"""
Module for abstract base class for checkpoint object and checkpoint collection
"""
from typing import Any, Dict, List, Optional


class Saver(object):
    """
    ABC for saver objects that implement saving/restoring to/from path, and merging two savers.
    Subclasses must override every method below; the base versions raise NotImplementedError.
    """
    @property
    def path(self):
        """
        Relative path for save/load. If two saver objects return the same path, they must be merge-able.
        """
        raise NotImplementedError

    def save(self, sess: Any, save_path: str) -> List[str]:
        """
        Save to save_path
        :param sess: active session for session-based frameworks (e.g. TF)
        :param save_path: full path to save checkpoint (typically directory plus self.path plus checkpoint count).
        :return: list of all saved paths
        """
        raise NotImplementedError

    def restore(self, sess: Any, restore_path: str) -> None:
        """
        Restore from restore_path
        :param sess: active session for session-based frameworks (e.g. TF)
        :param restore_path: full path to load checkpoint from.
        """
        raise NotImplementedError

    def merge(self, other: 'Saver') -> None:
        """
        Merge other saver into this saver
        :param other: saver to be merged into self
        """
        raise NotImplementedError


class SaverCollection(object):
    """
    Object for storing a collection of saver objects. It takes care of ensuring uniqueness of saver paths
    and merging savers if they have the same path. For example, if a saver handles saving a generic key/value
    file for all networks in a single file, it can use a more generic path and all savers of all networks would be
    merged into a single saver that saves/restores parameters for all networks.
    NOTE: If two savers have the same path, the respective saver class must support merging them
    into a single saver that saves/restores all merged parameters.
    """
    def __init__(self, saver: Optional[Saver] = None):
        """
        :param saver: optional initial saver for the collection
        """
        # Keyed by saver.path so that each relative path has exactly one saver.
        self._saver_dict = dict()  # type: Dict[str, Saver]
        if saver is not None:
            self._saver_dict[saver.path] = saver

    def add(self, saver: Saver):
        """
        Add a new saver to the collection. If saver.path is already in the collection, merge
        the new saver with the existing saver.
        :param saver: new saver to be added to collection
        """
        if saver.path in self._saver_dict:
            self._saver_dict[saver.path].merge(saver)
        else:
            self._saver_dict[saver.path] = saver

    def update(self, other: 'SaverCollection'):
        """
        Merge savers from other collection into self
        :param other: saver collection to update self with.
        """
        for c in other:
            self.add(c)

    @staticmethod
    def _full_path(base_path: str, saver: Saver) -> str:
        """
        Build the full checkpoint path for one saver. Both save and restore use this helper,
        so they always agree on the path.
        :param base_path: checkpoint path prefix shared by all savers in the collection
        :param saver: saver whose relative path is appended
        :return: '<base_path>.<saver.path>'
        """
        return "{}.{}".format(base_path, saver.path)

    def save(self, sess: Any, save_path: str) -> List[str]:
        """
        Call save on all savers in the collection
        :param sess: active session for session-based frameworks (e.g. TF)
        :param save_path: path for saving checkpoints using savers. All saved file paths must
        start with this path in their full path. For example if save_path is '/home/checkpoints/checkpoint-01',
        then saved file paths can be '/home/checkpoints/checkpoint-01.main-network' but not
        '/home/checkpoints/main-network'
        :return: list of all saved paths
        """
        paths = list()
        for saver in self:
            paths.extend(saver.save(sess, self._full_path(save_path, saver)))
        return paths

    def restore(self, sess: Any, restore_path: str) -> None:
        """
        Call restore on all savers in the collection
        :param sess: active session for session-based frameworks (e.g. TF)
        :param restore_path: path for restoring checkpoint using savers. Each saver is given
        '<restore_path>.<saver.path>', matching the path used by save().
        """
        for saver in self:
            # Use a per-saver local rather than rebinding restore_path. Otherwise each later
            # saver would get an accumulated path such as 'ckpt.a.b'.
            saver.restore(sess, self._full_path(restore_path, saver))

    def __iter__(self):
        """
        Return an iterator for savers in the collection
        :return: saver iterator
        """
        return iter(self._saver_dict.values())