1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Adding checkpointing framework (#74)

* Adding checkpointing framework as well as mxnet checkpointing implementation.

- MXNet checkpoint for each network is saved in a separate file.

* Adding checkpoint restore for mxnet to graph-manager

* Add unit-test for get_checkpoint_state()

* Added match.group() to fix unit-test failing on CI

* Added ONNX export support for MXNet
This commit is contained in:
Sina Afrooze
2018-11-19 09:45:49 -08:00
committed by shadiendrawis
parent 4da56b1ff2
commit 67eb9e4c28
19 changed files with 598 additions and 29 deletions

View File

@@ -0,0 +1,42 @@
import pytest
from rl_coach.saver import Saver, SaverCollection
@pytest.mark.unit_test
def test_checkpoint_collection():
class SaverTest(Saver):
def __init__(self, path):
self._path = path
self._count = 1
@property
def path(self):
return self._path
def merge(self, other: 'Saver'):
assert isinstance(other, SaverTest)
assert self.path == other.path
self._count += other._count
# test add
savers = SaverCollection(SaverTest('123'))
savers.add(SaverTest('123'))
savers.add(SaverTest('456'))
def check_collection(mul):
paths = ['123', '456']
for c in savers:
paths.remove(c.path)
if c.path == '123':
assert c._count == 2 * mul
elif c.path == '456':
assert c._count == 1 * mul
else:
assert False, "invalid path"
check_collection(1)
# test update
savers.update(savers)
check_collection(2)