mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 19:50:17 +01:00
Adding checkpointing framework (#74)
* Adding checkpointing framework as well as mxnet checkpointing implementation. - MXNet checkpoint for each network is saved in a separate file. * Adding checkpoint restore for mxnet to graph-manager * Add unit-test for get_checkpoint_state() * Added match.group() to fix unit-test failing on CI * Added ONNX export support for MXNet
This commit is contained in:
committed by
shadiendrawis
parent
4da56b1ff2
commit
67eb9e4c28
@@ -98,4 +98,4 @@ class VHead(Head):
|
||||
:param x: middleware state representation, of shape (batch_size, in_channels).
|
||||
:return: final output of value network, of shape (batch_size).
|
||||
"""
|
||||
return self.dense(x).squeeze()
|
||||
return self.dense(x).squeeze(axis=1)
|
||||
|
||||
Reference in New Issue
Block a user