mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Adding checkpointing framework (#74)
* Adding checkpointing framework as well as mxnet checkpointing implementation. - MXNet checkpoint for each network is saved in a separate file. * Adding checkpoint restore for mxnet to graph-manager * Add unit-test for get_checkpoint_state() * Added match.group() to fix unit-test failing on CI * Added ONNX export support for MXNet
This commit is contained in:
committed by
shadiendrawis
parent
4da56b1ff2
commit
67eb9e4c28
112
rl_coach/saver.py
Normal file
112
rl_coach/saver.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""
|
||||
Module for abstract base class for checkpoint object and checkpoint collection
|
||||
"""
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
class Saver(object):
|
||||
"""
|
||||
ABC for saver objects that implement saving/restoring to/from path, and merging two savers.
|
||||
"""
|
||||
@property
|
||||
def path(self):
|
||||
"""
|
||||
Relative path for save/load. If two saver objects return the same path, they must be merge-able.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self, sess: Any, save_path: str) -> List[str]:
|
||||
"""
|
||||
Save to save_path
|
||||
:param sess: active session for session-based frameworks (e.g. TF)
|
||||
:param save_path: full path to save checkpoint (typically directory plus self.path plus checkpoint count).
|
||||
:return: list of all saved paths
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def restore(self, sess: Any, restore_path: str) -> None:
|
||||
"""
|
||||
Restore from restore_path
|
||||
:param sess: active session for session-based frameworks (e.g. TF)
|
||||
:param restore_path: full path to load checkpoint from.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def merge(self, other: 'Saver') -> None:
|
||||
"""
|
||||
Merge other saver into this saver
|
||||
:param other: saver to be merged into self
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SaverCollection(object):
|
||||
"""
|
||||
Object for storing a collection of saver objects. It takes care of ensuring uniqueness of saver paths
|
||||
and merging savers if they have the same path. For example, if a saver handles saving a generic key/value
|
||||
file for all networks in a single file, it can use a more generic path and all savers of all networks would be
|
||||
merged into a single saver that saves/restores parameters for all networks.
|
||||
NOTE: If two savers have the same path, the respective saver class must support merging them
|
||||
into a single saver that saves/restores all merged parameters.
|
||||
"""
|
||||
def __init__(self, saver: Saver = None):
|
||||
"""
|
||||
:param saver: optional initial saver for the collection
|
||||
"""
|
||||
self._saver_dict = dict() # type: Dict[str, Saver]
|
||||
if saver is not None:
|
||||
self._saver_dict[saver.path] = saver
|
||||
|
||||
def add(self, saver: Saver):
|
||||
"""
|
||||
Add a new saver to the collection. If saver.path is already in the collection, merge
|
||||
the new saver with the existing saver.
|
||||
:param saver: new saver to be added to collection
|
||||
"""
|
||||
if saver.path in self._saver_dict:
|
||||
self._saver_dict[saver.path].merge(saver)
|
||||
else:
|
||||
self._saver_dict[saver.path] = saver
|
||||
|
||||
def update(self, other: 'SaverCollection'):
|
||||
"""
|
||||
Merge savers from other collection into self
|
||||
:param other: saver collection to update self with.
|
||||
"""
|
||||
for c in other:
|
||||
self.add(c)
|
||||
|
||||
def save(self, sess: Any, save_path: str) -> List[str]:
|
||||
"""
|
||||
Call save on all savers in the collection
|
||||
:param sess: active session for session-based frameworks (e.g. TF)
|
||||
:param save_path: path for saving checkpoints using savers. All saved file paths must
|
||||
start with this path in their full path. For example if save_path is '/home/checkpoints/checkpoint-01',
|
||||
then saved file paths can be '/home/checkpoints/checkpoint-01.main-network' but not
|
||||
'/home/checkpoints/main-network'
|
||||
:return: list of all saved paths
|
||||
"""
|
||||
paths = list()
|
||||
for saver in self:
|
||||
paths.extend(saver.save(sess, "{}.{}".format(save_path, saver.path)))
|
||||
return paths
|
||||
|
||||
def restore(self, sess: Any, restore_path: str) -> None:
|
||||
"""
|
||||
Call restore on all savers in the collection
|
||||
:param sess: active session for session-based frameworks (e.g. TF)
|
||||
:param restore_path: path for restoring checkpoint using savers.
|
||||
"""
|
||||
for saver in self:
|
||||
restore_path = "{}.{}".format(restore_path, saver.path)
|
||||
saver.restore(sess, restore_path)
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
Return an iterator for savers in the collection
|
||||
:return: saver iterator
|
||||
"""
|
||||
return (v for v in self._saver_dict.values())
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user