mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Adding checkpointing framework (#74)
* Adding checkpointing framework as well as mxnet checkpointing implementation. - MXNet checkpoint for each network is saved in a separate file. * Adding checkpoint restore for mxnet to graph-manager * Add unit-test for get_checkpoint_state() * Added match.group() to fix unit-test failing on CI * Added ONNX export support for MXNet
This commit is contained in:
committed by
shadiendrawis
parent
4da56b1ff2
commit
67eb9e4c28
@@ -18,6 +18,7 @@ from typing import List, Tuple
|
||||
|
||||
from rl_coach.base_parameters import Frameworks, AgentParameters
|
||||
from rl_coach.logger import failed_imports
|
||||
from rl_coach.saver import SaverCollection
|
||||
from rl_coach.spaces import SpacesDefinition
|
||||
try:
|
||||
import tensorflow as tf
|
||||
@@ -251,3 +252,25 @@ class NetworkWrapper(object):
|
||||
result.append(str(self.online_network))
|
||||
result.append("")
|
||||
return '\n'.join(result)
|
||||
|
||||
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
    """
    Collect the savers of the single network copy that needs to be checkpointed.

    Note: the global, online, and target networks are all copies of the same network whose parameters are
    updated at different rates, so it is enough to save the one copy that holds the most recent parameters.
    * The target network is created only for some agents and stabilizes training by receiving the online
      network's parameters at a slower rate; it therefore never holds the most recent parameters.
    * In single-worker training no global network is created, and the online network holds the most
      recent parameters.
    * In vertical distributed training with more than one worker, the global network is updated by all
      workers and holds the most recent parameters.
    Hence the global network is preferred when it exists; otherwise the online network is used.

    :param parent_path_suffix: path suffix of the parent of the network wrapper
        (e.g. could be name of level manager plus name of agent)
    :return: collection of all checkpoint objects, as returned by the selected network's collect_savers()
    """
    # `or` yields the global network when it is truthy and falls back to the online network otherwise.
    most_recent_network = self.global_network or self.online_network
    return most_recent_network.collect_savers(parent_path_suffix)
|
||||
|
||||
Reference in New Issue
Block a user