
Adding checkpointing framework (#74)

* Adding checkpointing framework as well as an MXNet checkpointing implementation.

- MXNet checkpoint for each network is saved in a separate file (a rough sketch of this layout follows the bullet list).

* Adding checkpoint restore for MXNet to the graph manager

* Add unit test for get_checkpoint_state()

* Added match.group() to fix a unit test that was failing on CI

* Added ONNX export support for MXNet
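
The first and fourth bullets above — one checkpoint file per network, and filename parsing via match.group() — can be pictured with a short sketch. This is not the code added by the commit; the function names, the "<id>_<network>.params" filename pattern, and the regex are assumptions made here for illustration, and the only real API used is MXNet Gluon's save_parameters().

import os
import re


def save_network_checkpoints(networks, checkpoint_dir, checkpoint_id):
    """Hypothetical helper: save each network's parameters to its own file."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    for name, block in networks.items():
        # e.g. "12_main_online.params" -- the naming pattern is an assumption, not Coach's actual scheme
        path = os.path.join(checkpoint_dir, '{}_{}.params'.format(checkpoint_id, name))
        block.save_parameters(path)  # gluon.Block.save_parameters(): one file per network


def get_latest_checkpoint_id(checkpoint_dir):
    """Hypothetical helper: return the highest checkpoint id on disk, or -1 if none.

    Echoes the kind of filename parsing the match.group() fix refers to; the regex is illustrative.
    """
    latest = -1
    pattern = re.compile(r'^(\d+)_.*\.params$')
    for fname in os.listdir(checkpoint_dir):
        match = pattern.match(fname)
        if match:
            latest = max(latest, int(match.group(1)))
    return latest

Restoring would then amount to calling load_parameters() on each network with the file that matches the chosen checkpoint id.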
Sina Afrooze authored on 2018-11-19 09:45:49 -08:00, committed by shadiendrawis
parent 4da56b1ff2, commit 67eb9e4c28
19 changed files with 598 additions and 29 deletions


@@ -279,7 +279,7 @@ def _get_output_head(
     return module


-class ScaledGradHead(HybridBlock):
+class ScaledGradHead(HybridBlock, utils.OnnxHandlerBlock):
     """
     Wrapper block for applying gradient scaling to input before feeding the head network
     """
@@ -292,7 +292,7 @@ class ScaledGradHead(HybridBlock):
                  agent_params: AgentParameters,
                  head_params: HeadParameters) -> None:
         """
-        :param head_idx: the head index
+        :param head_index: the head index
         :param head_type_index: the head type index (same index if head_param.num_output_head_copies>0)
         :param network_name: name of the network
         :param spaces: state and action space definitions
@@ -301,6 +301,7 @@ class ScaledGradHead(HybridBlock):
         :param head_params: head parameters
         """
         super(ScaledGradHead, self).__init__()
+        utils.OnnxHandlerBlock.__init__(self)
         head_params = _sanitize_activation(head_params)
         with self.name_scope():
@@ -330,7 +331,13 @@ class ScaledGradHead(HybridBlock):
         :param gradient_rescaler: gradient rescaler for partial blocking of gradient
         :return: head output
         """
-        grad_scaled_x = F.broadcast_mul((1 - gradient_rescaler), F.BlockGrad(x)) + F.broadcast_mul(gradient_rescaler, x)
+        if self._onnx:
+            # ONNX doesn't support the BlockGrad() operator, but it's typically not needed here
+            # because the ONNX-exported network is mostly used for forward calls only.
+            grad_scaled_x = x
+        else:
+            grad_scaled_x = (F.broadcast_mul((1 - gradient_rescaler), F.BlockGrad(x)) +
+                             F.broadcast_mul(gradient_rescaler, x))
         out = self.head(grad_scaled_x)
         return out
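
For context on the last hunk: F.BlockGrad(x) passes values through unchanged in the forward pass but stops gradients, so grad_scaled_x equals x going forward while its gradient is scaled by gradient_rescaler going backward. The ONNX branch simply skips that trick, since an exported graph is only used for inference. The diff also references utils.OnnxHandlerBlock, which sets up the self._onnx flag; the mixin itself is not shown in this hunk, so the following is only a guess at the pattern, with the enable/disable method names invented for illustration.

class OnnxHandlerBlock(object):
    """Illustrative mixin that toggles an ONNX-friendly forward path on a block.

    This is an assumption about what utils.OnnxHandlerBlock might look like,
    inferred from the diff above; it is not Coach's actual implementation.
    """

    def __init__(self):
        # checked inside hybrid_forward() to skip ONNX-incompatible operators such as BlockGrad
        self._onnx = False

    def enable_onnx_export(self):
        # hypothetical: call on every head before exporting the network to ONNX
        self._onnx = True

    def disable_onnx_export(self):
        # hypothetical: call after export to restore the training-time behaviour
        self._onnx = False

Before export, the graph manager would presumably flip this flag on each head, run a forward pass to trace the ONNX-safe graph, and flip it back afterwards.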