Mirror of https://github.com/gryf/coach.git (synced 2026-02-18 15:35:56 +01:00)
Adding checkpointing framework (#74)
* Adding checkpointing framework as well as MXNet checkpointing implementation.
  - MXNet checkpoint for each network is saved in a separate file.
* Adding checkpoint restore for MXNet to graph-manager
* Add unit-test for get_checkpoint_state()
* Added match.group() to fix unit-test failing on CI
* Added ONNX export support for MXNet
committed by: shadiendrawis
parent: 4da56b1ff2
commit: 67eb9e4c28
@@ -19,6 +19,7 @@ from typing import Any, Dict, List, Tuple
 
 import numpy as np
 
 from rl_coach.base_parameters import AgentParameters
+from rl_coach.saver import SaverCollection
 from rl_coach.spaces import SpacesDefinition
@@ -213,3 +214,12 @@ class Architecture(object):
         :param placeholder: a placeholder for binding the value to assign_op.
         """
         raise NotImplementedError
+
+    def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
+        """
+        Collection of all savers for the network (typically only one saver for network and one for ONNX export)
+        :param parent_path_suffix: path suffix of the parent of the network
+                                   (e.g. could be name of level manager plus name of agent)
+        :return: saver collection for the network
+        """
+        raise NotImplementedError
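The framework contract introduced here is small: each Saver exposes a relative path (savers reporting the same path must be merge-able), save/restore entry points, and merge; a SaverCollection can be built empty or around a single saver and extended with add(). A minimal sketch inferred from this diff (NullSaver is illustrative, not part of the commit):

    from typing import Any, List

    from rl_coach.saver import Saver, SaverCollection


    class NullSaver(Saver):
        """Toy saver that only records where it would have written a checkpoint."""
        def __init__(self, name: str):
            self._name = name

        @property
        def path(self):
            # two savers returning the same path are expected to support merge()
            return self._name

        def save(self, sess: Any, save_path: str) -> List[str]:
            return [save_path]  # nothing is written in this sketch

        def restore(self, sess: Any, restore_path: str):
            pass

        def merge(self, other: 'Saver'):
            pass


    savers = SaverCollection(NullSaver('level_0.agent.main.online'))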
@@ -24,7 +24,9 @@ from mxnet.ndarray import NDArray
 
 from rl_coach.architectures.architecture import Architecture
 from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION
 from rl_coach.architectures.mxnet_components import utils
-from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
+from rl_coach.architectures.mxnet_components.savers import ParameterDictSaver, OnnxSaver
+from rl_coach.base_parameters import AgentParameters
+from rl_coach.saver import SaverCollection
 from rl_coach.spaces import SpacesDefinition
 from rl_coach.utils import force_list, squeeze_list
@@ -81,17 +83,25 @@ class MxnetArchitecture(Architecture):
         """
         return (p.list_grad()[0].copy() for p in self.model.collect_params().values() if p.grad_req != 'null')
 
+    def _model_input_shapes(self) -> List[List[int]]:
+        """
+        Create a list of input array shapes
+        :return: type of input shapes
+        """
+        allowed_inputs = copy.copy(self.spaces.state.sub_spaces)
+        allowed_inputs["action"] = copy.copy(self.spaces.action)
+        allowed_inputs["goal"] = copy.copy(self.spaces.goal)
+        embedders = self.model.nets[0].input_embedders
+        return list([1] + allowed_inputs[emb.embedder_name].shape.tolist() for emb in embedders)
+
     def _dummy_model_inputs(self) -> Tuple[NDArray, ...]:
         """
         Creates a tuple of input arrays with correct shapes that can be used for shape inference
         of the model weights and for printing the summary
         :return: tuple of inputs for model forward pass
         """
-        allowed_inputs = copy.copy(self.spaces.state.sub_spaces)
-        allowed_inputs["action"] = copy.copy(self.spaces.action)
-        allowed_inputs["goal"] = copy.copy(self.spaces.goal)
-        embedders = self.model.nets[0].input_embedders
-        inputs = tuple(nd.zeros((1,) + tuple(allowed_inputs[emb.embedder_name].shape.tolist())) for emb in embedders)
+        input_shapes = self._model_input_shapes()
+        inputs = tuple(nd.zeros(tuple(shape)) for shape in input_shapes)
         return inputs
 
     def construct_model(self) -> None:
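The refactor routes dummy-input creation through the new _model_input_shapes(), which prepends a batch dimension of 1 to every embedder's input shape. A minimal sketch of the resulting arrays, assuming toy shapes of (8,) for an observation embedder and (2,) for an action embedder (values illustrative, not from the diff):

    from mxnet import nd

    # what _model_input_shapes() would return for the toy embedders: [1] + shape
    input_shapes = [[1, 8], [1, 2]]
    # _dummy_model_inputs() then builds zero arrays for shape inference
    inputs = tuple(nd.zeros(tuple(shape)) for shape in input_shapes)
    print([i.shape for i in inputs])  # [(1, 8), (1, 2)]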
@@ -402,3 +412,21 @@ class MxnetArchitecture(Architecture):
         :return: None
         """
         assert self.middleware.__class__.__name__ != 'LSTMMiddleware', 'LSTM middleware not supported'
+
+    def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
+        """
+        Collection of all checkpoints for the network (typically only one checkpoint)
+        :param parent_path_suffix: path suffix of the parent of the network
+                                   (e.g. could be name of level manager plus name of agent)
+        :return: checkpoint collection for the network
+        """
+        name = self.name.replace('/', '.')
+        savers = SaverCollection(ParameterDictSaver(
+            name="{}.{}".format(parent_path_suffix, name),
+            param_dict=self.model.collect_params()))
+        if self.ap.task_parameters.export_onnx_graph:
+            savers.add(OnnxSaver(
+                name="{}.{}.onnx".format(parent_path_suffix, name),
+                model=self.model,
+                input_shapes=self._model_input_shapes()))
+        return savers
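For reference, the saver names produced by the two format strings above, assuming an illustrative parent suffix of 'level_0.agent' and a network named 'main/online' (both values assumed, not from the diff):

    parent_path_suffix = 'level_0.agent'      # assumed value for illustration
    name = 'main/online'.replace('/', '.')    # as in collect_savers() above
    print("{}.{}".format(parent_path_suffix, name))       # level_0.agent.main.online
    print("{}.{}.onnx".format(parent_path_suffix, name))  # level_0.agent.main.online.onnx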
@@ -279,7 +279,7 @@ def _get_output_head(
     return module
 
 
-class ScaledGradHead(HybridBlock):
+class ScaledGradHead(HybridBlock, utils.OnnxHandlerBlock):
     """
     Wrapper block for applying gradient scaling to input before feeding the head network
     """
@@ -292,7 +292,7 @@ class ScaledGradHead(HybridBlock):
                  agent_params: AgentParameters,
                  head_params: HeadParameters) -> None:
         """
-        :param head_idx: the head index
+        :param head_index: the head index
         :param head_type_index: the head type index (same index if head_param.num_output_head_copies>0)
         :param network_name: name of the network
         :param spaces: state and action space definitions
@@ -301,6 +301,7 @@ class ScaledGradHead(HybridBlock):
         :param head_params: head parameters
         """
         super(ScaledGradHead, self).__init__()
+        utils.OnnxHandlerBlock.__init__(self)
 
         head_params = _sanitize_activation(head_params)
         with self.name_scope():
@@ -330,7 +331,13 @@ class ScaledGradHead(HybridBlock):
         :param gradient_rescaler: gradient rescaler for partial blocking of gradient
         :return: head output
         """
-        grad_scaled_x = F.broadcast_mul((1 - gradient_rescaler), F.BlockGrad(x)) + F.broadcast_mul(gradient_rescaler, x)
+        if self._onnx:
+            # ONNX doesn't support BlockGrad() operator, but it's not typically needed for
+            # ONNX because mostly forward calls are performed using ONNX exported network.
+            grad_scaled_x = x
+        else:
+            grad_scaled_x = (F.broadcast_mul((1 - gradient_rescaler), F.BlockGrad(x)) +
+                             F.broadcast_mul(gradient_rescaler, x))
         out = self.head(grad_scaled_x)
         return out
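The expression kept on the training path is a standard partial gradient-blocking identity: since BlockGrad() passes values through but stops gradients, (1 - g) * BlockGrad(x) + g * x equals x in the forward pass while scaling the backward gradient by g. A minimal sketch verifying this with MXNet autograd (the concrete values are assumed for illustration):

    from mxnet import autograd, nd

    x = nd.array([2.0])
    x.attach_grad()
    g = nd.array([0.25])  # gradient rescaler
    with autograd.record():
        y = (1 - g) * nd.BlockGrad(x) + g * x  # forward: y == x
    y.backward()
    print(y)       # [2.], same value as x
    print(x.grad)  # [0.25], gradient scaled by g instead of 1.0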
@@ -113,7 +113,7 @@ class PPOVHead(Head):
         :param x: middleware state representation, of shape (batch_size, in_channels).
         :return: final value output of network, of shape (batch_size).
         """
-        return self.dense(x).squeeze()
+        return self.dense(x).squeeze(axis=1)
 
     def loss(self) -> mx.gluon.loss.Loss:
         """
@@ -98,4 +98,4 @@ class VHead(Head):
         :param x: middleware state representation, of shape (batch_size, in_channels).
         :return: final output of value network, of shape (batch_size).
         """
-        return self.dense(x).squeeze()
+        return self.dense(x).squeeze(axis=1)
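Both value heads pin the squeeze to axis 1. A plausible reading (hedged; the commit message doesn't spell it out): a bare squeeze() is a candidate to remove every unit axis, so for a batch of one the batch axis is also at risk, and an explicit axis keeps the op unambiguous for the symbolic trace used by the new ONNX export. A quick sketch:

    from mxnet import nd

    v = nd.zeros((1, 1))            # value head output for a batch of one
    print(v.squeeze(axis=1).shape)  # (1,) -- only the channel axis is removed;
                                    # the batch axis can never be squeezed away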
rl_coach/architectures/mxnet_components/savers.py (new file, 113 lines)
@@ -0,0 +1,113 @@
+from typing import Any, List, Tuple
+
+from mxnet import gluon, sym
+from mxnet.contrib import onnx as onnx_mxnet
+import numpy as np
+
+from rl_coach.architectures.mxnet_components.utils import ScopedOnnxEnable
+from rl_coach.saver import Saver
+
+
+class ParameterDictSaver(Saver):
+    """
+    Child class that implements saver for mxnet gluon parameter dictionary
+    """
+    def __init__(self, name: str, param_dict: gluon.ParameterDict):
+        self._name = name
+        self._param_dict = param_dict
+
+    @property
+    def path(self):
+        """
+        Relative path for save/load. If two checkpoint objects return the same path, they must be merge-able.
+        """
+        return self._name
+
+    def save(self, sess: None, save_path: str) -> List[str]:
+        """
+        Save to save_path
+        :param sess: active session for session-based frameworks (e.g. TF)
+        :param save_path: full path to save checkpoint (typically directory plus self.path plus checkpoint count).
+        :return: list of all saved paths
+        """
+        assert sess is None
+        self._param_dict.save(save_path)
+        return [save_path]
+
+    def restore(self, sess: Any, restore_path: str):
+        """
+        Restore from restore_path
+        :param sess: active session for session-based frameworks (e.g. TF)
+        :param restore_path: full path to load checkpoint from.
+        """
+        assert sess is None
+        self._param_dict.load(restore_path)
+
+    def merge(self, other: 'Saver'):
+        """
+        Merge other saver into this saver
+        :param other: saver to be merged into self
+        """
+        if not isinstance(other, ParameterDictSaver):
+            raise TypeError('merging only supported with ParameterDictSaver (type:{})'.format(type(other)))
+        self._param_dict.update(other._param_dict)
+
+
+class OnnxSaver(Saver):
+    """
+    Child class that implements saver for exporting gluon HybridBlock to ONNX
+    """
+    def __init__(self, name: str, model: gluon.HybridBlock, input_shapes: List[List[int]]):
+        self._name = name
+        self._sym = self._get_onnx_sym(model, len(input_shapes))
+        self._param_dict = model.collect_params()
+        self._input_shapes = input_shapes
+
+    @staticmethod
+    def _get_onnx_sym(model: gluon.HybridBlock, num_inputs: int) -> sym.Symbol:
+        """
+        Returns a symbolic graph for the model
+        :param model: gluon HybridBlock that constructs the symbolic graph
+        :param num_inputs: number of inputs to the graph
+        :return: symbol for the network
+        """
+        var_args = [sym.Variable('Data{}'.format(i)) for i in range(num_inputs)]
+        with ScopedOnnxEnable(model):
+            return sym.Group(gluon.block._flatten(model(*var_args), "output")[0])
+
+    @property
+    def path(self):
+        """
+        Relative path for save/load. If two checkpoint objects return the same path, they must be merge-able.
+        """
+        return self._name
+
+    def save(self, sess: None, save_path: str) -> List[str]:
+        """
+        Save to save_path
+        :param sess: active session for session-based frameworks (e.g. TF). Must be None.
+        :param save_path: full path to save checkpoint (typically directory plus self.path plus checkpoint count).
+        :return: list of all saved paths
+        """
+        assert sess is None
+        params = {name: param._reduce() for name, param in self._param_dict.items()}
+        export_path = onnx_mxnet.export_model(self._sym, params, self._input_shapes, np.float32, save_path)
+        return [export_path]
+
+    def restore(self, sess: Any, restore_path: str):
+        """
+        Restore from restore_path
+        :param sess: active session for session-based frameworks (e.g. TF)
+        :param restore_path: full path to load checkpoint from.
+        """
+        assert sess is None
+        # Nothing to restore for ONNX
+
+    def merge(self, other: 'Saver'):
+        """
+        Merge other saver into this saver
+        :param other: saver to be merged into self
+        """
+        # No merging is supported for ONNX. self.path must be unique
+        raise RuntimeError('merging not supported for ONNX exporter')
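A hedged usage sketch for ParameterDictSaver with a toy gluon block (the net, path, and names are assumptions for illustration; the saver calls mirror the new file above):

    from mxnet import gluon, nd

    from rl_coach.architectures.mxnet_components.savers import ParameterDictSaver

    net = gluon.nn.Dense(4)
    net.initialize()
    net(nd.zeros((1, 8)))  # dummy forward pass so deferred parameters materialize

    saver = ParameterDictSaver(name='level_0.agent.main.online',
                               param_dict=net.collect_params())
    saved = saver.save(sess=None, save_path='/tmp/level_0.agent.main.online.1')
    saver.restore(sess=None, restore_path=saved[0])  # round-trips the parameters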
@@ -6,7 +6,7 @@ from typing import Any, Dict, Generator, Iterable, List, Tuple, Union
 from types import ModuleType
 
 import mxnet as mx
-from mxnet import nd
+from mxnet import gluon, nd
 from mxnet.ndarray import NDArray
 import numpy as np

@@ -278,3 +278,46 @@ def get_mxnet_activation_name(activation_name: str):
         "Activation function must be one of the following {}. instead it was: {}".format(
             activation_functions.keys(), activation_name)
     return activation_functions[activation_name]
+
+
+class OnnxHandlerBlock(object):
+    """
+    Helper base class for gluon blocks that must behave differently for ONNX export forward pass
+    """
+    def __init__(self):
+        self._onnx = False
+
+    def enable_onnx(self):
+        self._onnx = True
+
+    def disable_onnx(self):
+        self._onnx = False
+
+
+class ScopedOnnxEnable(object):
+    """
+    Helper scoped ONNX enable class
+    """
+    def __init__(self, net: gluon.HybridBlock):
+        self._onnx_handlers = self._get_onnx_handlers(net)
+
+    def __enter__(self):
+        for b in self._onnx_handlers:
+            b.enable_onnx()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        for b in self._onnx_handlers:
+            b.disable_onnx()
+
+    @staticmethod
+    def _get_onnx_handlers(block: gluon.HybridBlock) -> List[OnnxHandlerBlock]:
+        """
+        Iterates through all child blocks and returns all that are instances of OnnxHandlerBlock
+        :return: list of OnnxHandlerBlock child blocks
+        """
+        handlers = list()
+        if isinstance(block, OnnxHandlerBlock):
+            handlers.append(block)
+        for child_block in block._children.values():
+            handlers += ScopedOnnxEnable._get_onnx_handlers(child_block)
+        return handlers
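A minimal sketch of how the scoped switch behaves, using a toy block (ToyHead is illustrative, not from the commit; the utils imports are real):

    from mxnet import gluon
    from rl_coach.architectures.mxnet_components.utils import OnnxHandlerBlock, ScopedOnnxEnable

    class ToyHead(gluon.HybridBlock, OnnxHandlerBlock):
        """Toy block with an ONNX-safe forward branch."""
        def __init__(self):
            gluon.HybridBlock.__init__(self)
            OnnxHandlerBlock.__init__(self)

        def hybrid_forward(self, F, x):
            if self._onnx:
                return x           # export path: skip BlockGrad, unsupported in ONNX
            return F.BlockGrad(x)  # training path

    head = ToyHead()
    print(head._onnx)              # False
    with ScopedOnnxEnable(head):   # flips every OnnxHandlerBlock in the block tree
        print(head._onnx)          # True while exporting
    print(head._onnx)              # False again on exit

This is the mechanism _get_onnx_sym() in savers.py relies on: while the model is traced with symbolic inputs, every handler block takes its ONNX-safe branch.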
@@ -18,6 +18,7 @@ from typing import List, Tuple
 
 from rl_coach.base_parameters import Frameworks, AgentParameters
 from rl_coach.logger import failed_imports
+from rl_coach.saver import SaverCollection
 from rl_coach.spaces import SpacesDefinition
 
 try:
     import tensorflow as tf

@@ -251,3 +252,25 @@ class NetworkWrapper(object):
             result.append(str(self.online_network))
         result.append("")
         return '\n'.join(result)
+
+    def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
+        """
+        Collect all of network's savers for global or online network
+        Note: global, online, and target networks are all copies of the same network, whose parameters are
+        updated at different rates. So we only need to save one of the networks; the one that holds the most
+        recent parameters. The target network is created for some agents and is used for stabilizing training by
+        updating parameters from the online network at a slower rate. As a result, the target network never
+        contains the most recent set of parameters. In single-worker training, no global network is created and
+        the online network contains the most recent parameters. In vertical distributed training with more than
+        one worker, the global network is updated by all workers and contains the most recent parameters.
+        Therefore, preference is given to the global network if it exists; otherwise the online network is used
+        for saving.
+        :param parent_path_suffix: path suffix of the parent of the network wrapper
+                                   (e.g. could be name of level manager plus name of agent)
+        :return: collection of all checkpoint objects
+        """
+        if self.global_network:
+            savers = self.global_network.collect_savers(parent_path_suffix)
+        else:
+            savers = self.online_network.collect_savers(parent_path_suffix)
+        return savers
@@ -23,6 +23,7 @@ import tensorflow as tf
 
 from rl_coach.architectures.architecture import Architecture
 from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
 from rl_coach.core_types import GradientClippingMethod
+from rl_coach.saver import SaverCollection
 from rl_coach.spaces import SpacesDefinition
 from rl_coach.utils import force_list, squeeze_list, start_shell_command_and_wait

@@ -637,6 +638,16 @@ class TensorFlowArchitecture(Architecture):
         self.curr_rnn_c_in = self.middleware.c_init
         self.curr_rnn_h_in = self.middleware.h_init
 
+    def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
+        """
+        Collection of all checkpoints for the network (typically only one checkpoint)
+        :param parent_path_suffix: path suffix of the parent of the network
+                                   (e.g. could be name of level manager plus name of agent)
+        :return: checkpoint collection for the network
+        """
+        # TODO implement returning checkpoints for tensorflow
+        return SaverCollection()
+
 
 def save_onnx_graph(input_nodes, output_nodes, checkpoint_save_dir: str) -> None:
     """