
Adding checkpointing framework (#74)

* Adding checkpointing framework as well as mxnet checkpointing implementation.

- MXNet checkpoint for each network is saved in a separate file.

* Adding checkpoint restore for mxnet to graph-manager

* Add unit-test for get_checkpoint_state()

* Added match.group() to fix unit-test failing on CI

* Added ONNX export support for MXNet
Sina Afrooze
2018-11-19 09:45:49 -08:00
committed by shadiendrawis
parent 4da56b1ff2
commit 67eb9e4c28
19 changed files with 598 additions and 29 deletions
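In broad strokes, the framework added here lets the graph manager ask each network for a SaverCollection and drive checkpointing through it, independent of the backing framework. A rough, hypothetical sketch of that flow (assuming a SaverCollection can be iterated over its Saver objects; networks and checkpoint_dir are illustrative names, not symbols from this commit):

from rl_coach.saver import SaverCollection

def save_all_networks(networks, checkpoint_dir, checkpoint_count):
    # Gather savers from every network under one collection; savers that map to
    # the same relative path are expected to be merged by the collection.
    savers = SaverCollection()
    for name, network in networks.items():
        for saver in network.collect_savers(parent_path_suffix=name):
            savers.add(saver)
    # Each saver persists itself (gluon parameter dict, ONNX export, ...).
    for saver in savers:
        saver.save(sess=None, save_path='{}/{}-{}'.format(checkpoint_dir, saver.path, checkpoint_count))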

View File

@@ -19,6 +19,7 @@ from typing import Any, Dict, List, Tuple
import numpy as np
from rl_coach.base_parameters import AgentParameters
from rl_coach.saver import SaverCollection
from rl_coach.spaces import SpacesDefinition
@@ -213,3 +214,12 @@ class Architecture(object):
        :param placeholder: a placeholder for binding the value to assign_op.
        """
        raise NotImplementedError

    def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
        """
        Collection of all savers for the network (typically only one saver for the network and one for ONNX export)
        :param parent_path_suffix: path suffix of the parent of the network
            (e.g. could be name of level manager plus name of agent)
        :return: saver collection for the network
        """
        raise NotImplementedError
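collect_savers is expected to return a SaverCollection of Saver objects, and the contract each saver has to satisfy is small: a stable relative path plus save/restore/merge. A hypothetical saver that persists a dict of numpy arrays, sketched only to illustrate that interface (not part of this commit):

import numpy as np
from rl_coach.saver import Saver


class NumpyDictSaver(Saver):
    """Hypothetical saver that persists a dict of numpy arrays via np.savez."""
    def __init__(self, name, arrays):
        self._name = name
        self._arrays = arrays

    @property
    def path(self):
        # Relative path; savers returning the same path must be merge-able.
        return self._name

    def save(self, sess, save_path):
        np.savez(save_path, **self._arrays)
        return [save_path + '.npz']

    def restore(self, sess, restore_path):
        with np.load(restore_path) as data:
            self._arrays.update({key: data[key] for key in data.files})

    def merge(self, other):
        self._arrays.update(other._arrays)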

View File

@@ -24,7 +24,9 @@ from mxnet.ndarray import NDArray
from rl_coach.architectures.architecture import Architecture
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION
from rl_coach.architectures.mxnet_components import utils
-from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
+from rl_coach.architectures.mxnet_components.savers import ParameterDictSaver, OnnxSaver
+from rl_coach.base_parameters import AgentParameters
from rl_coach.saver import SaverCollection
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import force_list, squeeze_list
@@ -81,17 +83,25 @@ class MxnetArchitecture(Architecture):
"""
return (p.list_grad()[0].copy() for p in self.model.collect_params().values() if p.grad_req != 'null')
def _model_input_shapes(self) -> List[List[int]]:
"""
Create a list of input array shapes
:return: type of input shapes
"""
allowed_inputs = copy.copy(self.spaces.state.sub_spaces)
allowed_inputs["action"] = copy.copy(self.spaces.action)
allowed_inputs["goal"] = copy.copy(self.spaces.goal)
embedders = self.model.nets[0].input_embedders
return list([1] + allowed_inputs[emb.embedder_name].shape.tolist() for emb in embedders)
def _dummy_model_inputs(self) -> Tuple[NDArray, ...]:
"""
Creates a tuple of input arrays with correct shapes that can be used for shape inference
of the model weights and for printing the summary
:return: tuple of inputs for model forward pass
"""
allowed_inputs = copy.copy(self.spaces.state.sub_spaces)
allowed_inputs["action"] = copy.copy(self.spaces.action)
allowed_inputs["goal"] = copy.copy(self.spaces.goal)
embedders = self.model.nets[0].input_embedders
inputs = tuple(nd.zeros((1,) + tuple(allowed_inputs[emb.embedder_name].shape.tolist())) for emb in embedders)
input_shapes = self._model_input_shapes()
inputs = tuple(nd.zeros(tuple(shape)) for shape in input_shapes)
return inputs
def construct_model(self) -> None:
@@ -402,3 +412,21 @@ class MxnetArchitecture(Architecture):
        :return: None
        """
        assert self.middleware.__class__.__name__ != 'LSTMMiddleware', 'LSTM middleware not supported'

    def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
        """
        Collection of all checkpoints for the network (typically only one checkpoint)
        :param parent_path_suffix: path suffix of the parent of the network
            (e.g. could be name of level manager plus name of agent)
        :return: checkpoint collection for the network
        """
        name = self.name.replace('/', '.')
        savers = SaverCollection(ParameterDictSaver(
            name="{}.{}".format(parent_path_suffix, name),
            param_dict=self.model.collect_params()))
        if self.ap.task_parameters.export_onnx_graph:
            savers.add(OnnxSaver(
                name="{}.{}.onnx".format(parent_path_suffix, name),
                model=self.model,
                input_shapes=self._model_input_shapes()))
        return savers

View File

@@ -279,7 +279,7 @@ def _get_output_head(
return module
-class ScaledGradHead(HybridBlock):
+class ScaledGradHead(HybridBlock, utils.OnnxHandlerBlock):
    """
    Wrapper block for applying gradient scaling to input before feeding the head network
    """
@@ -292,7 +292,7 @@ class ScaledGradHead(HybridBlock):
                 agent_params: AgentParameters,
                 head_params: HeadParameters) -> None:
        """
-        :param head_idx: the head index
+        :param head_index: the head index
        :param head_type_index: the head type index (same index if head_param.num_output_head_copies>0)
        :param network_name: name of the network
        :param spaces: state and action space definitions
@@ -301,6 +301,7 @@ class ScaledGradHead(HybridBlock):
        :param head_params: head parameters
        """
        super(ScaledGradHead, self).__init__()
        utils.OnnxHandlerBlock.__init__(self)

        head_params = _sanitize_activation(head_params)
        with self.name_scope():
@@ -330,7 +331,13 @@ class ScaledGradHead(HybridBlock):
        :param gradient_rescaler: gradient rescaler for partial blocking of gradient
        :return: head output
        """
-        grad_scaled_x = F.broadcast_mul((1 - gradient_rescaler), F.BlockGrad(x)) + F.broadcast_mul(gradient_rescaler, x)
+        if self._onnx:
+            # ONNX doesn't support the BlockGrad() operator, but it is typically not needed
+            # here because the exported ONNX network is only used for forward passes.
+            grad_scaled_x = x
+        else:
+            grad_scaled_x = (F.broadcast_mul((1 - gradient_rescaler), F.BlockGrad(x)) +
+                             F.broadcast_mul(gradient_rescaler, x))
        out = self.head(grad_scaled_x)
        return out
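The rescaling expression keeps the forward value equal to x while scaling the gradient by gradient_rescaler (BlockGrad zeroes the gradient of the first term), which is why it can safely be bypassed in the forward-only ONNX path. A small standalone check of that identity in plain mxnet (illustrative, not part of the commit):

from mxnet import autograd, nd

x = nd.array([3.0])
x.attach_grad()
scaler = nd.array([0.25])
with autograd.record():
    # (1 - s) * BlockGrad(x) + s * x: value equals x, gradient w.r.t. x equals s
    y = nd.broadcast_mul(1 - scaler, nd.BlockGrad(x)) + nd.broadcast_mul(scaler, x)
y.backward()
print(y)       # [3.]
print(x.grad)  # [0.25]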

View File

@@ -113,7 +113,7 @@ class PPOVHead(Head):
        :param x: middleware state representation, of shape (batch_size, in_channels).
        :return: final value output of network, of shape (batch_size).
        """
-        return self.dense(x).squeeze()
+        return self.dense(x).squeeze(axis=1)

    def loss(self) -> mx.gluon.loss.Loss:
        """

View File

@@ -98,4 +98,4 @@ class VHead(Head):
        :param x: middleware state representation, of shape (batch_size, in_channels).
        :return: final output of value network, of shape (batch_size).
        """
-        return self.dense(x).squeeze()
+        return self.dense(x).squeeze(axis=1)

View File

@@ -0,0 +1,113 @@
from typing import Any, List, Tuple

from mxnet import gluon, sym
from mxnet.contrib import onnx as onnx_mxnet
import numpy as np

from rl_coach.architectures.mxnet_components.utils import ScopedOnnxEnable
from rl_coach.saver import Saver


class ParameterDictSaver(Saver):
    """
    Child class that implements saver for mxnet gluon parameter dictionary
    """
    def __init__(self, name: str, param_dict: gluon.ParameterDict):
        self._name = name
        self._param_dict = param_dict

    @property
    def path(self):
        """
        Relative path for save/load. If two checkpoint objects return the same path, they must be merge-able.
        """
        return self._name

    def save(self, sess: None, save_path: str) -> List[str]:
        """
        Save to save_path
        :param sess: active session for session-based frameworks (e.g. TF)
        :param save_path: full path to save checkpoint (typically directory plus self.path plus checkpoint count).
        :return: list of all saved paths
        """
        assert sess is None
        self._param_dict.save(save_path)
        return [save_path]

    def restore(self, sess: Any, restore_path: str):
        """
        Restore from restore_path
        :param sess: active session for session-based frameworks (e.g. TF)
        :param restore_path: full path to load checkpoint from.
        """
        assert sess is None
        self._param_dict.load(restore_path)

    def merge(self, other: 'Saver'):
        """
        Merge other saver into this saver
        :param other: saver to be merged into self
        """
        if not isinstance(other, ParameterDictSaver):
            raise TypeError('merging only supported with ParameterDictSaver (type:{})'.format(type(other)))
        self._param_dict.update(other._param_dict)


class OnnxSaver(Saver):
    """
    Child class that implements saver for exporting gluon HybridBlock to ONNX
    """
    def __init__(self, name: str, model: gluon.HybridBlock, input_shapes: List[List[int]]):
        self._name = name
        self._sym = self._get_onnx_sym(model, len(input_shapes))
        self._param_dict = model.collect_params()
        self._input_shapes = input_shapes

    @staticmethod
    def _get_onnx_sym(model: gluon.HybridBlock, num_inputs: int) -> sym.Symbol:
        """
        Returns a symbolic graph for the model
        :param model: gluon HybridBlock that constructs the symbolic graph
        :param num_inputs: number of inputs to the graph
        :return: symbol for the network
        """
        var_args = [sym.Variable('Data{}'.format(i)) for i in range(num_inputs)]
        with ScopedOnnxEnable(model):
            return sym.Group(gluon.block._flatten(model(*var_args), "output")[0])

    @property
    def path(self):
        """
        Relative path for save/load. If two checkpoint objects return the same path, they must be merge-able.
        """
        return self._name

    def save(self, sess: None, save_path: str) -> List[str]:
        """
        Save to save_path
        :param sess: active session for session-based frameworks (e.g. TF). Must be None.
        :param save_path: full path to save checkpoint (typically directory plus self.path plus checkpoint count).
        :return: list of all saved paths
        """
        assert sess is None
        params = {name: param._reduce() for name, param in self._param_dict.items()}
        export_path = onnx_mxnet.export_model(self._sym, params, self._input_shapes, np.float32, save_path)
        return [export_path]

    def restore(self, sess: Any, restore_path: str):
        """
        Restore from restore_path
        :param sess: active session for session-based frameworks (e.g. TF)
        :param restore_path: full path to load checkpoint from.
        """
        assert sess is None
        # Nothing to restore for ONNX

    def merge(self, other: 'Saver'):
        """
        Merge other saver into this saver
        :param other: saver to be merged into self
        """
        # No merging is supported for ONNX. self.path must be unique
        raise RuntimeError('merging not supported for ONNX exporter')
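As a quick usage sketch of the new savers (file paths below are illustrative), checkpointing a toy gluon block through ParameterDictSaver reduces to wrapping its ParameterDict:

import mxnet as mx
from mxnet import gluon, nd
from rl_coach.architectures.mxnet_components.savers import ParameterDictSaver

net = gluon.nn.Dense(4)
net.initialize(mx.init.Xavier())
net(nd.zeros((1, 8)))  # one forward pass so parameter shapes get inferred and allocated

saver = ParameterDictSaver(name='agent.main.online', param_dict=net.collect_params())
saved = saver.save(sess=None, save_path='/tmp/agent.main.online-0.params')
saver.restore(sess=None, restore_path=saved[0])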

View File

@@ -6,7 +6,7 @@ from typing import Any, Dict, Generator, Iterable, List, Tuple, Union
from types import ModuleType
import mxnet as mx
-from mxnet import nd
+from mxnet import gluon, nd
from mxnet.ndarray import NDArray
import numpy as np
@@ -278,3 +278,46 @@ def get_mxnet_activation_name(activation_name: str):
"Activation function must be one of the following {}. instead it was: {}".format(
activation_functions.keys(), activation_name)
return activation_functions[activation_name]
class OnnxHandlerBlock(object):
"""
Helper base class for gluon blocks that must behave differently for ONNX export forward pass
"""
def __init__(self):
self._onnx = False
def enable_onnx(self):
self._onnx = True
def disable_onnx(self):
self._onnx = False
class ScopedOnnxEnable(object):
"""
Helper scoped ONNX enable class
"""
def __init__(self, net: gluon.HybridBlock):
self._onnx_handlers = self._get_onnx_handlers(net)
def __enter__(self):
for b in self._onnx_handlers:
b.enable_onnx()
def __exit__(self, exc_type, exc_val, exc_tb):
for b in self._onnx_handlers:
b.disable_onnx()
@staticmethod
def _get_onnx_handlers(block: gluon.HybridBlock) -> List[OnnxHandlerBlock]:
"""
Iterates through all child blocks and return all of them that are instance of OnnxHandlerBlock
:return: list of OnnxHandlerBlock child blocks
"""
handlers = list()
if isinstance(block, OnnxHandlerBlock):
handlers.append(block)
for child_block in block._children.values():
handlers += ScopedOnnxEnable._get_onnx_handlers(child_block)
return handlers
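A block that needs a different forward pass during ONNX export mixes in OnnxHandlerBlock and branches on self._onnx; ScopedOnnxEnable then flips the flag on every such child for the duration of the symbolic export. A minimal sketch with a hypothetical block (not from this commit):

from mxnet import gluon, nd, sym
from rl_coach.architectures.mxnet_components.utils import OnnxHandlerBlock, ScopedOnnxEnable


class ClippedHead(gluon.HybridBlock, OnnxHandlerBlock):
    """Hypothetical block that skips an op it does not want in the exported graph."""
    def __init__(self):
        super(ClippedHead, self).__init__()
        OnnxHandlerBlock.__init__(self)
        with self.name_scope():
            self.dense = gluon.nn.Dense(2)

    def hybrid_forward(self, F, x):
        out = self.dense(x)
        if self._onnx:
            return out                     # export path: plain forward
        return F.clip(out, -1.0, 1.0)      # training path

net = ClippedHead()
net.initialize()
net(nd.zeros((1, 3)))                      # shape inference with a dummy batch
with ScopedOnnxEnable(net):                # children temporarily get _onnx=True
    graph = net(sym.Variable('data'))      # symbolic forward used for export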

View File

@@ -18,6 +18,7 @@ from typing import List, Tuple
from rl_coach.base_parameters import Frameworks, AgentParameters
from rl_coach.logger import failed_imports
from rl_coach.saver import SaverCollection
from rl_coach.spaces import SpacesDefinition
try:
import tensorflow as tf
@@ -251,3 +252,25 @@ class NetworkWrapper(object):
        result.append(str(self.online_network))
        result.append("")
        return '\n'.join(result)

    def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
        """
        Collect all of the network's savers for the global or online network.
        Note: the global, online, and target networks are all copies of the same network, with parameters that
        are updated at different rates, so only one of them needs to be saved: the one that holds the most
        recent parameters. The target network is created for some agents and is used to stabilize training by
        updating its parameters from the online network at a slower rate; as a result, the target network never
        contains the most recent set of parameters. In single-worker training, no global network is created and
        the online network contains the most recent parameters. In vertical distributed training with more than
        one worker, the global network is updated by all workers and contains the most recent parameters.
        Therefore, preference is given to the global network if it exists; otherwise the online network is used
        for saving.
        :param parent_path_suffix: path suffix of the parent of the network wrapper
            (e.g. could be name of level manager plus name of agent)
        :return: collection of all checkpoint objects
        """
        if self.global_network:
            savers = self.global_network.collect_savers(parent_path_suffix)
        else:
            savers = self.online_network.collect_savers(parent_path_suffix)
        return savers

View File

@@ -23,6 +23,7 @@ import tensorflow as tf
from rl_coach.architectures.architecture import Architecture
from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
from rl_coach.core_types import GradientClippingMethod
from rl_coach.saver import SaverCollection
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import force_list, squeeze_list, start_shell_command_and_wait
@@ -637,6 +638,16 @@ class TensorFlowArchitecture(Architecture):
        self.curr_rnn_c_in = self.middleware.c_init
        self.curr_rnn_h_in = self.middleware.h_init

    def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
        """
        Collection of all checkpoints for the network (typically only one checkpoint)
        :param parent_path_suffix: path suffix of the parent of the network
            (e.g. could be name of level manager plus name of agent)
        :return: checkpoint collection for the network
        """
        # TODO implement returning checkpoints for tensorflow
        return SaverCollection()


def save_onnx_graph(input_nodes, output_nodes, checkpoint_save_dir: str) -> None:
    """