1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

Adding checkpointing framework (#74)

* Adding checkpointing framework as well as mxnet checkpointing implementation.

- MXNet checkpoint for each network is saved in a separate file.

* Adding checkpoint restore for mxnet to graph-manager

* Add unit-test for get_checkpoint_state()

* Added match.group() to fix unit-test failing on CI

* Added ONNX export support for MXNet
This commit is contained in:
Sina Afrooze
2018-11-19 09:45:49 -08:00
committed by shadiendrawis
parent 4da56b1ff2
commit 67eb9e4c28
19 changed files with 598 additions and 29 deletions

View File

@@ -24,7 +24,9 @@ from mxnet.ndarray import NDArray
from rl_coach.architectures.architecture import Architecture
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION
from rl_coach.architectures.mxnet_components import utils
from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
from rl_coach.architectures.mxnet_components.savers import ParameterDictSaver, OnnxSaver
from rl_coach.base_parameters import AgentParameters
from rl_coach.saver import SaverCollection
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import force_list, squeeze_list
@@ -81,17 +83,25 @@ class MxnetArchitecture(Architecture):
"""
return (p.list_grad()[0].copy() for p in self.model.collect_params().values() if p.grad_req != 'null')
def _model_input_shapes(self) -> List[List[int]]:
    """
    Create a list of input array shapes, one per input embedder of the first network
    :return: list of input shapes; each shape is the shape of the embedder's space
             prefixed with a leading dimension of 1
    """
    # work on a copy of the state sub-spaces so the "action" and "goal" entries
    # added below go into the copy
    allowed_inputs = copy.copy(self.spaces.state.sub_spaces)
    allowed_inputs["action"] = copy.copy(self.spaces.action)
    allowed_inputs["goal"] = copy.copy(self.spaces.goal)
    embedders = self.model.nets[0].input_embedders
    # the space for each embedder is looked up by the embedder's name
    return [[1] + allowed_inputs[emb.embedder_name].shape.tolist() for emb in embedders]
def _dummy_model_inputs(self) -> Tuple[NDArray, ...]:
    """
    Creates a tuple of input arrays with correct shapes that can be used for shape inference
    of the model weights and for printing the summary
    :return: tuple of zero-filled inputs for model forward pass, one per input shape
    """
    # shapes come from _model_input_shapes() only; the previous inline computation of the
    # same shapes was redundant because its result was immediately overwritten
    input_shapes = self._model_input_shapes()
    return tuple(nd.zeros(tuple(shape)) for shape in input_shapes)
def construct_model(self) -> None:
@@ -402,3 +412,21 @@ class MxnetArchitecture(Architecture):
:return: None
"""
assert self.middleware.__class__.__name__ != 'LSTMMiddleware', 'LSTM middleware not supported'
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
    """
    Build the collection of savers for this network (typically only one checkpoint)
    :param parent_path_suffix: path suffix of the parent of the network
        (e.g. could be name of level manager plus name of agent)
    :return: saver collection holding the parameter saver and, when
        task_parameters.export_onnx_graph is set, an additional ONNX saver
    """
    # saver names use '.' wherever the network name contains '/'
    network_name = self.name.replace('/', '.')

    parameter_saver = ParameterDictSaver(
        name="{}.{}".format(parent_path_suffix, network_name),
        param_dict=self.model.collect_params())
    collection = SaverCollection(parameter_saver)

    if not self.ap.task_parameters.export_onnx_graph:
        return collection

    onnx_saver = OnnxSaver(
        name="{}.{}.onnx".format(parent_path_suffix, network_name),
        model=self.model,
        input_shapes=self._model_input_shapes())
    collection.add(onnx_saver)
    return collection