mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Adding checkpointing framework (#74)
* Adding checkpointing framework as well as mxnet checkpointing implementation. - MXNet checkpoint for each network is saved in a separate file. * Adding checkpoint restore for mxnet to graph-manager * Add unit-test for get_checkpoint_state() * Added match.group() to fix unit-test failing on CI * Added ONNX export support for MXNet
This commit is contained in:
committed by
shadiendrawis
parent
4da56b1ff2
commit
67eb9e4c28
@@ -30,6 +30,7 @@ from rl_coach.core_types import RunPhase, PredictionType, EnvironmentEpisodes, A
|
|||||||
from rl_coach.core_types import Transition, ActionInfo, TrainingSteps, EnvironmentSteps, EnvResponse
|
from rl_coach.core_types import Transition, ActionInfo, TrainingSteps, EnvironmentSteps, EnvResponse
|
||||||
from rl_coach.logger import screen, Logger, EpisodeLogger
|
from rl_coach.logger import screen, Logger, EpisodeLogger
|
||||||
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplay
|
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplay
|
||||||
|
from rl_coach.saver import SaverCollection
|
||||||
from rl_coach.spaces import SpacesDefinition, VectorObservationSpace, GoalsSpace, AttentionActionSpace
|
from rl_coach.spaces import SpacesDefinition, VectorObservationSpace, GoalsSpace, AttentionActionSpace
|
||||||
from rl_coach.utils import Signal, force_list
|
from rl_coach.utils import Signal, force_list
|
||||||
from rl_coach.utils import dynamic_import_and_instantiate_module_from_params
|
from rl_coach.utils import dynamic_import_and_instantiate_module_from_params
|
||||||
@@ -996,3 +997,16 @@ class Agent(AgentInterface):
|
|||||||
|
|
||||||
def get_success_rate(self) -> float:
|
def get_success_rate(self) -> float:
|
||||||
return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed
|
return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed
|
||||||
|
|
||||||
|
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
    """
    Gather the savers of every network owned by this agent into a single collection

    :param parent_path_suffix: path suffix of the agent's parent
        (e.g. the name of a level manager or of a composite agent)
    :return: collection holding the savers of all of the agent's networks
    """
    # Each network is asked for its savers under "<parent suffix>.<agent name>".
    agent_path_suffix = "{}.{}".format(parent_path_suffix, self.name)
    collected = SaverCollection()
    for net in self.networks.values():
        net_savers = net.collect_savers(agent_path_suffix)
        collected.update(net_savers)
    return collected
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from typing import Union, List, Dict
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, PredictionType, ActionType, Transition
|
from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, PredictionType, ActionType, Transition
|
||||||
|
from rl_coach.saver import SaverCollection
|
||||||
|
|
||||||
|
|
||||||
class AgentInterface(object):
|
class AgentInterface(object):
|
||||||
@@ -153,3 +154,12 @@ class AgentInterface(object):
|
|||||||
:return: A tuple containing the actual action and additional info on the action
|
:return: A tuple containing the actual action and additional info on the action
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("")
|
raise NotImplementedError("")
|
||||||
|
|
||||||
|
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
    """
    Interface method for gathering all of the agent's savers; this definition always raises

    :param parent_path_suffix: path suffix of the agent's parent
        (e.g. the name of a level manager or of a composite agent)
    :return: collection of all agent savers
    :raises NotImplementedError: always, in this interface definition
    """
    raise NotImplementedError
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from rl_coach.agents.agent_interface import AgentInterface
|
|||||||
from rl_coach.base_parameters import AgentParameters, VisualizationParameters
|
from rl_coach.base_parameters import AgentParameters, VisualizationParameters
|
||||||
from rl_coach.core_types import ActionInfo, EnvResponse, ActionType, RunPhase
|
from rl_coach.core_types import ActionInfo, EnvResponse, ActionType, RunPhase
|
||||||
from rl_coach.filters.observation.observation_crop_filter import ObservationCropFilter
|
from rl_coach.filters.observation.observation_crop_filter import ObservationCropFilter
|
||||||
|
from rl_coach.saver import SaverCollection
|
||||||
from rl_coach.spaces import ActionSpace
|
from rl_coach.spaces import ActionSpace
|
||||||
from rl_coach.spaces import AgentSelection, AttentionActionSpace, SpacesDefinition
|
from rl_coach.spaces import AgentSelection, AttentionActionSpace, SpacesDefinition
|
||||||
from rl_coach.utils import short_dynamic_import
|
from rl_coach.utils import short_dynamic_import
|
||||||
@@ -412,3 +413,16 @@ class CompositeAgent(AgentInterface):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
[agent.sync() for agent in self.agents.values()]
|
[agent.sync() for agent in self.agents.values()]
|
||||||
|
|
||||||
|
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
    """
    Gather the savers of every agent contained in this composite agent

    :param parent_path_suffix: path suffix of the composite agent's parent
        (e.g. the name of a level manager or of another composite agent)
    :return: collection holding the savers of all contained agents
    """
    collected = SaverCollection()
    for sub_agent in self.agents.values():
        # Contained agents are collected under "<parent suffix>.<composite agent name>".
        child_suffix = "{}.{}".format(parent_path_suffix, self.name)
        sub_agent_savers = sub_agent.collect_savers(parent_path_suffix=child_suffix)
        collected.update(sub_agent_savers)
    return collected
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from typing import Any, Dict, List, Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from rl_coach.base_parameters import AgentParameters
|
from rl_coach.base_parameters import AgentParameters
|
||||||
|
from rl_coach.saver import SaverCollection
|
||||||
from rl_coach.spaces import SpacesDefinition
|
from rl_coach.spaces import SpacesDefinition
|
||||||
|
|
||||||
|
|
||||||
@@ -213,3 +214,12 @@ class Architecture(object):
|
|||||||
:param placeholder: a placeholder for binding the value to assign_op.
|
:param placeholder: a placeholder for binding the value to assign_op.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
    """
    Interface method returning all savers of the network; this base definition always raises.
    Typically a network has one saver for its parameters and one for ONNX export.

    :param parent_path_suffix: path suffix of the network's parent
        (e.g. the name of the level manager plus the name of the agent)
    :return: saver collection for the network
    :raises NotImplementedError: always, in this base definition
    """
    raise NotImplementedError
|
||||||
|
|||||||
@@ -24,7 +24,9 @@ from mxnet.ndarray import NDArray
|
|||||||
from rl_coach.architectures.architecture import Architecture
|
from rl_coach.architectures.architecture import Architecture
|
||||||
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION
|
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION
|
||||||
from rl_coach.architectures.mxnet_components import utils
|
from rl_coach.architectures.mxnet_components import utils
|
||||||
from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
|
from rl_coach.architectures.mxnet_components.savers import ParameterDictSaver, OnnxSaver
|
||||||
|
from rl_coach.base_parameters import AgentParameters
|
||||||
|
from rl_coach.saver import SaverCollection
|
||||||
from rl_coach.spaces import SpacesDefinition
|
from rl_coach.spaces import SpacesDefinition
|
||||||
from rl_coach.utils import force_list, squeeze_list
|
from rl_coach.utils import force_list, squeeze_list
|
||||||
|
|
||||||
@@ -81,17 +83,25 @@ class MxnetArchitecture(Architecture):
|
|||||||
"""
|
"""
|
||||||
return (p.list_grad()[0].copy() for p in self.model.collect_params().values() if p.grad_req != 'null')
|
return (p.list_grad()[0].copy() for p in self.model.collect_params().values() if p.grad_req != 'null')
|
||||||
|
|
||||||
|
def _model_input_shapes(self) -> List[List[int]]:
    """
    Build the list of input array shapes for the model

    :return: one shape per input embedder of the first network (self.model.nets[0]), in embedder order;
        each shape is a list of ints with a leading 1 (presumably the batch dimension)
    """
    # Lookup from input name to its space: the state sub-spaces plus "action" and "goal".
    spaces_by_name = copy.copy(self.spaces.state.sub_spaces)
    spaces_by_name["action"] = copy.copy(self.spaces.action)
    spaces_by_name["goal"] = copy.copy(self.spaces.goal)
    shapes = []
    for embedder in self.model.nets[0].input_embedders:
        space_shape = spaces_by_name[embedder.embedder_name].shape.tolist()
        shapes.append([1] + space_shape)
    return shapes
|
||||||
|
|
||||||
def _dummy_model_inputs(self) -> Tuple[NDArray, ...]:
|
def _dummy_model_inputs(self) -> Tuple[NDArray, ...]:
|
||||||
"""
|
"""
|
||||||
Creates a tuple of input arrays with correct shapes that can be used for shape inference
|
Creates a tuple of input arrays with correct shapes that can be used for shape inference
|
||||||
of the model weights and for printing the summary
|
of the model weights and for printing the summary
|
||||||
:return: tuple of inputs for model forward pass
|
:return: tuple of inputs for model forward pass
|
||||||
"""
|
"""
|
||||||
allowed_inputs = copy.copy(self.spaces.state.sub_spaces)
|
input_shapes = self._model_input_shapes()
|
||||||
allowed_inputs["action"] = copy.copy(self.spaces.action)
|
inputs = tuple(nd.zeros(tuple(shape)) for shape in input_shapes)
|
||||||
allowed_inputs["goal"] = copy.copy(self.spaces.goal)
|
|
||||||
embedders = self.model.nets[0].input_embedders
|
|
||||||
inputs = tuple(nd.zeros((1,) + tuple(allowed_inputs[emb.embedder_name].shape.tolist())) for emb in embedders)
|
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
def construct_model(self) -> None:
|
def construct_model(self) -> None:
|
||||||
@@ -402,3 +412,21 @@ class MxnetArchitecture(Architecture):
|
|||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
assert self.middleware.__class__.__name__ != 'LSTMMiddleware', 'LSTM middleware not supported'
|
assert self.middleware.__class__.__name__ != 'LSTMMiddleware', 'LSTM middleware not supported'
|
||||||
|
|
||||||
|
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
    """
    Build the savers for this network: a parameter saver, plus an ONNX saver when
    task_parameters.export_onnx_graph is set

    :param parent_path_suffix: path suffix of the network's parent
        (e.g. the name of the level manager plus the name of the agent)
    :return: saver collection for the network
    """
    # '/' in the network name is replaced by '.' before it is used in saver names.
    network_name = self.name.replace('/', '.')
    parameter_saver = ParameterDictSaver(
        name="{}.{}".format(parent_path_suffix, network_name),
        param_dict=self.model.collect_params())
    collection = SaverCollection(parameter_saver)
    if not self.ap.task_parameters.export_onnx_graph:
        return collection
    onnx_saver = OnnxSaver(
        name="{}.{}.onnx".format(parent_path_suffix, network_name),
        model=self.model,
        input_shapes=self._model_input_shapes())
    collection.add(onnx_saver)
    return collection
|
||||||
|
|||||||
@@ -279,7 +279,7 @@ def _get_output_head(
|
|||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
class ScaledGradHead(HybridBlock):
|
class ScaledGradHead(HybridBlock, utils.OnnxHandlerBlock):
|
||||||
"""
|
"""
|
||||||
Wrapper block for applying gradient scaling to input before feeding the head network
|
Wrapper block for applying gradient scaling to input before feeding the head network
|
||||||
"""
|
"""
|
||||||
@@ -292,7 +292,7 @@ class ScaledGradHead(HybridBlock):
|
|||||||
agent_params: AgentParameters,
|
agent_params: AgentParameters,
|
||||||
head_params: HeadParameters) -> None:
|
head_params: HeadParameters) -> None:
|
||||||
"""
|
"""
|
||||||
:param head_idx: the head index
|
:param head_index: the head index
|
||||||
:param head_type_index: the head type index (same index if head_param.num_output_head_copies>0)
|
:param head_type_index: the head type index (same index if head_param.num_output_head_copies>0)
|
||||||
:param network_name: name of the network
|
:param network_name: name of the network
|
||||||
:param spaces: state and action space definitions
|
:param spaces: state and action space definitions
|
||||||
@@ -301,6 +301,7 @@ class ScaledGradHead(HybridBlock):
|
|||||||
:param head_params: head parameters
|
:param head_params: head parameters
|
||||||
"""
|
"""
|
||||||
super(ScaledGradHead, self).__init__()
|
super(ScaledGradHead, self).__init__()
|
||||||
|
utils.OnnxHandlerBlock.__init__(self)
|
||||||
|
|
||||||
head_params = _sanitize_activation(head_params)
|
head_params = _sanitize_activation(head_params)
|
||||||
with self.name_scope():
|
with self.name_scope():
|
||||||
@@ -330,7 +331,13 @@ class ScaledGradHead(HybridBlock):
|
|||||||
:param gradient_rescaler: gradient rescaler for partial blocking of gradient
|
:param gradient_rescaler: gradient rescaler for partial blocking of gradient
|
||||||
:return: head output
|
:return: head output
|
||||||
"""
|
"""
|
||||||
grad_scaled_x = F.broadcast_mul((1 - gradient_rescaler), F.BlockGrad(x)) + F.broadcast_mul(gradient_rescaler, x)
|
if self._onnx:
|
||||||
|
# ONNX doesn't support BlockGrad() operator, but it's not typically needed for
|
||||||
|
# ONNX because mostly forward calls are performed using ONNX exported network.
|
||||||
|
grad_scaled_x = x
|
||||||
|
else:
|
||||||
|
grad_scaled_x = (F.broadcast_mul((1 - gradient_rescaler), F.BlockGrad(x)) +
|
||||||
|
F.broadcast_mul(gradient_rescaler, x))
|
||||||
out = self.head(grad_scaled_x)
|
out = self.head(grad_scaled_x)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ class PPOVHead(Head):
|
|||||||
:param x: middleware state representation, of shape (batch_size, in_channels).
|
:param x: middleware state representation, of shape (batch_size, in_channels).
|
||||||
:return: final value output of network, of shape (batch_size).
|
:return: final value output of network, of shape (batch_size).
|
||||||
"""
|
"""
|
||||||
return self.dense(x).squeeze()
|
return self.dense(x).squeeze(axis=1)
|
||||||
|
|
||||||
def loss(self) -> mx.gluon.loss.Loss:
|
def loss(self) -> mx.gluon.loss.Loss:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -98,4 +98,4 @@ class VHead(Head):
|
|||||||
:param x: middleware state representation, of shape (batch_size, in_channels).
|
:param x: middleware state representation, of shape (batch_size, in_channels).
|
||||||
:return: final output of value network, of shape (batch_size).
|
:return: final output of value network, of shape (batch_size).
|
||||||
"""
|
"""
|
||||||
return self.dense(x).squeeze()
|
return self.dense(x).squeeze(axis=1)
|
||||||
|
|||||||
113
rl_coach/architectures/mxnet_components/savers.py
Normal file
113
rl_coach/architectures/mxnet_components/savers.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
from typing import Any, List, Tuple
|
||||||
|
|
||||||
|
from mxnet import gluon, sym
|
||||||
|
from mxnet.contrib import onnx as onnx_mxnet
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from rl_coach.architectures.mxnet_components.utils import ScopedOnnxEnable
|
||||||
|
from rl_coach.saver import Saver
|
||||||
|
|
||||||
|
|
||||||
|
class ParameterDictSaver(Saver):
    """
    Saver implementation that saves and restores an mxnet gluon parameter dictionary
    """
    def __init__(self, name: str, param_dict: gluon.ParameterDict):
        """
        :param name: value reported by the path property
        :param param_dict: gluon parameter dictionary that is saved and restored by this saver
        """
        self._param_dict = param_dict
        self._name = name

    @property
    def path(self):
        """
        Relative path used for save/load. Two savers that report the same path must be merge-able.
        """
        return self._name

    def save(self, sess: None, save_path: str) -> List[str]:
        """
        Save the parameter dictionary to save_path

        :param sess: active session for session-based frameworks (e.g. TF). Must be None.
        :param save_path: full path to save checkpoint (typically directory plus self.path plus checkpoint count).
        :return: list of all saved paths (here a single entry: save_path)
        """
        assert sess is None
        self._param_dict.save(save_path)
        saved_paths = [save_path]
        return saved_paths

    def restore(self, sess: Any, restore_path: str):
        """
        Load the parameter dictionary from restore_path

        :param sess: active session for session-based frameworks (e.g. TF). Must be None.
        :param restore_path: full path to load checkpoint from.
        """
        assert sess is None
        self._param_dict.load(restore_path)

    def merge(self, other: 'Saver'):
        """
        Merge the parameters of another ParameterDictSaver into this saver

        :param other: saver to be merged into self
        :raises TypeError: if other is not a ParameterDictSaver
        """
        if isinstance(other, ParameterDictSaver):
            self._param_dict.update(other._param_dict)
            return
        raise TypeError('merging only supported with ParameterDictSaver (type:{})'.format(type(other)))
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxSaver(Saver):
    """
    Saver implementation that exports a gluon HybridBlock to ONNX
    """
    def __init__(self, name: str, model: gluon.HybridBlock, input_shapes: List[List[int]]):
        """
        :param name: value reported by the path property
        :param model: gluon HybridBlock to export
        :param input_shapes: shapes of the model inputs, one entry per input
        """
        self._name = name
        # The symbolic graph is traced once, at construction time.
        self._sym = self._get_onnx_sym(model, len(input_shapes))
        self._param_dict = model.collect_params()
        self._input_shapes = input_shapes

    @staticmethod
    def _get_onnx_sym(model: gluon.HybridBlock, num_inputs: int) -> sym.Symbol:
        """
        Trace the model with symbolic inputs while ONNX mode is enabled on its handler blocks

        :param model: gluon HybridBlock that constructs the symbolic graph
        :param num_inputs: number of inputs to the graph
        :return: symbol grouping the flattened model outputs
        """
        # Symbolic inputs are named Data0, Data1, ...
        placeholders = [sym.Variable('Data{}'.format(idx)) for idx in range(num_inputs)]
        with ScopedOnnxEnable(model):
            outputs = model(*placeholders)
            flat_outputs = gluon.block._flatten(outputs, "output")[0]
            return sym.Group(flat_outputs)

    @property
    def path(self):
        """
        Relative path used for save/load. Two savers that report the same path must be merge-able.
        """
        return self._name

    def save(self, sess: None, save_path: str) -> List[str]:
        """
        Export the model to ONNX at save_path

        :param sess: active session for session-based frameworks (e.g. TF). Must be None.
        :param save_path: full path to save checkpoint (typically directory plus self.path plus checkpoint count).
        :return: list of all saved paths (here a single entry: the path returned by the ONNX exporter)
        """
        assert sess is None
        # NOTE(review): _reduce() is a private gluon Parameter method; presumably it yields a single
        # array per parameter - confirm against the mxnet version in use.
        exported_params = {}
        for param_name, param in self._param_dict.items():
            exported_params[param_name] = param._reduce()
        export_path = onnx_mxnet.export_model(
            self._sym, exported_params, self._input_shapes, np.float32, save_path)

        return [export_path]

    def restore(self, sess: Any, restore_path: str):
        """
        No-op apart from the session check: there is nothing to restore for ONNX

        :param sess: active session for session-based frameworks (e.g. TF). Must be None.
        :param restore_path: full path to load checkpoint from (unused).
        """
        assert sess is None
        # Nothing to restore for ONNX

    def merge(self, other: 'Saver'):
        """
        Merging is not supported for the ONNX exporter, so self.path must be unique

        :param other: saver to be merged into self
        :raises RuntimeError: always
        """
        # No merging is supported for ONNX. self.path must be unique
        raise RuntimeError('merging not supported for ONNX exporter')
|
||||||
@@ -6,7 +6,7 @@ from typing import Any, Dict, Generator, Iterable, List, Tuple, Union
|
|||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
|
|
||||||
import mxnet as mx
|
import mxnet as mx
|
||||||
from mxnet import nd
|
from mxnet import gluon, nd
|
||||||
from mxnet.ndarray import NDArray
|
from mxnet.ndarray import NDArray
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -278,3 +278,46 @@ def get_mxnet_activation_name(activation_name: str):
|
|||||||
"Activation function must be one of the following {}. instead it was: {}".format(
|
"Activation function must be one of the following {}. instead it was: {}".format(
|
||||||
activation_functions.keys(), activation_name)
|
activation_functions.keys(), activation_name)
|
||||||
return activation_functions[activation_name]
|
return activation_functions[activation_name]
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxHandlerBlock(object):
    """
    Mixin-style base class for gluon blocks whose forward pass must differ during ONNX export.
    It only keeps the boolean flag self._onnx, which subclasses can consult.
    """
    def __init__(self):
        # ONNX mode is off until enable_onnx() is called.
        self._onnx = False

    def enable_onnx(self):
        """
        Switch the block into ONNX mode (sets self._onnx to True)
        """
        self._onnx = True

    def disable_onnx(self):
        """
        Switch the block back to regular mode (sets self._onnx to False)
        """
        self._onnx = False
|
||||||
|
|
||||||
|
|
||||||
|
class ScopedOnnxEnable(object):
    """
    Context manager that enables ONNX mode on every OnnxHandlerBlock found in a network
    for the duration of a with-block, and disables it again on exit
    """
    def __init__(self, net: gluon.HybridBlock):
        """
        :param net: block whose OnnxHandlerBlock instances (itself and its descendants) are toggled
        """
        # The handlers are collected once, at construction time.
        self._onnx_handlers = self._get_onnx_handlers(net)

    def __enter__(self):
        """
        Enable ONNX mode on all collected handlers

        :return: this context manager, so that `with ScopedOnnxEnable(net) as scope:` binds a usable object
        """
        for handler in self._onnx_handlers:
            handler.enable_onnx()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Disable ONNX mode on all collected handlers. Returns None, so an exception raised
        inside the with-block is not suppressed.
        """
        for handler in self._onnx_handlers:
            handler.disable_onnx()

    @staticmethod
    def _get_onnx_handlers(block: gluon.HybridBlock) -> List[OnnxHandlerBlock]:
        """
        Recursively walk the block and its child blocks and collect those that are
        instances of OnnxHandlerBlock

        :param block: root block to inspect
        :return: list of OnnxHandlerBlock instances, the root first (if it is one), then its descendants
        """
        handlers = list()
        if isinstance(block, OnnxHandlerBlock):
            handlers.append(block)
        for child_block in block._children.values():
            handlers.extend(ScopedOnnxEnable._get_onnx_handlers(child_block))
        return handlers
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from typing import List, Tuple
|
|||||||
|
|
||||||
from rl_coach.base_parameters import Frameworks, AgentParameters
|
from rl_coach.base_parameters import Frameworks, AgentParameters
|
||||||
from rl_coach.logger import failed_imports
|
from rl_coach.logger import failed_imports
|
||||||
|
from rl_coach.saver import SaverCollection
|
||||||
from rl_coach.spaces import SpacesDefinition
|
from rl_coach.spaces import SpacesDefinition
|
||||||
try:
|
try:
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@@ -251,3 +252,25 @@ class NetworkWrapper(object):
|
|||||||
result.append(str(self.online_network))
|
result.append(str(self.online_network))
|
||||||
result.append("")
|
result.append("")
|
||||||
return '\n'.join(result)
|
return '\n'.join(result)
|
||||||
|
|
||||||
|
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
    """
    Collect the savers of the global network if it exists, otherwise of the online network

    Note: global, online, and target networks are all copies of the same network whose parameters are
    updated at different rates, so only one of them needs to be saved: the one that holds the most
    recent parameters. The target network is created for some agents and is used for stabilizing
    training by updating its parameters from the online network at a slower rate; as a result it never
    contains the most recent set of parameters. In single-worker training no global network is created
    and the online network contains the most recent parameters. In vertical distributed training with
    more than one worker, the global network is updated by all workers and contains the most recent
    parameters. Therefore preference is given to the global network if it exists, otherwise the online
    network is used for saving.

    :param parent_path_suffix: path suffix of the parent of the network wrapper
        (e.g. the name of the level manager plus the name of the agent)
    :return: collection of all checkpoint objects
    """
    source_network = self.global_network if self.global_network else self.online_network
    return source_network.collect_savers(parent_path_suffix)
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import tensorflow as tf
|
|||||||
from rl_coach.architectures.architecture import Architecture
|
from rl_coach.architectures.architecture import Architecture
|
||||||
from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
|
from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
|
||||||
from rl_coach.core_types import GradientClippingMethod
|
from rl_coach.core_types import GradientClippingMethod
|
||||||
|
from rl_coach.saver import SaverCollection
|
||||||
from rl_coach.spaces import SpacesDefinition
|
from rl_coach.spaces import SpacesDefinition
|
||||||
from rl_coach.utils import force_list, squeeze_list, start_shell_command_and_wait
|
from rl_coach.utils import force_list, squeeze_list, start_shell_command_and_wait
|
||||||
|
|
||||||
@@ -637,6 +638,16 @@ class TensorFlowArchitecture(Architecture):
|
|||||||
self.curr_rnn_c_in = self.middleware.c_init
|
self.curr_rnn_c_in = self.middleware.c_init
|
||||||
self.curr_rnn_h_in = self.middleware.h_init
|
self.curr_rnn_h_in = self.middleware.h_init
|
||||||
|
|
||||||
|
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
    """
    Return the savers for the network; for tensorflow this is currently an empty collection

    :param parent_path_suffix: path suffix of the network's parent
        (e.g. the name of the level manager plus the name of the agent)
    :return: checkpoint collection for the network (empty until the TODO below is implemented)
    """
    # TODO implement returning checkpoints for tensorflow
    empty_collection = SaverCollection()
    return empty_collection
|
||||||
|
|
||||||
|
|
||||||
def save_onnx_graph(input_nodes, output_nodes, checkpoint_save_dir: str) -> None:
|
def save_onnx_graph(input_nodes, output_nodes, checkpoint_save_dir: str) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -31,7 +31,8 @@ from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, T
|
|||||||
from rl_coach.environments.environment import Environment
|
from rl_coach.environments.environment import Environment
|
||||||
from rl_coach.level_manager import LevelManager
|
from rl_coach.level_manager import LevelManager
|
||||||
from rl_coach.logger import screen, Logger
|
from rl_coach.logger import screen, Logger
|
||||||
from rl_coach.utils import set_cpu, start_shell_command_and_wait
|
from rl_coach.saver import SaverCollection
|
||||||
|
from rl_coach.utils import get_checkpoint_state, set_cpu, start_shell_command_and_wait
|
||||||
from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator
|
from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator
|
||||||
from rl_coach.memories.backend.memory_impl import get_memory_backend
|
from rl_coach.memories.backend.memory_impl import get_memory_backend
|
||||||
from rl_coach.data_stores.data_store import SyncFiles
|
from rl_coach.data_stores.data_store import SyncFiles
|
||||||
@@ -87,7 +88,7 @@ class GraphManager(object):
|
|||||||
schedule_params: ScheduleParameters,
|
schedule_params: ScheduleParameters,
|
||||||
vis_params: VisualizationParameters = VisualizationParameters()):
|
vis_params: VisualizationParameters = VisualizationParameters()):
|
||||||
self.sess = None
|
self.sess = None
|
||||||
self.level_managers = []
|
self.level_managers = [] # type: List[LevelManager]
|
||||||
self.top_level_manager = None
|
self.top_level_manager = None
|
||||||
self.environments = []
|
self.environments = []
|
||||||
self.heatup_steps = schedule_params.heatup_steps
|
self.heatup_steps = schedule_params.heatup_steps
|
||||||
@@ -248,12 +249,22 @@ class GraphManager(object):
|
|||||||
if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
|
if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
|
||||||
self.save_graph()
|
self.save_graph()
|
||||||
|
|
||||||
|
def _create_session_mx(self):
    """
    Initialize all modules through set_session (with no session object), build checkpoint_saver
    from the savers of every level manager, and then call restore_checkpoint
    """
    # Initialize all modules
    self.set_session(sess=None)
    self.checkpoint_saver = SaverCollection()
    for level_manager in self.level_managers:
        self.checkpoint_saver.update(level_manager.collect_savers())
    # restore from checkpoint if given
    self.restore_checkpoint()
|
||||||
|
|
||||||
def create_session(self, task_parameters: TaskParameters):
|
def create_session(self, task_parameters: TaskParameters):
|
||||||
if task_parameters.framework_type == Frameworks.tensorflow:
|
if task_parameters.framework_type == Frameworks.tensorflow:
|
||||||
self._create_session_tf(task_parameters)
|
self._create_session_tf(task_parameters)
|
||||||
elif task_parameters.framework_type == Frameworks.mxnet:
|
elif task_parameters.framework_type == Frameworks.mxnet:
|
||||||
self.set_session(sess=None) # Initialize all modules
|
self._create_session_mx()
|
||||||
# TODO add checkpoint loading
|
|
||||||
else:
|
else:
|
||||||
raise ValueError('Invalid framework {}'.format(task_parameters.framework_type))
|
raise ValueError('Invalid framework {}'.format(task_parameters.framework_type))
|
||||||
|
|
||||||
@@ -270,14 +281,13 @@ class GraphManager(object):
|
|||||||
name='graphdef.pb',
|
name='graphdef.pb',
|
||||||
as_text=False)
|
as_text=False)
|
||||||
|
|
||||||
def save_onnx_graph(self) -> None:
|
def _save_onnx_graph_tf(self) -> None:
|
||||||
"""
|
"""
|
||||||
Save the graph as an ONNX graph.
|
Save the tensorflow graph as an ONNX graph.
|
||||||
This requires the graph and the weights checkpoint to be stored in the experiment directory.
|
This requires the graph and the weights checkpoint to be stored in the experiment directory.
|
||||||
It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX.
|
It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX.
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# collect input and output nodes
|
# collect input and output nodes
|
||||||
input_nodes = []
|
input_nodes = []
|
||||||
output_nodes = []
|
output_nodes = []
|
||||||
@@ -290,11 +300,20 @@ class GraphManager(object):
|
|||||||
for output in network.online_network.outputs:
|
for output in network.online_network.outputs:
|
||||||
output_nodes.append(output.name)
|
output_nodes.append(output.name)
|
||||||
|
|
||||||
# TODO: make this framework agnostic
|
|
||||||
from rl_coach.architectures.tensorflow_components.architecture import save_onnx_graph
|
from rl_coach.architectures.tensorflow_components.architecture import save_onnx_graph
|
||||||
|
|
||||||
save_onnx_graph(input_nodes, output_nodes, self.task_parameters.checkpoint_save_dir)
|
save_onnx_graph(input_nodes, output_nodes, self.task_parameters.checkpoint_save_dir)
|
||||||
|
|
||||||
|
def save_onnx_graph(self) -> None:
    """
    Save the graph as an ONNX graph.
    Only the tensorflow framework is handled here, by delegating to _save_onnx_graph_tf; for any
    other framework type this method does nothing.
    The tensorflow path requires the graph and the weights checkpoint to be stored in the experiment
    directory. It then freezes the graph (merging the graph and weights checkpoint), and converts it
    to ONNX.
    :return: None
    """
    uses_tensorflow = self.task_parameters.framework_type == Frameworks.tensorflow
    if uses_tensorflow:
        self._save_onnx_graph_tf()
|
||||||
|
|
||||||
def setup_logger(self) -> None:
|
def setup_logger(self) -> None:
|
||||||
# dump documentation
|
# dump documentation
|
||||||
logger_prefix = "{graph_name}".format(graph_name=self.name)
|
logger_prefix = "{graph_name}".format(graph_name=self.name)
|
||||||
@@ -526,14 +545,13 @@ class GraphManager(object):
|
|||||||
if self.evaluate(self.evaluation_steps):
|
if self.evaluate(self.evaluation_steps):
|
||||||
break
|
break
|
||||||
|
|
||||||
def _restore_checkpoint_tf(self, checkpoint_dir: str):
|
def _restore_checkpoint_tf(self, checkpoint_path: str):
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
|
|
||||||
screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
|
|
||||||
variables = {}
|
variables = {}
|
||||||
for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
|
reader = tf.contrib.framework.load_checkpoint(checkpoint_path)
|
||||||
|
for var_name, _ in reader.get_variable_to_shape_map().items():
|
||||||
# Load the variable
|
# Load the variable
|
||||||
var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
|
var = reader.get_tensor(var_name)
|
||||||
|
|
||||||
# Set the new name
|
# Set the new name
|
||||||
new_name = var_name
|
new_name = var_name
|
||||||
@@ -548,11 +566,14 @@ class GraphManager(object):
|
|||||||
|
|
||||||
# TODO: find better way to load checkpoints that were saved with a global network into the online network
|
# TODO: find better way to load checkpoints that were saved with a global network into the online network
|
||||||
if hasattr(self.task_parameters, 'checkpoint_restore_dir') and self.task_parameters.checkpoint_restore_dir:
|
if hasattr(self.task_parameters, 'checkpoint_restore_dir') and self.task_parameters.checkpoint_restore_dir:
|
||||||
|
|
||||||
|
checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
|
||||||
|
screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
|
||||||
|
|
||||||
if self.task_parameters.framework_type == Frameworks.tensorflow:
|
if self.task_parameters.framework_type == Frameworks.tensorflow:
|
||||||
self._restore_checkpoint_tf(self.task_parameters.checkpoint_restore_dir)
|
self._restore_checkpoint_tf(checkpoint.model_checkpoint_path)
|
||||||
elif self.task_parameters.framework_type == Frameworks.mxnet:
|
elif self.task_parameters.framework_type == Frameworks.mxnet:
|
||||||
# TODO implement checkpoint restore
|
self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
|
||||||
pass
|
|
||||||
else:
|
else:
|
||||||
raise ValueError('Invalid framework {}'.format(self.task_parameters.framework_type))
|
raise ValueError('Invalid framework {}'.format(self.task_parameters.framework_type))
|
||||||
|
|
||||||
@@ -572,6 +593,8 @@ class GraphManager(object):
|
|||||||
"{}_Step-{}.ckpt".format(
|
"{}_Step-{}.ckpt".format(
|
||||||
self.checkpoint_id,
|
self.checkpoint_id,
|
||||||
self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps]))
|
self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps]))
|
||||||
|
if not os.path.exists(os.path.dirname(checkpoint_path)):
|
||||||
|
os.mkdir(os.path.dirname(checkpoint_path)) # Create directory structure
|
||||||
if not isinstance(self.task_parameters, DistributedTaskParameters):
|
if not isinstance(self.task_parameters, DistributedTaskParameters):
|
||||||
if self.checkpoint_saver is not None:
|
if self.checkpoint_saver is not None:
|
||||||
saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
|
saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from rl_coach.agents.composite_agent import CompositeAgent
|
|||||||
from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, ActionType, EnvironmentSteps, Transition
|
from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, ActionType, EnvironmentSteps, Transition
|
||||||
from rl_coach.environments.environment import Environment
|
from rl_coach.environments.environment import Environment
|
||||||
from rl_coach.environments.environment_interface import EnvironmentInterface
|
from rl_coach.environments.environment_interface import EnvironmentInterface
|
||||||
|
from rl_coach.saver import SaverCollection
|
||||||
from rl_coach.spaces import ActionSpace, SpacesDefinition
|
from rl_coach.spaces import ActionSpace, SpacesDefinition
|
||||||
|
|
||||||
|
|
||||||
@@ -292,3 +293,13 @@ class LevelManager(EnvironmentInterface):
|
|||||||
|
|
||||||
def should_stop(self) -> bool:
    """
    Tell whether every agent has reached the environment's target success rate.
    :return: True when each agent's success rate is at least the target (also True when there are no agents)
    """
    # Build the full list first so that every agent is queried, exactly as before.
    reached_target = []
    for agent in self.agents.values():
        reached_target.append(agent.get_success_rate() >= self.environment.get_target_success_rate())
    return all(reached_target)
|
||||||
|
|
||||||
|
def collect_savers(self) -> SaverCollection:
    """
    Gather the savers of every agent in this level into one collection.

    Each agent is asked for its savers with this level's name passed as the
    parent path suffix, and the results are folded into a single collection.
    :return: saver collection holding the savers of all agents
    """
    combined = SaverCollection()
    for level_agent in self.agents.values():
        agent_savers = level_agent.collect_savers(parent_path_suffix=self.name)
        combined.update(agent_savers)
    return combined
|
||||||
|
|||||||
112
rl_coach/saver.py
Normal file
112
rl_coach/saver.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
"""
|
||||||
|
Module for abstract base class for checkpoint object and checkpoint collection
|
||||||
|
"""
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
|
||||||
|
class Saver(object):
    """
    Interface for saver objects. A saver knows how to save to / restore from a path
    and how to absorb another saver that shares its path. Every member raises
    NotImplementedError here and is meant to be overridden by subclasses.
    """
    @property
    def path(self):
        """
        Relative path used for save/load. Two savers that report the same path must be merge-able.
        """
        raise NotImplementedError()

    def save(self, sess: Any, save_path: str) -> List[str]:
        """
        Write a checkpoint to save_path
        :param sess: active session for session-based frameworks (e.g. TF)
        :param save_path: full path of the checkpoint (typically directory plus self.path plus checkpoint count).
        :return: list of all saved paths
        """
        raise NotImplementedError()

    def restore(self, sess: Any, restore_path: str) -> None:
        """
        Load a checkpoint from restore_path
        :param sess: active session for session-based frameworks (e.g. TF)
        :param restore_path: full path of the checkpoint to load.
        """
        raise NotImplementedError()

    def merge(self, other: 'Saver') -> None:
        """
        Absorb another saver into this one
        :param other: saver to be merged into self
        """
        raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class SaverCollection(object):
    """
    Object for storing a collection of saver objects. It takes care of ensuring uniqueness of saver paths
    and merging savers if they have the same path. For example, if a saver handles saving a generic key/value
    file for all networks in a single file, it can use a more generic path and all savers of all networks would be
    merged into a single saver that saves/restores parameters for all networks.
    NOTE: If two savers have the same path, the respective saver class must support merging them
    into a single saver that saves/restores all merged parameters.
    """
    def __init__(self, saver: 'Saver' = None):
        """
        :param saver: optional initial saver for the collection
        """
        self._saver_dict = dict()  # type: Dict[str, Saver]
        if saver is not None:
            self._saver_dict[saver.path] = saver

    def add(self, saver: 'Saver'):
        """
        Add a new saver to the collection. If saver.path is already in the collection, merge
        the new saver with the existing saver.
        :param saver: new saver to be added to collection
        """
        existing = self._saver_dict.get(saver.path)
        if existing is not None:
            existing.merge(saver)
        else:
            self._saver_dict[saver.path] = saver

    def update(self, other: 'SaverCollection'):
        """
        Merge savers from other collection into self
        :param other: saver collection to update self with.
        """
        # Iterate over a snapshot so that adding to self is safe even when other is self.
        for saver in list(other):
            self.add(saver)

    def save(self, sess: Any, save_path: str) -> List[str]:
        """
        Call save on all savers in the collection
        :param sess: active session for session-based frameworks (e.g. TF)
        :param save_path: path for saving checkpoints using savers. All saved file paths must
        start with this path in their full path. For example if save_path is '/home/checkpoints/checkpoint-01',
        then saved file paths can be '/home/checkpoints/checkpoint-01.main-network' but not
        '/home/checkpoints/main-network'
        :return: list of all saved paths
        """
        paths = list()
        for saver in self:
            paths.extend(saver.save(sess, "{}.{}".format(save_path, saver.path)))
        return paths

    def restore(self, sess: Any, restore_path: str) -> None:
        """
        Call restore on all savers in the collection. Each saver is given
        '<restore_path>.<saver.path>', mirroring the path layout used by save().
        :param sess: active session for session-based frameworks (e.g. TF)
        :param restore_path: path for restoring checkpoint using savers.
        """
        for saver in self:
            # Build a per-saver path without rebinding restore_path; rebinding it would make
            # every later saver inherit the suffixes of the savers restored before it.
            saver.restore(sess, "{}.{}".format(restore_path, saver.path))

    def __iter__(self):
        """
        Return an iterator for savers in the collection
        :return: saver iterator
        """
        return iter(self._saver_dict.values())
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -142,3 +142,42 @@ def test_hybrid_clip():
|
|||||||
b = mx.nd.array((2,))
|
b = mx.nd.array((2,))
|
||||||
clipped = hybrid_clip(F=mx.nd, x=x, clip_lower=a, clip_upper=b)
|
clipped = hybrid_clip(F=mx.nd, x=x, clip_lower=a, clip_upper=b)
|
||||||
assert (np.isclose(a= clipped.asnumpy(), b=(1, 1.5, 2))).all()
|
assert (np.isclose(a= clipped.asnumpy(), b=(1, 1.5, 2))).all()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit_test
def test_scoped_onxx_enable():
    """
    Check that blocks only take their ONNX branch while the network is wrapped in ScopedOnnxEnable.
    """
    class Counter(object):
        """Simple call counter shared by all blocks of the network."""
        def __init__(self):
            self._count = 0

        def increment(self):
            self._count += 1

        @property
        def count(self):
            return self._count

    class TempBlock(gluon.HybridBlock, OnnxHandlerBlock):
        """Identity block that bumps the shared counter whenever its _onnx flag is set."""
        def __init__(self, counter: Counter):
            super(TempBlock, self).__init__()
            OnnxHandlerBlock.__init__(self)
            self._counter = counter

        def hybrid_forward(self, F, x, *args, **kwargs):
            # NOTE(review): _onnx is presumably toggled by ScopedOnnxEnable - confirm in OnnxHandlerBlock
            if self._onnx:
                self._counter.increment()
            return x

    num_blocks = 10
    shared_counter = Counter()
    net = gluon.nn.HybridSequential()
    for _ in range(num_blocks):
        net.add(TempBlock(shared_counter))

    # ONNX disabled: a forward pass must not touch the counter
    net(nd.zeros((1,)))
    assert shared_counter.count == 0

    # ONNX enabled: every block counts once during the forward pass
    with ScopedOnnxEnable(net):
        net(nd.zeros((1,)))
    assert shared_counter.count == 10
|
||||||
|
|||||||
42
rl_coach/tests/test_saver.py
Normal file
42
rl_coach/tests/test_saver.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from rl_coach.saver import Saver, SaverCollection
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit_test
def test_checkpoint_collection():
    """
    Verify that SaverCollection keeps one saver per path and merges savers that share a path,
    both when adding single savers and when updating from a whole collection.
    """
    class SaverTest(Saver):
        """Minimal saver whose _count records how many savers were merged into it."""
        def __init__(self, path):
            self._path = path
            self._count = 1

        @property
        def path(self):
            return self._path

        def merge(self, other: 'Saver'):
            assert isinstance(other, SaverTest)
            assert self.path == other.path
            self._count += other._count

    # test add
    savers = SaverCollection(SaverTest('123'))
    savers.add(SaverTest('123'))
    savers.add(SaverTest('456'))

    def check_collection(mul):
        paths = ['123', '456']
        for c in savers:
            paths.remove(c.path)
            if c.path == '123':
                assert c._count == 2 * mul
            elif c.path == '456':
                assert c._count == 1 * mul
            else:
                assert False, "invalid path"
        # every expected path must have been seen; otherwise a dropped saver would go unnoticed
        assert not paths, "missing savers for paths: {}".format(paths)

    check_collection(1)

    # test update
    savers.update(savers)
    check_collection(2)
|
||||||
21
rl_coach/tests/test_utils.py
Normal file
21
rl_coach/tests/test_utils.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from rl_coach import utils
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit_test
def test_get_checkpoint_state_default():
    """
    With the default filename pattern, checkpoints are ordered by their leading number
    and the trailing extension is not part of the reported checkpoint name.
    """
    files = ['4.test.ckpt.ext', '2.test.ckpt.ext', '3.test.ckpt.ext', '1.test.ckpt.ext']
    # expected names are the sorted file names without the trailing '.ext'
    expected_paths = [name[:-4] for name in sorted(files)]

    state = utils.get_checkpoint_state(files)

    assert state.model_checkpoint_path == '4.test.ckpt'
    assert state.all_model_checkpoint_paths == expected_paths
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit_test
def test_get_checkpoint_state_custom():
    """
    Prefixed file names are ignored by the default pattern but are picked up
    by a custom pattern that is not anchored to the start of the name.
    """
    files = ['prefix.4.test.ckpt.ext', 'prefix.2.test.ckpt.ext', 'prefix.3.test.ckpt.ext', 'prefix.1.test.ckpt.ext']

    # doesn't match the default pattern
    default_state = utils.get_checkpoint_state(files)
    assert len(default_state.all_model_checkpoint_paths) == 0

    custom_state = utils.get_checkpoint_state(files, filename_pattern=r'([0-9]+)[^0-9].*?\.ckpt')
    # expected names drop the leading 'prefix.' and the trailing '.ext'
    expected_paths = [name[7:-4] for name in sorted(files)]
    assert custom_state.model_checkpoint_path == '4.test.ckpt'
    assert custom_state.all_model_checkpoint_paths == expected_paths
|
||||||
|
|
||||||
@@ -19,6 +19,7 @@ import importlib.util
|
|||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
@@ -26,7 +27,7 @@ import time
|
|||||||
import traceback
|
import traceback
|
||||||
from multiprocessing import Manager
|
from multiprocessing import Manager
|
||||||
from subprocess import Popen
|
from subprocess import Popen
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple, Union
|
||||||
|
|
||||||
import atexit
|
import atexit
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -547,3 +548,50 @@ def indent_string(string):
|
|||||||
return '\t' + string.replace('\n', '\n\t')
|
return '\t' + string.replace('\n', '\n\t')
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointState(object):
    """
    Helper class for checkpoint directory information. It replicates
    the CheckpointState protobuf class in tensorflow.
    """
    def __init__(self, checkpoints: List[str]):
        """
        :param checkpoints: checkpoint names; the last entry is reported as the model checkpoint path
        """
        self._checkpoints = checkpoints

    @property
    def all_model_checkpoint_paths(self):
        """
        :return: the list of checkpoint names given at construction
        """
        return self._checkpoints

    @property
    def model_checkpoint_path(self):
        """
        :return: the last checkpoint in the list. Raises IndexError if there are no checkpoints.
        """
        return self._checkpoints[-1]

    def __str__(self):
        # Don't go through model_checkpoint_path here: it raises IndexError for an empty
        # state, and printing/logging a state should never fail.
        latest = self._checkpoints[-1] if self._checkpoints else None
        lines = ['model_checkpoint_path: {}\n'.format(latest)]
        lines.extend('all_model_checkpoint_paths: {}\n'.format(c) for c in self._checkpoints)
        return ''.join(lines)

    def __repr__(self):
        return str(self._checkpoints)
|
||||||
|
|
||||||
|
|
||||||
|
COACH_CHECKPOINT_PATTERN = r'\A([0-9]+)[^0-9].*?\.ckpt'


def get_checkpoint_state(checkpoint_dir: Union[str, List[str]], filename_pattern: str=COACH_CHECKPOINT_PATTERN) ->\
        CheckpointState:
    """
    Build a CheckpointState from the checkpoint files that match filename_pattern. The first group of the
    pattern (i.e. group(1)) is converted to an integer and used as the sort key, so the checkpoint with the
    largest number ends up last. The checkpoint name is the whole matched text (group(0)), joined with
    checkpoint_dir when a directory was given.
    :param checkpoint_dir: directory where checkpoints are saved or list of all files in a directory
    :param filename_pattern: regex pattern for checkpoint filenames
    :return: a CheckpointState for checkpoint_dir containing a sorted list of checkpoint names
    """
    matcher = re.compile(filename_pattern)
    from_directory = isinstance(checkpoint_dir, str)
    candidates = os.listdir(checkpoint_dir) if from_directory else checkpoint_dir

    # checkpoint number -> checkpoint name; files that share a number collapse into one entry
    by_index = dict()
    for filename in candidates:
        match = matcher.search(filename)
        if match is None or match.group(1) is None:
            continue
        checkpoint_name = match.group(0)
        if from_directory:
            checkpoint_name = os.path.join(checkpoint_dir, checkpoint_name)
        by_index[int(match.group(1))] = checkpoint_name

    ordered = [by_index[index] for index in sorted(by_index.keys())]
    return CheckpointState(ordered)
|
||||||
|
|||||||
Reference in New Issue
Block a user