mirror of https://github.com/gryf/coach.git (synced 2025-12-18 11:40:18 +01:00)
Adding checkpointing framework (#74)
* Adding checkpointing framework as well as mxnet checkpointing implementation.
  - MXNet checkpoint for each network is saved in a separate file.
* Adding checkpoint restore for mxnet to graph-manager
* Add unit-test for get_checkpoint_state()
* Added match.group() to fix unit-test failing on CI
* Added ONNX export support for MXNet
committed by shadiendrawis
parent 4da56b1ff2
commit 67eb9e4c28
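The commit message mentions a get_checkpoint_state() helper plus a match.group() fix for a failing CI test. Below is a minimal, hypothetical sketch of what such a helper could look like; the file-name pattern, signature, and return value are assumptions for illustration, not Coach's actual API.

import os
import re
from typing import Optional

# Hypothetical checkpoint file pattern: a numeric step prefix and a .ckpt suffix.
_CKPT_RE = re.compile(r'^(\d+)\S*\.ckpt$')

def get_checkpoint_state(checkpoint_dir: str) -> Optional[str]:
    """Return the path of the newest checkpoint in checkpoint_dir, or None."""
    latest_step, latest_path = -1, None
    for name in os.listdir(checkpoint_dir):
        match = _CKPT_RE.match(name)
        if match is None:
            continue
        # match.group(1) holds the step prefix; the commit notes a CI fix
        # around exactly this kind of match.group() usage.
        step = int(match.group(1))
        if step > latest_step:
            latest_step, latest_path = step, os.path.join(checkpoint_dir, name)
    return latest_path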
@@ -279,7 +279,7 @@ def _get_output_head(
     return module


-class ScaledGradHead(HybridBlock):
+class ScaledGradHead(HybridBlock, utils.OnnxHandlerBlock):
     """
     Wrapper block for applying gradient scaling to input before feeding the head network
     """
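The new utils.OnnxHandlerBlock base is not shown in this diff. Judging from the self._onnx flag read in hybrid_forward below, a minimal sketch of such a mixin could look like the following; the context-manager interface is an assumption for illustration, not the actual Coach code.

class OnnxHandlerBlock(object):
    """Mixin letting a block switch to an ONNX-compatible forward path."""

    def __init__(self):
        # Flag checked in hybrid_forward() to skip ONNX-unsupported operators.
        self._onnx = False

    def __enter__(self):
        # Assumed usage: `with block: export(...)` flips the block into ONNX mode.
        self._onnx = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._onnx = False
        return False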
@@ -292,7 +292,7 @@ class ScaledGradHead(HybridBlock):
                  agent_params: AgentParameters,
                  head_params: HeadParameters) -> None:
         """
-        :param head_idx: the head index
+        :param head_index: the head index
         :param head_type_index: the head type index (same index if head_param.num_output_head_copies>0)
         :param network_name: name of the network
         :param spaces: state and action space definitions
@@ -301,6 +301,7 @@ class ScaledGradHead(HybridBlock):
         :param head_params: head parameters
         """
         super(ScaledGradHead, self).__init__()
+        utils.OnnxHandlerBlock.__init__(self)

         head_params = _sanitize_activation(head_params)
         with self.name_scope():
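For the "ONNX export support" part of the commit, here is a minimal sketch of exporting a hybridized Gluon block through MXNet's contrib ONNX API (MXNet 1.3+); the toy network and file names are placeholders, not Coach's actual export path.

import numpy as np
import mxnet as mx
from mxnet.contrib import onnx as onnx_mxnet

# Toy network: any hybridized block works once its graph has been cached
# by a forward pass.
net = mx.gluon.nn.HybridSequential()
net.add(mx.gluon.nn.Dense(10))
net.initialize()
net.hybridize()
net(mx.nd.zeros((1, 4)))  # one forward pass caches the symbolic graph
net.export('my_net')      # writes my_net-symbol.json / my_net-0000.params

# Convert the exported symbol/params pair to an ONNX file.
onnx_mxnet.export_model('my_net-symbol.json', 'my_net-0000.params',
                        [(1, 4)], np.float32, 'my_net.onnx')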
@@ -330,7 +331,13 @@ class ScaledGradHead(HybridBlock):
         :param gradient_rescaler: gradient rescaler for partial blocking of gradient
         :return: head output
         """
-        grad_scaled_x = F.broadcast_mul((1 - gradient_rescaler), F.BlockGrad(x)) + F.broadcast_mul(gradient_rescaler, x)
+        if self._onnx:
+            # ONNX doesn't support the BlockGrad() operator, but it is rarely needed
+            # there: an exported ONNX network is mostly used for forward calls only.
+            grad_scaled_x = x
+        else:
+            grad_scaled_x = (F.broadcast_mul((1 - gradient_rescaler), F.BlockGrad(x)) +
+                             F.broadcast_mul(gradient_rescaler, x))
         out = self.head(grad_scaled_x)
         return out

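To see why the non-ONNX branch above equals x in the forward pass but scales the gradient, here is a small standalone check in plain NDArray code, independent of Coach:

import mxnet as mx
from mxnet import autograd, nd

# y = (1 - g) * BlockGrad(x) + g * x equals x in value, but BlockGrad stops
# the gradient through its term, so dy/dx = g: the head's gradient into the
# shared trunk is scaled by the rescaler g.
x = nd.array([1.0, 2.0, 3.0])
g = nd.array([0.25])
x.attach_grad()
with autograd.record():
    y = nd.broadcast_mul(1 - g, nd.BlockGrad(x)) + nd.broadcast_mul(g, x)
y.backward()
print(y)       # [1. 2. 3.] -- same values as x
print(x.grad)  # [0.25 0.25 0.25] -- gradient scaled by g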