
Adding checkpointing framework (#74)

* Adding a checkpointing framework as well as an MXNet checkpointing implementation.

- MXNet checkpoint for each network is saved in a separate file.

* Adding MXNet checkpoint restore to the graph manager

* Add unit-test for get_checkpoint_state()

* Added match.group() to fix a unit-test failure on CI

* Added ONNX export support for MXNet
Authored by Sina Afrooze on 2018-11-19 09:45:49 -08:00
committed by shadiendrawis
parent 4da56b1ff2  commit 67eb9e4c28
19 changed files with 598 additions and 29 deletions
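
For orientation, the new framework threads saver collection through the whole object hierarchy: GraphManager asks each LevelManager for savers, each LevelManager asks its agents, each agent asks its network wrappers, and each wrapper returns the framework-specific savers for its network. A minimal sketch of that flow, using names from the diffs below (the graph_manager instance is assumed to be already constructed):

from rl_coach.saver import SaverCollection

def build_checkpoint_saver(graph_manager) -> SaverCollection:
    # Mirrors GraphManager._create_session_mx below: gather the savers of
    # every level manager into a single collection keyed by saver path.
    savers = SaverCollection()
    for level in graph_manager.level_managers:
        savers.update(level.collect_savers())
    return savers

Saver paths compose hierarchically (roughly '<level>.<agent>.<network>', plus an '.onnx' variant when ONNX export is enabled), so savers of different networks never collide unless they deliberately share a path.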

View File

@@ -30,6 +30,7 @@ from rl_coach.core_types import RunPhase, PredictionType, EnvironmentEpisodes, A
from rl_coach.core_types import Transition, ActionInfo, TrainingSteps, EnvironmentSteps, EnvResponse
from rl_coach.logger import screen, Logger, EpisodeLogger
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplay
from rl_coach.saver import SaverCollection
from rl_coach.spaces import SpacesDefinition, VectorObservationSpace, GoalsSpace, AttentionActionSpace
from rl_coach.utils import Signal, force_list
from rl_coach.utils import dynamic_import_and_instantiate_module_from_params
@@ -996,3 +997,16 @@ class Agent(AgentInterface):
def get_success_rate(self) -> float:
return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
"""
Collect all of the agent's network savers
:param parent_path_suffix: path suffix of the parent of the agent
(could be name of level manager or composite agent)
:return: collection of all agent savers
"""
parent_path_suffix = "{}.{}".format(parent_path_suffix, self.name)
savers = SaverCollection()
for network in self.networks.values():
savers.update(network.collect_savers(parent_path_suffix))
return savers

View File

@@ -19,6 +19,7 @@ from typing import Union, List, Dict
import numpy as np
from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, PredictionType, ActionType, Transition
from rl_coach.saver import SaverCollection
class AgentInterface(object):
@@ -153,3 +154,12 @@ class AgentInterface(object):
:return: A tuple containing the actual action and additional info on the action
"""
raise NotImplementedError("")
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
"""
Collect all of the agent's savers
:param parent_path_suffix: path suffix of the parent of the agent
(could be name of level manager or composite agent)
:return: collection of all agent savers
"""
raise NotImplementedError

View File

@@ -25,6 +25,7 @@ from rl_coach.agents.agent_interface import AgentInterface
from rl_coach.base_parameters import AgentParameters, VisualizationParameters
from rl_coach.core_types import ActionInfo, EnvResponse, ActionType, RunPhase
from rl_coach.filters.observation.observation_crop_filter import ObservationCropFilter
from rl_coach.saver import SaverCollection
from rl_coach.spaces import ActionSpace
from rl_coach.spaces import AgentSelection, AttentionActionSpace, SpacesDefinition
from rl_coach.utils import short_dynamic_import
@@ -412,3 +413,16 @@ class CompositeAgent(AgentInterface):
:return:
"""
[agent.sync() for agent in self.agents.values()]
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
"""
Collect all of the agent's network savers
:param parent_path_suffix: path suffix of the parent of the agent
(could be name of level manager or composite agent)
:return: collection of all agent savers
"""
savers = SaverCollection()
for agent in self.agents.values():
savers.update(agent.collect_savers(
parent_path_suffix="{}.{}".format(parent_path_suffix, self.name)))
return savers

View File

@@ -19,6 +19,7 @@ from typing import Any, Dict, List, Tuple
import numpy as np
from rl_coach.base_parameters import AgentParameters
from rl_coach.saver import SaverCollection
from rl_coach.spaces import SpacesDefinition
@@ -213,3 +214,12 @@ class Architecture(object):
:param placeholder: a placeholder for binding the value to assign_op.
"""
raise NotImplementedError
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
"""
Collection of all savers for the network (typically one saver for the network weights and one for ONNX export)
:param parent_path_suffix: path suffix of the parent of the network
(e.g. could be name of level manager plus name of agent)
:return: saver collection for the network
"""
raise NotImplementedError

View File

@@ -24,7 +24,9 @@ from mxnet.ndarray import NDArray
from rl_coach.architectures.architecture import Architecture
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION
from rl_coach.architectures.mxnet_components import utils
+from rl_coach.architectures.mxnet_components.savers import ParameterDictSaver, OnnxSaver
-from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
+from rl_coach.base_parameters import AgentParameters
from rl_coach.saver import SaverCollection
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import force_list, squeeze_list
@@ -81,17 +83,25 @@ class MxnetArchitecture(Architecture):
"""
return (p.list_grad()[0].copy() for p in self.model.collect_params().values() if p.grad_req != 'null')
def _model_input_shapes(self) -> List[List[int]]:
"""
Create a list of input array shapes
:return: list of input shapes, each with a leading batch dimension of 1
"""
allowed_inputs = copy.copy(self.spaces.state.sub_spaces)
allowed_inputs["action"] = copy.copy(self.spaces.action)
allowed_inputs["goal"] = copy.copy(self.spaces.goal)
embedders = self.model.nets[0].input_embedders
return list([1] + allowed_inputs[emb.embedder_name].shape.tolist() for emb in embedders)
def _dummy_model_inputs(self) -> Tuple[NDArray, ...]:
"""
Creates a tuple of input arrays with correct shapes that can be used for shape inference
of the model weights and for printing the summary
:return: tuple of inputs for model forward pass
"""
-allowed_inputs = copy.copy(self.spaces.state.sub_spaces)
-allowed_inputs["action"] = copy.copy(self.spaces.action)
-allowed_inputs["goal"] = copy.copy(self.spaces.goal)
-embedders = self.model.nets[0].input_embedders
-inputs = tuple(nd.zeros((1,) + tuple(allowed_inputs[emb.embedder_name].shape.tolist())) for emb in embedders)
+input_shapes = self._model_input_shapes()
+inputs = tuple(nd.zeros(tuple(shape)) for shape in input_shapes)
return inputs
def construct_model(self) -> None:
@@ -402,3 +412,21 @@ class MxnetArchitecture(Architecture):
:return: None
"""
assert self.middleware.__class__.__name__ != 'LSTMMiddleware', 'LSTM middleware not supported'
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
"""
Collection of all checkpoints for the network (typically only one checkpoint)
:param parent_path_suffix: path suffix of the parent of the network
(e.g. could be name of level manager plus name of agent)
:return: checkpoint collection for the network
"""
name = self.name.replace('/', '.')
savers = SaverCollection(ParameterDictSaver(
name="{}.{}".format(parent_path_suffix, name),
param_dict=self.model.collect_params()))
if self.ap.task_parameters.export_onnx_graph:
savers.add(OnnxSaver(
name="{}.{}.onnx".format(parent_path_suffix, name),
model=self.model,
input_shapes=self._model_input_shapes()))
return savers

View File

@@ -279,7 +279,7 @@ def _get_output_head(
return module
-class ScaledGradHead(HybridBlock):
+class ScaledGradHead(HybridBlock, utils.OnnxHandlerBlock):
"""
Wrapper block for applying gradient scaling to input before feeding the head network
"""
@@ -292,7 +292,7 @@ class ScaledGradHead(HybridBlock):
agent_params: AgentParameters,
head_params: HeadParameters) -> None:
"""
-:param head_idx: the head index
+:param head_index: the head index
:param head_type_index: the head type index (same index if head_param.num_output_head_copies>0)
:param network_name: name of the network
:param spaces: state and action space definitions
@@ -301,6 +301,7 @@ class ScaledGradHead(HybridBlock):
:param head_params: head parameters
"""
super(ScaledGradHead, self).__init__()
utils.OnnxHandlerBlock.__init__(self)
head_params = _sanitize_activation(head_params)
with self.name_scope():
@@ -330,7 +331,13 @@ class ScaledGradHead(HybridBlock):
:param gradient_rescaler: gradient rescaler for partial blocking of gradient
:return: head output
"""
-grad_scaled_x = F.broadcast_mul((1 - gradient_rescaler), F.BlockGrad(x)) + F.broadcast_mul(gradient_rescaler, x)
+if self._onnx:
+    # ONNX export doesn't support the BlockGrad() operator; it is typically not
+    # needed for ONNX anyway, since exported networks are mostly used for forward passes.
+    grad_scaled_x = x
+else:
+    grad_scaled_x = (F.broadcast_mul((1 - gradient_rescaler), F.BlockGrad(x)) +
+                     F.broadcast_mul(gradient_rescaler, x))
out = self.head(grad_scaled_x)
return out
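
As an aside on the non-ONNX branch: BlockGrad passes values through unchanged but stops gradients, so the expression above keeps the forward output equal to x while scaling the backward gradient by gradient_rescaler. A standalone check (illustrative values, not part of the commit):

from mxnet import autograd, nd

x = nd.array([1.0, 2.0, 3.0])
x.attach_grad()
g = nd.array([0.25])  # gradient rescaler

with autograd.record():
    y = nd.broadcast_mul(1 - g, nd.BlockGrad(x)) + nd.broadcast_mul(g, x)
y.backward()

print(y.asnumpy())       # [1. 2. 3.] -- forward value unchanged
print(x.grad.asnumpy())  # [0.25 0.25 0.25] -- gradient scaled by g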

View File

@@ -113,7 +113,7 @@ class PPOVHead(Head):
:param x: middleware state representation, of shape (batch_size, in_channels).
:return: final value output of network, of shape (batch_size).
"""
-return self.dense(x).squeeze()
+return self.dense(x).squeeze(axis=1)
def loss(self) -> mx.gluon.loss.Loss:
"""

View File

@@ -98,4 +98,4 @@ class VHead(Head):
:param x: middleware state representation, of shape (batch_size, in_channels).
:return: final output of value network, of shape (batch_size).
"""
-return self.dense(x).squeeze()
+return self.dense(x).squeeze(axis=1)

View File

@@ -0,0 +1,113 @@
from typing import Any, List, Tuple
from mxnet import gluon, sym
from mxnet.contrib import onnx as onnx_mxnet
import numpy as np
from rl_coach.architectures.mxnet_components.utils import ScopedOnnxEnable
from rl_coach.saver import Saver
class ParameterDictSaver(Saver):
"""
Child class that implements saver for mxnet gluon parameter dictionary
"""
def __init__(self, name: str, param_dict: gluon.ParameterDict):
self._name = name
self._param_dict = param_dict
@property
def path(self):
"""
Relative path for save/load. If two checkpoint objects return the same path, they must be merge-able.
"""
return self._name
def save(self, sess: None, save_path: str) -> List[str]:
"""
Save to save_path
:param sess: active session for session-based frameworks (e.g. TF)
:param save_path: full path to save checkpoint (typically directory plus self.path plus checkpoint count).
:return: list of all saved paths
"""
assert sess is None
self._param_dict.save(save_path)
return [save_path]
def restore(self, sess: Any, restore_path: str):
"""
Restore from restore_path
:param sess: active session for session-based frameworks (e.g. TF)
:param restore_path: full path to load checkpoint from.
"""
assert sess is None
self._param_dict.load(restore_path)
def merge(self, other: 'Saver'):
"""
Merge other saver into this saver
:param other: saver to be merged into self
"""
if not isinstance(other, ParameterDictSaver):
raise TypeError('merging only supported with ParameterDictSaver (type:{})'.format(type(other)))
self._param_dict.update(other._param_dict)
class OnnxSaver(Saver):
"""
Child class that implements saver for exporting gluon HybridBlock to ONNX
"""
def __init__(self, name: str, model: gluon.HybridBlock, input_shapes: List[List[int]]):
self._name = name
self._sym = self._get_onnx_sym(model, len(input_shapes))
self._param_dict = model.collect_params()
self._input_shapes = input_shapes
@staticmethod
def _get_onnx_sym(model: gluon.HybridBlock, num_inputs: int) -> sym.Symbol:
"""
Returns a symbolic graph for the model
:param model: gluon HybridBlock that constructs the symbolic graph
:param num_inputs: number of inputs to the graph
:return: symbol for the network
"""
var_args = [sym.Variable('Data{}'.format(i)) for i in range(num_inputs)]
with ScopedOnnxEnable(model):
return sym.Group(gluon.block._flatten(model(*var_args), "output")[0])
@property
def path(self):
"""
Relative path for save/load. If two checkpoint objects return the same path, they must be merge-able.
"""
return self._name
def save(self, sess: None, save_path: str) -> List[str]:
"""
Save to save_path
:param sess: active session for session-based frameworks (e.g. TF). Must be None.
:param save_path: full path to save checkpoint (typically directory plus self.path plus checkpoint count).
:return: list of all saved paths
"""
assert sess is None
params = {name:param._reduce() for name, param in self._param_dict.items()}
export_path = onnx_mxnet.export_model(self._sym, params, self._input_shapes, np.float32, save_path)
return [export_path]
def restore(self, sess: Any, restore_path: str):
"""
Restore from restore_path
:param sess: active session for session-based frameworks (e.g. TF)
:param restore_path: full path to load checkpoint from.
"""
assert sess is None
# Nothing to restore for ONNX
def merge(self, other: 'Saver'):
"""
Merge other saver into this saver
:param other: saver to be merged into self
"""
# No merging is supported for ONNX. self.path must be unique
raise RuntimeError('merging not supported for ONNX exporter')
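
The export recipe used by OnnxSaver can be reproduced standalone: call the block on sym.Variable inputs to trace a Symbol, reduce each gluon parameter to an NDArray, and hand both to onnx_mxnet.export_model. A minimal, hypothetical example with a single Dense layer (blocks containing OnnxHandlerBlock children would additionally need the ScopedOnnxEnable context, as in _get_onnx_sym above):

import numpy as np
from mxnet import gluon, nd, sym
from mxnet.contrib import onnx as onnx_mxnet

net = gluon.nn.Dense(4)
net.initialize()
net(nd.zeros((1, 8)))  # run once so parameter shapes are inferred

symbol = net(sym.Variable('Data0'))  # symbolic trace of the forward pass
params = {name: p._reduce() for name, p in net.collect_params().items()}
onnx_mxnet.export_model(symbol, params, [[1, 8]], np.float32, 'net.onnx')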

View File

@@ -6,7 +6,7 @@ from typing import Any, Dict, Generator, Iterable, List, Tuple, Union
from types import ModuleType
import mxnet as mx
-from mxnet import nd
+from mxnet import gluon, nd
from mxnet.ndarray import NDArray
import numpy as np
@@ -278,3 +278,46 @@ def get_mxnet_activation_name(activation_name: str):
"Activation function must be one of the following {}. instead it was: {}".format(
activation_functions.keys(), activation_name)
return activation_functions[activation_name]
class OnnxHandlerBlock(object):
"""
Helper base class for gluon blocks that must behave differently for ONNX export forward pass
"""
def __init__(self):
self._onnx = False
def enable_onnx(self):
self._onnx = True
def disable_onnx(self):
self._onnx = False
class ScopedOnnxEnable(object):
"""
Helper scoped ONNX enable class
"""
def __init__(self, net: gluon.HybridBlock):
self._onnx_handlers = self._get_onnx_handlers(net)
def __enter__(self):
for b in self._onnx_handlers:
b.enable_onnx()
def __exit__(self, exc_type, exc_val, exc_tb):
for b in self._onnx_handlers:
b.disable_onnx()
@staticmethod
def _get_onnx_handlers(block: gluon.HybridBlock) -> List[OnnxHandlerBlock]:
"""
Iterates through all child blocks and returns those that are instances of OnnxHandlerBlock
:return: list of OnnxHandlerBlock child blocks
"""
handlers = list()
if isinstance(block, OnnxHandlerBlock):
handlers.append(block)
for child_block in block._children.values():
handlers += ScopedOnnxEnable._get_onnx_handlers(child_block)
return handlers

View File

@@ -18,6 +18,7 @@ from typing import List, Tuple
from rl_coach.base_parameters import Frameworks, AgentParameters
from rl_coach.logger import failed_imports
from rl_coach.saver import SaverCollection
from rl_coach.spaces import SpacesDefinition
try:
import tensorflow as tf
@@ -251,3 +252,25 @@ class NetworkWrapper(object):
result.append(str(self.online_network))
result.append("")
return '\n'.join(result)
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
"""
Collect all of the network wrapper's savers for the global or online network
Note: the global, online, and target networks are all copies of the same network, with
parameters that are updated at different rates, so only one of them needs to be saved: the
one holding the most recent parameters. The target network, created for some agents to
stabilize training, updates its parameters from the online network at a slower rate and
therefore never holds the most recent set. In single-worker training no global network is
created, and the online network holds the most recent parameters. In vertical distributed
training with more than one worker, the global network is updated by all workers and holds
the most recent parameters. Preference is therefore given to the global network if it
exists; otherwise the online network is used for saving.
:param parent_path_suffix: path suffix of the parent of the network wrapper
(e.g. could be name of level manager plus name of agent)
:return: collection of all checkpoint objects
"""
if self.global_network:
savers = self.global_network.collect_savers(parent_path_suffix)
else:
savers = self.online_network.collect_savers(parent_path_suffix)
return savers

View File

@@ -23,6 +23,7 @@ import tensorflow as tf
from rl_coach.architectures.architecture import Architecture
from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
from rl_coach.core_types import GradientClippingMethod
from rl_coach.saver import SaverCollection
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import force_list, squeeze_list, start_shell_command_and_wait
@@ -637,6 +638,16 @@ class TensorFlowArchitecture(Architecture):
self.curr_rnn_c_in = self.middleware.c_init
self.curr_rnn_h_in = self.middleware.h_init
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
"""
Collection of all checkpoints for the network (typically only one checkpoint)
:param parent_path_suffix: path suffix of the parent of the network
(e.g. could be name of level manager plus name of agent)
:return: checkpoint collection for the network
"""
# TODO implement returning checkpoints for tensorflow
return SaverCollection()
def save_onnx_graph(input_nodes, output_nodes, checkpoint_save_dir: str) -> None:
"""

View File

@@ -31,7 +31,8 @@ from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, T
from rl_coach.environments.environment import Environment
from rl_coach.level_manager import LevelManager
from rl_coach.logger import screen, Logger
-from rl_coach.utils import set_cpu, start_shell_command_and_wait
+from rl_coach.saver import SaverCollection
+from rl_coach.utils import get_checkpoint_state, set_cpu, start_shell_command_and_wait
from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator
from rl_coach.memories.backend.memory_impl import get_memory_backend
from rl_coach.data_stores.data_store import SyncFiles
@@ -87,7 +88,7 @@ class GraphManager(object):
schedule_params: ScheduleParameters,
vis_params: VisualizationParameters = VisualizationParameters()):
self.sess = None
-self.level_managers = []
+self.level_managers = []  # type: List[LevelManager]
self.top_level_manager = None
self.environments = []
self.heatup_steps = schedule_params.heatup_steps
@@ -248,12 +249,22 @@ class GraphManager(object):
if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
self.save_graph()
def _create_session_mx(self):
"""
Call set_session to initialize parameters and construct checkpoint_saver
"""
self.set_session(sess=None) # Initialize all modules
self.checkpoint_saver = SaverCollection()
for level in self.level_managers:
self.checkpoint_saver.update(level.collect_savers())
# restore from checkpoint if given
self.restore_checkpoint()
def create_session(self, task_parameters: TaskParameters):
if task_parameters.framework_type == Frameworks.tensorflow:
self._create_session_tf(task_parameters)
elif task_parameters.framework_type == Frameworks.mxnet:
-self.set_session(sess=None) # Initialize all modules
-# TODO add checkpoint loading
+self._create_session_mx()
else:
raise ValueError('Invalid framework {}'.format(task_parameters.framework_type))
@@ -270,14 +281,13 @@ class GraphManager(object):
name='graphdef.pb',
as_text=False)
-def save_onnx_graph(self) -> None:
+def _save_onnx_graph_tf(self) -> None:
"""
-Save the graph as an ONNX graph.
+Save the tensorflow graph as an ONNX graph.
This requires the graph and the weights checkpoint to be stored in the experiment directory.
It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX.
:return: None
"""
# collect input and output nodes
input_nodes = []
output_nodes = []
@@ -290,11 +300,20 @@ class GraphManager(object):
for output in network.online_network.outputs:
output_nodes.append(output.name)
# TODO: make this framework agnostic
from rl_coach.architectures.tensorflow_components.architecture import save_onnx_graph
save_onnx_graph(input_nodes, output_nodes, self.task_parameters.checkpoint_save_dir)
+def save_onnx_graph(self) -> None:
+"""
+Save the graph as an ONNX graph.
+This requires the graph and the weights checkpoint to be stored in the experiment directory.
+It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX.
+:return: None
+"""
+if self.task_parameters.framework_type == Frameworks.tensorflow:
+self._save_onnx_graph_tf()
def setup_logger(self) -> None:
# dump documentation
logger_prefix = "{graph_name}".format(graph_name=self.name)
@@ -526,14 +545,13 @@ class GraphManager(object):
if self.evaluate(self.evaluation_steps):
break
-def _restore_checkpoint_tf(self, checkpoint_dir: str):
+def _restore_checkpoint_tf(self, checkpoint_path: str):
import tensorflow as tf
-checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
-screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
variables = {}
-for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
+reader = tf.contrib.framework.load_checkpoint(checkpoint_path)
+for var_name, _ in reader.get_variable_to_shape_map().items():
# Load the variable
-var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
+var = reader.get_tensor(var_name)
# Set the new name
new_name = var_name
@@ -548,11 +566,14 @@ class GraphManager(object):
# TODO: find better way to load checkpoints that were saved with a global network into the online network
if hasattr(self.task_parameters, 'checkpoint_restore_dir') and self.task_parameters.checkpoint_restore_dir:
+checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
+screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
if self.task_parameters.framework_type == Frameworks.tensorflow:
-self._restore_checkpoint_tf(self.task_parameters.checkpoint_restore_dir)
+self._restore_checkpoint_tf(checkpoint.model_checkpoint_path)
elif self.task_parameters.framework_type == Frameworks.mxnet:
-# TODO implement checkpoint restore
-pass
+self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
+else:
+raise ValueError('Invalid framework {}'.format(self.task_parameters.framework_type))
@@ -572,6 +593,8 @@ class GraphManager(object):
"{}_Step-{}.ckpt".format(
self.checkpoint_id,
self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps]))
if not os.path.exists(os.path.dirname(checkpoint_path)):
os.mkdir(os.path.dirname(checkpoint_path)) # Create directory structure
if not isinstance(self.task_parameters, DistributedTaskParameters):
if self.checkpoint_saver is not None:
saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
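
Note the resulting on-disk naming: the base checkpoint name is '<checkpoint_id>_Step-<env_steps>.ckpt', and SaverCollection.save then appends each saver's path to it. A hypothetical example of the composition:

import os

# assumed values: checkpoint_id=3, 1000 environment steps taken
checkpoint_path = os.path.join('/tmp/ckpt', '{}_Step-{}.ckpt'.format(3, 1000))
# a saver with path 'level.agent.main' would then write to
# '/tmp/ckpt/3_Step-1000.ckpt.level.agent.main'

The leading integer in the file name is what COACH_CHECKPOINT_PATTERN, added in rl_coach/utils.py below, keys on when sorting checkpoints.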

View File

@@ -20,6 +20,7 @@ from rl_coach.agents.composite_agent import CompositeAgent
from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, ActionType, EnvironmentSteps, Transition
from rl_coach.environments.environment import Environment
from rl_coach.environments.environment_interface import EnvironmentInterface
from rl_coach.saver import SaverCollection
from rl_coach.spaces import ActionSpace, SpacesDefinition
@@ -292,3 +293,13 @@ class LevelManager(EnvironmentInterface):
def should_stop(self) -> bool:
return all([agent.get_success_rate() >= self.environment.get_target_success_rate() for agent in self.agents.values()])
def collect_savers(self) -> SaverCollection:
"""
Calls collect_savers() on all agents and combines the results to a single collection
:return: saver collection of all agent savers
"""
savers = SaverCollection()
for agent in self.agents.values():
savers.update(agent.collect_savers(parent_path_suffix=self.name))
return savers

rl_coach/saver.py Normal file
View File

@@ -0,0 +1,112 @@
"""
Module for abstract base class for checkpoint object and checkpoint collection
"""
from typing import Any, Dict, List
class Saver(object):
"""
ABC for saver objects that implement saving/restoring to/from path, and merging two savers.
"""
@property
def path(self):
"""
Relative path for save/load. If two saver objects return the same path, they must be merge-able.
"""
raise NotImplementedError
def save(self, sess: Any, save_path: str) -> List[str]:
"""
Save to save_path
:param sess: active session for session-based frameworks (e.g. TF)
:param save_path: full path to save checkpoint (typically directory plus self.path plus checkpoint count).
:return: list of all saved paths
"""
raise NotImplementedError
def restore(self, sess: Any, restore_path: str) -> None:
"""
Restore from restore_path
:param sess: active session for session-based frameworks (e.g. TF)
:param restore_path: full path to load checkpoint from.
"""
raise NotImplementedError
def merge(self, other: 'Saver') -> None:
"""
Merge other saver into this saver
:param other: saver to be merged into self
"""
raise NotImplementedError
class SaverCollection(object):
"""
Object for storing a collection of saver objects. It takes care of ensuring uniqueness of saver paths
and merging savers if they have the same path. For example, if a saver handles saving a generic key/value
file for all networks in a single file, it can use a more generic path and all savers of all networks would be
merged into a single saver that saves/restores parameters for all networks.
NOTE: If two savers have the same path, the respective saver class must support merging them
into a single saver that saves/restores all merged parameters.
"""
def __init__(self, saver: Saver = None):
"""
:param saver: optional initial saver for the collection
"""
self._saver_dict = dict() # type: Dict[str, Saver]
if saver is not None:
self._saver_dict[saver.path] = saver
def add(self, saver: Saver):
"""
Add a new saver to the collection. If saver.path is already in the collection, merge
the new saver with the existing saver.
:param saver: new saver to be added to collection
"""
if saver.path in self._saver_dict:
self._saver_dict[saver.path].merge(saver)
else:
self._saver_dict[saver.path] = saver
def update(self, other: 'SaverCollection'):
"""
Merge savers from other collection into self
:param other: saver collection to update self with.
"""
for c in other:
self.add(c)
def save(self, sess: Any, save_path: str) -> List[str]:
"""
Call save on all savers in the collection
:param sess: active session for session-based frameworks (e.g. TF)
:param save_path: path for saving checkpoints using savers. All saved file paths must
start with this path in their full path. For example if save_path is '/home/checkpoints/checkpoint-01',
then saved file paths can be '/home/checkpoints/checkpoint-01.main-network' but not
'/home/checkpoints/main-network'
:return: list of all saved paths
"""
paths = list()
for saver in self:
paths.extend(saver.save(sess, "{}.{}".format(save_path, saver.path)))
return paths
def restore(self, sess: Any, restore_path: str) -> None:
"""
Call restore on all savers in the collection
:param sess: active session for session-based frameworks (e.g. TF)
:param restore_path: path for restoring checkpoint using savers.
"""
for saver in self:
    # build each saver's full path from the base restore_path without
    # mutating it, so later savers don't see accumulated suffixes
    saver.restore(sess, "{}.{}".format(restore_path, saver.path))
def __iter__(self):
"""
Return an iterator for savers in the collection
:return: saver iterator
"""
return (v for v in self._saver_dict.values())
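
To make the path-merging contract concrete, here is a toy saver (hypothetical, along the lines of the unit test further below): savers with equal paths are merged into one entry, and save() suffixes each saver's path onto the checkpoint prefix.

from typing import Any, List
from rl_coach.saver import Saver, SaverCollection

class PrintSaver(Saver):
    def __init__(self, path: str):
        self._path = path

    @property
    def path(self):
        return self._path

    def save(self, sess: Any, save_path: str) -> List[str]:
        print('saving', save_path)
        return [save_path]

    def merge(self, other: 'Saver'):
        pass  # same-path savers must know how to combine their state

savers = SaverCollection(PrintSaver('agent.main'))
savers.add(PrintSaver('agent.main'))  # merged; the collection keeps one entry
savers.add(PrintSaver('agent.onnx'))
savers.save(None, '/tmp/ckpt/0_Step-0.ckpt')
# -> saving /tmp/ckpt/0_Step-0.ckpt.agent.main
# -> saving /tmp/ckpt/0_Step-0.ckpt.agent.onnx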

View File

@@ -142,3 +142,42 @@ def test_hybrid_clip():
b = mx.nd.array((2,))
clipped = hybrid_clip(F=mx.nd, x=x, clip_lower=a, clip_upper=b)
assert np.isclose(a=clipped.asnumpy(), b=(1, 1.5, 2)).all()
@pytest.mark.unit_test
def test_scoped_onnx_enable():
class Counter(object):
def __init__(self):
self._count = 0
def increment(self):
self._count += 1
@property
def count(self):
return self._count
class TempBlock(gluon.HybridBlock, OnnxHandlerBlock):
def __init__(self, counter: Counter):
super(TempBlock, self).__init__()
OnnxHandlerBlock.__init__(self)
self._counter = counter
def hybrid_forward(self, F, x, *args, **kwargs):
if self._onnx:
self._counter.increment()
return x
counter = Counter()
net = gluon.nn.HybridSequential()
for _ in range(10):
net.add(TempBlock(counter))
# ONNX disabled
net(nd.zeros((1,)))
assert counter.count == 0
# ONNX enabled
with ScopedOnnxEnable(net):
net(nd.zeros((1,)))
assert counter.count == 10

View File

@@ -0,0 +1,42 @@
import pytest
from rl_coach.saver import Saver, SaverCollection
@pytest.mark.unit_test
def test_checkpoint_collection():
class SaverTest(Saver):
def __init__(self, path):
self._path = path
self._count = 1
@property
def path(self):
return self._path
def merge(self, other: 'Saver'):
assert isinstance(other, SaverTest)
assert self.path == other.path
self._count += other._count
# test add
savers = SaverCollection(SaverTest('123'))
savers.add(SaverTest('123'))
savers.add(SaverTest('456'))
def check_collection(mul):
paths = ['123', '456']
for c in savers:
paths.remove(c.path)
if c.path == '123':
assert c._count == 2 * mul
elif c.path == '456':
assert c._count == 1 * mul
else:
assert False, "invalid path"
check_collection(1)
# test update
savers.update(savers)
check_collection(2)

View File

@@ -0,0 +1,21 @@
import pytest
from rl_coach import utils
@pytest.mark.unit_test
def test_get_checkpoint_state_default():
files = ['4.test.ckpt.ext', '2.test.ckpt.ext', '3.test.ckpt.ext', '1.test.ckpt.ext']
checkpoint_state = utils.get_checkpoint_state(files)
assert checkpoint_state.model_checkpoint_path == '4.test.ckpt'
assert checkpoint_state.all_model_checkpoint_paths == [f[:-4] for f in sorted(files)]
@pytest.mark.unit_test
def test_get_checkpoint_state_custom():
files = ['prefix.4.test.ckpt.ext', 'prefix.2.test.ckpt.ext', 'prefix.3.test.ckpt.ext', 'prefix.1.test.ckpt.ext']
assert len(utils.get_checkpoint_state(files).all_model_checkpoint_paths) == 0 # doesn't match the default pattern
checkpoint_state = utils.get_checkpoint_state(files, filename_pattern=r'([0-9]+)[^0-9].*?\.ckpt')
assert checkpoint_state.model_checkpoint_path == '4.test.ckpt'
assert checkpoint_state.all_model_checkpoint_paths == [f[7:-4] for f in sorted(files)]

View File

@@ -19,6 +19,7 @@ import importlib.util
import inspect
import json
import os
import re
import signal
import sys
import threading
@@ -26,7 +27,7 @@ import time
import traceback
from multiprocessing import Manager
from subprocess import Popen
-from typing import List, Tuple
+from typing import List, Tuple, Union
import atexit
import numpy as np
@@ -547,3 +548,50 @@ def indent_string(string):
return '\t' + string.replace('\n', '\n\t')
class CheckpointState(object):
"""
Helper class for checkpoint directory information. It replicates
the CheckpointState protobuf class in tensorflow.
"""
def __init__(self, checkpoints: List[str]):
self._checkpoints = checkpoints
@property
def all_model_checkpoint_paths(self):
return self._checkpoints
@property
def model_checkpoint_path(self):
return self._checkpoints[-1]
def __str__(self):
out_str = 'model_checkpoint_path: {}\n'.format(self.model_checkpoint_path)
for c in self._checkpoints:
out_str += 'all_model_checkpoint_paths: {}\n'.format(c)
return out_str
def __repr__(self):
return str(self._checkpoints)
COACH_CHECKPOINT_PATTERN = r'\A([0-9]+)[^0-9].*?\.ckpt'
def get_checkpoint_state(checkpoint_dir: Union[str, List[str]], filename_pattern: str=COACH_CHECKPOINT_PATTERN) ->\
CheckpointState:
"""
Finds the latest checkpoint file. It uses the first group of filename_pattern (i.e. group(1)) to sort
the checkpoint names and find the latest checkpoint
:param checkpoint_dir: directory where checkpoints are saved or list of all files in a directory
:param filename_pattern: regex pattern for checkpoint filenames
:return: a CheckpointState for checkpoint_dir containing a sorted list of checkpoint names
"""
prog = re.compile(filename_pattern)
checkpoints = dict()
filenames = os.listdir(checkpoint_dir) if isinstance(checkpoint_dir, str) else checkpoint_dir
for name in filenames:
m = prog.search(name)
if m is not None and m.group(1) is not None:
full_path = os.path.join(checkpoint_dir, m.group(0)) if isinstance(checkpoint_dir, str) else m.group(0)
checkpoints[int(m.group(1))] = full_path
return CheckpointState([checkpoints[k] for k in sorted(checkpoints.keys())])
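
As a concrete illustration of the default pattern (consistent with the unit tests above): the non-greedy .*? stops group(0) right after '.ckpt', and group(1) is the integer used for ordering.

import re
from rl_coach.utils import COACH_CHECKPOINT_PATTERN  # r'\A([0-9]+)[^0-9].*?\.ckpt'

m = re.compile(COACH_CHECKPOINT_PATTERN).search('4.test.ckpt.ext')
print(m.group(0))  # '4.test.ckpt' -- the stored checkpoint name
print(m.group(1))  # '4'           -- the sort key (converted to int)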