mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Adding checkpointing framework (#74)
* Adding checkpointing framework as well as mxnet checkpointing implementation. - MXNet checkpoint for each network is saved in a separate file. * Adding checkpoint restore for mxnet to graph-manager * Add unit-test for get_checkpoint_state() * Added match.group() to fix unit-test failing on CI * Added ONNX export support for MXNet
This commit is contained in:
committed by
shadiendrawis
parent
4da56b1ff2
commit
67eb9e4c28
21
rl_coach/tests/test_utils.py
Normal file
21
rl_coach/tests/test_utils.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import pytest
|
||||
|
||||
from rl_coach import utils
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_get_checkpoint_state_default():
|
||||
files = ['4.test.ckpt.ext', '2.test.ckpt.ext', '3.test.ckpt.ext', '1.test.ckpt.ext']
|
||||
checkpoint_state = utils.get_checkpoint_state(files)
|
||||
assert checkpoint_state.model_checkpoint_path == '4.test.ckpt'
|
||||
assert checkpoint_state.all_model_checkpoint_paths == [f[:-4] for f in sorted(files)]
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_get_checkpoint_state_custom():
|
||||
files = ['prefix.4.test.ckpt.ext', 'prefix.2.test.ckpt.ext', 'prefix.3.test.ckpt.ext', 'prefix.1.test.ckpt.ext']
|
||||
assert len(utils.get_checkpoint_state(files).all_model_checkpoint_paths) == 0 # doesn't match the default pattern
|
||||
checkpoint_state = utils.get_checkpoint_state(files, filename_pattern=r'([0-9]+)[^0-9].*?\.ckpt')
|
||||
assert checkpoint_state.model_checkpoint_path == '4.test.ckpt'
|
||||
assert checkpoint_state.all_model_checkpoint_paths == [f[7:-4] for f in sorted(files)]
|
||||
|
||||
Reference in New Issue
Block a user