mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
Adding checkpointing framework (#74)
* Adding checkpointing framework as well as mxnet checkpointing implementation.
  - MXNet checkpoint for each network is saved in a separate file.
* Adding checkpoint restore for mxnet to graph-manager
* Add unit-test for get_checkpoint_state()
* Added match.group() to fix unit-test failing on CI
* Added ONNX export support for MXNet
This commit is contained in:
committed by
shadiendrawis
parent
4da56b1ff2
commit
67eb9e4c28
@@ -19,6 +19,7 @@ import importlib.util
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
@@ -26,7 +27,7 @@ import time
|
||||
import traceback
|
||||
from multiprocessing import Manager
|
||||
from subprocess import Popen
|
||||
from typing import List, Tuple
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import atexit
|
||||
import numpy as np
|
||||
@@ -547,3 +548,50 @@ def indent_string(string):
|
||||
return '\t' + string.replace('\n', '\n\t')
|
||||
|
||||
|
||||
class CheckpointState(object):
    """
    Helper class for checkpoint directory information. It replicates
    the CheckpointState protobuf class in tensorflow.
    """
    def __init__(self, checkpoints: List[str]):
        """
        :param checkpoints: list of checkpoint paths; the last entry is the one reported
            by model_checkpoint_path
        """
        self._checkpoints = checkpoints

    @property
    def all_model_checkpoint_paths(self) -> List[str]:
        """
        :return: the list of checkpoint paths this state was constructed with
        """
        return self._checkpoints

    @property
    def model_checkpoint_path(self) -> str:
        """
        :return: the last path in the checkpoint list
        :raises IndexError: if the checkpoint list is empty
        """
        return self._checkpoints[-1]

    def __str__(self) -> str:
        """
        :return: one 'model_checkpoint_path: ...' line followed by one
            'all_model_checkpoint_paths: ...' line per checkpoint, or an empty string
            when there are no checkpoints
        """
        # An empty state has no last checkpoint. Render it as an empty string rather than
        # letting model_checkpoint_path raise IndexError from inside str()/print().
        if not self._checkpoints:
            return ''
        lines = ['model_checkpoint_path: {}\n'.format(self.model_checkpoint_path)]
        lines.extend('all_model_checkpoint_paths: {}\n'.format(path) for path in self._checkpoints)
        return ''.join(lines)

    def __repr__(self) -> str:
        """
        :return: the string form of the underlying checkpoint list
        """
        return str(self._checkpoints)
|
||||
|
||||
|
||||
# Default regex for checkpoint filenames: group(1) captures the digits at the very start of the
# name; they must be followed by a non-digit character and, further along, by '.ckpt'.
COACH_CHECKPOINT_PATTERN = r'\A([0-9]+)[^0-9].*?\.ckpt'
|
||||
|
||||
|
||||
def get_checkpoint_state(checkpoint_dir: Union[str, List[str]], filename_pattern: str=COACH_CHECKPOINT_PATTERN) ->\
        CheckpointState:
    """
    Collects the checkpoints matching filename_pattern and orders them by checkpoint number.
    The number is taken from the first group of filename_pattern (i.e. group(1)) and converted
    to int; the stored name is the matched portion of the filename (i.e. group(0)). When several
    filenames yield the same number, only the last one encountered is kept.
    :param checkpoint_dir: directory where checkpoints are saved or list of all files in a directory
    :param filename_pattern: regex pattern for checkpoint filenames
    :return: a CheckpointState holding the matched names sorted by ascending checkpoint number; names are
        joined with checkpoint_dir when a directory was given and left as matched when a list was given
    """
    matcher = re.compile(filename_pattern)
    # A str argument is a directory to list; anything else is taken to be the file names themselves.
    from_directory = isinstance(checkpoint_dir, str)
    candidates = os.listdir(checkpoint_dir) if from_directory else checkpoint_dir

    path_by_index = {}
    for candidate in candidates:
        match = matcher.search(candidate)
        if match is None or match.group(1) is None:
            continue
        matched_name = match.group(0)
        if from_directory:
            matched_name = os.path.join(checkpoint_dir, matched_name)
        path_by_index[int(match.group(1))] = matched_name

    # Keys are unique ints, so sorting the items orders purely by checkpoint number.
    return CheckpointState([path for _, path in sorted(path_by_index.items())])
|
||||
|
||||
Reference in New Issue
Block a user