Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 11:10:20 +01:00)
Adding checkpointing framework (#74)
* Adding checkpointing framework as well as mxnet checkpointing implementation.
  - MXNet checkpoint for each network is saved in a separate file.
* Adding checkpoint restore for mxnet to graph-manager
* Add unit-test for get_checkpoint_state()
* Added match.group() to fix unit-test failing on CI
* Added ONNX export support for MXNet
Committed by shadiendrawis · parent 4da56b1ff2 · commit 67eb9e4c28
@@ -30,6 +30,7 @@ from rl_coach.core_types import RunPhase, PredictionType, EnvironmentEpisodes, A
 from rl_coach.core_types import Transition, ActionInfo, TrainingSteps, EnvironmentSteps, EnvResponse
 from rl_coach.logger import screen, Logger, EpisodeLogger
 from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplay
+from rl_coach.saver import SaverCollection
 from rl_coach.spaces import SpacesDefinition, VectorObservationSpace, GoalsSpace, AttentionActionSpace
 from rl_coach.utils import Signal, force_list
 from rl_coach.utils import dynamic_import_and_instantiate_module_from_params
@@ -996,3 +997,16 @@ class Agent(AgentInterface):
 
     def get_success_rate(self) -> float:
         return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed
+
+    def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
+        """
+        Collect all of the agent's network savers
+        :param parent_path_suffix: path suffix of the parent of the agent
+                                   (could be name of level manager or composite agent)
+        :return: collection of all agent savers
+        """
+        parent_path_suffix = "{}.{}".format(parent_path_suffix, self.name)
+        savers = SaverCollection()
+        for network in self.networks.values():
+            savers.update(network.collect_savers(parent_path_suffix))
+        return savers
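For reference, saver paths compose as dot-separated suffixes while collect_savers() propagates from level manager through agent down to network. A minimal sketch of the resulting path, using hypothetical level/agent/network names:

# Sketch only: how dot-separated saver paths compose across the hierarchy.
# 'main_level', 'agent' and 'main/online' are hypothetical names.
parent = "{}.{}".format('main_level', 'agent')     # LevelManager -> Agent
network_name = 'main/online'.replace('/', '.')     # as in MxnetArchitecture.collect_savers
saver_path = "{}.{}".format(parent, network_name)
assert saver_path == 'main_level.agent.main.online'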
@@ -19,6 +19,7 @@ from typing import Union, List, Dict
 import numpy as np
 
 from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, PredictionType, ActionType, Transition
+from rl_coach.saver import SaverCollection
 
 
 class AgentInterface(object):
@@ -153,3 +154,12 @@ class AgentInterface(object):
         :return: A tuple containing the actual action and additional info on the action
         """
         raise NotImplementedError("")
+
+    def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
+        """
+        Collect all of the agent's savers
+        :param parent_path_suffix: path suffix of the parent of the agent
+                                   (could be name of level manager or composite agent)
+        :return: collection of all agent savers
+        """
+        raise NotImplementedError
@@ -25,6 +25,7 @@ from rl_coach.agents.agent_interface import AgentInterface
 from rl_coach.base_parameters import AgentParameters, VisualizationParameters
 from rl_coach.core_types import ActionInfo, EnvResponse, ActionType, RunPhase
 from rl_coach.filters.observation.observation_crop_filter import ObservationCropFilter
+from rl_coach.saver import SaverCollection
 from rl_coach.spaces import ActionSpace
 from rl_coach.spaces import AgentSelection, AttentionActionSpace, SpacesDefinition
 from rl_coach.utils import short_dynamic_import
@@ -412,3 +413,16 @@ class CompositeAgent(AgentInterface):
         :return:
         """
         [agent.sync() for agent in self.agents.values()]
+
+    def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
+        """
+        Collect the network savers of all of this composite agent's sub-agents
+        :param parent_path_suffix: path suffix of the parent of the composite agent
+                                   (could be name of level manager or composite agent)
+        :return: collection of all agent savers
+        """
+        savers = SaverCollection()
+        for agent in self.agents.values():
+            savers.update(agent.collect_savers(
+                parent_path_suffix="{}.{}".format(parent_path_suffix, self.name)))
+        return savers
@@ -19,6 +19,7 @@ from typing import Any, Dict, List, Tuple
 import numpy as np
 
 from rl_coach.base_parameters import AgentParameters
+from rl_coach.saver import SaverCollection
 from rl_coach.spaces import SpacesDefinition
 
 
@@ -213,3 +214,12 @@ class Architecture(object):
         :param placeholder: a placeholder for binding the value to assign_op.
         """
         raise NotImplementedError
+
+    def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
+        """
+        Collect all savers for the network (typically one saver for the network and one for ONNX export)
+        :param parent_path_suffix: path suffix of the parent of the network
+                                   (e.g. could be name of level manager plus name of agent)
+        :return: saver collection for the network
+        """
+        raise NotImplementedError
@@ -24,7 +24,9 @@ from mxnet.ndarray import NDArray
 from rl_coach.architectures.architecture import Architecture
 from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION
 from rl_coach.architectures.mxnet_components import utils
-from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
+from rl_coach.architectures.mxnet_components.savers import ParameterDictSaver, OnnxSaver
+from rl_coach.base_parameters import AgentParameters
+from rl_coach.saver import SaverCollection
 from rl_coach.spaces import SpacesDefinition
 from rl_coach.utils import force_list, squeeze_list
@@ -81,17 +83,25 @@ class MxnetArchitecture(Architecture):
         """
         return (p.list_grad()[0].copy() for p in self.model.collect_params().values() if p.grad_req != 'null')
 
+    def _model_input_shapes(self) -> List[List[int]]:
+        """
+        Create a list of input array shapes
+        :return: list of input shapes
+        """
+        allowed_inputs = copy.copy(self.spaces.state.sub_spaces)
+        allowed_inputs["action"] = copy.copy(self.spaces.action)
+        allowed_inputs["goal"] = copy.copy(self.spaces.goal)
+        embedders = self.model.nets[0].input_embedders
+        return list([1] + allowed_inputs[emb.embedder_name].shape.tolist() for emb in embedders)
+
     def _dummy_model_inputs(self) -> Tuple[NDArray, ...]:
         """
         Creates a tuple of input arrays with correct shapes that can be used for shape inference
         of the model weights and for printing the summary
         :return: tuple of inputs for model forward pass
         """
-        allowed_inputs = copy.copy(self.spaces.state.sub_spaces)
-        allowed_inputs["action"] = copy.copy(self.spaces.action)
-        allowed_inputs["goal"] = copy.copy(self.spaces.goal)
-        embedders = self.model.nets[0].input_embedders
-        inputs = tuple(nd.zeros((1,) + tuple(allowed_inputs[emb.embedder_name].shape.tolist())) for emb in embedders)
+        input_shapes = self._model_input_shapes()
+        inputs = tuple(nd.zeros(tuple(shape)) for shape in input_shapes)
         return inputs
 
     def construct_model(self) -> None:
@@ -402,3 +412,21 @@ class MxnetArchitecture(Architecture):
         :return: None
         """
         assert self.middleware.__class__.__name__ != 'LSTMMiddleware', 'LSTM middleware not supported'
+
+    def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
+        """
+        Collection of all checkpoints for the network (typically only one checkpoint)
+        :param parent_path_suffix: path suffix of the parent of the network
+                                   (e.g. could be name of level manager plus name of agent)
+        :return: checkpoint collection for the network
+        """
+        name = self.name.replace('/', '.')
+        savers = SaverCollection(ParameterDictSaver(
+            name="{}.{}".format(parent_path_suffix, name),
+            param_dict=self.model.collect_params()))
+        if self.ap.task_parameters.export_onnx_graph:
+            savers.add(OnnxSaver(
+                name="{}.{}.onnx".format(parent_path_suffix, name),
+                model=self.model,
+                input_shapes=self._model_input_shapes()))
+        return savers
@@ -279,7 +279,7 @@ def _get_output_head(
     return module
 
 
-class ScaledGradHead(HybridBlock):
+class ScaledGradHead(HybridBlock, utils.OnnxHandlerBlock):
     """
     Wrapper block for applying gradient scaling to input before feeding the head network
     """
@@ -292,7 +292,7 @@ class ScaledGradHead(HybridBlock):
                  agent_params: AgentParameters,
                  head_params: HeadParameters) -> None:
         """
-        :param head_idx: the head index
+        :param head_index: the head index
         :param head_type_index: the head type index (same index if head_param.num_output_head_copies>0)
         :param network_name: name of the network
         :param spaces: state and action space definitions
@@ -301,6 +301,7 @@ class ScaledGradHead(HybridBlock):
         :param head_params: head parameters
         """
         super(ScaledGradHead, self).__init__()
+        utils.OnnxHandlerBlock.__init__(self)
 
         head_params = _sanitize_activation(head_params)
         with self.name_scope():
@@ -330,7 +331,13 @@ class ScaledGradHead(HybridBlock):
         :param gradient_rescaler: gradient rescaler for partial blocking of gradient
         :return: head output
         """
-        grad_scaled_x = F.broadcast_mul((1 - gradient_rescaler), F.BlockGrad(x)) + F.broadcast_mul(gradient_rescaler, x)
+        if self._onnx:
+            # ONNX doesn't support the BlockGrad() operator, but it is typically not needed
+            # for export, since mostly forward calls are performed on the ONNX-exported network.
+            grad_scaled_x = x
+        else:
+            grad_scaled_x = (F.broadcast_mul((1 - gradient_rescaler), F.BlockGrad(x)) +
+                             F.broadcast_mul(gradient_rescaler, x))
         out = self.head(grad_scaled_x)
         return out
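For reference, the rescaling expression keeps the forward value of x unchanged while scaling its gradient: BlockGrad stops the gradient through the first term, so the forward output is (1 - s)*x + s*x = x, while only the second term contributes s to the backward pass. A minimal sketch verifying both properties (assumes a working MXNet install):

# Sketch: (1 - s) * BlockGrad(x) + s * x has forward value x and gradient s.
import mxnet as mx
from mxnet import autograd, nd

x = nd.array([2.0])
x.attach_grad()
s = nd.array([0.25])  # gradient rescaler
with autograd.record():
    y = (1 - s) * nd.BlockGrad(x) + s * x
y.backward()
print(y.asnumpy())       # [2.]   forward value unchanged
print(x.grad.asnumpy())  # [0.25] gradient scaled by s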
@@ -113,7 +113,7 @@ class PPOVHead(Head):
         :param x: middleware state representation, of shape (batch_size, in_channels).
         :return: final value output of network, of shape (batch_size).
         """
-        return self.dense(x).squeeze()
+        return self.dense(x).squeeze(axis=1)
 
     def loss(self) -> mx.gluon.loss.Loss:
         """
@@ -98,4 +98,4 @@ class VHead(Head):
         :param x: middleware state representation, of shape (batch_size, in_channels).
         :return: final output of value network, of shape (batch_size).
         """
-        return self.dense(x).squeeze()
+        return self.dense(x).squeeze(axis=1)
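The axis=1 argument matters for a batch of size 1: squeeze() with no axis removes every singleton dimension, so a (1, 1) output can lose its batch dimension as well, whereas squeeze(axis=1) only drops the trailing value dimension. A quick shape illustration (numpy shown for the squeeze semantics):

import numpy as np

out = np.zeros((1, 1))            # batch_size=1, one value per sample
print(out.squeeze().shape)        # () - batch dimension lost too
print(out.squeeze(axis=1).shape)  # (1,) - batch dimension preserved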
rl_coach/architectures/mxnet_components/savers.py (new file, 113 lines)
@@ -0,0 +1,113 @@
from typing import Any, List, Tuple

from mxnet import gluon, sym
from mxnet.contrib import onnx as onnx_mxnet
import numpy as np

from rl_coach.architectures.mxnet_components.utils import ScopedOnnxEnable
from rl_coach.saver import Saver


class ParameterDictSaver(Saver):
    """
    Child class that implements a saver for an mxnet gluon parameter dictionary
    """
    def __init__(self, name: str, param_dict: gluon.ParameterDict):
        self._name = name
        self._param_dict = param_dict

    @property
    def path(self):
        """
        Relative path for save/load. If two checkpoint objects return the same path, they must be merge-able.
        """
        return self._name

    def save(self, sess: None, save_path: str) -> List[str]:
        """
        Save to save_path
        :param sess: active session for session-based frameworks (e.g. TF). Must be None.
        :param save_path: full path to save checkpoint (typically directory plus self.path plus checkpoint count).
        :return: list of all saved paths
        """
        assert sess is None
        self._param_dict.save(save_path)
        return [save_path]

    def restore(self, sess: Any, restore_path: str):
        """
        Restore from restore_path
        :param sess: active session for session-based frameworks (e.g. TF). Must be None.
        :param restore_path: full path to load checkpoint from.
        """
        assert sess is None
        self._param_dict.load(restore_path)

    def merge(self, other: 'Saver'):
        """
        Merge other saver into this saver
        :param other: saver to be merged into self
        """
        if not isinstance(other, ParameterDictSaver):
            raise TypeError('merging only supported with ParameterDictSaver (type: {})'.format(type(other)))
        self._param_dict.update(other._param_dict)


class OnnxSaver(Saver):
    """
    Child class that implements a saver for exporting a gluon HybridBlock to ONNX
    """
    def __init__(self, name: str, model: gluon.HybridBlock, input_shapes: List[List[int]]):
        self._name = name
        self._sym = self._get_onnx_sym(model, len(input_shapes))
        self._param_dict = model.collect_params()
        self._input_shapes = input_shapes

    @staticmethod
    def _get_onnx_sym(model: gluon.HybridBlock, num_inputs: int) -> sym.Symbol:
        """
        Returns a symbolic graph for the model
        :param model: gluon HybridBlock that constructs the symbolic graph
        :param num_inputs: number of inputs to the graph
        :return: symbol for the network
        """
        var_args = [sym.Variable('Data{}'.format(i)) for i in range(num_inputs)]
        with ScopedOnnxEnable(model):
            return sym.Group(gluon.block._flatten(model(*var_args), "output")[0])

    @property
    def path(self):
        """
        Relative path for save/load. If two checkpoint objects return the same path, they must be merge-able.
        """
        return self._name

    def save(self, sess: None, save_path: str) -> List[str]:
        """
        Save to save_path
        :param sess: active session for session-based frameworks (e.g. TF). Must be None.
        :param save_path: full path to save checkpoint (typically directory plus self.path plus checkpoint count).
        :return: list of all saved paths
        """
        assert sess is None
        params = {name: param._reduce() for name, param in self._param_dict.items()}
        export_path = onnx_mxnet.export_model(self._sym, params, self._input_shapes, np.float32, save_path)
        return [export_path]

    def restore(self, sess: Any, restore_path: str):
        """
        Restore from restore_path
        :param sess: active session for session-based frameworks (e.g. TF). Must be None.
        :param restore_path: full path to load checkpoint from.
        """
        assert sess is None
        # Nothing to restore for ONNX

    def merge(self, other: 'Saver'):
        """
        Merge other saver into this saver
        :param other: saver to be merged into self
        """
        # No merging is supported for ONNX. self.path must be unique
        raise RuntimeError('merging not supported for ONNX exporter')
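A minimal sketch of round-tripping gluon parameters through ParameterDictSaver (the network and file path below are hypothetical):

from mxnet import gluon, nd

net = gluon.nn.Dense(4)
net.initialize()
net(nd.zeros((1, 8)))  # run once so parameter shapes are inferred

saver = ParameterDictSaver(name='level.agent.main', param_dict=net.collect_params())
saved = saver.save(sess=None, save_path='/tmp/0_Step-0.ckpt.level.agent.main')
saver.restore(sess=None, restore_path=saved[0])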
@@ -6,7 +6,7 @@ from typing import Any, Dict, Generator, Iterable, List, Tuple, Union
 from types import ModuleType
 
 import mxnet as mx
-from mxnet import nd
+from mxnet import gluon, nd
 from mxnet.ndarray import NDArray
 import numpy as np
@@ -278,3 +278,46 @@ def get_mxnet_activation_name(activation_name: str):
         "Activation function must be one of the following {}. instead it was: {}".format(
             activation_functions.keys(), activation_name)
     return activation_functions[activation_name]
+
+
+class OnnxHandlerBlock(object):
+    """
+    Helper base class for gluon blocks that must behave differently during an ONNX export forward pass
+    """
+    def __init__(self):
+        self._onnx = False
+
+    def enable_onnx(self):
+        self._onnx = True
+
+    def disable_onnx(self):
+        self._onnx = False
+
+
+class ScopedOnnxEnable(object):
+    """
+    Helper class for enabling ONNX mode on all child blocks within a scope
+    """
+    def __init__(self, net: gluon.HybridBlock):
+        self._onnx_handlers = self._get_onnx_handlers(net)
+
+    def __enter__(self):
+        for b in self._onnx_handlers:
+            b.enable_onnx()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        for b in self._onnx_handlers:
+            b.disable_onnx()
+
+    @staticmethod
+    def _get_onnx_handlers(block: gluon.HybridBlock) -> List[OnnxHandlerBlock]:
+        """
+        Iterates through all child blocks and returns those that are instances of OnnxHandlerBlock
+        :return: list of OnnxHandlerBlock child blocks
+        """
+        handlers = list()
+        if isinstance(block, OnnxHandlerBlock):
+            handlers.append(block)
+        for child_block in block._children.values():
+            handlers += ScopedOnnxEnable._get_onnx_handlers(child_block)
+        return handlers
@@ -18,6 +18,7 @@ from typing import List, Tuple
 
 from rl_coach.base_parameters import Frameworks, AgentParameters
 from rl_coach.logger import failed_imports
+from rl_coach.saver import SaverCollection
 from rl_coach.spaces import SpacesDefinition
 try:
     import tensorflow as tf
@@ -251,3 +252,25 @@ class NetworkWrapper(object):
         result.append(str(self.online_network))
         result.append("")
         return '\n'.join(result)
+
+    def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
+        """
+        Collect all of the network's savers for the global or online network.
+        Note: the global, online, and target networks are all copies of the same network, with parameters
+        that are updated at different rates, so we only need to save one of them: the one that holds the
+        most recent parameters. The target network is created for some agents and used for stabilizing
+        training by updating its parameters from the online network at a slower rate; as a result, the
+        target network never contains the most recent set of parameters. In single-worker training, no
+        global network is created and the online network contains the most recent parameters. In vertical
+        distributed training with more than one worker, the global network is updated by all workers and
+        contains the most recent parameters. Therefore, preference is given to the global network if it
+        exists; otherwise, the online network is used for saving.
+        :param parent_path_suffix: path suffix of the parent of the network wrapper
+                                   (e.g. could be name of level manager plus name of agent)
+        :return: collection of all checkpoint objects
+        """
+        if self.global_network:
+            savers = self.global_network.collect_savers(parent_path_suffix)
+        else:
+            savers = self.online_network.collect_savers(parent_path_suffix)
+        return savers
@@ -23,6 +23,7 @@ import tensorflow as tf
 from rl_coach.architectures.architecture import Architecture
 from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
 from rl_coach.core_types import GradientClippingMethod
+from rl_coach.saver import SaverCollection
 from rl_coach.spaces import SpacesDefinition
 from rl_coach.utils import force_list, squeeze_list, start_shell_command_and_wait
@@ -637,6 +638,16 @@ class TensorFlowArchitecture(Architecture):
             self.curr_rnn_c_in = self.middleware.c_init
             self.curr_rnn_h_in = self.middleware.h_init
 
+    def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
+        """
+        Collection of all checkpoints for the network (typically only one checkpoint)
+        :param parent_path_suffix: path suffix of the parent of the network
+                                   (e.g. could be name of level manager plus name of agent)
+        :return: checkpoint collection for the network
+        """
+        # TODO implement returning checkpoints for tensorflow
+        return SaverCollection()
+
 
 def save_onnx_graph(input_nodes, output_nodes, checkpoint_save_dir: str) -> None:
     """
@@ -31,7 +31,8 @@ from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, T
 from rl_coach.environments.environment import Environment
 from rl_coach.level_manager import LevelManager
 from rl_coach.logger import screen, Logger
-from rl_coach.utils import set_cpu, start_shell_command_and_wait
+from rl_coach.saver import SaverCollection
+from rl_coach.utils import get_checkpoint_state, set_cpu, start_shell_command_and_wait
 from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator
 from rl_coach.memories.backend.memory_impl import get_memory_backend
 from rl_coach.data_stores.data_store import SyncFiles
@@ -87,7 +88,7 @@ class GraphManager(object):
                  schedule_params: ScheduleParameters,
                  vis_params: VisualizationParameters = VisualizationParameters()):
         self.sess = None
-        self.level_managers = []
+        self.level_managers = []  # type: List[LevelManager]
         self.top_level_manager = None
         self.environments = []
         self.heatup_steps = schedule_params.heatup_steps
@@ -248,12 +249,22 @@ class GraphManager(object):
         if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
             self.save_graph()
 
+    def _create_session_mx(self):
+        """
+        Call set_session to initialize parameters and construct checkpoint_saver
+        """
+        self.set_session(sess=None)  # Initialize all modules
+        self.checkpoint_saver = SaverCollection()
+        for level in self.level_managers:
+            self.checkpoint_saver.update(level.collect_savers())
+        # restore from checkpoint if given
+        self.restore_checkpoint()
+
     def create_session(self, task_parameters: TaskParameters):
         if task_parameters.framework_type == Frameworks.tensorflow:
             self._create_session_tf(task_parameters)
         elif task_parameters.framework_type == Frameworks.mxnet:
-            self.set_session(sess=None)  # Initialize all modules
-            # TODO add checkpoint loading
+            self._create_session_mx()
         else:
             raise ValueError('Invalid framework {}'.format(task_parameters.framework_type))
@@ -270,14 +281,13 @@ class GraphManager(object):
             name='graphdef.pb',
             as_text=False)
 
-    def save_onnx_graph(self) -> None:
+    def _save_onnx_graph_tf(self) -> None:
         """
-        Save the graph as an ONNX graph.
+        Save the tensorflow graph as an ONNX graph.
         This requires the graph and the weights checkpoint to be stored in the experiment directory.
         It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX.
         :return: None
         """
-
         # collect input and output nodes
         input_nodes = []
         output_nodes = []
@@ -290,11 +300,20 @@ class GraphManager(object):
             for output in network.online_network.outputs:
                 output_nodes.append(output.name)
 
-        # TODO: make this framework agnostic
         from rl_coach.architectures.tensorflow_components.architecture import save_onnx_graph
-
         save_onnx_graph(input_nodes, output_nodes, self.task_parameters.checkpoint_save_dir)
 
+    def save_onnx_graph(self) -> None:
+        """
+        Save the graph as an ONNX graph.
+        This requires the graph and the weights checkpoint to be stored in the experiment directory.
+        It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX.
+        :return: None
+        """
+        if self.task_parameters.framework_type == Frameworks.tensorflow:
+            self._save_onnx_graph_tf()
+
     def setup_logger(self) -> None:
         # dump documentation
         logger_prefix = "{graph_name}".format(graph_name=self.name)
@@ -526,14 +545,13 @@ class GraphManager(object):
             if self.evaluate(self.evaluation_steps):
                 break
 
-    def _restore_checkpoint_tf(self, checkpoint_dir: str):
+    def _restore_checkpoint_tf(self, checkpoint_path: str):
         import tensorflow as tf
-        checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
-        screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
         variables = {}
-        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
+        reader = tf.contrib.framework.load_checkpoint(checkpoint_path)
+        for var_name, _ in reader.get_variable_to_shape_map().items():
             # Load the variable
-            var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
+            var = reader.get_tensor(var_name)
 
             # Set the new name
             new_name = var_name
@@ -548,11 +566,14 @@ class GraphManager(object):
 
         # TODO: find better way to load checkpoints that were saved with a global network into the online network
         if hasattr(self.task_parameters, 'checkpoint_restore_dir') and self.task_parameters.checkpoint_restore_dir:
+            checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
+            screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
+
             if self.task_parameters.framework_type == Frameworks.tensorflow:
-                self._restore_checkpoint_tf(self.task_parameters.checkpoint_restore_dir)
+                self._restore_checkpoint_tf(checkpoint.model_checkpoint_path)
             elif self.task_parameters.framework_type == Frameworks.mxnet:
-                # TODO implement checkpoint restore
-                pass
+                self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
             else:
                 raise ValueError('Invalid framework {}'.format(self.task_parameters.framework_type))
@@ -572,6 +593,8 @@ class GraphManager(object):
                                            "{}_Step-{}.ckpt".format(
                                                self.checkpoint_id,
                                                self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps]))
+        if not os.path.exists(os.path.dirname(checkpoint_path)):
+            os.mkdir(os.path.dirname(checkpoint_path))  # Create directory structure
         if not isinstance(self.task_parameters, DistributedTaskParameters):
             if self.checkpoint_saver is not None:
                 saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
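Putting the pieces together, each saver in the collection writes its own file under the checkpoint prefix. A sketch of how a final file name is composed (the checkpoint id, step, and saver path below are hypothetical):

# Prefix built by GraphManager.save_checkpoint, suffix appended by SaverCollection.save
prefix = "{}_Step-{}.ckpt".format(3, 1500)
full_path = "{}.{}".format(prefix, 'main_level.agent.main.online')
assert full_path == '3_Step-1500.ckpt.main_level.agent.main.online'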
@@ -20,6 +20,7 @@ from rl_coach.agents.composite_agent import CompositeAgent
 from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, ActionType, EnvironmentSteps, Transition
 from rl_coach.environments.environment import Environment
 from rl_coach.environments.environment_interface import EnvironmentInterface
+from rl_coach.saver import SaverCollection
 from rl_coach.spaces import ActionSpace, SpacesDefinition
 
 
@@ -292,3 +293,13 @@ class LevelManager(EnvironmentInterface):
 
     def should_stop(self) -> bool:
         return all([agent.get_success_rate() >= self.environment.get_target_success_rate() for agent in self.agents.values()])
+
+    def collect_savers(self) -> SaverCollection:
+        """
+        Calls collect_savers() on all agents and combines the results into a single collection
+        :return: saver collection of all agent savers
+        """
+        savers = SaverCollection()
+        for agent in self.agents.values():
+            savers.update(agent.collect_savers(parent_path_suffix=self.name))
+        return savers
rl_coach/saver.py (new file, 112 lines)
@@ -0,0 +1,112 @@
"""
Module with abstract base classes for a checkpoint object and a checkpoint collection
"""
from typing import Any, Dict, List


class Saver(object):
    """
    ABC for saver objects that implement saving/restoring to/from a path, and merging two savers.
    """
    @property
    def path(self):
        """
        Relative path for save/load. If two saver objects return the same path, they must be merge-able.
        """
        raise NotImplementedError

    def save(self, sess: Any, save_path: str) -> List[str]:
        """
        Save to save_path
        :param sess: active session for session-based frameworks (e.g. TF)
        :param save_path: full path to save checkpoint (typically directory plus self.path plus checkpoint count).
        :return: list of all saved paths
        """
        raise NotImplementedError

    def restore(self, sess: Any, restore_path: str) -> None:
        """
        Restore from restore_path
        :param sess: active session for session-based frameworks (e.g. TF)
        :param restore_path: full path to load checkpoint from.
        """
        raise NotImplementedError

    def merge(self, other: 'Saver') -> None:
        """
        Merge other saver into this saver
        :param other: saver to be merged into self
        """
        raise NotImplementedError


class SaverCollection(object):
    """
    Object for storing a collection of saver objects. It takes care of ensuring uniqueness of saver paths
    and merging savers that have the same path. For example, if a saver handles saving a generic key/value
    file for all networks in a single file, it can use a more generic path, and all such savers across all
    networks would be merged into a single saver that saves/restores parameters for all networks.
    NOTE: If two savers have the same path, the respective saver class must support merging them
    into a single saver that saves/restores all merged parameters.
    """
    def __init__(self, saver: Saver = None):
        """
        :param saver: optional initial saver for the collection
        """
        self._saver_dict = dict()  # type: Dict[str, Saver]
        if saver is not None:
            self._saver_dict[saver.path] = saver

    def add(self, saver: Saver):
        """
        Add a new saver to the collection. If saver.path is already in the collection, merge
        the new saver with the existing saver.
        :param saver: new saver to be added to the collection
        """
        if saver.path in self._saver_dict:
            self._saver_dict[saver.path].merge(saver)
        else:
            self._saver_dict[saver.path] = saver

    def update(self, other: 'SaverCollection'):
        """
        Merge savers from another collection into self
        :param other: saver collection to update self with.
        """
        for c in other:
            self.add(c)

    def save(self, sess: Any, save_path: str) -> List[str]:
        """
        Call save on all savers in the collection
        :param sess: active session for session-based frameworks (e.g. TF)
        :param save_path: path for saving checkpoints using savers. All saved file paths must
            start with this path in their full path. For example, if save_path is
            '/home/checkpoints/checkpoint-01', then a saved file path can be
            '/home/checkpoints/checkpoint-01.main-network' but not '/home/checkpoints/main-network'.
        :return: list of all saved paths
        """
        paths = list()
        for saver in self:
            paths.extend(saver.save(sess, "{}.{}".format(save_path, saver.path)))
        return paths

    def restore(self, sess: Any, restore_path: str) -> None:
        """
        Call restore on all savers in the collection
        :param sess: active session for session-based frameworks (e.g. TF)
        :param restore_path: path for restoring checkpoints using savers.
        """
        for saver in self:
            # NOTE: build a per-saver path; reassigning restore_path itself would
            # compound the suffixes across iterations.
            saver.restore(sess, "{}.{}".format(restore_path, saver.path))

    def __iter__(self):
        """
        Return an iterator over the savers in the collection
        :return: saver iterator
        """
        return (v for v in self._saver_dict.values())
@@ -142,3 +142,42 @@ def test_hybrid_clip():
     b = mx.nd.array((2,))
     clipped = hybrid_clip(F=mx.nd, x=x, clip_lower=a, clip_upper=b)
     assert (np.isclose(a=clipped.asnumpy(), b=(1, 1.5, 2))).all()
+
+
+@pytest.mark.unit_test
+def test_scoped_onnx_enable():
+    class Counter(object):
+        def __init__(self):
+            self._count = 0
+
+        def increment(self):
+            self._count += 1
+
+        @property
+        def count(self):
+            return self._count
+
+    class TempBlock(gluon.HybridBlock, OnnxHandlerBlock):
+        def __init__(self, counter: Counter):
+            super(TempBlock, self).__init__()
+            OnnxHandlerBlock.__init__(self)
+            self._counter = counter
+
+        def hybrid_forward(self, F, x, *args, **kwargs):
+            if self._onnx:
+                self._counter.increment()
+            return x
+
+    counter = Counter()
+    net = gluon.nn.HybridSequential()
+    for _ in range(10):
+        net.add(TempBlock(counter))
+
+    # ONNX disabled
+    net(nd.zeros((1,)))
+    assert counter.count == 0
+
+    # ONNX enabled
+    with ScopedOnnxEnable(net):
+        net(nd.zeros((1,)))
+    assert counter.count == 10
rl_coach/tests/test_saver.py (new file, 42 lines)
@@ -0,0 +1,42 @@
import pytest

from rl_coach.saver import Saver, SaverCollection


@pytest.mark.unit_test
def test_checkpoint_collection():
    class SaverTest(Saver):
        def __init__(self, path):
            self._path = path
            self._count = 1

        @property
        def path(self):
            return self._path

        def merge(self, other: 'Saver'):
            assert isinstance(other, SaverTest)
            assert self.path == other.path
            self._count += other._count

    # test add
    savers = SaverCollection(SaverTest('123'))
    savers.add(SaverTest('123'))
    savers.add(SaverTest('456'))

    def check_collection(mul):
        paths = ['123', '456']
        for c in savers:
            paths.remove(c.path)
            if c.path == '123':
                assert c._count == 2 * mul
            elif c.path == '456':
                assert c._count == 1 * mul
            else:
                assert False, "invalid path"

    check_collection(1)

    # test update
    savers.update(savers)
    check_collection(2)
rl_coach/tests/test_utils.py (new file, 21 lines)
@@ -0,0 +1,21 @@
import pytest

from rl_coach import utils


@pytest.mark.unit_test
def test_get_checkpoint_state_default():
    files = ['4.test.ckpt.ext', '2.test.ckpt.ext', '3.test.ckpt.ext', '1.test.ckpt.ext']
    checkpoint_state = utils.get_checkpoint_state(files)
    assert checkpoint_state.model_checkpoint_path == '4.test.ckpt'
    assert checkpoint_state.all_model_checkpoint_paths == [f[:-4] for f in sorted(files)]


@pytest.mark.unit_test
def test_get_checkpoint_state_custom():
    files = ['prefix.4.test.ckpt.ext', 'prefix.2.test.ckpt.ext', 'prefix.3.test.ckpt.ext', 'prefix.1.test.ckpt.ext']
    assert len(utils.get_checkpoint_state(files).all_model_checkpoint_paths) == 0  # doesn't match the default pattern
    checkpoint_state = utils.get_checkpoint_state(files, filename_pattern=r'([0-9]+)[^0-9].*?\.ckpt')
    assert checkpoint_state.model_checkpoint_path == '4.test.ckpt'
    assert checkpoint_state.all_model_checkpoint_paths == [f[7:-4] for f in sorted(files)]
@@ -19,6 +19,7 @@ import importlib.util
 import inspect
 import json
 import os
+import re
 import signal
 import sys
 import threading
@@ -26,7 +27,7 @@ import time
 import traceback
 from multiprocessing import Manager
 from subprocess import Popen
-from typing import List, Tuple
+from typing import List, Tuple, Union
 
 import atexit
 import numpy as np
@@ -547,3 +548,50 @@ def indent_string(string):
     return '\t' + string.replace('\n', '\n\t')
+
+
+class CheckpointState(object):
+    """
+    Helper class for checkpoint directory information. It replicates
+    the CheckpointState protobuf class in tensorflow.
+    """
+    def __init__(self, checkpoints: List[str]):
+        self._checkpoints = checkpoints
+
+    @property
+    def all_model_checkpoint_paths(self):
+        return self._checkpoints
+
+    @property
+    def model_checkpoint_path(self):
+        return self._checkpoints[-1]
+
+    def __str__(self):
+        out_str = 'model_checkpoint_path: {}\n'.format(self.model_checkpoint_path)
+        for c in self._checkpoints:
+            out_str += 'all_model_checkpoint_paths: {}\n'.format(c)
+        return out_str
+
+    def __repr__(self):
+        return str(self._checkpoints)
+
+
+COACH_CHECKPOINT_PATTERN = r'\A([0-9]+)[^0-9].*?\.ckpt'
+
+
+def get_checkpoint_state(checkpoint_dir: Union[str, List[str]],
+                         filename_pattern: str = COACH_CHECKPOINT_PATTERN) -> CheckpointState:
+    """
+    Finds the latest checkpoint file. It uses the first group of filename_pattern (i.e. group(1)) to sort
+    the checkpoint names and find the latest checkpoint.
+    :param checkpoint_dir: directory where checkpoints are saved, or a list of all files in a directory
+    :param filename_pattern: regex pattern for checkpoint filenames
+    :return: a CheckpointState for checkpoint_dir containing a sorted list of checkpoint names
+    """
+    prog = re.compile(filename_pattern)
+    checkpoints = dict()
+    filenames = os.listdir(checkpoint_dir) if isinstance(checkpoint_dir, str) else checkpoint_dir
+    for name in filenames:
+        m = prog.search(name)
+        if m is not None and m.group(1) is not None:
+            full_path = os.path.join(checkpoint_dir, m.group(0)) if isinstance(checkpoint_dir, str) else m.group(0)
+            checkpoints[int(m.group(1))] = full_path
+    return CheckpointState([checkpoints[k] for k in sorted(checkpoints.keys())])
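As a worked example of the default pattern, the leading integer is the sort key and the match stops at the first '.ckpt':

import re

COACH_CHECKPOINT_PATTERN = r'\A([0-9]+)[^0-9].*?\.ckpt'
m = re.search(COACH_CHECKPOINT_PATTERN, '12_Step-3000.ckpt.main')
print(m.group(0))  # '12_Step-3000.ckpt'  -> checkpoint name
print(m.group(1))  # '12'                 -> sort key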