#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
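
"""
General network composition for the MXNet backend: composes input embedders, a
middleware and one or more output heads into trainable Gluon blocks.
"""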

import copy
from itertools import chain
from typing import Dict, List, Tuple, Union
from types import ModuleType

import numpy as np
import mxnet as mx
from mxnet import nd, sym
from mxnet.gluon import HybridBlock
from mxnet.ndarray import NDArray
from mxnet.symbol import Symbol

from rl_coach.base_parameters import NetworkParameters
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.head_parameters import HeadParameters, PPOHeadParameters
from rl_coach.architectures.head_parameters import PPOVHeadParameters, VHeadParameters, QHeadParameters
from rl_coach.architectures.middleware_parameters import MiddlewareParameters
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters, LSTMMiddlewareParameters
from rl_coach.architectures.mxnet_components.architecture import MxnetArchitecture
from rl_coach.architectures.mxnet_components.embedders import ImageEmbedder, TensorEmbedder, VectorEmbedder
from rl_coach.architectures.mxnet_components.heads import Head, HeadLoss, PPOHead, PPOVHead, VHead, QHead
from rl_coach.architectures.mxnet_components.middlewares import FCMiddleware, LSTMMiddleware
from rl_coach.architectures.mxnet_components import utils
from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace


class GeneralMxnetNetwork(MxnetArchitecture):
    """
    A generalized version of all possible networks implemented using mxnet.
    """
    def __init__(self,
                 agent_parameters: AgentParameters,
                 spaces: SpacesDefinition,
                 name: str,
                 global_network=None,
                 network_is_local: bool=True,
                 network_is_trainable: bool=False):
        """
        :param agent_parameters: the agent parameters
        :param spaces: the spaces definition of the agent
        :param name: the name of the network
        :param global_network: the global network replica that is shared between all the workers
        :param network_is_local: whether the network is local (dedicated to the worker) rather than
            global (shared between workers)
        :param network_is_trainable: whether the network is trainable (i.e. gradients can be applied to it)
        """
        self.network_wrapper_name = name.split('/')[0]
        self.network_parameters = agent_parameters.network_wrappers[self.network_wrapper_name]
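        # either build one network per head (each with its own embedders and middleware),
        # or a single network that hosts all of the heads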
        if self.network_parameters.use_separate_networks_per_head:
            self.num_heads_per_network = 1
            self.num_networks = len(self.network_parameters.heads_parameters)
        else:
            self.num_heads_per_network = len(self.network_parameters.heads_parameters)
            self.num_networks = 1

        super().__init__(agent_parameters, spaces, name, global_network,
                         network_is_local, network_is_trainable)

    def construct_model(self):
        # validate the configuration
        if len(self.network_parameters.input_embedders_parameters) == 0:
            raise ValueError("At least one input type should be defined")

        if len(self.network_parameters.heads_parameters) == 0:
            raise ValueError("At least one output type should be defined")

        if self.network_parameters.middleware_parameters is None:
            raise ValueError("Exactly one middleware type should be defined")

        self.model = GeneralModel(
            num_networks=self.num_networks,
            num_heads_per_network=self.num_heads_per_network,
            network_is_local=self.network_is_local,
            network_name=self.network_wrapper_name,
            agent_parameters=self.ap,
            network_parameters=self.network_parameters,
            spaces=self.spaces)

        self.losses = self.model.losses()

        # Learning rate
        lr_scheduler = None
        if self.network_parameters.learning_rate_decay_rate != 0:
            lr_scheduler = mx.lr_scheduler.FactorScheduler(
                step=self.network_parameters.learning_rate_decay_steps,
                factor=self.network_parameters.learning_rate_decay_rate)

        # Optimizer
        # FIXME Does this code for distributed training make sense?
        if self.distributed_training and self.network_is_local and self.network_parameters.shared_optimizer:
            # distributed training + is a local network + optimizer shared -> take the global optimizer
            self.optimizer = self.global_network.optimizer
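        # in every other supported configuration, a fresh optimizer is created for this network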
        elif (self.distributed_training and self.network_is_local and not self.network_parameters.shared_optimizer)\
                or self.network_parameters.shared_optimizer or not self.distributed_training:
            if self.network_parameters.optimizer_type == 'Adam':
                self.optimizer = mx.optimizer.Adam(
                    learning_rate=self.network_parameters.learning_rate,
                    beta1=self.network_parameters.adam_optimizer_beta1,
                    beta2=self.network_parameters.adam_optimizer_beta2,
                    epsilon=self.network_parameters.optimizer_epsilon,
                    lr_scheduler=lr_scheduler)
            elif self.network_parameters.optimizer_type == 'RMSProp':
                self.optimizer = mx.optimizer.RMSProp(
                    learning_rate=self.network_parameters.learning_rate,
                    gamma1=self.network_parameters.rms_prop_optimizer_decay,
                    epsilon=self.network_parameters.optimizer_epsilon,
                    lr_scheduler=lr_scheduler)
            elif self.network_parameters.optimizer_type == 'LBFGS':
                raise NotImplementedError('LBFGS optimizer not implemented')
            else:
                raise Exception("{} is not a valid optimizer type".format(self.network_parameters.optimizer_type))

    @property
    def output_heads(self):
        return self.model.output_heads


def _get_activation(activation_function_string: str):
    """
    Map the activation function from a string to the mxnet framework equivalent
    :param activation_function_string: the type of the activation function
    :return: mxnet activation function string
    """
    return utils.get_mxnet_activation_name(activation_function_string)


def _sanitize_activation(params: Union[InputEmbedderParameters, MiddlewareParameters, HeadParameters]) ->\
        Union[InputEmbedderParameters, MiddlewareParameters, HeadParameters]:
    """
    Replace the activation function with the mxnet-specific value
    :param params: any parameter object that has an activation_function property
    :return: copy of params with the activation function correctly set
    """
    params_copy = copy.copy(params)
    params_copy.activation_function = _get_activation(params.activation_function)
    return params_copy


def _get_input_embedder(spaces: SpacesDefinition,
                        input_name: str,
                        embedder_params: InputEmbedderParameters) -> HybridBlock:
    """
    Given an input embedder parameters class, creates the input embedder and returns it
    :param spaces: the spaces definition of the agent
    :param input_name: the name of the input to the embedder (used for retrieving the shape). The input should
        be a value within the state or the action.
    :param embedder_params: the parameters of the class of the embedder
    :return: the embedder instance
    """
    allowed_inputs = copy.copy(spaces.state.sub_spaces)
    allowed_inputs["action"] = copy.copy(spaces.action)
    allowed_inputs["goal"] = copy.copy(spaces.goal)

    if input_name not in allowed_inputs.keys():
        raise ValueError("The key for the input embedder ({}) must match one of the following keys: {}"
                         .format(input_name, allowed_inputs.keys()))

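    # choose the embedder type from the observation space: generic tensor and planar-map (image)
    # spaces get dedicated embedders; anything else is embedded as a flat vector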
type = "vector"
|
|
if isinstance(allowed_inputs[input_name], TensorObservationSpace):
|
|
type = "tensor"
|
|
elif isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace):
|
|
type = "image"
|
|
|
|
def sanitize_params(params: InputEmbedderParameters):
|
|
params_copy = _sanitize_activation(params)
|
|
# params_copy.input_rescaling = params_copy.input_rescaling[type]
|
|
# params_copy.input_offset = params_copy.input_offset[type]
|
|
params_copy.name = input_name
|
|
return params_copy
|
|
|
|
embedder_params = sanitize_params(embedder_params)
|
|
if type == 'vector':
|
|
module = VectorEmbedder(embedder_params)
|
|
elif type == 'image':
|
|
module = ImageEmbedder(embedder_params)
|
|
elif type == 'tensor':
|
|
module = TensorEmbedder(embedder_params)
|
|
else:
|
|
raise KeyError('Unsupported embedder type: {}'.format(type))
|
|
return module
|
|
|
|
|
|
def _get_middleware(middleware_params: MiddlewareParameters) -> HybridBlock:
    """
    Given a middleware type, creates the middleware and returns it
    :param middleware_params: the parameters of the middleware class
    :return: the middleware instance
    """
    middleware_params = _sanitize_activation(middleware_params)
    if isinstance(middleware_params, FCMiddlewareParameters):
        module = FCMiddleware(middleware_params)
    elif isinstance(middleware_params, LSTMMiddlewareParameters):
        module = LSTMMiddleware(middleware_params)
    else:
        raise KeyError('Unsupported middleware type: {}'.format(type(middleware_params)))

    return module


def _get_output_head(
        head_params: HeadParameters,
        head_idx: int,
        head_type_index: int,
        agent_params: AgentParameters,
        spaces: SpacesDefinition,
        network_name: str,
        is_local: bool) -> Head:
    """
    Given a head type, creates the head and returns it
    :param head_params: the parameters of the head to create
    :param head_idx: the head index
    :param head_type_index: the head type index (shared by all copies when head_param.num_output_head_copies > 1)
    :param agent_params: agent parameters
    :param spaces: state and action space definitions
    :param network_name: name of the network
    :param is_local: whether the network is local
    :return: head block
    """
    head_params = _sanitize_activation(head_params)
    if isinstance(head_params, PPOHeadParameters):
        module = PPOHead(
            agent_parameters=agent_params,
            spaces=spaces,
            network_name=network_name,
            head_type_idx=head_type_index,
            loss_weight=head_params.loss_weight,
            is_local=is_local,
            activation_function=head_params.activation_function,
            dense_layer=head_params.dense_layer)
    elif isinstance(head_params, VHeadParameters):
        module = VHead(
            agent_parameters=agent_params,
            spaces=spaces,
            network_name=network_name,
            head_type_idx=head_type_index,
            loss_weight=head_params.loss_weight,
            is_local=is_local,
            activation_function=head_params.activation_function,
            dense_layer=head_params.dense_layer)
    elif isinstance(head_params, PPOVHeadParameters):
        module = PPOVHead(
            agent_parameters=agent_params,
            spaces=spaces,
            network_name=network_name,
            head_type_idx=head_type_index,
            loss_weight=head_params.loss_weight,
            is_local=is_local,
            activation_function=head_params.activation_function,
            dense_layer=head_params.dense_layer)
    elif isinstance(head_params, QHeadParameters):
        module = QHead(
            agent_parameters=agent_params,
            spaces=spaces,
            network_name=network_name,
            head_type_idx=head_type_index,
            loss_weight=head_params.loss_weight,
            is_local=is_local,
            activation_function=head_params.activation_function,
            dense_layer=head_params.dense_layer)
    else:
        raise KeyError('Unsupported head type: {}'.format(type(head_params)))

    return module


class ScaledGradHead(HybridBlock, utils.OnnxHandlerBlock):
    """
    Wrapper block that rescales the gradient flowing into its input before feeding the head network
    """
    def __init__(self,
                 head_index: int,
                 head_type_index: int,
                 network_name: str,
                 spaces: SpacesDefinition,
                 network_is_local: bool,
                 agent_params: AgentParameters,
                 head_params: HeadParameters) -> None:
        """
        :param head_index: the head index
        :param head_type_index: the head type index (shared by all copies when head_param.num_output_head_copies > 1)
        :param network_name: name of the network
        :param spaces: state and action space definitions
        :param network_is_local: whether network is local
        :param agent_params: agent parameters
        :param head_params: head parameters
        """
        super(ScaledGradHead, self).__init__()
        utils.OnnxHandlerBlock.__init__(self)

        head_params = _sanitize_activation(head_params)
        with self.name_scope():
            self.head = _get_output_head(
                head_params=head_params,
                head_idx=head_index,
                head_type_index=head_type_index,
                agent_params=agent_params,
                spaces=spaces,
                network_name=network_name,
                is_local=network_is_local)
            self.gradient_rescaler = self.params.get_constant(
                name='gradient_rescaler',
                value=np.array([float(head_params.rescale_gradient_from_head_by_factor)]))
            # self.gradient_rescaler = self.params.get(
            #     name='gradient_rescaler',
            #     shape=(1,),
            #     init=mx.init.Constant(float(head_params.rescale_gradient_from_head_by_factor)))

    def hybrid_forward(self,
                       F: ModuleType,
                       x: Union[NDArray, Symbol],
                       gradient_rescaler: Union[NDArray, Symbol]) -> Tuple[Union[NDArray, Symbol], ...]:
        """ Overrides gluon.HybridBlock.hybrid_forward
        :param nd or sym F: ndarray or symbol module
        :param x: head input
        :param gradient_rescaler: gradient rescaler for partial blocking of gradient
        :return: head output
        """
        if self._onnx:
            # ONNX doesn't support the BlockGrad() operator, but gradient blocking is only needed
            # for training; ONNX-exported networks are used for forward passes only.
            grad_scaled_x = x
        else:
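            # gradient rescaling trick: BlockGrad(x) equals x in the forward pass but passes zero
            # gradient backwards, so
            #     (1 - s) * BlockGrad(x) + s * x
            # evaluates to x on the forward pass while the backward gradient is scaled by s
            # (the constant rescale factor taken from head_params)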
            grad_scaled_x = (F.broadcast_mul((1 - gradient_rescaler), F.BlockGrad(x)) +
                             F.broadcast_mul(gradient_rescaler, x))
        out = self.head(grad_scaled_x)
        return out


class SingleModel(HybridBlock):
    """
    Block that connects one or more input embedders, through a middleware, to one or more heads
    """
    def __init__(self,
                 network_is_local: bool,
                 network_name: str,
                 agent_parameters: AgentParameters,
                 in_emb_param_dict: Dict[str, InputEmbedderParameters],
                 embedding_merger_type: EmbeddingMergerType,
                 middleware_param: MiddlewareParameters,
                 head_param_list: List[HeadParameters],
                 head_type_idx_start: int,
                 spaces: SpacesDefinition,
                 *args, **kwargs):
        """
        :param network_is_local: True if network is local
        :param network_name: name of the network
        :param agent_parameters: agent parameters
        :param in_emb_param_dict: dictionary mapping embedder name to embedder parameters
        :param embedding_merger_type: type of merging for the embedder outputs: concatenation or sum
        :param middleware_param: middleware parameters
        :param head_param_list: list of head parameters, one per head type
        :param head_type_idx_start: start index for head type index counting
        :param spaces: state and action space definitions
        """
        super(SingleModel, self).__init__(*args, **kwargs)

        self._embedding_merger_type = embedding_merger_type
        self._input_embedders = list()  # type: List[HybridBlock]
        self._output_heads = list()  # type: List[ScaledGradHead]

        with self.name_scope():
            for input_name in sorted(in_emb_param_dict):
                input_type = in_emb_param_dict[input_name]
                input_embedder = _get_input_embedder(spaces, input_name, input_type)
                self.register_child(input_embedder)
                self._input_embedders.append(input_embedder)

            self.middleware = _get_middleware(middleware_param)

            for i, head_param in enumerate(head_param_list):
                for head_copy_idx in range(head_param.num_output_head_copies):
                    # create output head and add it to the output heads list
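                    # global head index: copies of each head type are laid out consecutively, so
                    # head type t with copy c maps to index t * num_output_head_copies + c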
                    output_head = ScaledGradHead(
                        head_index=(head_type_idx_start + i) * head_param.num_output_head_copies + head_copy_idx,
                        head_type_index=head_type_idx_start + i,
                        network_name=network_name,
                        spaces=spaces,
                        network_is_local=network_is_local,
                        agent_params=agent_parameters,
                        head_params=head_param)
                    self.register_child(output_head)
                    self._output_heads.append(output_head)

    def hybrid_forward(self, F, *inputs: Union[NDArray, Symbol]) -> Tuple[Union[NDArray, Symbol], ...]:
        """ Overrides gluon.HybridBlock.hybrid_forward
        :param nd or sym F: ndarray or symbol module
        :param inputs: model inputs, one for each embedder
        :return: head outputs in a tuple
        """
        # Input Embeddings
        state_embedding = list()
        for inp, embedder in zip(inputs, self._input_embedders):
            state_embedding.append(embedder(inp))

        # Merger
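        # a single embedding passes through unchanged; multiple embeddings are fused either by
        # concatenation along the channel axis or by an elementwise sum (which requires the
        # embeddings to share the same shape)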
        if len(state_embedding) == 1:
            state_embedding = state_embedding[0]
        else:
            if self._embedding_merger_type == EmbeddingMergerType.Concat:
                state_embedding = F.concat(*state_embedding, dim=1, name='merger')  # NC or NCHW layout
            elif self._embedding_merger_type == EmbeddingMergerType.Sum:
                state_embedding = F.add_n(*state_embedding, name='merger')

        # Middleware
        state_embedding = self.middleware(state_embedding)

        # Head
        outputs = tuple()
        for head in self._output_heads:
            out = head(state_embedding)
            if not isinstance(out, tuple):
                out = (out,)
            outputs += out

        return outputs

    @property
    def input_embedders(self) -> List[HybridBlock]:
        """
        :return: list of input embedders
        """
        return self._input_embedders

    @property
    def output_heads(self) -> List[Head]:
        """
        :return: list of output heads
        """
        return [h.head for h in self._output_heads]


class GeneralModel(HybridBlock):
    """
    Block that composes one or more SingleModel networks
    """
    def __init__(self,
                 num_networks: int,
                 num_heads_per_network: int,
                 network_is_local: bool,
                 network_name: str,
                 agent_parameters: AgentParameters,
                 network_parameters: NetworkParameters,
                 spaces: SpacesDefinition,
                 *args, **kwargs):
        """
        :param num_networks: number of networks to create
        :param num_heads_per_network: number of heads per network to create
        :param network_is_local: True if network is local
        :param network_name: name of the network
        :param agent_parameters: agent parameters
        :param network_parameters: network parameters
        :param spaces: state and action space definitions
        """
        super(GeneralModel, self).__init__(*args, **kwargs)

        with self.name_scope():
            self.nets = list()
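            # each SingleModel receives a contiguous slice of the head parameters; when separate
            # networks per head are used, every slice holds exactly one head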
            for network_idx in range(num_networks):
                head_type_idx_start = network_idx * num_heads_per_network
                head_type_idx_end = head_type_idx_start + num_heads_per_network
                net = SingleModel(
                    head_type_idx_start=head_type_idx_start,
                    network_name=network_name,
                    network_is_local=network_is_local,
                    agent_parameters=agent_parameters,
                    in_emb_param_dict=network_parameters.input_embedders_parameters,
                    embedding_merger_type=network_parameters.embedding_merger_type,
                    middleware_param=network_parameters.middleware_parameters,
                    head_param_list=network_parameters.heads_parameters[head_type_idx_start:head_type_idx_end],
                    spaces=spaces)
                self.register_child(net)
                self.nets.append(net)

    def hybrid_forward(self, F, *inputs):
        """ Overrides gluon.HybridBlock.hybrid_forward
        :param nd or sym F: ndarray or symbol module
        :param inputs: model inputs, one for each embedder. Passed to all networks.
        :return: head outputs in a tuple
        """
        outputs = tuple()
        for net in self.nets:
            out = net(*inputs)
            outputs += out
        return outputs

    @property
    def output_heads(self) -> List[Head]:
        """ Return all heads in a single list
        Note: There is a one-to-one mapping between output_heads and losses
        :return: list of heads
        """
        return list(chain.from_iterable(net.output_heads for net in self.nets))

    def losses(self) -> List[HeadLoss]:
        """ Construct loss blocks for network training
        Note: There is a one-to-one mapping between output_heads and losses
        :return: list of loss blocks
        """
        return [h.loss() for net in self.nets for h in net.output_heads]
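

# Usage sketch (hypothetical names, for illustration only; in practice rl_coach's graph manager
# builds these networks from preset parameters):
#
#     network = GeneralMxnetNetwork(agent_parameters=agent_params,
#                                   spaces=spaces,
#                                   name='main/online',
#                                   network_is_trainable=True)
#     head_outputs = network.model(*model_inputs)  # tuple with one entry per head output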