mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Adding checkpointing framework (#74)
* Adding checkpointing framework as well as MXNet checkpointing implementation.
  - MXNet checkpoint for each network is saved in a separate file.
* Adding checkpoint restore for MXNet to graph-manager
* Add unit-test for get_checkpoint_state()
* Added match.group() to fix unit-test failing on CI
* Added ONNX export support for MXNet
This commit is contained in:
committed by
shadiendrawis
parent
4da56b1ff2
commit
67eb9e4c28
@@ -23,6 +23,7 @@ import tensorflow as tf
|
||||
from rl_coach.architectures.architecture import Architecture
|
||||
from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
|
||||
from rl_coach.core_types import GradientClippingMethod
|
||||
from rl_coach.saver import SaverCollection
|
||||
from rl_coach.spaces import SpacesDefinition
|
||||
from rl_coach.utils import force_list, squeeze_list, start_shell_command_and_wait
|
||||
|
||||
@@ -637,6 +638,16 @@ class TensorFlowArchitecture(Architecture):
|
||||
self.curr_rnn_c_in = self.middleware.c_init
|
||||
self.curr_rnn_h_in = self.middleware.h_init
|
||||
|
||||
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
    """
    Gather the savers (checkpoints) belonging to this network; typically there is only one.

    :param parent_path_suffix: path suffix of the network's parent
        (for example, the level manager name followed by the agent name)
    :return: saver collection for the network; at present a newly constructed
        SaverCollection to which this method adds nothing
    """
    # TODO implement returning checkpoints for tensorflow
    network_savers = SaverCollection()
    return network_savers
|
||||
|
||||
|
||||
def save_onnx_graph(input_nodes, output_nodes, checkpoint_save_dir: str) -> None:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user