1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

Adding checkpointing framework (#74)

* Adding checkpointing framework as well as mxnet checkpointing implementation.

- MXNet checkpoint for each network is saved in a separate file.

* Adding checkpoint restore for mxnet to graph-manager

* Add unit-test for get_checkpoint_state()

* Added match.group() to fix unit-test failing on CI

* Added ONNX export support for MXNet
This commit is contained in:
Sina Afrooze
2018-11-19 09:45:49 -08:00
committed by shadiendrawis
parent 4da56b1ff2
commit 67eb9e4c28
19 changed files with 598 additions and 29 deletions

View File

@@ -6,7 +6,7 @@ from typing import Any, Dict, Generator, Iterable, List, Tuple, Union
from types import ModuleType
import mxnet as mx
from mxnet import nd
from mxnet import gluon, nd
from mxnet.ndarray import NDArray
import numpy as np
@@ -278,3 +278,46 @@ def get_mxnet_activation_name(activation_name: str):
"Activation function must be one of the following {}. instead it was: {}".format(
activation_functions.keys(), activation_name)
return activation_functions[activation_name]
class OnnxHandlerBlock(object):
"""
Helper base class for gluon blocks that must behave differently for ONNX export forward pass
"""
def __init__(self):
self._onnx = False
def enable_onnx(self):
self._onnx = True
def disable_onnx(self):
self._onnx = False
class ScopedOnnxEnable(object):
"""
Helper scoped ONNX enable class
"""
def __init__(self, net: gluon.HybridBlock):
self._onnx_handlers = self._get_onnx_handlers(net)
def __enter__(self):
for b in self._onnx_handlers:
b.enable_onnx()
def __exit__(self, exc_type, exc_val, exc_tb):
for b in self._onnx_handlers:
b.disable_onnx()
@staticmethod
def _get_onnx_handlers(block: gluon.HybridBlock) -> List[OnnxHandlerBlock]:
"""
Iterates through all child blocks and return all of them that are instance of OnnxHandlerBlock
:return: list of OnnxHandlerBlock child blocks
"""
handlers = list()
if isinstance(block, OnnxHandlerBlock):
handlers.append(block)
for child_block in block._children.values():
handlers += ScopedOnnxEnable._get_onnx_handlers(child_block)
return handlers