mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Adding checkpointing framework (#74)
* Adding checkpointing framework as well as mxnet checkpointing implementation. - MXNet checkpoint for each network is saved in a separate file. * Adding checkpoint restore for mxnet to graph-manager * Add unit-test for get_checkpoint_state() * Added match.group() to fix unit-test failing on CI * Added ONNX export support for MXNet
This commit is contained in:
committed by
shadiendrawis
parent
4da56b1ff2
commit
67eb9e4c28
@@ -6,7 +6,7 @@ from typing import Any, Dict, Generator, Iterable, List, Tuple, Union
|
||||
from types import ModuleType
|
||||
|
||||
import mxnet as mx
|
||||
from mxnet import nd
|
||||
from mxnet import gluon, nd
|
||||
from mxnet.ndarray import NDArray
|
||||
import numpy as np
|
||||
|
||||
@@ -278,3 +278,46 @@ def get_mxnet_activation_name(activation_name: str):
|
||||
"Activation function must be one of the following {}. instead it was: {}".format(
|
||||
activation_functions.keys(), activation_name)
|
||||
return activation_functions[activation_name]
|
||||
|
||||
|
||||
class OnnxHandlerBlock(object):
|
||||
"""
|
||||
Helper base class for gluon blocks that must behave differently for ONNX export forward pass
|
||||
"""
|
||||
def __init__(self):
|
||||
self._onnx = False
|
||||
|
||||
def enable_onnx(self):
|
||||
self._onnx = True
|
||||
|
||||
def disable_onnx(self):
|
||||
self._onnx = False
|
||||
|
||||
|
||||
class ScopedOnnxEnable(object):
|
||||
"""
|
||||
Helper scoped ONNX enable class
|
||||
"""
|
||||
def __init__(self, net: gluon.HybridBlock):
|
||||
self._onnx_handlers = self._get_onnx_handlers(net)
|
||||
|
||||
def __enter__(self):
|
||||
for b in self._onnx_handlers:
|
||||
b.enable_onnx()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
for b in self._onnx_handlers:
|
||||
b.disable_onnx()
|
||||
|
||||
@staticmethod
|
||||
def _get_onnx_handlers(block: gluon.HybridBlock) -> List[OnnxHandlerBlock]:
|
||||
"""
|
||||
Iterates through all child blocks and return all of them that are instance of OnnxHandlerBlock
|
||||
:return: list of OnnxHandlerBlock child blocks
|
||||
"""
|
||||
handlers = list()
|
||||
if isinstance(block, OnnxHandlerBlock):
|
||||
handlers.append(block)
|
||||
for child_block in block._children.values():
|
||||
handlers += ScopedOnnxEnable._get_onnx_handlers(child_block)
|
||||
return handlers
|
||||
|
||||
Reference in New Issue
Block a user