Mirror of https://github.com/gryf/coach.git
Adding checkpointing framework (#74)

* Adding checkpointing framework as well as an MXNet checkpointing implementation.
  - The MXNet checkpoint for each network is saved in a separate file.
* Adding checkpoint restore for MXNet to the graph manager.
* Add a unit test for get_checkpoint_state().
* Added match.group() to fix a unit test failing on CI.
* Added ONNX export support for MXNet.
Committed by: shadiendrawis
Parent: 4da56b1ff2
Commit: 67eb9e4c28
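For orientation before the diff: "MXNet checkpoint for each network is saved in a separate file" ultimately rests on MXNet Gluon's `save_parameters`/`load_parameters`. The sketch below is not Coach's checkpointing framework; the toy network, network names, and file names are invented for illustration, and only the Gluon calls are assumed to behave as shown.

```python
# Minimal sketch, NOT the Coach checkpointing framework: it only illustrates the
# MXNet Gluon calls that a per-network checkpoint file boils down to.
import mxnet as mx
from mxnet.gluon import nn


def build_net() -> nn.HybridSequential:
    """A toy network standing in for one of Coach's online/target networks."""
    net = nn.HybridSequential()
    net.add(nn.Dense(64, activation='relu'), nn.Dense(4))
    net.initialize(mx.init.Xavier())
    return net


# One file per network, mirroring "MXNet checkpoint for each network is saved
# in a separate file" from the commit message.
networks = {'main.online': build_net(), 'main.target': build_net()}
for name, net in networks.items():
    net(mx.nd.zeros((1, 8)))                       # forward pass materializes the deferred shapes
    net.save_parameters('{}.params'.format(name))  # e.g. main.online.params

# Restoring is the mirror image of saving.
restored = build_net()
restored.load_parameters('main.online.params')
```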
@@ -24,7 +24,9 @@ from mxnet.ndarray import NDArray
 from rl_coach.architectures.architecture import Architecture
 from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION
 from rl_coach.architectures.mxnet_components import utils
-from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
+from rl_coach.architectures.mxnet_components.savers import ParameterDictSaver, OnnxSaver
+from rl_coach.base_parameters import AgentParameters
+from rl_coach.saver import SaverCollection
 from rl_coach.spaces import SpacesDefinition
 from rl_coach.utils import force_list, squeeze_list
 
@@ -81,17 +83,25 @@ class MxnetArchitecture(Architecture):
         """
         return (p.list_grad()[0].copy() for p in self.model.collect_params().values() if p.grad_req != 'null')
 
+    def _model_input_shapes(self) -> List[List[int]]:
+        """
+        Create a list of input array shapes
+        :return: type of input shapes
+        """
+        allowed_inputs = copy.copy(self.spaces.state.sub_spaces)
+        allowed_inputs["action"] = copy.copy(self.spaces.action)
+        allowed_inputs["goal"] = copy.copy(self.spaces.goal)
+        embedders = self.model.nets[0].input_embedders
+        return list([1] + allowed_inputs[emb.embedder_name].shape.tolist() for emb in embedders)
+
     def _dummy_model_inputs(self) -> Tuple[NDArray, ...]:
         """
         Creates a tuple of input arrays with correct shapes that can be used for shape inference
         of the model weights and for printing the summary
         :return: tuple of inputs for model forward pass
         """
-        allowed_inputs = copy.copy(self.spaces.state.sub_spaces)
-        allowed_inputs["action"] = copy.copy(self.spaces.action)
-        allowed_inputs["goal"] = copy.copy(self.spaces.goal)
-        embedders = self.model.nets[0].input_embedders
-        inputs = tuple(nd.zeros((1,) + tuple(allowed_inputs[emb.embedder_name].shape.tolist())) for emb in embedders)
+        input_shapes = self._model_input_shapes()
+        inputs = tuple(nd.zeros(tuple(shape)) for shape in input_shapes)
         return inputs
 
     def construct_model(self) -> None:
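As a hedged illustration of why `_model_input_shapes` prepends a batch dimension of 1: Gluon infers parameter shapes lazily, so a single batch-size-1 forward pass with zero arrays is enough to materialize the weights and print a summary. The network and shapes below are invented for the example; Coach derives the real shapes from its observation/action/goal spaces.

```python
# Sketch of shape inference via dummy inputs, assuming a toy network and
# hand-written input shapes; Coach builds the shapes from its spaces instead.
import mxnet as mx
from mxnet import nd
from mxnet.gluon import nn

net = nn.HybridSequential()
net.add(nn.Dense(32, activation='relu'), nn.Dense(2))
net.initialize()

# Analogue of _model_input_shapes(): one shape per input embedder,
# each with a leading batch dimension of 1.
input_shapes = [[1, 8]]

# Analogue of _dummy_model_inputs(): zero arrays of the matching shapes.
dummy_inputs = tuple(nd.zeros(tuple(shape)) for shape in input_shapes)

net(*dummy_inputs)          # forward pass triggers deferred shape inference
net.summary(*dummy_inputs)  # prints a layer-by-layer summary with output shapes
```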
@@ -402,3 +412,21 @@ class MxnetArchitecture(Architecture):
         :return: None
         """
         assert self.middleware.__class__.__name__ != 'LSTMMiddleware', 'LSTM middleware not supported'
+
+    def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
+        """
+        Collection of all checkpoints for the network (typically only one checkpoint)
+        :param parent_path_suffix: path suffix of the parent of the network
+            (e.g. could be name of level manager plus name of agent)
+        :return: checkpoint collection for the network
+        """
+        name = self.name.replace('/', '.')
+        savers = SaverCollection(ParameterDictSaver(
+            name="{}.{}".format(parent_path_suffix, name),
+            param_dict=self.model.collect_params()))
+        if self.ap.task_parameters.export_onnx_graph:
+            savers.add(OnnxSaver(
+                name="{}.{}.onnx".format(parent_path_suffix, name),
+                model=self.model,
+                input_shapes=self._model_input_shapes()))
+        return savers
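For context on the `export_onnx_graph` branch above, here is a hedged sketch of exporting a Gluon model to ONNX with MXNet's bundled `mxnet.contrib.onnx` converter. The `OnnxSaver` added in this commit presumably wraps something along these lines, but its exact mechanics may differ; the toy network and file names are assumptions.

```python
# Sketch of ONNX export for a Gluon model using MXNet's bundled converter;
# this is NOT the OnnxSaver implementation, just the general recipe.
import numpy as np
import mxnet as mx
from mxnet.contrib import onnx as onnx_mxnet
from mxnet.gluon import nn

net = nn.HybridSequential()
net.add(nn.Dense(32, activation='relu'), nn.Dense(2))
net.initialize()
net.hybridize()
net(mx.nd.zeros((1, 8)))   # forward pass so the symbolic graph gets cached

# Serialize the hybridized graph: writes model-symbol.json and model-0000.params
net.export('model', epoch=0)

# Convert the serialized graph to ONNX, using the same input shape as above.
onnx_file = onnx_mxnet.export_model(
    sym='model-symbol.json',
    params='model-0000.params',
    input_shape=[(1, 8)],
    input_type=np.float32,
    onnx_file_path='model.onnx')
```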