mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Adding checkpointing framework (#74)
* Adding checkpointing framework as well as mxnet checkpointing implementation. - MXNet checkpoint for each network is saved in a separate file. * Adding checkpoint restore for mxnet to graph-manager * Add unit-test for get_checkpoint_state() * Added match.group() to fix unit-test failing on CI * Added ONNX export support for MXNet
This commit is contained in:
committed by
shadiendrawis
parent
4da56b1ff2
commit
67eb9e4c28
42
rl_coach/tests/test_saver.py
Normal file
42
rl_coach/tests/test_saver.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import pytest
|
||||
|
||||
from rl_coach.saver import Saver, SaverCollection
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_checkpoint_collection():
|
||||
class SaverTest(Saver):
|
||||
def __init__(self, path):
|
||||
self._path = path
|
||||
self._count = 1
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return self._path
|
||||
|
||||
def merge(self, other: 'Saver'):
|
||||
assert isinstance(other, SaverTest)
|
||||
assert self.path == other.path
|
||||
self._count += other._count
|
||||
|
||||
# test add
|
||||
savers = SaverCollection(SaverTest('123'))
|
||||
savers.add(SaverTest('123'))
|
||||
savers.add(SaverTest('456'))
|
||||
|
||||
def check_collection(mul):
|
||||
paths = ['123', '456']
|
||||
for c in savers:
|
||||
paths.remove(c.path)
|
||||
if c.path == '123':
|
||||
assert c._count == 2 * mul
|
||||
elif c.path == '456':
|
||||
assert c._count == 1 * mul
|
||||
else:
|
||||
assert False, "invalid path"
|
||||
|
||||
check_collection(1)
|
||||
|
||||
# test update
|
||||
savers.update(savers)
|
||||
check_collection(2)
|
||||
Reference in New Issue
Block a user