mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 19:50:17 +01:00
Adding checkpointing framework (#74)
* Adding checkpointing framework as well as mxnet checkpointing implementation. - MXNet checkpoint for each network is saved in a separate file. * Adding checkpoint restore for mxnet to graph-manager * Add unit-test for get_checkpoint_state() * Added match.group() to fix unit-test failing on CI * Added ONNX export support for MXNet
This commit is contained in:
committed by
shadiendrawis
parent
4da56b1ff2
commit
67eb9e4c28
113
rl_coach/architectures/mxnet_components/savers.py
Normal file
113
rl_coach/architectures/mxnet_components/savers.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
from mxnet import gluon, sym
|
||||
from mxnet.contrib import onnx as onnx_mxnet
|
||||
import numpy as np
|
||||
|
||||
from rl_coach.architectures.mxnet_components.utils import ScopedOnnxEnable
|
||||
from rl_coach.saver import Saver
|
||||
|
||||
|
||||
class ParameterDictSaver(Saver):
    """
    Saver implementation that checkpoints an mxnet gluon parameter dictionary
    """
    def __init__(self, name: str, param_dict: gluon.ParameterDict):
        """
        :param name: value reported by the path property
        :param param_dict: gluon parameter dictionary that is saved, restored and merged into
        """
        self._param_dict = param_dict
        self._name = name

    @property
    def path(self):
        """
        Relative path for save/load (the name given at construction). Checkpoint objects that return the
        same path must be merge-able.
        """
        return self._name

    def save(self, sess: None, save_path: str) -> List[str]:
        """
        Write the wrapped parameter dictionary to save_path
        :param sess: active session for session-based frameworks (e.g. TF). Must be None.
        :param save_path: full path to save checkpoint (typically directory plus self.path plus checkpoint count).
        :return: list of all saved paths (a single entry: save_path)
        """
        assert sess is None
        self._param_dict.save(save_path)
        written_paths = [save_path]
        return written_paths

    def restore(self, sess: Any, restore_path: str):
        """
        Load the wrapped parameter dictionary from restore_path
        :param sess: active session for session-based frameworks (e.g. TF). Must be None.
        :param restore_path: full path to load checkpoint from.
        """
        assert sess is None
        self._param_dict.load(restore_path)

    def merge(self, other: 'Saver'):
        """
        Merge the parameters held by another saver into this saver
        :param other: saver to be merged into self; must be a ParameterDictSaver, otherwise TypeError is raised
        """
        if isinstance(other, ParameterDictSaver):
            # Only touch the other saver's internals once its type is confirmed
            incoming = other._param_dict
            self._param_dict.update(incoming)
            return
        raise TypeError('merging only supported with ParameterDictSaver (type:{})'.format(type(other)))
|
||||
|
||||
|
||||
class OnnxSaver(Saver):
    """
    Saver implementation that exports a gluon HybridBlock to ONNX
    """
    def __init__(self, name: str, model: gluon.HybridBlock, input_shapes: List[List[int]]):
        """
        :param name: value reported by the path property
        :param model: gluon HybridBlock to export
        :param input_shapes: one shape per model input; its length sets the number of symbolic inputs
        """
        num_inputs = len(input_shapes)
        self._name = name
        # The symbolic graph is built before the parameters are collected, as in the original ordering
        self._sym = self._get_onnx_sym(model, num_inputs)
        self._param_dict = model.collect_params()
        self._input_shapes = input_shapes

    @staticmethod
    def _get_onnx_sym(model: gluon.HybridBlock, num_inputs: int) -> sym.Symbol:
        """
        Build a symbolic graph for the model by calling it on symbolic input variables
        :param model: gluon HybridBlock that constructs the symbolic graph
        :param num_inputs: number of inputs to the graph
        :return: symbol grouping the flattened outputs of the network
        """
        placeholders = []
        for index in range(num_inputs):
            placeholders.append(sym.Variable('Data{}'.format(index)))
        with ScopedOnnxEnable(model):
            outputs = model(*placeholders)
            flattened = gluon.block._flatten(outputs, "output")
            # Only the first element of the _flatten result is grouped into the symbol
            return sym.Group(flattened[0])

    @property
    def path(self):
        """
        Relative path for save/load (the name given at construction). Checkpoint objects that return the
        same path must be merge-able.
        """
        return self._name

    def save(self, sess: None, save_path: str) -> List[str]:
        """
        Export the symbolic graph and the model parameters to ONNX at save_path
        :param sess: active session for session-based frameworks (e.g. TF). Must be None.
        :param save_path: full path to save checkpoint (typically directory plus self.path plus checkpoint count).
        :return: list of all saved paths (a single entry: the path returned by the ONNX exporter)
        """
        assert sess is None
        reduced_params = {}
        for key, parameter in self._param_dict.items():
            reduced_params[key] = parameter._reduce()
        export_path = onnx_mxnet.export_model(self._sym, reduced_params, self._input_shapes, np.float32, save_path)
        return [export_path]

    def restore(self, sess: Any, restore_path: str):
        """
        No-op apart from the session check: nothing is loaded back from an ONNX export
        :param sess: active session for session-based frameworks (e.g. TF). Must be None.
        :param restore_path: full path to load checkpoint from (unused).
        """
        assert sess is None
        # Nothing to restore for ONNX

    def merge(self, other: 'Saver'):
        """
        Merging is not supported for the ONNX exporter; always raises RuntimeError
        :param other: saver to be merged into self (unused)
        """
        # No merging is supported for ONNX. self.path must be unique
        raise RuntimeError('merging not supported for ONNX exporter')
|
||||
Reference in New Issue
Block a user