From 67eb9e4c28098d93ac122d65833c20b22b7e86c7 Mon Sep 17 00:00:00 2001 From: Sina Afrooze Date: Mon, 19 Nov 2018 09:45:49 -0800 Subject: [PATCH] Adding checkpointing framework (#74) * Adding checkpointing framework as well as mxnet checkpointing implementation. - MXNet checkpoint for each network is saved in a separate file. * Adding checkpoint restore for mxnet to graph-manager * Add unit-test for get_checkpoint_state() * Added match.group() to fix unit-test failing on CI * Added ONNX export support for MXNet --- rl_coach/agents/agent.py | 14 +++ rl_coach/agents/agent_interface.py | 10 ++ rl_coach/agents/composite_agent.py | 14 +++ rl_coach/architectures/architecture.py | 10 ++ .../mxnet_components/architecture.py | 40 ++++++- .../mxnet_components/general_network.py | 13 +- .../mxnet_components/heads/ppo_v_head.py | 2 +- .../mxnet_components/heads/v_head.py | 2 +- .../architectures/mxnet_components/savers.py | 113 ++++++++++++++++++ .../architectures/mxnet_components/utils.py | 45 ++++++- rl_coach/architectures/network_wrapper.py | 23 ++++ .../tensorflow_components/architecture.py | 11 ++ rl_coach/graph_managers/graph_manager.py | 55 ++++++--- rl_coach/level_manager.py | 11 ++ rl_coach/saver.py | 112 +++++++++++++++++ .../mxnet_components/test_utils.py | 39 ++++++ rl_coach/tests/test_saver.py | 42 +++++++ rl_coach/tests/test_utils.py | 21 ++++ rl_coach/utils.py | 50 +++++++- 19 files changed, 598 insertions(+), 29 deletions(-) create mode 100644 rl_coach/architectures/mxnet_components/savers.py create mode 100644 rl_coach/saver.py create mode 100644 rl_coach/tests/test_saver.py create mode 100644 rl_coach/tests/test_utils.py diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py index c6e992e..a9ecbb2 100644 --- a/rl_coach/agents/agent.py +++ b/rl_coach/agents/agent.py @@ -30,6 +30,7 @@ from rl_coach.core_types import RunPhase, PredictionType, EnvironmentEpisodes, A from rl_coach.core_types import Transition, ActionInfo, TrainingSteps, EnvironmentSteps, EnvResponse from rl_coach.logger import screen, Logger, EpisodeLogger from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplay +from rl_coach.saver import SaverCollection from rl_coach.spaces import SpacesDefinition, VectorObservationSpace, GoalsSpace, AttentionActionSpace from rl_coach.utils import Signal, force_list from rl_coach.utils import dynamic_import_and_instantiate_module_from_params @@ -996,3 +997,16 @@ class Agent(AgentInterface): def get_success_rate(self) -> float: return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed + + def collect_savers(self, parent_path_suffix: str) -> SaverCollection: + """ + Collect all of agent's network savers + :param parent_path_suffix: path suffix of the parent of the agent + (could be name of level manager or composite agent) + :return: collection of all agent savers + """ + parent_path_suffix = "{}.{}".format(parent_path_suffix, self.name) + savers = SaverCollection() + for network in self.networks.values(): + savers.update(network.collect_savers(parent_path_suffix)) + return savers diff --git a/rl_coach/agents/agent_interface.py b/rl_coach/agents/agent_interface.py index 968fa43..0a7aaab 100644 --- a/rl_coach/agents/agent_interface.py +++ b/rl_coach/agents/agent_interface.py @@ -19,6 +19,7 @@ from typing import Union, List, Dict import numpy as np from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, PredictionType, ActionType, Transition +from rl_coach.saver import SaverCollection class AgentInterface(object): @@ -153,3 +154,12 @@ class AgentInterface(object): :return: A tuple containing the actual action and additional info on the action """ raise NotImplementedError("") + + def collect_savers(self, parent_path_suffix: str) -> SaverCollection: + """ + Collect all of agent savers + :param parent_path_suffix: path suffix of the parent of the agent + (could be name of level manager or composite agent) + :return: collection of all agent savers + """ + raise NotImplementedError diff --git a/rl_coach/agents/composite_agent.py b/rl_coach/agents/composite_agent.py index 5f8f53f..42dcad9 100644 --- a/rl_coach/agents/composite_agent.py +++ b/rl_coach/agents/composite_agent.py @@ -25,6 +25,7 @@ from rl_coach.agents.agent_interface import AgentInterface from rl_coach.base_parameters import AgentParameters, VisualizationParameters from rl_coach.core_types import ActionInfo, EnvResponse, ActionType, RunPhase from rl_coach.filters.observation.observation_crop_filter import ObservationCropFilter +from rl_coach.saver import SaverCollection from rl_coach.spaces import ActionSpace from rl_coach.spaces import AgentSelection, AttentionActionSpace, SpacesDefinition from rl_coach.utils import short_dynamic_import @@ -412,3 +413,16 @@ class CompositeAgent(AgentInterface): :return: """ [agent.sync() for agent in self.agents.values()] + + def collect_savers(self, parent_path_suffix: str) -> SaverCollection: + """ + Collect all of agent's network savers + :param parent_path_suffix: path suffix of the parent of the agent + (could be name of level manager or composite agent) + :return: collection of all agent savers + """ + savers = SaverCollection() + for agent in self.agents.values(): + savers.update(agent.collect_savers( + parent_path_suffix="{}.{}".format(parent_path_suffix, self.name))) + return savers diff --git a/rl_coach/architectures/architecture.py b/rl_coach/architectures/architecture.py index 92b0f84..637eef6 100644 --- a/rl_coach/architectures/architecture.py +++ b/rl_coach/architectures/architecture.py @@ -19,6 +19,7 @@ from typing import Any, Dict, List, Tuple import numpy as np from rl_coach.base_parameters import AgentParameters +from rl_coach.saver import SaverCollection from rl_coach.spaces import SpacesDefinition @@ -213,3 +214,12 @@ class Architecture(object): :param placeholder: a placeholder for binding the value to assign_op. """ raise NotImplementedError + + def collect_savers(self, parent_path_suffix: str) -> SaverCollection: + """ + Collection of all savers for the network (typically only one saver for network and one for ONNX export) + :param parent_path_suffix: path suffix of the parent of the network + (e.g. could be name of level manager plus name of agent) + :return: saver collection for the network + """ + raise NotImplementedError diff --git a/rl_coach/architectures/mxnet_components/architecture.py b/rl_coach/architectures/mxnet_components/architecture.py index 8b39851..8f4c6a1 100644 --- a/rl_coach/architectures/mxnet_components/architecture.py +++ b/rl_coach/architectures/mxnet_components/architecture.py @@ -24,7 +24,9 @@ from mxnet.ndarray import NDArray from rl_coach.architectures.architecture import Architecture from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION from rl_coach.architectures.mxnet_components import utils -from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters +from rl_coach.architectures.mxnet_components.savers import ParameterDictSaver, OnnxSaver +from rl_coach.base_parameters import AgentParameters +from rl_coach.saver import SaverCollection from rl_coach.spaces import SpacesDefinition from rl_coach.utils import force_list, squeeze_list @@ -81,17 +83,25 @@ class MxnetArchitecture(Architecture): """ return (p.list_grad()[0].copy() for p in self.model.collect_params().values() if p.grad_req != 'null') + def _model_input_shapes(self) -> List[List[int]]: + """ + Create a list of input array shapes + :return: type of input shapes + """ + allowed_inputs = copy.copy(self.spaces.state.sub_spaces) + allowed_inputs["action"] = copy.copy(self.spaces.action) + allowed_inputs["goal"] = copy.copy(self.spaces.goal) + embedders = self.model.nets[0].input_embedders + return list([1] + allowed_inputs[emb.embedder_name].shape.tolist() for emb in embedders) + def _dummy_model_inputs(self) -> Tuple[NDArray, ...]: """ Creates a tuple of input arrays with correct shapes that can be used for shape inference of the model weights and for printing the summary :return: tuple of inputs for model forward pass """ - allowed_inputs = copy.copy(self.spaces.state.sub_spaces) - allowed_inputs["action"] = copy.copy(self.spaces.action) - allowed_inputs["goal"] = copy.copy(self.spaces.goal) - embedders = self.model.nets[0].input_embedders - inputs = tuple(nd.zeros((1,) + tuple(allowed_inputs[emb.embedder_name].shape.tolist())) for emb in embedders) + input_shapes = self._model_input_shapes() + inputs = tuple(nd.zeros(tuple(shape)) for shape in input_shapes) return inputs def construct_model(self) -> None: @@ -402,3 +412,21 @@ class MxnetArchitecture(Architecture): :return: None """ assert self.middleware.__class__.__name__ != 'LSTMMiddleware', 'LSTM middleware not supported' + + def collect_savers(self, parent_path_suffix: str) -> SaverCollection: + """ + Collection of all checkpoints for the network (typically only one checkpoint) + :param parent_path_suffix: path suffix of the parent of the network + (e.g. could be name of level manager plus name of agent) + :return: checkpoint collection for the network + """ + name = self.name.replace('/', '.') + savers = SaverCollection(ParameterDictSaver( + name="{}.{}".format(parent_path_suffix, name), + param_dict=self.model.collect_params())) + if self.ap.task_parameters.export_onnx_graph: + savers.add(OnnxSaver( + name="{}.{}.onnx".format(parent_path_suffix, name), + model=self.model, + input_shapes=self._model_input_shapes())) + return savers diff --git a/rl_coach/architectures/mxnet_components/general_network.py b/rl_coach/architectures/mxnet_components/general_network.py index 8a856f3..5ded582 100644 --- a/rl_coach/architectures/mxnet_components/general_network.py +++ b/rl_coach/architectures/mxnet_components/general_network.py @@ -279,7 +279,7 @@ def _get_output_head( return module -class ScaledGradHead(HybridBlock): +class ScaledGradHead(HybridBlock, utils.OnnxHandlerBlock): """ Wrapper block for applying gradient scaling to input before feeding the head network """ @@ -292,7 +292,7 @@ class ScaledGradHead(HybridBlock): agent_params: AgentParameters, head_params: HeadParameters) -> None: """ - :param head_idx: the head index + :param head_index: the head index :param head_type_index: the head type index (same index if head_param.num_output_head_copies>0) :param network_name: name of the network :param spaces: state and action space definitions @@ -301,6 +301,7 @@ class ScaledGradHead(HybridBlock): :param head_params: head parameters """ super(ScaledGradHead, self).__init__() + utils.OnnxHandlerBlock.__init__(self) head_params = _sanitize_activation(head_params) with self.name_scope(): @@ -330,7 +331,13 @@ class ScaledGradHead(HybridBlock): :param gradient_rescaler: gradient rescaler for partial blocking of gradient :return: head output """ - grad_scaled_x = F.broadcast_mul((1 - gradient_rescaler), F.BlockGrad(x)) + F.broadcast_mul(gradient_rescaler, x) + if self._onnx: + # ONNX doesn't support BlockGrad() operator, but it's not typically needed for + # ONNX because mostly forward calls are performed using ONNX exported network. + grad_scaled_x = x + else: + grad_scaled_x = (F.broadcast_mul((1 - gradient_rescaler), F.BlockGrad(x)) + + F.broadcast_mul(gradient_rescaler, x)) out = self.head(grad_scaled_x) return out diff --git a/rl_coach/architectures/mxnet_components/heads/ppo_v_head.py b/rl_coach/architectures/mxnet_components/heads/ppo_v_head.py index 7b675e4..a353cc9 100644 --- a/rl_coach/architectures/mxnet_components/heads/ppo_v_head.py +++ b/rl_coach/architectures/mxnet_components/heads/ppo_v_head.py @@ -113,7 +113,7 @@ class PPOVHead(Head): :param x: middleware state representation, of shape (batch_size, in_channels). :return: final value output of network, of shape (batch_size). """ - return self.dense(x).squeeze() + return self.dense(x).squeeze(axis=1) def loss(self) -> mx.gluon.loss.Loss: """ diff --git a/rl_coach/architectures/mxnet_components/heads/v_head.py b/rl_coach/architectures/mxnet_components/heads/v_head.py index cfa765e..35107c4 100644 --- a/rl_coach/architectures/mxnet_components/heads/v_head.py +++ b/rl_coach/architectures/mxnet_components/heads/v_head.py @@ -98,4 +98,4 @@ class VHead(Head): :param x: middleware state representation, of shape (batch_size, in_channels). :return: final output of value network, of shape (batch_size). """ - return self.dense(x).squeeze() + return self.dense(x).squeeze(axis=1) diff --git a/rl_coach/architectures/mxnet_components/savers.py b/rl_coach/architectures/mxnet_components/savers.py new file mode 100644 index 0000000..74f3895 --- /dev/null +++ b/rl_coach/architectures/mxnet_components/savers.py @@ -0,0 +1,113 @@ +from typing import Any, List, Tuple + +from mxnet import gluon, sym +from mxnet.contrib import onnx as onnx_mxnet +import numpy as np + +from rl_coach.architectures.mxnet_components.utils import ScopedOnnxEnable +from rl_coach.saver import Saver + + +class ParameterDictSaver(Saver): + """ + Child class that implements saver for mxnet gluon parameter dictionary + """ + def __init__(self, name: str, param_dict: gluon.ParameterDict): + self._name = name + self._param_dict = param_dict + + @property + def path(self): + """ + Relative path for save/load. If two checkpoint objects return the same path, they must be merge-able. + """ + return self._name + + def save(self, sess: None, save_path: str) -> List[str]: + """ + Save to save_path + :param sess: active session for session-based frameworks (e.g. TF) + :param save_path: full path to save checkpoint (typically directory plus self.path plus checkpoint count). + :return: list of all saved paths + """ + assert sess is None + self._param_dict.save(save_path) + return [save_path] + + def restore(self, sess: Any, restore_path: str): + """ + Restore from restore_path + :param sess: active session for session-based frameworks (e.g. TF) + :param restore_path: full path to load checkpoint from. + """ + assert sess is None + self._param_dict.load(restore_path) + + def merge(self, other: 'Saver'): + """ + Merge other saver into this saver + :param other: saver to be merged into self + """ + if not isinstance(other, ParameterDictSaver): + raise TypeError('merging only supported with ParameterDictSaver (type:{})'.format(type(other))) + self._param_dict.update(other._param_dict) + + +class OnnxSaver(Saver): + """ + Child class that implements saver for exporting gluon HybridBlock to ONNX + """ + def __init__(self, name: str, model: gluon.HybridBlock, input_shapes: List[List[int]]): + self._name = name + self._sym = self._get_onnx_sym(model, len(input_shapes)) + self._param_dict = model.collect_params() + self._input_shapes = input_shapes + + @staticmethod + def _get_onnx_sym(model: gluon.HybridBlock, num_inputs: int) -> sym.Symbol: + """ + Returns a symbolic graph for the model + :param model: gluon HybridBlock that constructs the symbolic graph + :param num_inputs: number of inputs to the graph + :return: symbol for the network + """ + var_args = [sym.Variable('Data{}'.format(i)) for i in range(num_inputs)] + with ScopedOnnxEnable(model): + return sym.Group(gluon.block._flatten(model(*var_args), "output")[0]) + + @property + def path(self): + """ + Relative path for save/load. If two checkpoint objects return the same path, they must be merge-able. + """ + return self._name + + def save(self, sess: None, save_path: str) -> List[str]: + """ + Save to save_path + :param sess: active session for session-based frameworks (e.g. TF). Must be None. + :param save_path: full path to save checkpoint (typically directory plus self.path plus checkpoint count). + :return: list of all saved paths + """ + assert sess is None + params = {name:param._reduce() for name, param in self._param_dict.items()} + export_path = onnx_mxnet.export_model(self._sym, params, self._input_shapes, np.float32, save_path) + + return [export_path] + + def restore(self, sess: Any, restore_path: str): + """ + Restore from restore_path + :param sess: active session for session-based frameworks (e.g. TF) + :param restore_path: full path to load checkpoint from. + """ + assert sess is None + # Nothing to restore for ONNX + + def merge(self, other: 'Saver'): + """ + Merge other saver into this saver + :param other: saver to be merged into self + """ + # No merging is supported for ONNX. self.path must be unique + raise RuntimeError('merging not supported for ONNX exporter') \ No newline at end of file diff --git a/rl_coach/architectures/mxnet_components/utils.py b/rl_coach/architectures/mxnet_components/utils.py index 5f1659c..cfd497f 100644 --- a/rl_coach/architectures/mxnet_components/utils.py +++ b/rl_coach/architectures/mxnet_components/utils.py @@ -6,7 +6,7 @@ from typing import Any, Dict, Generator, Iterable, List, Tuple, Union from types import ModuleType import mxnet as mx -from mxnet import nd +from mxnet import gluon, nd from mxnet.ndarray import NDArray import numpy as np @@ -278,3 +278,46 @@ def get_mxnet_activation_name(activation_name: str): "Activation function must be one of the following {}. instead it was: {}".format( activation_functions.keys(), activation_name) return activation_functions[activation_name] + + +class OnnxHandlerBlock(object): + """ + Helper base class for gluon blocks that must behave differently for ONNX export forward pass + """ + def __init__(self): + self._onnx = False + + def enable_onnx(self): + self._onnx = True + + def disable_onnx(self): + self._onnx = False + + +class ScopedOnnxEnable(object): + """ + Helper scoped ONNX enable class + """ + def __init__(self, net: gluon.HybridBlock): + self._onnx_handlers = self._get_onnx_handlers(net) + + def __enter__(self): + for b in self._onnx_handlers: + b.enable_onnx() + + def __exit__(self, exc_type, exc_val, exc_tb): + for b in self._onnx_handlers: + b.disable_onnx() + + @staticmethod + def _get_onnx_handlers(block: gluon.HybridBlock) -> List[OnnxHandlerBlock]: + """ + Iterates through all child blocks and return all of them that are instance of OnnxHandlerBlock + :return: list of OnnxHandlerBlock child blocks + """ + handlers = list() + if isinstance(block, OnnxHandlerBlock): + handlers.append(block) + for child_block in block._children.values(): + handlers += ScopedOnnxEnable._get_onnx_handlers(child_block) + return handlers diff --git a/rl_coach/architectures/network_wrapper.py b/rl_coach/architectures/network_wrapper.py index 9122190..61a3d14 100644 --- a/rl_coach/architectures/network_wrapper.py +++ b/rl_coach/architectures/network_wrapper.py @@ -18,6 +18,7 @@ from typing import List, Tuple from rl_coach.base_parameters import Frameworks, AgentParameters from rl_coach.logger import failed_imports +from rl_coach.saver import SaverCollection from rl_coach.spaces import SpacesDefinition try: import tensorflow as tf @@ -251,3 +252,25 @@ class NetworkWrapper(object): result.append(str(self.online_network)) result.append("") return '\n'.join(result) + + def collect_savers(self, parent_path_suffix: str) -> SaverCollection: + """ + Collect all of network's savers for global or online network + Note: global, online, and target network are all copies fo the same network which parameters that are + updated at different rates. So we only need to save one of the networks; the one that holds the most + recent parameters. target network is created for some agents and used for stabilizing training by + updating parameters from online network at a slower rate. As a result, target network never contains + the most recent set of parameters. In single-worker training, no global network is created and online + network contains the most recent parameters. In vertical distributed training with more than one worker, + global network is updated by all workers and contains the most recent parameters. + Therefore preference is given to global network if it exists, otherwise online network is used + for saving. + :param parent_path_suffix: path suffix of the parent of the network wrapper + (e.g. could be name of level manager plus name of agent) + :return: collection of all checkpoint objects + """ + if self.global_network: + savers = self.global_network.collect_savers(parent_path_suffix) + else: + savers = self.online_network.collect_savers(parent_path_suffix) + return savers diff --git a/rl_coach/architectures/tensorflow_components/architecture.py b/rl_coach/architectures/tensorflow_components/architecture.py index 8e7deae..03ad98c 100644 --- a/rl_coach/architectures/tensorflow_components/architecture.py +++ b/rl_coach/architectures/tensorflow_components/architecture.py @@ -23,6 +23,7 @@ import tensorflow as tf from rl_coach.architectures.architecture import Architecture from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters from rl_coach.core_types import GradientClippingMethod +from rl_coach.saver import SaverCollection from rl_coach.spaces import SpacesDefinition from rl_coach.utils import force_list, squeeze_list, start_shell_command_and_wait @@ -637,6 +638,16 @@ class TensorFlowArchitecture(Architecture): self.curr_rnn_c_in = self.middleware.c_init self.curr_rnn_h_in = self.middleware.h_init + def collect_savers(self, parent_path_suffix: str) -> SaverCollection: + """ + Collection of all checkpoints for the network (typically only one checkpoint) + :param parent_path_suffix: path suffix of the parent of the network + (e.g. could be name of level manager plus name of agent) + :return: checkpoint collection for the network + """ + # TODO implement returning checkpoints for tensorflow + return SaverCollection() + def save_onnx_graph(input_nodes, output_nodes, checkpoint_save_dir: str) -> None: """ diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index 94a1250..b68f55c 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -31,7 +31,8 @@ from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, T from rl_coach.environments.environment import Environment from rl_coach.level_manager import LevelManager from rl_coach.logger import screen, Logger -from rl_coach.utils import set_cpu, start_shell_command_and_wait +from rl_coach.saver import SaverCollection +from rl_coach.utils import get_checkpoint_state, set_cpu, start_shell_command_and_wait from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator from rl_coach.memories.backend.memory_impl import get_memory_backend from rl_coach.data_stores.data_store import SyncFiles @@ -87,7 +88,7 @@ class GraphManager(object): schedule_params: ScheduleParameters, vis_params: VisualizationParameters = VisualizationParameters()): self.sess = None - self.level_managers = [] + self.level_managers = [] # type: List[LevelManager] self.top_level_manager = None self.environments = [] self.heatup_steps = schedule_params.heatup_steps @@ -248,12 +249,22 @@ class GraphManager(object): if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir: self.save_graph() + def _create_session_mx(self): + """ + Call set_session to initialize parameters and construct checkpoint_saver + """ + self.set_session(sess=None) # Initialize all modules + self.checkpoint_saver = SaverCollection() + for level in self.level_managers: + self.checkpoint_saver.update(level.collect_savers()) + # restore from checkpoint if given + self.restore_checkpoint() + def create_session(self, task_parameters: TaskParameters): if task_parameters.framework_type == Frameworks.tensorflow: self._create_session_tf(task_parameters) elif task_parameters.framework_type == Frameworks.mxnet: - self.set_session(sess=None) # Initialize all modules - # TODO add checkpoint loading + self._create_session_mx() else: raise ValueError('Invalid framework {}'.format(task_parameters.framework_type)) @@ -270,14 +281,13 @@ class GraphManager(object): name='graphdef.pb', as_text=False) - def save_onnx_graph(self) -> None: + def _save_onnx_graph_tf(self) -> None: """ - Save the graph as an ONNX graph. + Save the tensorflow graph as an ONNX graph. This requires the graph and the weights checkpoint to be stored in the experiment directory. It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX. :return: None """ - # collect input and output nodes input_nodes = [] output_nodes = [] @@ -290,11 +300,20 @@ class GraphManager(object): for output in network.online_network.outputs: output_nodes.append(output.name) - # TODO: make this framework agnostic from rl_coach.architectures.tensorflow_components.architecture import save_onnx_graph save_onnx_graph(input_nodes, output_nodes, self.task_parameters.checkpoint_save_dir) + def save_onnx_graph(self) -> None: + """ + Save the graph as an ONNX graph. + This requires the graph and the weights checkpoint to be stored in the experiment directory. + It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX. + :return: None + """ + if self.task_parameters.framework_type == Frameworks.tensorflow: + self._save_onnx_graph_tf() + def setup_logger(self) -> None: # dump documentation logger_prefix = "{graph_name}".format(graph_name=self.name) @@ -526,14 +545,13 @@ class GraphManager(object): if self.evaluate(self.evaluation_steps): break - def _restore_checkpoint_tf(self, checkpoint_dir: str): + def _restore_checkpoint_tf(self, checkpoint_path: str): import tensorflow as tf - checkpoint = tf.train.get_checkpoint_state(checkpoint_dir) - screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path)) variables = {} - for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir): + reader = tf.contrib.framework.load_checkpoint(checkpoint_path) + for var_name, _ in reader.get_variable_to_shape_map().items(): # Load the variable - var = tf.contrib.framework.load_variable(checkpoint_dir, var_name) + var = reader.get_tensor(var_name) # Set the new name new_name = var_name @@ -548,11 +566,14 @@ class GraphManager(object): # TODO: find better way to load checkpoints that were saved with a global network into the online network if hasattr(self.task_parameters, 'checkpoint_restore_dir') and self.task_parameters.checkpoint_restore_dir: + + checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir) + screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path)) + if self.task_parameters.framework_type == Frameworks.tensorflow: - self._restore_checkpoint_tf(self.task_parameters.checkpoint_restore_dir) + self._restore_checkpoint_tf(checkpoint.model_checkpoint_path) elif self.task_parameters.framework_type == Frameworks.mxnet: - # TODO implement checkpoint restore - pass + self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path) else: raise ValueError('Invalid framework {}'.format(self.task_parameters.framework_type)) @@ -572,6 +593,8 @@ class GraphManager(object): "{}_Step-{}.ckpt".format( self.checkpoint_id, self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps])) + if not os.path.exists(os.path.dirname(checkpoint_path)): + os.mkdir(os.path.dirname(checkpoint_path)) # Create directory structure if not isinstance(self.task_parameters, DistributedTaskParameters): if self.checkpoint_saver is not None: saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path) diff --git a/rl_coach/level_manager.py b/rl_coach/level_manager.py index 3df2766..312a5be 100644 --- a/rl_coach/level_manager.py +++ b/rl_coach/level_manager.py @@ -20,6 +20,7 @@ from rl_coach.agents.composite_agent import CompositeAgent from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, ActionType, EnvironmentSteps, Transition from rl_coach.environments.environment import Environment from rl_coach.environments.environment_interface import EnvironmentInterface +from rl_coach.saver import SaverCollection from rl_coach.spaces import ActionSpace, SpacesDefinition @@ -292,3 +293,13 @@ class LevelManager(EnvironmentInterface): def should_stop(self) -> bool: return all([agent.get_success_rate() >= self.environment.get_target_success_rate() for agent in self.agents.values()]) + + def collect_savers(self) -> SaverCollection: + """ + Calls collect_savers() on all agents and combines the results to a single collection + :return: saver collection of all agent savers + """ + savers = SaverCollection() + for agent in self.agents.values(): + savers.update(agent.collect_savers(parent_path_suffix=self.name)) + return savers diff --git a/rl_coach/saver.py b/rl_coach/saver.py new file mode 100644 index 0000000..ce22c1f --- /dev/null +++ b/rl_coach/saver.py @@ -0,0 +1,112 @@ +""" +Module for abstract base class for checkpoint object and checkpoint collection +""" +from typing import Any, Dict, List + + +class Saver(object): + """ + ABC for saver objects that implement saving/restoring to/from path, and merging two savers. + """ + @property + def path(self): + """ + Relative path for save/load. If two saver objects return the same path, they must be merge-able. + """ + raise NotImplementedError + + def save(self, sess: Any, save_path: str) -> List[str]: + """ + Save to save_path + :param sess: active session for session-based frameworks (e.g. TF) + :param save_path: full path to save checkpoint (typically directory plus self.path plus checkpoint count). + :return: list of all saved paths + """ + raise NotImplementedError + + def restore(self, sess: Any, restore_path: str) -> None: + """ + Restore from restore_path + :param sess: active session for session-based frameworks (e.g. TF) + :param restore_path: full path to load checkpoint from. + """ + raise NotImplementedError + + def merge(self, other: 'Saver') -> None: + """ + Merge other saver into this saver + :param other: saver to be merged into self + """ + raise NotImplementedError + + +class SaverCollection(object): + """ + Object for storing a collection of saver objects. It takes care of ensuring uniqueness of saver paths + and merging savers if they have the same path. For example, if a saver handles saving a generic key/value + file for all networks in a single file, it can use a more generic path and all savers of all networks would be + merged into a single saver that saves/restores parameters for all networks. + NOTE: If two savers have the same path, the respective saver class must support merging them + into a single saver that saves/restores all merged parameters. + """ + def __init__(self, saver: Saver = None): + """ + :param saver: optional initial saver for the collection + """ + self._saver_dict = dict() # type: Dict[str, Saver] + if saver is not None: + self._saver_dict[saver.path] = saver + + def add(self, saver: Saver): + """ + Add a new saver to the collection. If saver.path is already in the collection, merge + the new saver with the existing saver. + :param saver: new saver to be added to collection + """ + if saver.path in self._saver_dict: + self._saver_dict[saver.path].merge(saver) + else: + self._saver_dict[saver.path] = saver + + def update(self, other: 'SaverCollection'): + """ + Merge savers from other collection into self + :param other: saver collection to update self with. + """ + for c in other: + self.add(c) + + def save(self, sess: Any, save_path: str) -> List[str]: + """ + Call save on all savers in the collection + :param sess: active session for session-based frameworks (e.g. TF) + :param save_path: path for saving checkpoints using savers. All saved file paths must + start with this path in their full path. For example if save_path is '/home/checkpoints/checkpoint-01', + then saved file paths can be '/home/checkpoints/checkpoint-01.main-network' but not + '/home/checkpoints/main-network' + :return: list of all saved paths + """ + paths = list() + for saver in self: + paths.extend(saver.save(sess, "{}.{}".format(save_path, saver.path))) + return paths + + def restore(self, sess: Any, restore_path: str) -> None: + """ + Call restore on all savers in the collection + :param sess: active session for session-based frameworks (e.g. TF) + :param restore_path: path for restoring checkpoint using savers. + """ + for saver in self: + restore_path = "{}.{}".format(restore_path, saver.path) + saver.restore(sess, restore_path) + + def __iter__(self): + """ + Return an iterator for savers in the collection + :return: saver iterator + """ + return (v for v in self._saver_dict.values()) + + + diff --git a/rl_coach/tests/architectures/mxnet_components/test_utils.py b/rl_coach/tests/architectures/mxnet_components/test_utils.py index 0af729a..2765998 100644 --- a/rl_coach/tests/architectures/mxnet_components/test_utils.py +++ b/rl_coach/tests/architectures/mxnet_components/test_utils.py @@ -142,3 +142,42 @@ def test_hybrid_clip(): b = mx.nd.array((2,)) clipped = hybrid_clip(F=mx.nd, x=x, clip_lower=a, clip_upper=b) assert (np.isclose(a= clipped.asnumpy(), b=(1, 1.5, 2))).all() + + +@pytest.mark.unit_test +def test_scoped_onxx_enable(): + class Counter(object): + def __init__(self): + self._count = 0 + + def increment(self): + self._count += 1 + + @property + def count(self): + return self._count + + class TempBlock(gluon.HybridBlock, OnnxHandlerBlock): + def __init__(self, counter: Counter): + super(TempBlock, self).__init__() + OnnxHandlerBlock.__init__(self) + self._counter = counter + + def hybrid_forward(self, F, x, *args, **kwargs): + if self._onnx: + self._counter.increment() + return x + + counter = Counter() + net = gluon.nn.HybridSequential() + for _ in range(10): + net.add(TempBlock(counter)) + + # ONNX disabled + net(nd.zeros((1,))) + assert counter.count == 0 + + # ONNX enabled + with ScopedOnnxEnable(net): + net(nd.zeros((1,))) + assert counter.count == 10 diff --git a/rl_coach/tests/test_saver.py b/rl_coach/tests/test_saver.py new file mode 100644 index 0000000..c02e6ef --- /dev/null +++ b/rl_coach/tests/test_saver.py @@ -0,0 +1,42 @@ +import pytest + +from rl_coach.saver import Saver, SaverCollection + + +@pytest.mark.unit_test +def test_checkpoint_collection(): + class SaverTest(Saver): + def __init__(self, path): + self._path = path + self._count = 1 + + @property + def path(self): + return self._path + + def merge(self, other: 'Saver'): + assert isinstance(other, SaverTest) + assert self.path == other.path + self._count += other._count + + # test add + savers = SaverCollection(SaverTest('123')) + savers.add(SaverTest('123')) + savers.add(SaverTest('456')) + + def check_collection(mul): + paths = ['123', '456'] + for c in savers: + paths.remove(c.path) + if c.path == '123': + assert c._count == 2 * mul + elif c.path == '456': + assert c._count == 1 * mul + else: + assert False, "invalid path" + + check_collection(1) + + # test update + savers.update(savers) + check_collection(2) diff --git a/rl_coach/tests/test_utils.py b/rl_coach/tests/test_utils.py new file mode 100644 index 0000000..3ffd686 --- /dev/null +++ b/rl_coach/tests/test_utils.py @@ -0,0 +1,21 @@ +import pytest + +from rl_coach import utils + + +@pytest.mark.unit_test +def test_get_checkpoint_state_default(): + files = ['4.test.ckpt.ext', '2.test.ckpt.ext', '3.test.ckpt.ext', '1.test.ckpt.ext'] + checkpoint_state = utils.get_checkpoint_state(files) + assert checkpoint_state.model_checkpoint_path == '4.test.ckpt' + assert checkpoint_state.all_model_checkpoint_paths == [f[:-4] for f in sorted(files)] + + +@pytest.mark.unit_test +def test_get_checkpoint_state_custom(): + files = ['prefix.4.test.ckpt.ext', 'prefix.2.test.ckpt.ext', 'prefix.3.test.ckpt.ext', 'prefix.1.test.ckpt.ext'] + assert len(utils.get_checkpoint_state(files).all_model_checkpoint_paths) == 0 # doesn't match the default pattern + checkpoint_state = utils.get_checkpoint_state(files, filename_pattern=r'([0-9]+)[^0-9].*?\.ckpt') + assert checkpoint_state.model_checkpoint_path == '4.test.ckpt' + assert checkpoint_state.all_model_checkpoint_paths == [f[7:-4] for f in sorted(files)] + diff --git a/rl_coach/utils.py b/rl_coach/utils.py index 61808c8..4a4d248 100644 --- a/rl_coach/utils.py +++ b/rl_coach/utils.py @@ -19,6 +19,7 @@ import importlib.util import inspect import json import os +import re import signal import sys import threading @@ -26,7 +27,7 @@ import time import traceback from multiprocessing import Manager from subprocess import Popen -from typing import List, Tuple +from typing import List, Tuple, Union import atexit import numpy as np @@ -547,3 +548,50 @@ def indent_string(string): return '\t' + string.replace('\n', '\n\t') +class CheckpointState(object): + """ + Helper class for checkpoint directory information. It replicates + the CheckpointState protobuf class in tensorflow. + """ + def __init__(self, checkpoints: List[str]): + self._checkpoints = checkpoints + + @property + def all_model_checkpoint_paths(self): + return self._checkpoints + + @property + def model_checkpoint_path(self): + return self._checkpoints[-1] + + def __str__(self): + out_str = 'model_checkpoint_path: {}\n'.format(self.model_checkpoint_path) + for c in self._checkpoints: + out_str += 'all_model_checkpoint_paths: {}\n'.format(c) + return out_str + + def __repr__(self): + return str(self._checkpoints) + + +COACH_CHECKPOINT_PATTERN = r'\A([0-9]+)[^0-9].*?\.ckpt' + + +def get_checkpoint_state(checkpoint_dir: Union[str, List[str]], filename_pattern: str=COACH_CHECKPOINT_PATTERN) ->\ + CheckpointState: + """ + Finds the latest checkpoint file. It uses the first group of filename_pattern (i.e. group(1)) to sort + the checkpoint names and find the latest checkpoint + :param checkpoint_dir: directory where checkpoints are saved or list of all files in a directory + :param filename_pattern: regex pattern for checkpoint filenames + :return: a CheckpointState for checkpoint_dir containing a sorted list of checkpoint names + """ + prog = re.compile(filename_pattern) + checkpoints = dict() + filenames = os.listdir(checkpoint_dir) if isinstance(checkpoint_dir, str) else checkpoint_dir + for name in filenames: + m = prog.search(name) + if m is not None and m.group(1) is not None: + full_path = os.path.join(checkpoint_dir, m.group(0)) if isinstance(checkpoint_dir, str) else m.group(0) + checkpoints[int(m.group(1))] = full_path + return CheckpointState([checkpoints[k] for k in sorted(checkpoints.keys())])