#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import copy
from typing import Any, Dict, Generator, List, Tuple, Union

import numpy as np
import mxnet as mx
from mxnet import autograd, gluon, nd
from mxnet.ndarray import NDArray

from rl_coach.architectures.architecture import Architecture
from rl_coach.architectures.mxnet_components.heads.head import HeadLoss, LOSS_OUT_TYPE_LOSS, \
    LOSS_OUT_TYPE_REGULARIZATION
from rl_coach.architectures.mxnet_components import utils
from rl_coach.architectures.mxnet_components.savers import ParameterDictSaver, OnnxSaver
from rl_coach.base_parameters import AgentParameters
from rl_coach.saver import SaverCollection
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import force_list, squeeze_list


class MxnetArchitecture(Architecture):
    def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, devices: List[mx.Context],
                 name: str = "", global_network=None, network_is_local: bool=True, network_is_trainable: bool=False):
        """
        :param agent_parameters: the agent parameters
        :param spaces: the spaces definition of the agent
        :param devices: list of MXNet contexts (CPU/GPU devices) to construct the network on
        :param name: the name of the network
        :param global_network: the global network replica that is shared between all the workers
        :param network_is_local: is the network global (shared between workers) or local (dedicated to the worker)
        :param network_is_trainable: is the network trainable (we can apply gradients on it)
        """
        super().__init__(agent_parameters, spaces, name)
        self.middleware = None
        self.network_is_local = network_is_local
        self.global_network = global_network
        if not self.network_parameters.tensorflow_support:
            raise ValueError('TensorFlow is not supported for this agent')
        self.losses = []  # type: List[HeadLoss]
        self.shared_accumulated_gradients = []
        self.curr_rnn_c_in = None
        self.curr_rnn_h_in = None
        self.gradients_wrt_inputs = []
        self.train_writer = None
        self.accumulated_gradients = None
        self.network_is_trainable = network_is_trainable
        self.is_training = False
        self.model = None  # type: GeneralModel
        self._devices = devices

        self.is_chief = self.ap.task_parameters.task_index == 0
        self.network_is_global = not self.network_is_local and global_network is None
        self.distributed_training = self.network_is_global or self.network_is_local and global_network is not None

        self.optimizer_type = self.network_parameters.optimizer_type
        if self.ap.task_parameters.seed is not None:
            mx.random.seed(self.ap.task_parameters.seed)

        # Call to child class to create the model
        self.construct_model()

        self.trainer = None  # type: gluon.Trainer

    def __str__(self):
        return self.model.summary(*self._dummy_model_inputs())

    def _model_grads(self, index: int=0) ->\
            Union[Generator[NDArray, NDArray, Any], Generator[List[NDArray], List[NDArray], Any]]:
        """
        Returns model gradients as a generator, in the same order as collect_params()
        :param index: device index. Set to -1 to get a tuple of list of NDArrays for all devices
        :return: a generator for model gradient values
        """
        if index < 0:
            return (p.list_grad() for p in self.model.collect_params().values() if p.grad_req != 'null')
        else:
            return (p.list_grad()[index] for p in self.model.collect_params().values() if p.grad_req != 'null')

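    # A hedged usage note (kept as comments so this module stays importable), mirroring how the
    # training code further below consumes this generator: gradients are copied before being
    # mutated so that the live parameter gradients stay intact.
    #
    #   grads_cpy = [g.copy() for g in self._model_grads()]      # gradients on device 0
    #   per_device = list(self._model_grads(index=-1))           # per-parameter lists, one entry per device
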
    def _model_input_shapes(self) -> List[List[int]]:
        """
        Create a list of input array shapes
        :return: list of input shapes
        """
        allowed_inputs = copy.copy(self.spaces.state.sub_spaces)
        allowed_inputs["action"] = copy.copy(self.spaces.action)
        allowed_inputs["goal"] = copy.copy(self.spaces.goal)
        embedders = self.model.nets[0].input_embedders
        return list([1] + allowed_inputs[emb.embedder_name].shape.tolist() for emb in embedders)

    def _dummy_model_inputs(self) -> Tuple[NDArray, ...]:
        """
        Creates a tuple of input arrays with correct shapes that can be used for shape inference
        of the model weights and for printing the summary
        :return: tuple of inputs for model forward pass
        """
        input_shapes = self._model_input_shapes()
        inputs = tuple(nd.zeros(tuple(shape), ctx=self._devices[0]) for shape in input_shapes)
        return inputs

    def construct_model(self) -> None:
        """
        Construct network model. Implemented by child class.
        """
        raise NotImplementedError

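    # A minimal sketch (comments only) of what a child class is expected to do in construct_model.
    # The builder helper named here is hypothetical; the concrete subclass assembles the input
    # embedders, middleware and heads into self.model and the matching HeadLoss blocks into
    # self.losses, which is what the base class relies on.
    #
    #   class MyMxnetNetwork(MxnetArchitecture):
    #       def construct_model(self) -> None:
    #           self.model = build_general_model(self.network_parameters, self.spaces)  # hypothetical helper
    #           self.losses = [head.loss() for head in self.model.output_heads]         # assumed head API
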
    def set_session(self, sess) -> None:
        """
        Initializes the model parameters and creates the model trainer.
        NOTE: Session for mxnet backend must be None.
        :param sess: must be None
        """
        assert sess is None
        # FIXME Add initializer
        self.model.collect_params().initialize(ctx=self._devices)
        # Hybridize model and losses
        self.model.hybridize()
        for l in self.losses:
            l.hybridize()

        # Pass dummy data with correct shape to trigger shape inference and full parameter initialization
        self.model(*self._dummy_model_inputs())

        if self.network_is_trainable:
            self.trainer = gluon.Trainer(
                self.model.collect_params(), optimizer=self.optimizer, update_on_kvstore=False)

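    # Why set_session runs a dummy forward pass: Gluon parameters are deferred-initialized, so
    # their shapes are only fully known after the first call. A standalone sketch of the same
    # pattern with a small hypothetical network (comments only):
    #
    #   net = gluon.nn.HybridSequential()
    #   net.add(gluon.nn.Dense(64, activation='relu'), gluon.nn.Dense(4))
    #   net.collect_params().initialize(ctx=mx.cpu())    # shapes not yet known here
    #   net.hybridize()
    #   net(nd.zeros((1, 8), ctx=mx.cpu()))              # first call infers shapes and allocates weights
    #   trainer = gluon.Trainer(net.collect_params(), optimizer='adam', update_on_kvstore=False)
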
    def reset_accumulated_gradients(self) -> None:
        """
        Reset model gradients as well as accumulated gradients to zero. If accumulated gradients
        have not been created yet, they are constructed as a copy of the current model gradients.
        """
        # Set model gradients to zero
        for p in self.model.collect_params().values():
            p.zero_grad()
        # Set accumulated gradients to zero if already initialized, otherwise create a copy
        if self.accumulated_gradients:
            for a in self.accumulated_gradients:
                a *= 0
        else:
            self.accumulated_gradients = [g.copy() for g in self._model_grads()]

    def accumulate_gradients(self,
                             inputs: Dict[str, np.ndarray],
                             targets: List[np.ndarray],
                             additional_fetches: List[Tuple[int, str]] = None,
                             importance_weights: np.ndarray = None,
                             no_accumulation: bool = False) -> Tuple[float, List[float], float, list]:
        """
        Runs a forward & backward pass, clips gradients if needed and accumulates them into the accumulated gradients
        :param inputs: environment states (observation, etc.) as well as extra inputs required by loss. Shape of
            ndarray is (batch_size, observation_space_size) or (batch_size, observation_space_size, stack_size)
        :param targets: targets required by loss (e.g. sum of discounted rewards)
        :param additional_fetches: additional fetches to calculate and return. Each fetch is specified as (int, str)
            tuple of head-type-index and fetch-name. The tuple is obtained from each head.
        :param importance_weights: ndarray of shape (batch_size,) to multiply with batch loss.
        :param no_accumulation: if True, set gradient values to the new gradients, otherwise sum with previously
            calculated gradients
        :return: tuple of total_loss, losses, norm_unclipped_grads, fetched_tensors
            total_loss (float): sum of all head losses
            losses (list of float): list of all losses. The order is list of target losses followed by list of
                regularization losses. The specifics of the losses are dependent on the network parameters
                (number of heads, etc.)
            norm_unclipped_grads (float): global norm of all gradients before any gradient clipping is applied
            fetched_tensors: all values for additional_fetches
        """
        if self.accumulated_gradients is None:
            self.reset_accumulated_gradients()

        embedders = [emb.embedder_name for emb in self.model.nets[0].input_embedders]
        nd_inputs = tuple(nd.array(inputs[emb], ctx=self._devices[0]) for emb in embedders)

        assert self.middleware.__class__.__name__ != 'LSTMMiddleware', "LSTM middleware not supported"

        targets = force_list(targets)
        with autograd.record():
            out_per_head = utils.split_outputs_per_head(self.model(*nd_inputs), self.model.output_heads)
            tgt_per_loss = utils.split_targets_per_loss(targets, self.losses)

            losses = list()
            regularizations = list()
            # Guard against additional_fetches being None (the default) before expanding it
            additional_fetches = [(k, None) for k in (additional_fetches or [])]
            for h, h_loss, h_out, l_tgt in zip(self.model.output_heads, self.losses, out_per_head, tgt_per_loss):
                l_in = utils.get_loss_agent_inputs(inputs, head_type_idx=h.head_type_idx, loss=h_loss)
                # Align arguments with loss.loss_forward and convert to NDArray
                l_args = utils.to_mx_ndarray(utils.align_loss_args(h_out, l_in, l_tgt, h_loss), h_out[0].context)
                # Calculate loss and all auxiliary outputs
                loss_outputs = utils.loss_output_dict(utils.to_list(h_loss(*l_args)), h_loss.output_schema)
                if LOSS_OUT_TYPE_LOSS in loss_outputs:
                    losses.extend(loss_outputs[LOSS_OUT_TYPE_LOSS])
                if LOSS_OUT_TYPE_REGULARIZATION in loss_outputs:
                    regularizations.extend(loss_outputs[LOSS_OUT_TYPE_REGULARIZATION])
                # Set additional fetches
                for i, fetch in enumerate(additional_fetches):
                    head_type_idx, fetch_name = fetch[0]  # fetch key is a tuple of (head_type_index, fetch_name)
                    if head_type_idx == h.head_type_idx:
                        assert fetch[1] is None  # sanity check that fetch is None
                        additional_fetches[i] = (fetch[0], loss_outputs[fetch_name])

            # Total loss is losses and regularization (NOTE: order is important)
            total_loss_list = losses + regularizations
            total_loss = nd.add_n(*total_loss_list)

        # Calculate gradients
        total_loss.backward()

        assert self.optimizer_type != 'LBFGS', 'LBFGS not supported'

        # allreduce gradients from all contexts
        self.trainer.allreduce_grads()

        model_grads_cpy = [g.copy() for g in self._model_grads()]
        # Calculate global norm of gradients
        # FIXME global norm is returned even when not used for clipping! Is this necessary?
        # FIXME global norm might be calculated twice if clipping method is global norm
        norm_unclipped_grads = utils.global_norm(model_grads_cpy)

        # Clip gradients
        if self.network_parameters.clip_gradients:
            utils.clip_grad(
                model_grads_cpy,
                clip_method=self.network_parameters.gradients_clipping_method,
                clip_val=self.network_parameters.clip_gradients,
                inplace=True)

        # Update self.accumulated_gradients depending on no_accumulation flag
        if no_accumulation:
            for acc_grad, model_grad in zip(self.accumulated_gradients, model_grads_cpy):
                acc_grad[:] = model_grad
        else:
            for acc_grad, model_grad in zip(self.accumulated_gradients, model_grads_cpy):
                acc_grad += model_grad

        # result of additional fetches
        fetched_tensors = [fetch[1] for fetch in additional_fetches]

        # convert everything to numpy or scalar before returning
        result = utils.asnumpy_or_asscalar((total_loss, total_loss_list, norm_unclipped_grads, fetched_tensors))
        return result

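    # A hedged usage sketch (comments only): accumulating gradients over several batches and
    # applying them once, which is the same flow train_on_batch performs for a single batch.
    # The `batches` iterable and its fields are hypothetical placeholders.
    #
    #   for batch in batches:
    #       network.accumulate_gradients(batch.inputs, batch.targets)
    #   network.apply_and_reset_gradients(network.accumulated_gradients)
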
    def apply_and_reset_gradients(self, gradients: List[np.ndarray], scaler: float=1.) -> None:
        """
        Applies the given gradients to the network weights and resets accumulated gradients to zero
        :param gradients: The gradients to use for the update
        :param scaler: A scaling factor that allows rescaling the gradients before applying them
        """
        self.apply_gradients(gradients, scaler)
        self.reset_accumulated_gradients()

    def apply_gradients(self, gradients: List[np.ndarray], scaler: float=1.) -> None:
        """
        Applies the given gradients to the network weights
        :param gradients: The gradients to use for the update
        :param scaler: A scaling factor that allows rescaling the gradients before applying them.
            The gradients will be MULTIPLIED by this factor
        """
        assert self.optimizer_type != 'LBFGS'

        batch_size = 1
        if self.distributed_training and not self.network_parameters.async_training:
            # rescale the gradients so that they average out with the gradients from the other workers
            if self.network_parameters.scale_down_gradients_by_number_of_workers_for_sync_training:
                batch_size = self.ap.task_parameters.num_training_tasks

        # set parameter gradients to gradients passed in
        for param_grad, gradient in zip(self._model_grads(-1), gradients):
            for pg in param_grad:
                pg[:] = gradient
        # apply the optimizer update to the weights
        self.trainer.update(batch_size=batch_size)

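    # Note on the batch_size trick above: gluon.Trainer.update(batch_size=N) normalizes the
    # gradients by 1/N inside the optimizer, so passing the number of synchronous workers makes
    # the summed worker gradients average out. Minimal standalone sketch (hypothetical net/loss):
    #
    #   with autograd.record():
    #       loss = loss_fn(net(x), y)
    #   loss.backward()
    #   trainer.allreduce_grads()
    #   trainer.update(batch_size=num_workers)   # gradients divided by num_workers before the step
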
    def _predict(self, inputs: Dict[str, np.ndarray]) -> Tuple[NDArray, ...]:
        """
        Run a forward pass of the network using the given input
        :param inputs: The input dictionary for the network. Key is name of the embedder.
        :return: The network output

        WARNING: must only be called once per state since each call is assumed by LSTM to be a new time step.
        """
        embedders = [emb.embedder_name for emb in self.model.nets[0].input_embedders]
        nd_inputs = tuple(nd.array(inputs[emb], ctx=self._devices[0]) for emb in embedders)

        assert self.middleware.__class__.__name__ != 'LSTMMiddleware'

        output = self.model(*nd_inputs)
        return output

    def predict(self,
                inputs: Dict[str, np.ndarray],
                outputs: List[str]=None,
                squeeze_output: bool=True,
                initial_feed_dict: Dict[str, np.ndarray]=None) -> Tuple[np.ndarray, ...]:
        """
        Run a forward pass of the network using the given input
        :param inputs: The input dictionary for the network. Key is name of the embedder.
        :param outputs: list of outputs to return. Return all outputs if unspecified (currently not supported)
        :param squeeze_output: call squeeze_list on output if True
        :param initial_feed_dict: a dictionary of extra inputs for forward pass (currently not supported)
        :return: The network output

        WARNING: must only be called once per state since each call is assumed by LSTM to be a new time step.
        """
        assert initial_feed_dict is None, "initial_feed_dict must be None"
        assert outputs is None, "outputs must be None"

        output = self._predict(inputs)
        output = list(o.asnumpy() for o in output)
        if squeeze_output:
            output = squeeze_list(output)
        return output

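    # A hedged usage sketch (comments only): keys of `inputs` must match the model's embedder
    # names (e.g. 'observation'), and a leading batch dimension is expected even for one state.
    # The observation shape below is only an example.
    #
    #   state = {'observation': np.zeros((1, 84, 84, 4), dtype=np.float32)}
    #   action_values = network.predict(state)                       # numpy output, squeezed by default
    #   raw_outputs = network.predict(state, squeeze_output=False)   # keep the list structure
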
    @staticmethod
    def parallel_predict(sess: Any,
                         network_input_tuples: List[Tuple['MxnetArchitecture', Dict[str, np.ndarray]]]) -> \
            Tuple[np.ndarray, ...]:
        """
        :param sess: active session to use for prediction (must be None for MXNet)
        :param network_input_tuples: tuple of network and corresponding input
        :return: tuple of outputs from all networks
        """
        assert sess is None
        output = list()
        for net, inputs in network_input_tuples:
            output += net._predict(inputs)
        return tuple(o.asnumpy() for o in output)

    def train_on_batch(self,
                       inputs: Dict[str, np.ndarray],
                       targets: List[np.ndarray],
                       scaler: float = 1.,
                       additional_fetches: list = None,
                       importance_weights: np.ndarray = None) -> Tuple[float, List[float], float, list]:
        """
        Given a batch of inputs (e.g. states) and targets (e.g. discounted rewards), takes a training step: i.e. runs a
        forward pass and backward pass of the network, accumulates the gradients and applies an optimization step to
        update the weights.
        :param inputs: environment states (observation, etc.) as well as extra inputs required by loss. Shape of
            ndarray is (batch_size, observation_space_size) or (batch_size, observation_space_size, stack_size)
        :param targets: targets required by loss (e.g. sum of discounted rewards)
        :param scaler: value to scale gradients by before optimizing network weights
        :param additional_fetches: additional fetches to calculate and return. Each fetch is specified as (int, str)
            tuple of head-type-index and fetch-name. The tuple is obtained from each head.
        :param importance_weights: ndarray of shape (batch_size,) to multiply with batch loss.
        :return: tuple of total_loss, losses, norm_unclipped_grads, fetched_tensors
            total_loss (float): sum of all head losses
            losses (list of float): list of all losses. The order is list of target losses followed by list
                of regularization losses. The specifics of the losses are dependent on the network parameters
                (number of heads, etc.)
            norm_unclipped_grads (float): global norm of all gradients before any gradient clipping is applied
            fetched_tensors: all values for additional_fetches
        """
        loss = self.accumulate_gradients(inputs, targets, additional_fetches=additional_fetches,
                                         importance_weights=importance_weights)
        self.apply_and_reset_gradients(self.accumulated_gradients, scaler)
        return loss

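    # A hedged usage sketch (comments only): one optimization step from a sampled batch.
    # The batch arrays below are hypothetical placeholders.
    #
    #   total_loss, losses, grad_norm, fetches = network.train_on_batch(
    #       inputs={'observation': batch_observations},
    #       targets=[batch_discounted_returns])
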
    def get_weights(self) -> gluon.ParameterDict:
        """
        :return: a ParameterDict containing all network weights
        """
        return self.model.collect_params()

    def set_weights(self, weights: gluon.ParameterDict, new_rate: float=1.0) -> None:
        """
        Sets the network weights from the given ParameterDict
        :param weights: a ParameterDict (e.g. from another network's get_weights()) to copy values from
        :param new_rate: ratio for adding new and old weight values: val=rate*weights + (1-rate)*old_weights
        """
        old_weights = self.model.collect_params()
        for name, p in weights.items():
            name = name[len(weights.prefix):]  # Strip prefix
            old_p = old_weights[old_weights.prefix + name]  # Add prefix
            old_p.set_data(new_rate * p._reduce() + (1 - new_rate) * old_p._reduce())

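    # A hedged usage sketch (comments only): the classic use of new_rate is a soft (Polyak)
    # update of a target network towards an online network; the two network objects below are
    # hypothetical.
    #
    #   target_network.set_weights(online_network.get_weights(), new_rate=0.005)   # slow tracking
    #   target_network.set_weights(online_network.get_weights())                   # hard copy (new_rate=1.0)
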
    def get_variable_value(self, variable: Union[gluon.Parameter, NDArray]) -> np.ndarray:
        """
        Get the value of a variable
        :param variable: the variable
        :return: the value of the variable
        """
        if isinstance(variable, gluon.Parameter):
            variable = variable._reduce().asnumpy()
        if isinstance(variable, NDArray):
            return variable.asnumpy()
        return variable

    def set_variable_value(self, assign_op: callable, value: Any, placeholder=None) -> None:
        """
        Updates value of a variable.
        :param assign_op: a callable assign function for setting the variable
        :param value: a value to set the variable to
        :param placeholder: unused (placeholder in symbolic framework backends)
        """
        assert callable(assign_op)
        assign_op(value)

    def set_is_training(self, state: bool) -> None:
        """
        Set the phase of the network between training and testing
        :param state: The current state (True = Training, False = Testing)
        :return: None
        """
        self.is_training = state

    def reset_internal_memory(self) -> None:
        """
        Reset any internal memory used by the network. For example, an LSTM internal state
        :return: None
        """
        assert self.middleware.__class__.__name__ != 'LSTMMiddleware', 'LSTM middleware not supported'

    def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
        """
        Collection of all checkpoints for the network (typically only one checkpoint)
        :param parent_path_suffix: path suffix of the parent of the network
            (e.g. could be name of level manager plus name of agent)
        :return: checkpoint collection for the network
        """
        name = self.name.replace('/', '.')
        savers = SaverCollection(ParameterDictSaver(
            name="{}.{}".format(parent_path_suffix, name),
            param_dict=self.model.collect_params()))
        if self.ap.task_parameters.export_onnx_graph:
            savers.add(OnnxSaver(
                name="{}.{}.onnx".format(parent_path_suffix, name),
                model=self.model,
                input_shapes=self._model_input_shapes()))
        return savers