1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 19:50:17 +01:00

Adding checkpointing framework (#74)

* Adding checkpointing framework as well as mxnet checkpointing implementation.

- MXNet checkpoint for each network is saved in a separate file.

* Adding checkpoint restore for mxnet to graph-manager

* Add unit-test for get_checkpoint_state()

* Added match.group() to fix unit-test failing on CI

* Added ONNX export support for MXNet
This commit is contained in:
Sina Afrooze
2018-11-19 09:45:49 -08:00
committed by shadiendrawis
parent 4da56b1ff2
commit 67eb9e4c28
19 changed files with 598 additions and 29 deletions

View File

@@ -98,4 +98,4 @@ class VHead(Head):
:param x: middleware state representation, of shape (batch_size, in_channels).
:return: final output of value network, of shape (batch_size).
"""
return self.dense(x).squeeze()
return self.dense(x).squeeze(axis=1)