1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

Adding checkpointing framework (#74)

* Adding checkpointing framework as well as mxnet checkpointing implementation.

- MXNet checkpoint for each network is saved in a separate file.

* Adding checkpoint restore for mxnet to graph-manager

* Add unit-test for get_checkpoint_state()

* Added match.group() to fix unit-test failing on CI

* Added ONNX export support for MXNet
This commit is contained in:
Sina Afrooze
2018-11-19 09:45:49 -08:00
committed by shadiendrawis
parent 4da56b1ff2
commit 67eb9e4c28
19 changed files with 598 additions and 29 deletions

View File

@@ -18,6 +18,7 @@ from typing import List, Tuple
from rl_coach.base_parameters import Frameworks, AgentParameters
from rl_coach.logger import failed_imports
from rl_coach.saver import SaverCollection
from rl_coach.spaces import SpacesDefinition
try:
import tensorflow as tf
@@ -251,3 +252,25 @@ class NetworkWrapper(object):
result.append(str(self.online_network))
result.append("")
return '\n'.join(result)
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
"""
Collect all of network's savers for global or online network
Note: global, online, and target network are all copies fo the same network which parameters that are
updated at different rates. So we only need to save one of the networks; the one that holds the most
recent parameters. target network is created for some agents and used for stabilizing training by
updating parameters from online network at a slower rate. As a result, target network never contains
the most recent set of parameters. In single-worker training, no global network is created and online
network contains the most recent parameters. In vertical distributed training with more than one worker,
global network is updated by all workers and contains the most recent parameters.
Therefore preference is given to global network if it exists, otherwise online network is used
for saving.
:param parent_path_suffix: path suffix of the parent of the network wrapper
(e.g. could be name of level manager plus name of agent)
:return: collection of all checkpoint objects
"""
if self.global_network:
savers = self.global_network.collect_savers(parent_path_suffix)
else:
savers = self.online_network.collect_savers(parent_path_suffix)
return savers