mirror of
https://github.com/gryf/coach.git
synced 2026-03-04 15:55:47 +01:00
Adding mxnet components to rl_coach/architectures (#60)
Adding mxnet components to rl_coach architectures. - Supports PPO and DQN - Tested with CartPole_PPO and CarPole_DQN - Normalizing filters don't work right now (see #49) and are disabled in CartPole_PPO preset - Checkpointing is disabled for MXNet
This commit is contained in:
@@ -41,22 +41,41 @@ class Architecture(object):
|
||||
self.optimizer = None
|
||||
self.ap = agent_parameters
|
||||
|
||||
def predict(self, inputs: Dict[str, np.ndarray]) -> List[np.ndarray]:
|
||||
def predict(self,
|
||||
inputs: Dict[str, np.ndarray],
|
||||
outputs: List[Any] = None,
|
||||
squeeze_output: bool = True,
|
||||
initial_feed_dict: Dict[Any, np.ndarray] = None) -> Tuple[np.ndarray, ...]:
|
||||
"""
|
||||
Given input observations, use the model to make predictions (e.g. action or value).
|
||||
|
||||
:param inputs: current state (i.e. observations, measurements, goals, etc.)
|
||||
(e.g. `{'observation': numpy.ndarray}` of shape (batch_size, observation_space_size))
|
||||
:param outputs: list of outputs to return. Return all outputs if unspecified. Type of the list elements
|
||||
depends on the framework backend.
|
||||
:param squeeze_output: call squeeze_list on output before returning if True
|
||||
:param initial_feed_dict: a dictionary of extra inputs for forward pass.
|
||||
:return: predictions of action or value of shape (batch_size, action_space_size) for action predictions)
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def parallel_predict(sess: Any,
|
||||
network_input_tuples: List[Tuple['Architecture', Dict[str, np.ndarray]]]) -> \
|
||||
Tuple[np.ndarray, ...]:
|
||||
"""
|
||||
:param sess: active session to use for prediction
|
||||
:param network_input_tuples: tuple of network and corresponding input
|
||||
:return: list or tuple of outputs from all networks
|
||||
"""
|
||||
pass
|
||||
|
||||
def train_on_batch(self,
|
||||
inputs: Dict[str, np.ndarray],
|
||||
targets: List[np.ndarray],
|
||||
scaler: float=1.,
|
||||
additional_fetches: list=None,
|
||||
importance_weights: np.ndarray=None) -> tuple:
|
||||
importance_weights: np.ndarray=None) -> Tuple[float, List[float], float, list]:
|
||||
"""
|
||||
Given a batch of inputs (e.g. states) and targets (e.g. discounted rewards), takes a training step: i.e. runs a
|
||||
forward pass and backward pass of the network, accumulates the gradients and applies an optimization step to
|
||||
@@ -118,8 +137,7 @@ class Architecture(object):
|
||||
targets: List[np.ndarray],
|
||||
additional_fetches: list=None,
|
||||
importance_weights: np.ndarray=None,
|
||||
no_accumulation: bool=False) ->\
|
||||
Tuple[float, List[float], float, list]:
|
||||
no_accumulation: bool=False) -> Tuple[float, List[float], float, list]:
|
||||
"""
|
||||
Given a batch of inputs (i.e. states) and targets (e.g. discounted rewards), computes and accumulates the
|
||||
gradients for model parameters. Will run forward and backward pass to compute gradients, clip the gradient
|
||||
@@ -142,30 +160,33 @@ class Architecture(object):
|
||||
calculated gradients
|
||||
:return: tuple of total_loss, losses, norm_unclipped_grads, fetched_tensors
|
||||
total_loss (float): sum of all head losses
|
||||
losses (list of float): list of all losses. The order is list of target losses followed by list of regularization losses.
|
||||
The specifics of losses is dependant on the network parameters (number of heads, etc.)
|
||||
losses (list of float): list of all losses. The order is list of target losses followed by list of
|
||||
regularization losses. The specifics of losses is dependant on the network parameters
|
||||
(number of heads, etc.)
|
||||
norm_unclippsed_grads (float): global norm of all gradients before any gradient clipping is applied
|
||||
fetched_tensors: all values for additional_fetches
|
||||
"""
|
||||
pass
|
||||
|
||||
def apply_and_reset_gradients(self, gradients: List[np.ndarray]) -> None:
|
||||
def apply_and_reset_gradients(self, gradients: List[np.ndarray], scaler: float=1.) -> None:
|
||||
"""
|
||||
Applies the given gradients to the network weights and resets the gradient accumulations.
|
||||
Has the same impact as calling `apply_gradients`, then `reset_accumulated_gradients`.
|
||||
|
||||
:param gradients: gradients for the parameter weights, taken from `accumulated_gradients` property
|
||||
of an identical network (either self or another identical network)
|
||||
:param scaler: A scaling factor that allows rescaling the gradients before applying them
|
||||
"""
|
||||
pass
|
||||
|
||||
def apply_gradients(self, gradients: List[np.ndarray]) -> None:
|
||||
def apply_gradients(self, gradients: List[np.ndarray], scaler: float=1.) -> None:
|
||||
"""
|
||||
Applies the given gradients to the network weights.
|
||||
Will be performed sync or async depending on `network_parameters.async_training`
|
||||
|
||||
:param gradients: gradients for the parameter weights, taken from `accumulated_gradients` property
|
||||
of an identical network (either self or another identical network)
|
||||
:param scaler: A scaling factor that allows rescaling the gradients before applying them
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
0
rl_coach/architectures/mxnet_components/__init__.py
Normal file
0
rl_coach/architectures/mxnet_components/__init__.py
Normal file
405
rl_coach/architectures/mxnet_components/architecture.py
Normal file
405
rl_coach/architectures/mxnet_components/architecture.py
Normal file
@@ -0,0 +1,405 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import copy
|
||||
from typing import Any, Dict, Generator, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import mxnet as mx
|
||||
from mxnet import autograd, gluon, nd
|
||||
from mxnet.ndarray import NDArray
|
||||
|
||||
from rl_coach.architectures.architecture import Architecture
|
||||
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION
|
||||
from rl_coach.architectures.mxnet_components import utils
|
||||
from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
|
||||
from rl_coach.spaces import SpacesDefinition
|
||||
from rl_coach.utils import force_list, squeeze_list
|
||||
|
||||
|
||||
class MxnetArchitecture(Architecture):
|
||||
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str= "",
|
||||
global_network=None, network_is_local: bool=True, network_is_trainable: bool=False):
|
||||
"""
|
||||
:param agent_parameters: the agent parameters
|
||||
:param spaces: the spaces definition of the agent
|
||||
:param name: the name of the network
|
||||
:param global_network: the global network replica that is shared between all the workers
|
||||
:param network_is_local: is the network global (shared between workers) or local (dedicated to the worker)
|
||||
:param network_is_trainable: is the network trainable (we can apply gradients on it)
|
||||
"""
|
||||
super().__init__(agent_parameters, spaces, name)
|
||||
self.middleware = None
|
||||
self.network_is_local = network_is_local
|
||||
self.global_network = global_network
|
||||
if not self.network_parameters.tensorflow_support:
|
||||
raise ValueError('TensorFlow is not supported for this agent')
|
||||
self.losses = [] # type: List[HeadLoss]
|
||||
self.shared_accumulated_gradients = []
|
||||
self.curr_rnn_c_in = None
|
||||
self.curr_rnn_h_in = None
|
||||
self.gradients_wrt_inputs = []
|
||||
self.train_writer = None
|
||||
self.accumulated_gradients = None
|
||||
self.network_is_trainable = network_is_trainable
|
||||
self.is_training = False
|
||||
self.model = None # type: GeneralModel
|
||||
|
||||
self.is_chief = self.ap.task_parameters.task_index == 0
|
||||
self.network_is_global = not self.network_is_local and global_network is None
|
||||
self.distributed_training = self.network_is_global or self.network_is_local and global_network is not None
|
||||
|
||||
self.optimizer_type = self.network_parameters.optimizer_type
|
||||
if self.ap.task_parameters.seed is not None:
|
||||
mx.random.seed(self.ap.task_parameters.seed)
|
||||
|
||||
# Call to child class to create the model
|
||||
self.construct_model()
|
||||
|
||||
self.trainer = None # type: gluon.Trainer
|
||||
|
||||
def __str__(self):
|
||||
return self.model.summary(*self._dummy_model_inputs())
|
||||
|
||||
@property
|
||||
def _model_grads(self) -> Generator[NDArray, NDArray, Any]:
|
||||
"""
|
||||
Creates a copy of model gradients and returns them in a list, in the same order as collect_params()
|
||||
:return: a generator for model gradient values
|
||||
"""
|
||||
return (p.list_grad()[0].copy() for p in self.model.collect_params().values() if p.grad_req != 'null')
|
||||
|
||||
def _dummy_model_inputs(self) -> Tuple[NDArray, ...]:
|
||||
"""
|
||||
Creates a tuple of input arrays with correct shapes that can be used for shape inference
|
||||
of the model weights and for printing the summary
|
||||
:return: tuple of inputs for model forward pass
|
||||
"""
|
||||
allowed_inputs = copy.copy(self.spaces.state.sub_spaces)
|
||||
allowed_inputs["action"] = copy.copy(self.spaces.action)
|
||||
allowed_inputs["goal"] = copy.copy(self.spaces.goal)
|
||||
embedders = self.model.nets[0].input_embedders
|
||||
inputs = tuple(nd.zeros((1,) + tuple(allowed_inputs[emb.embedder_name].shape.tolist())) for emb in embedders)
|
||||
return inputs
|
||||
|
||||
def construct_model(self) -> None:
|
||||
"""
|
||||
Construct network model. Implemented by child class.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def set_session(self, sess) -> None:
|
||||
"""
|
||||
Initializes the model parameters and creates the model trainer.
|
||||
NOTEL Session for mxnet backend must be None.
|
||||
:param sess: must be None
|
||||
"""
|
||||
assert sess is None
|
||||
# FIXME Add GPU initialization
|
||||
# FIXME Add initializer
|
||||
self.model.collect_params().initialize(ctx=mx.cpu())
|
||||
# Hybridize model and losses
|
||||
self.model.hybridize()
|
||||
for l in self.losses:
|
||||
l.hybridize()
|
||||
|
||||
# Pass dummy data with correct shape to trigger shape inference and full parameter initialization
|
||||
self.model(*self._dummy_model_inputs())
|
||||
|
||||
if self.network_is_trainable:
|
||||
self.trainer = gluon.Trainer(
|
||||
self.model.collect_params(), optimizer=self.optimizer, update_on_kvstore=False)
|
||||
|
||||
def reset_accumulated_gradients(self) -> None:
|
||||
"""
|
||||
Reset model gradients as well as accumulated gradients to zero. If accumulated gradients
|
||||
have not been created yet, it constructs them on CPU.
|
||||
"""
|
||||
# Set model gradients to zero
|
||||
for p in self.model.collect_params().values():
|
||||
p.zero_grad()
|
||||
# Set accumulated gradients to zero if already initialized, otherwise create a copy
|
||||
if self.accumulated_gradients:
|
||||
for a in self.accumulated_gradients:
|
||||
a *= 0
|
||||
else:
|
||||
self.accumulated_gradients = [g.copy() for g in self._model_grads]
|
||||
|
||||
def accumulate_gradients(self,
|
||||
inputs: Dict[str, np.ndarray],
|
||||
targets: List[np.ndarray],
|
||||
additional_fetches: List[Tuple[int, str]] = None,
|
||||
importance_weights: np.ndarray = None,
|
||||
no_accumulation: bool = False) -> Tuple[float, List[float], float, list]:
|
||||
"""
|
||||
Runs a forward & backward pass, clips gradients if needed and accumulates them into the accumulation
|
||||
:param inputs: environment states (observation, etc.) as well extra inputs required by loss. Shape of ndarray
|
||||
is (batch_size, observation_space_size) or (batch_size, observation_space_size, stack_size)
|
||||
:param targets: targets required by loss (e.g. sum of discounted rewards)
|
||||
:param additional_fetches: additional fetches to calculate and return. Each fetch is specified as (int, str)
|
||||
tuple of head-type-index and fetch-name. The tuple is obtained from each head.
|
||||
:param importance_weights: ndarray of shape (batch_size,) to multiply with batch loss.
|
||||
:param no_accumulation: if True, set gradient values to the new gradients, otherwise sum with previously
|
||||
calculated gradients
|
||||
:return: tuple of total_loss, losses, norm_unclipped_grads, fetched_tensors
|
||||
total_loss (float): sum of all head losses
|
||||
losses (list of float): list of all losses. The order is list of target losses followed by list of
|
||||
regularization losses. The specifics of losses is dependant on the network parameters
|
||||
(number of heads, etc.)
|
||||
norm_unclippsed_grads (float): global norm of all gradients before any gradient clipping is applied
|
||||
fetched_tensors: all values for additional_fetches
|
||||
"""
|
||||
if self.accumulated_gradients is None:
|
||||
self.reset_accumulated_gradients()
|
||||
|
||||
embedders = [emb.embedder_name for emb in self.model.nets[0].input_embedders]
|
||||
nd_inputs = tuple(nd.array(inputs[emb]) for emb in embedders)
|
||||
|
||||
assert self.middleware.__class__.__name__ != 'LSTMMiddleware', "LSTM middleware not supported"
|
||||
|
||||
targets = force_list(targets)
|
||||
with autograd.record():
|
||||
out_per_head = utils.split_outputs_per_head(self.model(*nd_inputs), self.model.output_heads)
|
||||
tgt_per_loss = utils.split_targets_per_loss(targets, self.losses)
|
||||
|
||||
losses = list()
|
||||
regularizations = list()
|
||||
additional_fetches = [(k, None) for k in additional_fetches]
|
||||
for h, h_loss, h_out, l_tgt in zip(self.model.output_heads, self.losses, out_per_head, tgt_per_loss):
|
||||
l_in = utils.get_loss_agent_inputs(inputs, head_type_idx=h.head_type_idx, loss=h_loss)
|
||||
# Align arguments with loss.loss_forward and convert to NDArray
|
||||
l_args = utils.to_mx_ndarray(utils.align_loss_args(h_out, l_in, l_tgt, h_loss))
|
||||
# Calculate loss and all auxiliary outputs
|
||||
loss_outputs = utils.loss_output_dict(utils.to_list(h_loss(*l_args)), h_loss.output_schema)
|
||||
if LOSS_OUT_TYPE_LOSS in loss_outputs:
|
||||
losses.extend(loss_outputs[LOSS_OUT_TYPE_LOSS])
|
||||
if LOSS_OUT_TYPE_REGULARIZATION in loss_outputs:
|
||||
regularizations.extend(loss_outputs[LOSS_OUT_TYPE_REGULARIZATION])
|
||||
# Set additional fetches
|
||||
for i, fetch in enumerate(additional_fetches):
|
||||
head_type_idx, fetch_name = fetch[0] # fetch key is a tuple of (head_type_index, fetch_name)
|
||||
if head_type_idx == h.head_type_idx:
|
||||
assert fetch[1] is None # sanity check that fetch is None
|
||||
additional_fetches[i] = (fetch[0], loss_outputs[fetch_name])
|
||||
|
||||
# Total loss is losses and regularization (NOTE: order is important)
|
||||
total_loss_list = losses + regularizations
|
||||
total_loss = nd.add_n(*total_loss_list)
|
||||
|
||||
# Calculate gradients
|
||||
total_loss.backward()
|
||||
|
||||
assert self.optimizer_type != 'LBFGS', 'LBFGS not supported'
|
||||
|
||||
# allreduce gradients from all contexts
|
||||
self.trainer.allreduce_grads()
|
||||
|
||||
# Calculate global norm of gradients
|
||||
# FIXME global norm is returned even when not used for clipping! Is this necessary?
|
||||
# FIXME global norm might be calculated twice if clipping method is global norm
|
||||
norm_unclipped_grads = utils.global_norm(self._model_grads)
|
||||
|
||||
# Clip gradients
|
||||
if self.network_parameters.clip_gradients:
|
||||
utils.clip_grad(
|
||||
self._model_grads,
|
||||
clip_method=self.network_parameters.gradients_clipping_method,
|
||||
clip_val=self.network_parameters.clip_gradients,
|
||||
inplace=True)
|
||||
|
||||
# Update self.accumulated_gradients depending on no_accumulation flag
|
||||
if no_accumulation:
|
||||
for acc_grad, model_grad in zip(self.accumulated_gradients, self._model_grads):
|
||||
acc_grad[:] = model_grad
|
||||
else:
|
||||
for acc_grad, model_grad in zip(self.accumulated_gradients, self._model_grads):
|
||||
acc_grad += model_grad
|
||||
|
||||
# result of of additional fetches
|
||||
fetched_tensors = [fetch[1] for fetch in additional_fetches]
|
||||
|
||||
# convert everything to numpy or scalar before returning
|
||||
result = utils.asnumpy_or_asscalar((total_loss, total_loss_list, norm_unclipped_grads, fetched_tensors))
|
||||
return result
|
||||
|
||||
def apply_and_reset_gradients(self, gradients: List[np.ndarray], scaler: float=1.) -> None:
|
||||
"""
|
||||
Applies the given gradients to the network weights and resets accumulated gradients to zero
|
||||
:param gradients: The gradients to use for the update
|
||||
:param scaler: A scaling factor that allows rescaling the gradients before applying them
|
||||
"""
|
||||
self.apply_gradients(gradients, scaler)
|
||||
self.reset_accumulated_gradients()
|
||||
|
||||
def apply_gradients(self, gradients: List[np.ndarray], scaler: float=1.) -> None:
|
||||
"""
|
||||
Applies the given gradients to the network weights
|
||||
:param gradients: The gradients to use for the update
|
||||
:param scaler: A scaling factor that allows rescaling the gradients before applying them.
|
||||
The gradients will be MULTIPLIED by this factor
|
||||
"""
|
||||
assert self.optimizer_type != 'LBFGS'
|
||||
|
||||
batch_size = 1
|
||||
if self.distributed_training and not self.network_parameters.async_training:
|
||||
# rescale the gradients so that they average out with the gradients from the other workers
|
||||
if self.network_parameters.scale_down_gradients_by_number_of_workers_for_sync_training:
|
||||
batch_size = self.ap.task_parameters.num_training_tasks
|
||||
|
||||
# set parameter gradients to gradients passed in
|
||||
for param_grad, gradient in zip(self._model_grads, gradients):
|
||||
param_grad[:] = gradient
|
||||
# update gradients
|
||||
self.trainer.update(batch_size=batch_size)
|
||||
|
||||
def _predict(self, inputs: Dict[str, np.ndarray]) -> Tuple[NDArray, ...]:
|
||||
"""
|
||||
Run a forward pass of the network using the given input
|
||||
:param inputs: The input dictionary for the network. Key is name of the embedder.
|
||||
:return: The network output
|
||||
|
||||
WARNING: must only call once per state since each call is assumed by LSTM to be a new time step.
|
||||
"""
|
||||
embedders = [emb.embedder_name for emb in self.model.nets[0].input_embedders]
|
||||
nd_inputs = tuple(nd.array(inputs[emb]) for emb in embedders)
|
||||
|
||||
assert self.middleware.__class__.__name__ != 'LSTMMiddleware'
|
||||
|
||||
output = self.model(*nd_inputs)
|
||||
return output
|
||||
|
||||
def predict(self,
|
||||
inputs: Dict[str, np.ndarray],
|
||||
outputs: List[str]=None,
|
||||
squeeze_output: bool=True,
|
||||
initial_feed_dict: Dict[str, np.ndarray]=None) -> Tuple[np.ndarray, ...]:
|
||||
"""
|
||||
Run a forward pass of the network using the given input
|
||||
:param inputs: The input dictionary for the network. Key is name of the embedder.
|
||||
:param outputs: list of outputs to return. Return all outputs if unspecified (currently not supported)
|
||||
:param squeeze_output: call squeeze_list on output if True
|
||||
:param initial_feed_dict: a dictionary of extra inputs for forward pass (currently not supported)
|
||||
:return: The network output
|
||||
|
||||
WARNING: must only call once per state since each call is assumed by LSTM to be a new time step.
|
||||
"""
|
||||
assert initial_feed_dict is None, "initial_feed_dict must be None"
|
||||
assert outputs is None, "outputs must be None"
|
||||
|
||||
output = self._predict(inputs)
|
||||
|
||||
output = tuple(o.asnumpy() for o in output)
|
||||
if squeeze_output:
|
||||
output = squeeze_list(output)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def parallel_predict(sess: Any,
|
||||
network_input_tuples: List[Tuple['MxnetArchitecture', Dict[str, np.ndarray]]]) -> \
|
||||
Tuple[np.ndarray, ...]:
|
||||
"""
|
||||
:param sess: active session to use for prediction (must be None for MXNet)
|
||||
:param network_input_tuples: tuple of network and corresponding input
|
||||
:return: tuple of outputs from all networks
|
||||
"""
|
||||
assert sess is None
|
||||
output = list()
|
||||
for net, inputs in network_input_tuples:
|
||||
output += net._predict(inputs)
|
||||
return tuple(o.asnumpy() for o in output)
|
||||
|
||||
def train_on_batch(self,
|
||||
inputs: Dict[str, np.ndarray],
|
||||
targets: List[np.ndarray],
|
||||
scaler: float = 1.,
|
||||
additional_fetches: list = None,
|
||||
importance_weights: np.ndarray = None) -> Tuple[float, List[float], float, list]:
|
||||
"""
|
||||
Given a batch of inputs (e.g. states) and targets (e.g. discounted rewards), takes a training step: i.e. runs a
|
||||
forward pass and backward pass of the network, accumulates the gradients and applies an optimization step to
|
||||
update the weights.
|
||||
:param inputs: environment states (observation, etc.) as well extra inputs required by loss. Shape of ndarray
|
||||
is (batch_size, observation_space_size) or (batch_size, observation_space_size, stack_size)
|
||||
:param targets: targets required by loss (e.g. sum of discounted rewards)
|
||||
:param scaler: value to scale gradients by before optimizing network weights
|
||||
:param additional_fetches: additional fetches to calculate and return. Each fetch is specified as (int, str)
|
||||
tuple of head-type-index and fetch-name. The tuple is obtained from each head.
|
||||
:param importance_weights: ndarray of shape (batch_size,) to multiply with batch loss.
|
||||
:return: tuple of total_loss, losses, norm_unclipped_grads, fetched_tensors
|
||||
total_loss (float): sum of all head losses
|
||||
losses (list of float): list of all losses. The order is list of target losses followed by list
|
||||
of regularization losses. The specifics of losses is dependant on the network parameters
|
||||
(number of heads, etc.)
|
||||
norm_unclippsed_grads (float): global norm of all gradients before any gradient clipping is applied
|
||||
fetched_tensors: all values for additional_fetches
|
||||
"""
|
||||
loss = self.accumulate_gradients(inputs, targets, additional_fetches=additional_fetches,
|
||||
importance_weights=importance_weights)
|
||||
self.apply_and_reset_gradients(self.accumulated_gradients, scaler)
|
||||
return loss
|
||||
|
||||
def get_weights(self) -> gluon.ParameterDict:
|
||||
"""
|
||||
:return: a ParameterDict containing all network weights
|
||||
"""
|
||||
return self.model.collect_params()
|
||||
|
||||
def set_weights(self, weights: gluon.ParameterDict, new_rate: float=1.0) -> None:
|
||||
"""
|
||||
Sets the network weights from the given ParameterDict
|
||||
:param new_rate: ratio for adding new and old weight values: val=rate*weights + (1-rate)*old_weights
|
||||
"""
|
||||
old_weights = self.model.collect_params()
|
||||
for name, p in weights.items():
|
||||
name = name[len(weights.prefix):] # Strip prefix
|
||||
old_p = old_weights[old_weights.prefix + name] # Add prefix
|
||||
old_p.set_data(new_rate * p._reduce() + (1 - new_rate) * old_p._reduce())
|
||||
|
||||
def get_variable_value(self, variable: Union[gluon.Parameter, NDArray]) -> np.ndarray:
|
||||
"""
|
||||
Get the value of a variable
|
||||
:param variable: the variable
|
||||
:return: the value of the variable
|
||||
"""
|
||||
if isinstance(variable, gluon.Parameter):
|
||||
variable = variable._reduce().asnumpy()
|
||||
if isinstance(variable, NDArray):
|
||||
return variable.asnumpy()
|
||||
return variable
|
||||
|
||||
def set_variable_value(self, assign_op: callable, value: Any, placeholder=None) -> None:
|
||||
"""
|
||||
Updates value of a variable.
|
||||
:param assign_op: a callable assign function for setting the variable
|
||||
:param value: a value to set the variable to
|
||||
:param placeholder: unused (placeholder in symbolic framework backends)
|
||||
"""
|
||||
assert callable(assign_op)
|
||||
assign_op(value)
|
||||
|
||||
def set_is_training(self, state: bool) -> None:
|
||||
"""
|
||||
Set the phase of the network between training and testing
|
||||
:param state: The current state (True = Training, False = Testing)
|
||||
:return: None
|
||||
"""
|
||||
self.is_training = state
|
||||
|
||||
def reset_internal_memory(self) -> None:
|
||||
"""
|
||||
Reset any internal memory used by the network. For example, an LSTM internal state
|
||||
:return: None
|
||||
"""
|
||||
assert self.middleware.__class__.__name__ != 'LSTMMiddleware', 'LSTM middleware not supported'
|
||||
@@ -0,0 +1,4 @@
|
||||
from .image_embedder import ImageEmbedder
|
||||
from .vector_embedder import VectorEmbedder
|
||||
|
||||
__all__ = ['ImageEmbedder', 'VectorEmbedder']
|
||||
@@ -0,0 +1,71 @@
|
||||
from typing import Union
|
||||
from types import ModuleType
|
||||
|
||||
import mxnet as mx
|
||||
from mxnet.gluon import nn
|
||||
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
|
||||
from rl_coach.architectures.mxnet_components.layers import convert_layer
|
||||
from rl_coach.base_parameters import EmbedderScheme
|
||||
|
||||
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
|
||||
|
||||
|
||||
class InputEmbedder(nn.HybridBlock):
|
||||
def __init__(self, params: InputEmbedderParameters):
|
||||
"""
|
||||
An input embedder is the first part of the network, which takes the input from the state and produces a vector
|
||||
embedding by passing it through a neural network. The embedder will mostly be input type dependent, and there
|
||||
can be multiple embedders in a single network.
|
||||
|
||||
:param params: parameters object containing input_clipping, input_rescaling, batchnorm, activation_function
|
||||
and dropout properties.
|
||||
"""
|
||||
super(InputEmbedder, self).__init__()
|
||||
self.embedder_name = params.name
|
||||
self.input_clipping = params.input_clipping
|
||||
self.scheme = params.scheme
|
||||
|
||||
with self.name_scope():
|
||||
self.net = nn.HybridSequential()
|
||||
if isinstance(self.scheme, EmbedderScheme):
|
||||
blocks = self.schemes[self.scheme]
|
||||
else:
|
||||
# if scheme is specified directly, convert to MX layer if it's not a callable object
|
||||
# NOTE: if layer object is callable, it must return a gluon block when invoked
|
||||
blocks = [convert_layer(l) for l in self.scheme]
|
||||
for block in blocks:
|
||||
self.net.add(block())
|
||||
if params.batchnorm:
|
||||
self.net.add(nn.BatchNorm())
|
||||
if params.activation_function:
|
||||
self.net.add(nn.Activation(params.activation_function))
|
||||
if params.dropout:
|
||||
self.net.add(nn.Dropout(rate=params.dropout))
|
||||
|
||||
@property
|
||||
def schemes(self) -> dict:
|
||||
"""
|
||||
Schemes are the pre-defined network architectures of various depths and complexities that can be used for the
|
||||
InputEmbedder. Should be implemented in child classes, and are used to create Block when InputEmbedder is
|
||||
initialised.
|
||||
|
||||
:return: dictionary of schemes, with key of type EmbedderScheme enum and value being list of mxnet.gluon.Block.
|
||||
"""
|
||||
raise NotImplementedError("Inheriting embedder must define schemes matching its allowed default "
|
||||
"configurations.")
|
||||
|
||||
def hybrid_forward(self, F: ModuleType, x: nd_sym_type, *args, **kwargs) -> nd_sym_type:
|
||||
"""
|
||||
Used for forward pass through embedder network.
|
||||
|
||||
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
|
||||
:param x: environment state, where first dimension is batch_size, then dimensions are data type dependent.
|
||||
:return: embedding of environment state, where shape is (batch_size, channels).
|
||||
"""
|
||||
# `input_rescaling` and `input_offset` set on inheriting embedder
|
||||
x = x / self.input_rescaling
|
||||
x = x - self.input_offset
|
||||
if self.input_clipping is not None:
|
||||
x.clip(a_min=self.input_clipping[0], a_max=self.input_clipping[1])
|
||||
x = self.net(x)
|
||||
return x.flatten()
|
||||
@@ -0,0 +1,76 @@
|
||||
from typing import Union
|
||||
from types import ModuleType
|
||||
|
||||
import mxnet as mx
|
||||
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
|
||||
from rl_coach.architectures.mxnet_components.embedders.embedder import InputEmbedder
|
||||
from rl_coach.architectures.mxnet_components.layers import Conv2d
|
||||
from rl_coach.base_parameters import EmbedderScheme
|
||||
|
||||
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
|
||||
|
||||
|
||||
class ImageEmbedder(InputEmbedder):
|
||||
def __init__(self, params: InputEmbedderParameters):
|
||||
"""
|
||||
An image embedder is an input embedder that takes an image input from the state and produces a vector
|
||||
embedding by passing it through a neural network.
|
||||
|
||||
:param params: parameters object containing input_clipping, input_rescaling, batchnorm, activation_function
|
||||
and dropout properties.
|
||||
"""
|
||||
super(ImageEmbedder, self).__init__(params)
|
||||
self.input_rescaling = params.input_rescaling['image']
|
||||
self.input_offset = params.input_offset['image']
|
||||
|
||||
@property
|
||||
def schemes(self) -> dict:
|
||||
"""
|
||||
Schemes are the pre-defined network architectures of various depths and complexities that can be used. Are used
|
||||
to create Block when ImageEmbedder is initialised.
|
||||
|
||||
:return: dictionary of schemes, with key of type EmbedderScheme enum and value being list of mxnet.gluon.Block.
|
||||
"""
|
||||
return {
|
||||
EmbedderScheme.Empty:
|
||||
[],
|
||||
|
||||
EmbedderScheme.Shallow:
|
||||
[
|
||||
Conv2d(num_filters=32, kernel_size=8, strides=4)
|
||||
],
|
||||
|
||||
# Use for Atari DQN
|
||||
EmbedderScheme.Medium:
|
||||
[
|
||||
Conv2d(num_filters=32, kernel_size=8, strides=4),
|
||||
Conv2d(num_filters=64, kernel_size=4, strides=2),
|
||||
Conv2d(num_filters=64, kernel_size=3, strides=1)
|
||||
],
|
||||
|
||||
# Use for Carla
|
||||
EmbedderScheme.Deep:
|
||||
[
|
||||
Conv2d(num_filters=32, kernel_size=5, strides=2),
|
||||
Conv2d(num_filters=32, kernel_size=3, strides=1),
|
||||
Conv2d(num_filters=64, kernel_size=3, strides=2),
|
||||
Conv2d(num_filters=64, kernel_size=3, strides=1),
|
||||
Conv2d(num_filters=128, kernel_size=3, strides=2),
|
||||
Conv2d(num_filters=128, kernel_size=3, strides=1),
|
||||
Conv2d(num_filters=256, kernel_size=3, strides=2),
|
||||
Conv2d(num_filters=256, kernel_size=3, strides=1)
|
||||
]
|
||||
}
|
||||
|
||||
def hybrid_forward(self, F: ModuleType, x: nd_sym_type, *args, **kwargs) -> nd_sym_type:
|
||||
"""
|
||||
Used for forward pass through embedder network.
|
||||
|
||||
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
|
||||
:param x: image representing environment state, of shape (batch_size, in_channels, height, width).
|
||||
:return: embedding of environment state, of shape (batch_size, channels).
|
||||
"""
|
||||
if len(x.shape) != 4 and self.scheme != EmbedderScheme.Empty:
|
||||
raise ValueError("Image embedders expect the input size to have 4 dimensions. The given size is: {}"
|
||||
.format(x.shape))
|
||||
return super(ImageEmbedder, self).hybrid_forward(F, x, *args, **kwargs)
|
||||
@@ -0,0 +1,71 @@
|
||||
from typing import Union
|
||||
from types import ModuleType
|
||||
|
||||
import mxnet as mx
|
||||
from mxnet import nd, sym
|
||||
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
|
||||
from rl_coach.architectures.mxnet_components.embedders.embedder import InputEmbedder
|
||||
from rl_coach.architectures.mxnet_components.layers import Dense
|
||||
from rl_coach.base_parameters import EmbedderScheme
|
||||
|
||||
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
|
||||
|
||||
|
||||
class VectorEmbedder(InputEmbedder):
|
||||
def __init__(self, params: InputEmbedderParameters):
|
||||
"""
|
||||
An vector embedder is an input embedder that takes an vector input from the state and produces a vector
|
||||
embedding by passing it through a neural network.
|
||||
|
||||
:param params: parameters object containing input_clipping, input_rescaling, batchnorm, activation_function
|
||||
and dropout properties.
|
||||
"""
|
||||
super(VectorEmbedder, self).__init__(params)
|
||||
self.input_rescaling = params.input_rescaling['vector']
|
||||
self.input_offset = params.input_offset['vector']
|
||||
|
||||
@property
|
||||
def schemes(self):
|
||||
"""
|
||||
Schemes are the pre-defined network architectures of various depths and complexities that can be used. Are used
|
||||
to create Block when VectorEmbedder is initialised.
|
||||
|
||||
:return: dictionary of schemes, with key of type EmbedderScheme enum and value being list of mxnet.gluon.Block.
|
||||
"""
|
||||
return {
|
||||
EmbedderScheme.Empty:
|
||||
[],
|
||||
|
||||
EmbedderScheme.Shallow:
|
||||
[
|
||||
Dense(units=128)
|
||||
],
|
||||
|
||||
# Use for DQN
|
||||
EmbedderScheme.Medium:
|
||||
[
|
||||
Dense(units=256)
|
||||
],
|
||||
|
||||
# Use for Carla
|
||||
EmbedderScheme.Deep:
|
||||
[
|
||||
Dense(units=128),
|
||||
Dense(units=128),
|
||||
Dense(units=128)
|
||||
]
|
||||
}
|
||||
|
||||
def hybrid_forward(self, F: ModuleType, x: nd_sym_type, *args, **kwargs) -> nd_sym_type:
|
||||
"""
|
||||
Used for forward pass through embedder network.
|
||||
|
||||
:param F: backend api, either `nd` or `sym` (if block has been hybridized).
|
||||
:type F: nd or sym
|
||||
:param x: vector representing environment state, of shape (batch_size, in_channels).
|
||||
:return: embedding of environment state, of shape (batch_size, channels).
|
||||
"""
|
||||
if isinstance(x, nd.NDArray) and len(x.shape) != 2 and self.scheme != EmbedderScheme.Empty:
|
||||
raise ValueError("Vector embedders expect the input size to have 2 dimensions. The given size is: {}"
|
||||
.format(x.shape))
|
||||
return super(VectorEmbedder, self).hybrid_forward(F, x, *args, **kwargs)
|
||||
501
rl_coach/architectures/mxnet_components/general_network.py
Normal file
501
rl_coach/architectures/mxnet_components/general_network.py
Normal file
@@ -0,0 +1,501 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import copy
|
||||
from itertools import chain
|
||||
from typing import List, Tuple, Union
|
||||
from types import ModuleType
|
||||
|
||||
import numpy as np
|
||||
import mxnet as mx
|
||||
from mxnet import nd, sym
|
||||
from mxnet.gluon import HybridBlock
|
||||
from mxnet.ndarray import NDArray
|
||||
from mxnet.symbol import Symbol
|
||||
|
||||
from rl_coach.base_parameters import NetworkParameters
|
||||
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
|
||||
from rl_coach.architectures.head_parameters import HeadParameters, PPOHeadParameters
|
||||
from rl_coach.architectures.head_parameters import PPOVHeadParameters, VHeadParameters, QHeadParameters
|
||||
from rl_coach.architectures.middleware_parameters import MiddlewareParameters
|
||||
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters, LSTMMiddlewareParameters
|
||||
from rl_coach.architectures.mxnet_components.architecture import MxnetArchitecture
|
||||
from rl_coach.architectures.mxnet_components.embedders import ImageEmbedder, VectorEmbedder
|
||||
from rl_coach.architectures.mxnet_components.heads import Head, HeadLoss, PPOHead, PPOVHead, VHead, QHead
|
||||
from rl_coach.architectures.mxnet_components.middlewares import FCMiddleware, LSTMMiddleware
|
||||
from rl_coach.architectures.mxnet_components import utils
|
||||
from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
|
||||
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace
|
||||
|
||||
|
||||
class GeneralMxnetNetwork(MxnetArchitecture):
|
||||
"""
|
||||
A generalized version of all possible networks implemented using mxnet.
|
||||
"""
|
||||
def __init__(self,
|
||||
agent_parameters: AgentParameters,
|
||||
spaces: SpacesDefinition,
|
||||
name: str,
|
||||
global_network=None,
|
||||
network_is_local: bool=True,
|
||||
network_is_trainable: bool=False):
|
||||
"""
|
||||
:param agent_parameters: the agent parameters
|
||||
:param spaces: the spaces definition of the agent
|
||||
:param name: the name of the network
|
||||
:param global_network: the global network replica that is shared between all the workers
|
||||
:param network_is_local: is the network global (shared between workers) or local (dedicated to the worker)
|
||||
:param network_is_trainable: is the network trainable (we can apply gradients on it)
|
||||
"""
|
||||
self.network_wrapper_name = name.split('/')[0]
|
||||
self.network_parameters = agent_parameters.network_wrappers[self.network_wrapper_name]
|
||||
if self.network_parameters.use_separate_networks_per_head:
|
||||
self.num_heads_per_network = 1
|
||||
self.num_networks = len(self.network_parameters.heads_parameters)
|
||||
else:
|
||||
self.num_heads_per_network = len(self.network_parameters.heads_parameters)
|
||||
self.num_networks = 1
|
||||
|
||||
super().__init__(agent_parameters, spaces, name, global_network,
|
||||
network_is_local, network_is_trainable)
|
||||
|
||||
def construct_model(self):
|
||||
# validate the configuration
|
||||
if len(self.network_parameters.input_embedders_parameters) == 0:
|
||||
raise ValueError("At least one input type should be defined")
|
||||
|
||||
if len(self.network_parameters.heads_parameters) == 0:
|
||||
raise ValueError("At least one output type should be defined")
|
||||
|
||||
if self.network_parameters.middleware_parameters is None:
|
||||
raise ValueError("Exactly one middleware type should be defined")
|
||||
|
||||
self.model = GeneralModel(
|
||||
num_networks=self.num_networks,
|
||||
num_heads_per_network=self.num_heads_per_network,
|
||||
network_is_local=self.network_is_local,
|
||||
network_name=self.network_wrapper_name,
|
||||
agent_parameters=self.ap,
|
||||
network_parameters=self.network_parameters,
|
||||
spaces=self.spaces)
|
||||
|
||||
self.losses = self.model.losses()
|
||||
|
||||
# Learning rate
|
||||
lr_scheduler = None
|
||||
if self.network_parameters.learning_rate_decay_rate != 0:
|
||||
lr_scheduler = mx.lr_scheduler.FactorScheduler(
|
||||
step=self.network_parameters.learning_rate_decay_steps,
|
||||
factor=self.network_parameters.learning_rate_decay_rate)
|
||||
|
||||
# Optimizer
|
||||
# FIXME Does this code for distributed training make sense?
|
||||
if self.distributed_training and self.network_is_local and self.network_parameters.shared_optimizer:
|
||||
# distributed training + is a local network + optimizer shared -> take the global optimizer
|
||||
self.optimizer = self.global_network.optimizer
|
||||
elif (self.distributed_training and self.network_is_local and not self.network_parameters.shared_optimizer)\
|
||||
or self.network_parameters.shared_optimizer or not self.distributed_training:
|
||||
|
||||
if self.network_parameters.optimizer_type == 'Adam':
|
||||
self.optimizer = mx.optimizer.Adam(
|
||||
learning_rate=self.network_parameters.learning_rate,
|
||||
beta1=self.network_parameters.adam_optimizer_beta1,
|
||||
beta2=self.network_parameters.adam_optimizer_beta2,
|
||||
epsilon=self.network_parameters.optimizer_epsilon,
|
||||
lr_scheduler=lr_scheduler)
|
||||
elif self.network_parameters.optimizer_type == 'RMSProp':
|
||||
self.optimizer = mx.optimizer.RMSProp(
|
||||
learning_rate=self.network_parameters.learning_rate,
|
||||
gamma1=self.network_parameters.rms_prop_optimizer_decay,
|
||||
epsilon=self.network_parameters.optimizer_epsilon,
|
||||
lr_scheduler=lr_scheduler)
|
||||
elif self.network_parameters.optimizer_type == 'LBFGS':
|
||||
raise NotImplementedError('LBFGS optimizer not implemented')
|
||||
else:
|
||||
raise Exception("{} is not a valid optimizer type".format(self.network_parameters.optimizer_type))
|
||||
|
||||
@property
|
||||
def output_heads(self):
|
||||
return self.model.output_heads
|
||||
|
||||
|
||||
def _get_activation(activation_function_string: str):
|
||||
"""
|
||||
Map the activation function from a string to the mxnet framework equivalent
|
||||
:param activation_function_string: the type of the activation function
|
||||
:return: mxnet activation function string
|
||||
"""
|
||||
return utils.get_mxnet_activation_name(activation_function_string)
|
||||
|
||||
|
||||
def _sanitize_activation(params: Union[InputEmbedderParameters, MiddlewareParameters, HeadParameters]) ->\
|
||||
Union[InputEmbedderParameters, MiddlewareParameters, HeadParameters]:
|
||||
"""
|
||||
Change activation function to the mxnet specific value
|
||||
:param params: any parameter that has activation_function property
|
||||
:return: copy of params with activation function correctly set
|
||||
"""
|
||||
params_copy = copy.copy(params)
|
||||
params_copy.activation_function = _get_activation(params.activation_function)
|
||||
return params_copy
|
||||
|
||||
|
||||
def _get_input_embedder(spaces: SpacesDefinition,
|
||||
input_name: str,
|
||||
embedder_params: InputEmbedderParameters) -> ModuleType:
|
||||
"""
|
||||
Given an input embedder parameters class, creates the input embedder and returns it
|
||||
:param input_name: the name of the input to the embedder (used for retrieving the shape). The input should
|
||||
be a value within the state or the action.
|
||||
:param embedder_params: the parameters of the class of the embedder
|
||||
:return: the embedder instance
|
||||
"""
|
||||
allowed_inputs = copy.copy(spaces.state.sub_spaces)
|
||||
allowed_inputs["action"] = copy.copy(spaces.action)
|
||||
allowed_inputs["goal"] = copy.copy(spaces.goal)
|
||||
|
||||
if input_name not in allowed_inputs.keys():
|
||||
raise ValueError("The key for the input embedder ({}) must match one of the following keys: {}"
|
||||
.format(input_name, allowed_inputs.keys()))
|
||||
|
||||
type = "vector"
|
||||
if isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace):
|
||||
type = "image"
|
||||
|
||||
def sanitize_params(params: InputEmbedderParameters):
|
||||
params_copy = _sanitize_activation(params)
|
||||
# params_copy.input_rescaling = params_copy.input_rescaling[type]
|
||||
# params_copy.input_offset = params_copy.input_offset[type]
|
||||
params_copy.name = input_name
|
||||
return params_copy
|
||||
|
||||
embedder_params = sanitize_params(embedder_params)
|
||||
if type == 'vector':
|
||||
module = VectorEmbedder(embedder_params)
|
||||
elif type == 'image':
|
||||
module = ImageEmbedder(embedder_params)
|
||||
else:
|
||||
raise KeyError('Unsupported embedder type: {}'.format(type))
|
||||
return module
|
||||
|
||||
|
||||
def _get_middleware(middleware_params: MiddlewareParameters) -> ModuleType:
|
||||
"""
|
||||
Given a middleware type, creates the middleware and returns it
|
||||
:param middleware_params: the paramaeters of the middleware class
|
||||
:return: the middleware instance
|
||||
"""
|
||||
middleware_params = _sanitize_activation(middleware_params)
|
||||
if isinstance(middleware_params, FCMiddlewareParameters):
|
||||
module = FCMiddleware(middleware_params)
|
||||
elif isinstance(middleware_params, LSTMMiddlewareParameters):
|
||||
module = LSTMMiddleware(middleware_params)
|
||||
else:
|
||||
raise KeyError('Unsupported middleware type: {}'.format(type(middleware_params)))
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def _get_output_head(
|
||||
head_params: HeadParameters,
|
||||
head_idx: int,
|
||||
head_type_index: int,
|
||||
agent_params: AgentParameters,
|
||||
spaces: SpacesDefinition,
|
||||
network_name: str,
|
||||
is_local: bool) -> Head:
|
||||
"""
|
||||
Given a head type, creates the head and returns it
|
||||
:param head_params: the parameters of the head to create
|
||||
:param head_idx: the head index
|
||||
:param head_type_index: the head type index (same index if head_param.num_output_head_copies>0)
|
||||
:param agent_params: agent parameters
|
||||
:param spaces: state and action space definitions
|
||||
:param network_name: name of the network
|
||||
:param is_local:
|
||||
:return: head block
|
||||
"""
|
||||
head_params = _sanitize_activation(head_params)
|
||||
if isinstance(head_params, PPOHeadParameters):
|
||||
module = PPOHead(
|
||||
agent_parameters=agent_params,
|
||||
spaces=spaces,
|
||||
network_name=network_name,
|
||||
head_type_idx=head_type_index,
|
||||
loss_weight=head_params.loss_weight,
|
||||
is_local=is_local,
|
||||
activation_function=head_params.activation_function,
|
||||
dense_layer=head_params.dense_layer)
|
||||
elif isinstance(head_params, VHeadParameters):
|
||||
module = VHead(
|
||||
agent_parameters=agent_params,
|
||||
spaces=spaces,
|
||||
network_name=network_name,
|
||||
head_type_idx=head_type_index,
|
||||
loss_weight=head_params.loss_weight,
|
||||
is_local=is_local,
|
||||
activation_function=head_params.activation_function,
|
||||
dense_layer=head_params.dense_layer)
|
||||
elif isinstance(head_params, PPOVHeadParameters):
|
||||
module = PPOVHead(
|
||||
agent_parameters=agent_params,
|
||||
spaces=spaces,
|
||||
network_name=network_name,
|
||||
head_type_idx=head_type_index,
|
||||
loss_weight=head_params.loss_weight,
|
||||
is_local=is_local,
|
||||
activation_function=head_params.activation_function,
|
||||
dense_layer=head_params.dense_layer)
|
||||
elif isinstance(head_params, QHeadParameters):
|
||||
module = QHead(
|
||||
agent_parameters=agent_params,
|
||||
spaces=spaces,
|
||||
network_name=network_name,
|
||||
head_type_idx=head_type_index,
|
||||
loss_weight=head_params.loss_weight,
|
||||
is_local=is_local,
|
||||
activation_function=head_params.activation_function,
|
||||
dense_layer=head_params.dense_layer)
|
||||
else:
|
||||
raise KeyError('Unsupported head type: {}'.format(type(head_params)))
|
||||
|
||||
return module
|
||||
|
||||
|
||||
class ScaledGradHead(HybridBlock):
|
||||
"""
|
||||
Wrapper block for applying gradient scaling to input before feeding the head network
|
||||
"""
|
||||
def __init__(self,
|
||||
head_index: int,
|
||||
head_type_index: int,
|
||||
network_name: str,
|
||||
spaces: SpacesDefinition,
|
||||
network_is_local: bool,
|
||||
agent_params: AgentParameters,
|
||||
head_params: HeadParameters) -> None:
|
||||
"""
|
||||
:param head_idx: the head index
|
||||
:param head_type_index: the head type index (same index if head_param.num_output_head_copies>0)
|
||||
:param network_name: name of the network
|
||||
:param spaces: state and action space definitions
|
||||
:param network_is_local: whether network is local
|
||||
:param agent_params: agent parameters
|
||||
:param head_params: head parameters
|
||||
"""
|
||||
super(ScaledGradHead, self).__init__()
|
||||
|
||||
head_params = _sanitize_activation(head_params)
|
||||
with self.name_scope():
|
||||
self.head = _get_output_head(
|
||||
head_params=head_params,
|
||||
head_idx=head_index,
|
||||
head_type_index=head_type_index,
|
||||
agent_params=agent_params,
|
||||
spaces=spaces,
|
||||
network_name=network_name,
|
||||
is_local=network_is_local)
|
||||
self.gradient_rescaler = self.params.get_constant(
|
||||
name='gradient_rescaler',
|
||||
value=np.array([float(head_params.rescale_gradient_from_head_by_factor)]))
|
||||
# self.gradient_rescaler = self.params.get(
|
||||
# name='gradient_rescaler',
|
||||
# shape=(1,),
|
||||
# init=mx.init.Constant(float(head_params.rescale_gradient_from_head_by_factor)))
|
||||
|
||||
def hybrid_forward(self,
|
||||
F: ModuleType,
|
||||
x: Union[NDArray, Symbol],
|
||||
gradient_rescaler: Union[NDArray, Symbol]) -> Tuple[Union[NDArray, Symbol], ...]:
|
||||
""" Overrides gluon.HybridBlock.hybrid_forward
|
||||
:param nd or sym F: ndarray or symbol module
|
||||
:param x: head input
|
||||
:param gradient_rescaler: gradient rescaler for partial blocking of gradient
|
||||
:return: head output
|
||||
"""
|
||||
grad_scaled_x = F.broadcast_mul((1 - gradient_rescaler), F.BlockGrad(x)) + F.broadcast_mul(gradient_rescaler, x)
|
||||
out = self.head(grad_scaled_x)
|
||||
return out
|
||||
|
||||
|
||||
class SingleModel(HybridBlock):
|
||||
"""
|
||||
Block that connects a single embedder, with middleware and one to multiple heads
|
||||
"""
|
||||
def __init__(self,
|
||||
network_is_local: bool,
|
||||
network_name: str,
|
||||
agent_parameters: AgentParameters,
|
||||
in_emb_param_dict: {str: InputEmbedderParameters},
|
||||
embedding_merger_type: EmbeddingMergerType,
|
||||
middleware_param: MiddlewareParameters,
|
||||
head_param_list: [HeadParameters],
|
||||
head_type_idx_start: int,
|
||||
spaces: SpacesDefinition,
|
||||
*args, **kwargs):
|
||||
"""
|
||||
:param network_is_local: True if network is local
|
||||
:param network_name: name of the network
|
||||
:param agent_parameters: agent parameters
|
||||
:param in_emb_param_dict: dictionary of embedder name to embedding parameters
|
||||
:param embedding_merger_type: type of merging output of embedders: concatenate or sum
|
||||
:param middleware_param: middleware parameters
|
||||
:param head_param_list: list of head parameters, one per head type
|
||||
:param head_type_idx_start: start index for head type index counting
|
||||
:param spaces: state and action space definition
|
||||
"""
|
||||
super(SingleModel, self).__init__(*args, **kwargs)
|
||||
|
||||
self._embedding_merger_type = embedding_merger_type
|
||||
self._input_embedders = list() # type: List[HybridBlock]
|
||||
self._output_heads = list() # type: List[ScaledGradHead]
|
||||
|
||||
with self.name_scope():
|
||||
for input_name in sorted(in_emb_param_dict):
|
||||
input_type = in_emb_param_dict[input_name]
|
||||
input_embedder = _get_input_embedder(spaces, input_name, input_type)
|
||||
self.register_child(input_embedder)
|
||||
self._input_embedders.append(input_embedder)
|
||||
|
||||
self.middleware = _get_middleware(middleware_param)
|
||||
|
||||
for i, head_param in enumerate(head_param_list):
|
||||
for head_copy_idx in range(head_param.num_output_head_copies):
|
||||
# create output head and add it to the output heads list
|
||||
output_head = ScaledGradHead(
|
||||
head_index=(head_type_idx_start + i) * head_param.num_output_head_copies + head_copy_idx,
|
||||
head_type_index=head_type_idx_start + i,
|
||||
network_name=network_name,
|
||||
spaces=spaces,
|
||||
network_is_local=network_is_local,
|
||||
agent_params=agent_parameters,
|
||||
head_params=head_param)
|
||||
self.register_child(output_head)
|
||||
self._output_heads.append(output_head)
|
||||
|
||||
def hybrid_forward(self, F, *inputs: Union[NDArray, Symbol]) -> Tuple[Union[NDArray, Symbol], ...]:
|
||||
""" Overrides gluon.HybridBlock.hybrid_forward
|
||||
:param nd or sym F: ndarray or symbol block
|
||||
:param inputs: model inputs, one for each embedder
|
||||
:return: head outputs in a tuple
|
||||
"""
|
||||
# Input Embeddings
|
||||
state_embedding = list()
|
||||
for input, embedder in zip(inputs, self._input_embedders):
|
||||
state_embedding.append(embedder(input))
|
||||
|
||||
# Merger
|
||||
if len(state_embedding) == 1:
|
||||
state_embedding = state_embedding[0]
|
||||
else:
|
||||
if self._embedding_merger_type == EmbeddingMergerType.Concat:
|
||||
state_embedding = F.concat(*state_embedding, dim=1, name='merger') # NC or NCHW layout
|
||||
elif self._embedding_merger_type == EmbeddingMergerType.Sum:
|
||||
state_embedding = F.add_n(*state_embedding, name='merger')
|
||||
|
||||
# Middleware
|
||||
state_embedding = self.middleware(state_embedding)
|
||||
|
||||
# Head
|
||||
outputs = tuple()
|
||||
for head in self._output_heads:
|
||||
outputs += (head(state_embedding),)
|
||||
|
||||
return outputs
|
||||
|
||||
@property
|
||||
def input_embedders(self) -> List[HybridBlock]:
|
||||
"""
|
||||
:return: list of input embedders
|
||||
"""
|
||||
return self._input_embedders
|
||||
|
||||
@property
|
||||
def output_heads(self) -> List[Head]:
|
||||
"""
|
||||
:return: list of output heads
|
||||
"""
|
||||
return [h.head for h in self._output_heads]
|
||||
|
||||
|
||||
class GeneralModel(HybridBlock):
|
||||
"""
|
||||
Block that creates multiple single models
|
||||
"""
|
||||
def __init__(self,
|
||||
num_networks: int,
|
||||
num_heads_per_network: int,
|
||||
network_is_local: bool,
|
||||
network_name: str,
|
||||
agent_parameters: AgentParameters,
|
||||
network_parameters: NetworkParameters,
|
||||
spaces: SpacesDefinition,
|
||||
*args, **kwargs):
|
||||
"""
|
||||
:param num_networks: number of networks to create
|
||||
:param num_heads_per_network: number of heads per network to create
|
||||
:param network_is_local: True if network is local
|
||||
:param network_name: name of the network
|
||||
:param agent_parameters: agent parameters
|
||||
:param network_parameters: network parameters
|
||||
:param spaces: state and action space definitions
|
||||
"""
|
||||
super(GeneralModel, self).__init__(*args, **kwargs)
|
||||
|
||||
with self.name_scope():
|
||||
self.nets = list()
|
||||
for network_idx in range(num_networks):
|
||||
head_type_idx_start = network_idx * num_heads_per_network
|
||||
head_type_idx_end = head_type_idx_start + num_heads_per_network
|
||||
net = SingleModel(
|
||||
head_type_idx_start=head_type_idx_start,
|
||||
network_name=network_name,
|
||||
network_is_local=network_is_local,
|
||||
agent_parameters=agent_parameters,
|
||||
in_emb_param_dict=network_parameters.input_embedders_parameters,
|
||||
embedding_merger_type=network_parameters.embedding_merger_type,
|
||||
middleware_param=network_parameters.middleware_parameters,
|
||||
head_param_list=network_parameters.heads_parameters[head_type_idx_start:head_type_idx_end],
|
||||
spaces=spaces)
|
||||
self.register_child(net)
|
||||
self.nets.append(net)
|
||||
|
||||
def hybrid_forward(self, F, *inputs):
|
||||
""" Overrides gluon.HybridBlock.hybrid_forward
|
||||
:param nd or sym F: ndarray or symbol block
|
||||
:param inputs: model inputs, one for each embedder. Passed to all networks.
|
||||
:return: head outputs in a tuple
|
||||
"""
|
||||
outputs = tuple()
|
||||
for net in self.nets:
|
||||
out = net(*inputs)
|
||||
outputs += out
|
||||
return outputs
|
||||
|
||||
@property
|
||||
def output_heads(self) -> List[Head]:
|
||||
""" Return all heads in a single list
|
||||
Note: There is a one-to-one mapping between output_heads and losses
|
||||
:return: list of heads
|
||||
"""
|
||||
return list(chain.from_iterable(net.output_heads for net in self.nets))
|
||||
|
||||
def losses(self) -> List[HeadLoss]:
|
||||
""" Construct loss blocks for network training
|
||||
Note: There is a one-to-one mapping between output_heads and losses
|
||||
:return: list of loss blocks
|
||||
"""
|
||||
return [h.loss() for net in self.nets for h in net.output_heads]
|
||||
14
rl_coach/architectures/mxnet_components/heads/__init__.py
Normal file
14
rl_coach/architectures/mxnet_components/heads/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from .head import Head, HeadLoss
|
||||
from .q_head import QHead
|
||||
from .ppo_head import PPOHead
|
||||
from .ppo_v_head import PPOVHead
|
||||
from .v_head import VHead
|
||||
|
||||
__all__ = [
|
||||
'Head',
|
||||
'HeadLoss',
|
||||
'QHead',
|
||||
'PPOHead',
|
||||
'PPOVHead',
|
||||
'VHead'
|
||||
]
|
||||
181
rl_coach/architectures/mxnet_components/heads/head.py
Normal file
181
rl_coach/architectures/mxnet_components/heads/head.py
Normal file
@@ -0,0 +1,181 @@
|
||||
from typing import Dict, List, Union, Tuple
|
||||
|
||||
from mxnet.gluon import nn, loss
|
||||
from mxnet.ndarray import NDArray
|
||||
from mxnet.symbol import Symbol
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.spaces import SpacesDefinition
|
||||
|
||||
|
||||
LOSS_OUT_TYPE_LOSS = 'loss'
|
||||
LOSS_OUT_TYPE_REGULARIZATION = 'regularization'
|
||||
|
||||
|
||||
class LossInputSchema(object):
|
||||
"""
|
||||
Helper class to contain schema for loss hybrid_forward input
|
||||
"""
|
||||
def __init__(self, head_outputs: List[str], agent_inputs: List[str], targets: List[str]):
|
||||
"""
|
||||
:param head_outputs: list of argument names in hybrid_forward that are outputs of the head.
|
||||
The order and number MUST MATCH the output from the head.
|
||||
:param agent_inputs: list of argument names in hybrid_forward that are inputs from the agent.
|
||||
The order and number MUST MATCH `output_<head_type_idx>_<order>` for this head.
|
||||
:param targets: list of argument names in hybrid_forward that are targets for the loss.
|
||||
The order and number MUST MATCH targets passed from the agent.
|
||||
"""
|
||||
self._head_outputs = head_outputs
|
||||
self._agent_inputs = agent_inputs
|
||||
self._targets = targets
|
||||
|
||||
@property
|
||||
def head_outputs(self):
|
||||
return self._head_outputs
|
||||
|
||||
@property
|
||||
def agent_inputs(self):
|
||||
return self._agent_inputs
|
||||
|
||||
@property
|
||||
def targets(self):
|
||||
return self._targets
|
||||
|
||||
|
||||
class HeadLoss(loss.Loss):
|
||||
"""
|
||||
ABC for loss functions of each head. Child class must implement input_schema() and loss_forward()
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(HeadLoss, self).__init__(*args, **kwargs)
|
||||
self._output_schema = None # type: List[str]
|
||||
|
||||
@property
|
||||
def input_schema(self) -> LossInputSchema:
|
||||
"""
|
||||
:return: schema for input of hybrid_forward. Read docstring for LossInputSchema for details.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def output_schema(self) -> List[str]:
|
||||
"""
|
||||
:return: schema for output of hybrid_forward. Must contain 'loss' and 'regularization' keys at least once.
|
||||
The order and total number must match that of returned values from the loss. 'loss' and 'regularization'
|
||||
are special keys. Any other string is treated as auxiliary outputs and must include match auxiliary
|
||||
fetch names returned by the head.
|
||||
"""
|
||||
return self._output_schema
|
||||
|
||||
def forward(self, *args):
|
||||
"""
|
||||
Override forward() so that number of outputs can be checked against the schema
|
||||
"""
|
||||
outputs = super(HeadLoss, self).forward(*args)
|
||||
if isinstance(outputs, tuple) or isinstance(outputs, list):
|
||||
num_outputs = len(outputs)
|
||||
else:
|
||||
assert isinstance(outputs, NDArray) or isinstance(outputs, Symbol)
|
||||
num_outputs = 1
|
||||
assert num_outputs == len(self.output_schema), "Number of outputs don't match schema ({} != {})".format(
|
||||
num_outputs, len(self.output_schema))
|
||||
return outputs
|
||||
|
||||
def _loss_output(self, outputs: List[Tuple[Union[NDArray, Symbol], str]]):
|
||||
"""
|
||||
Must be called on the output from hybrid_forward().
|
||||
Saves the returned output as the schema and returns output values in a list
|
||||
:return: list of output values
|
||||
"""
|
||||
output_schema = [o[1] for o in outputs]
|
||||
assert self._output_schema is None or self._output_schema == output_schema
|
||||
self._output_schema = output_schema
|
||||
return tuple(o[0] for o in outputs)
|
||||
|
||||
def hybrid_forward(self, F, x, *args, **kwargs):
|
||||
"""
|
||||
Passes the cal to loss_forward() and constructs output schema from its output by calling loss_output()
|
||||
"""
|
||||
return self._loss_output(self.loss_forward(F, x, *args, **kwargs))
|
||||
|
||||
def loss_forward(self, F, x, *args, **kwargs) -> List[Tuple[Union[NDArray, Symbol], str]]:
|
||||
"""
|
||||
Similar to hybrid_forward, but returns list of (NDArray, type_str)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Head(nn.HybridBlock):
|
||||
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition,
|
||||
network_name: str, head_type_idx: int=0, loss_weight: float=1., is_local: bool=True,
|
||||
activation_function: str='relu', dense_layer: None=None):
|
||||
"""
|
||||
A head is the final part of the network. It takes the embedding from the middleware embedder and passes it
|
||||
through a neural network to produce the output of the network. There can be multiple heads in a network, and
|
||||
each one has an assigned loss function. The heads are algorithm dependent.
|
||||
|
||||
:param agent_parameters: containing algorithm parameters such as clip_likelihood_ratio_using_epsilon
|
||||
and beta_entropy.
|
||||
:param spaces: containing action spaces used for defining size of network output.
|
||||
:param network_name: name of head network. currently unused.
|
||||
:param head_type_idx: index of head network. currently unused.
|
||||
:param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
|
||||
:param is_local: flag to denote if network is local. currently unused.
|
||||
:param activation_function: activation function to use between layers. currently unused.
|
||||
:param dense_layer: type of dense layer to use in network. currently unused.
|
||||
"""
|
||||
super(Head, self).__init__()
|
||||
self.head_type_idx = head_type_idx
|
||||
self.network_name = network_name
|
||||
self.loss_weight = loss_weight
|
||||
self.is_local = is_local
|
||||
self.ap = agent_parameters
|
||||
self.spaces = spaces
|
||||
self.return_type = None
|
||||
self.activation_function = activation_function
|
||||
self.dense_layer = dense_layer
|
||||
self._num_outputs = None
|
||||
|
||||
def loss(self) -> HeadLoss:
|
||||
"""
|
||||
Returns loss block to be used for specific head implementation.
|
||||
|
||||
:return: loss block (can be called as function) for outputs returned by the head network.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def num_outputs(self):
|
||||
""" Returns number of outputs that forward() call will return
|
||||
|
||||
:return:
|
||||
"""
|
||||
assert self._num_outputs is not None, 'must call forward() once to configure number of outputs'
|
||||
return self._num_outputs
|
||||
|
||||
def forward(self, *args):
|
||||
"""
|
||||
Override forward() so that number of outputs can be automatically set
|
||||
"""
|
||||
outputs = super(Head, self).forward(*args)
|
||||
if isinstance(outputs, tuple):
|
||||
num_outputs = len(outputs)
|
||||
else:
|
||||
assert isinstance(outputs, NDArray) or isinstance(outputs, Symbol)
|
||||
num_outputs = 1
|
||||
if self._num_outputs is None:
|
||||
self._num_outputs = num_outputs
|
||||
else:
|
||||
assert self._num_outputs == num_outputs, 'Number of outputs cannot change ({} != {})'.format(
|
||||
self._num_outputs, num_outputs)
|
||||
assert self._num_outputs == len(self.loss().input_schema.head_outputs)
|
||||
return outputs
|
||||
|
||||
def hybrid_forward(self, F, x, *args, **kwargs):
|
||||
"""
|
||||
Used for forward pass through head network.
|
||||
|
||||
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
|
||||
:param x: middleware state representation, of shape (batch_size, in_channels).
|
||||
:return: final output of network, that will be used in loss calculations.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
669
rl_coach/architectures/mxnet_components/heads/ppo_head.py
Normal file
669
rl_coach/architectures/mxnet_components/heads/ppo_head.py
Normal file
@@ -0,0 +1,669 @@
|
||||
from typing import List, Tuple, Union
|
||||
from types import ModuleType
|
||||
|
||||
import math
|
||||
import mxnet as mx
|
||||
from mxnet.gluon import nn
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.core_types import ActionProbabilities
|
||||
from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace
|
||||
from rl_coach.utils import eps
|
||||
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
|
||||
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION
|
||||
from rl_coach.architectures.mxnet_components.utils import hybrid_clip
|
||||
|
||||
|
||||
LOSS_OUT_TYPE_KL = 'kl_divergence'
|
||||
LOSS_OUT_TYPE_ENTROPY = 'entropy'
|
||||
LOSS_OUT_TYPE_LIKELIHOOD_RATIO = 'likelihood_ratio'
|
||||
LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO = 'clipped_likelihood_ratio'
|
||||
|
||||
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
|
||||
|
||||
|
||||
class MultivariateNormalDist:
|
||||
def __init__(self,
|
||||
num_var: int,
|
||||
mean: nd_sym_type,
|
||||
sigma: nd_sym_type,
|
||||
F: ModuleType=mx.nd) -> None:
|
||||
"""
|
||||
Distribution object for Multivariate Normal. Works with batches.
|
||||
Optionally works with batches and time steps, but be consistent in usage: i.e. if using time_step,
|
||||
mean, sigma and data for log_prob must all include a time_step dimension.
|
||||
|
||||
:param num_var: number of variables in distribution
|
||||
:param mean: mean for each variable,
|
||||
of shape (num_var) or
|
||||
of shape (batch_size, num_var) or
|
||||
of shape (batch_size, time_step, num_var).
|
||||
:param sigma: covariance matrix,
|
||||
of shape (num_var, num_var) or
|
||||
of shape (batch_size, num_var, num_var) or
|
||||
of shape (batch_size, time_step, num_var, num_var).
|
||||
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
|
||||
"""
|
||||
self.num_var = num_var
|
||||
self.mean = mean
|
||||
self.sigma = sigma
|
||||
self.F = F
|
||||
|
||||
def inverse_using_cholesky(self, matrix: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
Calculate inverses for a batch of matrices using Cholesky decomposition method.
|
||||
|
||||
:param matrix: matrix (or matrices) to invert,
|
||||
of shape (num_var, num_var) or
|
||||
of shape (batch_size, num_var, num_var) or
|
||||
of shape (batch_size, time_step, num_var, num_var).
|
||||
:return: inverted matrix (or matrices),
|
||||
of shape (num_var, num_var) or
|
||||
of shape (batch_size, num_var, num_var) or
|
||||
of shape (batch_size, time_step, num_var, num_var).
|
||||
"""
|
||||
cholesky_factor = self.F.linalg.potrf(matrix)
|
||||
return self.F.linalg.potri(cholesky_factor)
|
||||
|
||||
def log_det(self, matrix: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
Calculate log of the determinant for a batch of matrices using Cholesky decomposition method.
|
||||
|
||||
:param matrix: matrix (or matrices) to invert,
|
||||
of shape (num_var, num_var) or
|
||||
of shape (batch_size, num_var, num_var) or
|
||||
of shape (batch_size, time_step, num_var, num_var).
|
||||
:return: inverted matrix (or matrices),
|
||||
of shape (num_var, num_var) or
|
||||
of shape (batch_size, num_var, num_var) or
|
||||
of shape (batch_size, time_step, num_var, num_var).
|
||||
"""
|
||||
cholesky_factor = self.F.linalg.potrf(matrix)
|
||||
return 2 * self.F.linalg.sumlogdiag(cholesky_factor)
|
||||
|
||||
def log_prob(self, x: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
Calculate the log probability of data given the current distribution.
|
||||
|
||||
See http://www.notenoughthoughts.net/posts/normal-log-likelihood-gradient.html
|
||||
and https://discuss.mxnet.io/t/multivariate-gaussian-log-density-operator/1169/7
|
||||
|
||||
:param x: input data,
|
||||
of shape (num_var) or
|
||||
of shape (batch_size, num_var) or
|
||||
of shape (batch_size, time_step, num_var).
|
||||
:return: log_probability,
|
||||
of shape (1) or
|
||||
of shape (batch_size) or
|
||||
of shape (batch_size, time_step).
|
||||
"""
|
||||
a = (self.num_var / 2) * math.log(2 * math.pi)
|
||||
log_det_sigma = self.log_det(self.sigma)
|
||||
b = (1 / 2) * log_det_sigma
|
||||
sigma_inv = self.inverse_using_cholesky(self.sigma)
|
||||
# deviation from mean, and dev_t is equivalent to transpose on last two dims.
|
||||
dev = (x - self.mean).expand_dims(-1)
|
||||
dev_t = (x - self.mean).expand_dims(-2)
|
||||
|
||||
# since batch_dot only works with ndarrays with ndim of 3,
|
||||
# and we could have ndarrays with ndim of 4,
|
||||
# we flatten batch_size and time_step into single dim.
|
||||
dev_flat = dev.reshape(shape=(-1, 0, 0), reverse=1)
|
||||
sigma_inv_flat = sigma_inv.reshape(shape=(-1, 0, 0), reverse=1)
|
||||
dev_t_flat = dev_t.reshape(shape=(-1, 0, 0), reverse=1)
|
||||
c = (1 / 2) * self.F.batch_dot(self.F.batch_dot(dev_t_flat, sigma_inv_flat), dev_flat)
|
||||
# and now reshape back to (batch_size, time_step) if required.
|
||||
c = c.reshape_like(b)
|
||||
|
||||
log_likelihood = -a - b - c
|
||||
return log_likelihood
|
||||
|
||||
def entropy(self) -> nd_sym_type:
|
||||
"""
|
||||
Calculate entropy of current distribution.
|
||||
|
||||
See http://www.nowozin.net/sebastian/blog/the-entropy-of-a-normal-distribution.html
|
||||
:return: entropy,
|
||||
of shape (1) or
|
||||
of shape (batch_size) or
|
||||
of shape (batch_size, time_step).
|
||||
"""
|
||||
# todo: check if differential entropy is correct
|
||||
log_det_sigma = self.log_det(self.sigma)
|
||||
return (self.num_var / 2) + ((self.num_var / 2) * math.log(2 * math.pi)) + ((1 / 2) * log_det_sigma)
|
||||
|
||||
def kl_div(self, alt_dist) -> nd_sym_type:
|
||||
"""
|
||||
Calculated KL-Divergence with another MultivariateNormalDist distribution
|
||||
See https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
|
||||
Specifically https://wikimedia.org/api/rest_v1/media/math/render/svg/a3bf3b4917bd1fcb8be48d6d6139e2e387bdc7d3
|
||||
|
||||
:param alt_dist: alternative distribution used for kl divergence calculation
|
||||
:type alt_dist: MultivariateNormalDist
|
||||
:return: KL-Divergence, of shape (1,)
|
||||
"""
|
||||
sigma_a_inv = self.F.linalg.potri(self.F.linalg.potrf(self.sigma))
|
||||
sigma_b_inv = self.F.linalg.potri(self.F.linalg.potrf(alt_dist.sigma))
|
||||
term1a = mx.nd.batch_dot(sigma_b_inv, self.sigma)
|
||||
# sum of diagonal for batch of matrices
|
||||
term1 = (self.F.eye(self.num_var).broadcast_like(term1a) * term1a).sum(axis=-1).sum(axis=-1)
|
||||
mean_diff = (alt_dist.mean - self.mean).expand_dims(-1)
|
||||
mean_diff_t = (alt_dist.mean - self.mean).expand_dims(-2)
|
||||
term2 = self.F.batch_dot(self.F.batch_dot(mean_diff_t, sigma_b_inv), mean_diff).reshape_like(term1)
|
||||
term3 = (2 * self.F.linalg.sumlogdiag(self.F.linalg.potrf(alt_dist.sigma))) -\
|
||||
(2 * self.F.linalg.sumlogdiag(self.F.linalg.potrf(self.sigma)))
|
||||
return 0.5 * (term1 + term2 - self.num_var + term3)
|
||||
|
||||
|
||||
|
||||
class CategoricalDist:
|
||||
def __init__(self, n_classes: int, probs: nd_sym_type, F: ModuleType=mx.nd) -> None:
|
||||
"""
|
||||
Distribution object for Categorical data.
|
||||
Optionally works with batches and time steps, but be consistent in usage: i.e. if using time_step,
|
||||
mean, sigma and data for log_prob must all include a time_step dimension.
|
||||
|
||||
:param n_classes: number of classes in distribution
|
||||
:param probs: probabilities for each class,
|
||||
of shape (n_classes),
|
||||
of shape (batch_size, n_classes) or
|
||||
of shape (batch_size, time_step, n_classes)
|
||||
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
|
||||
"""
|
||||
self.n_classes = n_classes
|
||||
self.probs = probs
|
||||
self.F = F
|
||||
|
||||
|
||||
def log_prob(self, actions: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
Calculate the log probability of data given the current distribution.
|
||||
|
||||
:param actions: actions, with int8 data type,
|
||||
of shape (1) if probs was (n_classes),
|
||||
of shape (batch_size) if probs was (batch_size, n_classes) and
|
||||
of shape (batch_size, time_step) if probs was (batch_size, time_step, n_classes)
|
||||
:return: log_probability,
|
||||
of shape (1) if probs was (n_classes),
|
||||
of shape (batch_size) if probs was (batch_size, n_classes) and
|
||||
of shape (batch_size, time_step) if probs was (batch_size, time_step, n_classes)
|
||||
"""
|
||||
action_mask = actions.one_hot(depth=self.n_classes)
|
||||
action_probs = (self.probs * action_mask).sum(axis=-1)
|
||||
return action_probs.log()
|
||||
|
||||
def entropy(self) -> nd_sym_type:
|
||||
"""
|
||||
Calculate entropy of current distribution.
|
||||
|
||||
:return: entropy,
|
||||
of shape (1) if probs was (n_classes),
|
||||
of shape (batch_size) if probs was (batch_size, n_classes) and
|
||||
of shape (batch_size, time_step) if probs was (batch_size, time_step, n_classes)
|
||||
"""
|
||||
# todo: look into numerical stability
|
||||
return -(self.probs.log()*self.probs).sum(axis=-1)
|
||||
|
||||
def kl_div(self, alt_dist) -> nd_sym_type:
|
||||
"""
|
||||
Calculated KL-Divergence with another Categorical distribution
|
||||
|
||||
:param alt_dist: alternative distribution used for kl divergence calculation
|
||||
:type alt_dist: CategoricalDist
|
||||
:return: KL-Divergence
|
||||
"""
|
||||
logits_a = self.probs.clip(a_min=eps, a_max=1 - eps).log()
|
||||
logits_b = alt_dist.probs.clip(a_min=eps, a_max=1 - eps).log()
|
||||
t = self.probs * (logits_a - logits_b)
|
||||
t = self.F.where(condition=(alt_dist.probs == 0), x=self.F.ones_like(alt_dist.probs) * math.inf, y=t)
|
||||
t = self.F.where(condition=(self.probs == 0), x=self.F.zeros_like(self.probs), y=t)
|
||||
return t.sum(axis=-1)
|
||||
|
||||
|
||||
class DiscretePPOHead(nn.HybridBlock):
|
||||
def __init__(self, num_actions: int) -> None:
|
||||
"""
|
||||
Head block for Discrete Proximal Policy Optimization, to calculate probabilities for each action given
|
||||
middleware representation of the environment state.
|
||||
|
||||
:param num_actions: number of actions in action space.
|
||||
"""
|
||||
super(DiscretePPOHead, self).__init__()
|
||||
with self.name_scope():
|
||||
self.dense = nn.Dense(units=num_actions, flatten=False)
|
||||
|
||||
def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
Used for forward pass through head network.
|
||||
|
||||
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
|
||||
:param x: middleware state representation,
|
||||
of shape (batch_size, in_channels) or
|
||||
of shape (batch_size, time_step, in_channels).
|
||||
:return: batch of probabilities for each action,
|
||||
of shape (batch_size, num_actions) or
|
||||
of shape (batch_size, time_step, num_actions).
|
||||
"""
|
||||
policy_values = self.dense(x)
|
||||
policy_probs = F.softmax(policy_values)
|
||||
return policy_probs
|
||||
|
||||
|
||||
class ContinuousPPOHead(nn.HybridBlock):
|
||||
def __init__(self, num_actions: int) -> None:
|
||||
"""
|
||||
Head block for Continuous Proximal Policy Optimization, to calculate probabilities for each action given
|
||||
middleware representation of the environment state.
|
||||
|
||||
:param num_actions: number of actions in action space.
|
||||
"""
|
||||
super(ContinuousPPOHead, self).__init__()
|
||||
with self.name_scope():
|
||||
# todo: change initialization strategy
|
||||
self.dense = nn.Dense(units=num_actions, flatten=False)
|
||||
# all samples (across batch, and time step) share the same covariance, which is learnt,
|
||||
# but since we assume the action probability variables are independent,
|
||||
# only the diagonal entries of the covariance matrix are specified.
|
||||
self.log_std = self.params.get('log_std',
|
||||
shape=num_actions,
|
||||
init=mx.init.Zero(),
|
||||
allow_deferred_init=True)
|
||||
# todo: is_local?
|
||||
|
||||
def hybrid_forward(self, F: ModuleType, x: nd_sym_type, log_std: nd_sym_type) -> List[nd_sym_type]:
|
||||
"""
|
||||
Used for forward pass through head network.
|
||||
|
||||
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
|
||||
:param x: middleware state representation,
|
||||
of shape (batch_size, in_channels) or
|
||||
of shape (batch_size, time_step, in_channels).
|
||||
:return: batch of probabilities for each action,
|
||||
of shape (batch_size, action_mean) or
|
||||
of shape (batch_size, time_step, action_mean).
|
||||
"""
|
||||
policy_means = self.dense(x)
|
||||
policy_std = log_std.exp()
|
||||
return [policy_means, policy_std]
|
||||
|
||||
|
||||
class ClippedPPOLossDiscrete(HeadLoss):
|
||||
def __init__(self,
|
||||
num_actions: int,
|
||||
clip_likelihood_ratio_using_epsilon: float,
|
||||
beta: float=0,
|
||||
use_kl_regularization: bool=False,
|
||||
initial_kl_coefficient: float=1,
|
||||
kl_cutoff: float=0,
|
||||
high_kl_penalty_coefficient: float=1,
|
||||
weight: float=1,
|
||||
batch_axis: int=0) -> None:
|
||||
"""
|
||||
Loss for discrete version of Clipped PPO.
|
||||
|
||||
:param num_actions: number of actions in action space.
|
||||
:param clip_likelihood_ratio_using_epsilon: epsilon to use for likelihood ratio clipping.
|
||||
:param beta: loss coefficient applied to entropy
|
||||
:param use_kl_regularization: option to add kl divergence loss
|
||||
:param initial_kl_coefficient: loss coefficient applied kl divergence loss (also see high_kl_penalty_coefficient).
|
||||
:param kl_cutoff: threshold for using high_kl_penalty_coefficient
|
||||
:param high_kl_penalty_coefficient: loss coefficient applied to kv divergence above kl_cutoff
|
||||
:param weight: scalar used to adjust relative weight of loss (if using this loss with others).
|
||||
:param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
|
||||
"""
|
||||
super(ClippedPPOLossDiscrete, self).__init__(weight=weight, batch_axis=batch_axis)
|
||||
self.weight = weight
|
||||
self.num_actions = num_actions
|
||||
self.clip_likelihood_ratio_using_epsilon = clip_likelihood_ratio_using_epsilon
|
||||
self.beta = beta
|
||||
self.use_kl_regularization = use_kl_regularization
|
||||
self.initial_kl_coefficient = initial_kl_coefficient if self.use_kl_regularization else 0.0
|
||||
self.kl_coefficient = self.params.get('kl_coefficient',
|
||||
shape=(1,),
|
||||
init=mx.init.Constant([initial_kl_coefficient,]),
|
||||
differentiable=False)
|
||||
self.kl_cutoff = kl_cutoff
|
||||
self.high_kl_penalty_coefficient = high_kl_penalty_coefficient
|
||||
|
||||
@property
|
||||
def input_schema(self) -> LossInputSchema:
|
||||
return LossInputSchema(
|
||||
head_outputs=['new_policy_probs'],
|
||||
agent_inputs=['actions', 'old_policy_probs', 'clip_param_rescaler'],
|
||||
targets=['advantages']
|
||||
)
|
||||
|
||||
def loss_forward(self,
|
||||
F: ModuleType,
|
||||
new_policy_probs: nd_sym_type,
|
||||
actions: nd_sym_type,
|
||||
old_policy_probs: nd_sym_type,
|
||||
clip_param_rescaler: nd_sym_type,
|
||||
advantages: nd_sym_type,
|
||||
kl_coefficient: nd_sym_type) -> List[Tuple[nd_sym_type, str]]:
|
||||
"""
|
||||
Used for forward pass through loss computations.
|
||||
Works with batches of data, and optionally time_steps, but be consistent in usage: i.e. if using time_step,
|
||||
new_policy_probs, old_policy_probs, actions and advantages all must include a time_step dimension.
|
||||
|
||||
NOTE: order of input arguments MUST NOT CHANGE because it matches the order
|
||||
parameters are passed in ppo_agent:train_network()
|
||||
|
||||
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
|
||||
:param new_policy_probs: action probabilities predicted by DiscretePPOHead network,
|
||||
of shape (batch_size, num_actions) or
|
||||
of shape (batch_size, time_step, num_actions).
|
||||
:param old_policy_probs: action probabilities for previous policy,
|
||||
of shape (batch_size, num_actions) or
|
||||
of shape (batch_size, time_step, num_actions).
|
||||
:param actions: true actions taken during rollout,
|
||||
of shape (batch_size) or
|
||||
of shape (batch_size, time_step).
|
||||
:param clip_param_rescaler: scales epsilon to use for likelihood ratio clipping.
|
||||
:param advantages: change in state value after taking action (a.k.a advantage)
|
||||
of shape (batch_size) or
|
||||
of shape (batch_size, time_step).
|
||||
:param kl_coefficient: loss coefficient applied kl divergence loss (also see high_kl_penalty_coefficient).
|
||||
:return: loss, of shape (batch_size).
|
||||
"""
|
||||
|
||||
old_policy_dist = CategoricalDist(self.num_actions, old_policy_probs, F=F)
|
||||
action_probs_wrt_old_policy = old_policy_dist.log_prob(actions)
|
||||
|
||||
new_policy_dist = CategoricalDist(self.num_actions, new_policy_probs, F=F)
|
||||
action_probs_wrt_new_policy = new_policy_dist.log_prob(actions)
|
||||
|
||||
entropy_loss = - self.beta * new_policy_dist.entropy().mean()
|
||||
|
||||
if self.use_kl_regularization:
|
||||
kl_div = old_policy_dist.kl_div(new_policy_dist).mean()
|
||||
weighted_kl_div = kl_coefficient * kl_div
|
||||
high_kl_div = F.stack(F.zeros_like(kl_div), kl_div - self.kl_cutoff).max().square()
|
||||
weighted_high_kl_div = self.high_kl_penalty_coefficient * high_kl_div
|
||||
kl_div_loss = weighted_kl_div + weighted_high_kl_div
|
||||
else:
|
||||
kl_div_loss = F.zeros(shape=(1,))
|
||||
|
||||
# working with log probs, so minus first, then exponential (same as division)
|
||||
likelihood_ratio = (action_probs_wrt_new_policy - action_probs_wrt_old_policy).exp()
|
||||
|
||||
if self.clip_likelihood_ratio_using_epsilon is not None:
|
||||
# clipping of likelihood ratio
|
||||
min_value = 1 - self.clip_likelihood_ratio_using_epsilon * clip_param_rescaler
|
||||
max_value = 1 + self.clip_likelihood_ratio_using_epsilon * clip_param_rescaler
|
||||
|
||||
# can't use F.clip (with variable clipping bounds), hence custom implementation
|
||||
clipped_likelihood_ratio = hybrid_clip(F, likelihood_ratio, clip_lower=min_value, clip_upper=max_value)
|
||||
|
||||
# lower bound of original, and clipped versions or each scaled advantage
|
||||
# element-wise min between the two ndarrays
|
||||
unclipped_scaled_advantages = likelihood_ratio * advantages
|
||||
clipped_scaled_advantages = clipped_likelihood_ratio * advantages
|
||||
scaled_advantages = F.stack(unclipped_scaled_advantages, clipped_scaled_advantages).min(axis=0)
|
||||
else:
|
||||
scaled_advantages = likelihood_ratio * advantages
|
||||
clipped_likelihood_ratio = F.zeros_like(likelihood_ratio)
|
||||
|
||||
# for each batch, calculate expectation of scaled_advantages across time steps,
|
||||
# but want code to work with data without time step too, so reshape to add timestep if doesn't exist.
|
||||
scaled_advantages_w_time = scaled_advantages.reshape(shape=(0, -1))
|
||||
expected_scaled_advantages = scaled_advantages_w_time.mean(axis=1)
|
||||
# want to maximize expected_scaled_advantages, add minus so can minimize.
|
||||
surrogate_loss = (-expected_scaled_advantages * self.weight).mean()
|
||||
|
||||
return [
|
||||
(surrogate_loss, LOSS_OUT_TYPE_LOSS),
|
||||
(entropy_loss + kl_div_loss, LOSS_OUT_TYPE_REGULARIZATION),
|
||||
(kl_div_loss, LOSS_OUT_TYPE_KL),
|
||||
(entropy_loss, LOSS_OUT_TYPE_ENTROPY),
|
||||
(likelihood_ratio, LOSS_OUT_TYPE_LIKELIHOOD_RATIO),
|
||||
(clipped_likelihood_ratio, LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO)
|
||||
]
|
||||
|
||||
|
||||
class ClippedPPOLossContinuous(HeadLoss):
|
||||
def __init__(self,
|
||||
num_actions: int,
|
||||
clip_likelihood_ratio_using_epsilon: float,
|
||||
beta: float=0,
|
||||
use_kl_regularization: bool=False,
|
||||
initial_kl_coefficient: float=1,
|
||||
kl_cutoff: float=0,
|
||||
high_kl_penalty_coefficient: float=1,
|
||||
weight: float=1,
|
||||
batch_axis: int=0):
|
||||
"""
|
||||
Loss for continuous version of Clipped PPO.
|
||||
|
||||
:param num_actions: number of actions in action space.
|
||||
:param clip_likelihood_ratio_using_epsilon: epsilon to use for likelihood ratio clipping.
|
||||
:param beta: loss coefficient applied to entropy
|
||||
:param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
|
||||
:param use_kl_regularization: option to add kl divergence loss
|
||||
:param initial_kl_coefficient: initial loss coefficient applied kl divergence loss (also see high_kl_penalty_coefficient).
|
||||
:param kl_cutoff: threshold for using high_kl_penalty_coefficient
|
||||
:param high_kl_penalty_coefficient: loss coefficient applied to kv divergence above kl_cutoff
|
||||
:param weight: scalar used to adjust relative weight of loss (if using this loss with others).
|
||||
:param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
|
||||
"""
|
||||
super(ClippedPPOLossContinuous, self).__init__(weight=weight, batch_axis=batch_axis)
|
||||
self.weight = weight
|
||||
self.num_actions = num_actions
|
||||
self.clip_likelihood_ratio_using_epsilon = clip_likelihood_ratio_using_epsilon
|
||||
self.beta = beta
|
||||
self.use_kl_regularization = use_kl_regularization
|
||||
self.initial_kl_coefficient = initial_kl_coefficient if self.use_kl_regularization else 0.0
|
||||
self.kl_coefficient = self.params.get('kl_coefficient',
|
||||
shape=(1,),
|
||||
init=mx.init.Constant([initial_kl_coefficient,]),
|
||||
differentiable=False)
|
||||
self.kl_cutoff = kl_cutoff
|
||||
self.high_kl_penalty_coefficient = high_kl_penalty_coefficient
|
||||
|
||||
@property
|
||||
def input_schema(self) -> LossInputSchema:
|
||||
return LossInputSchema(
|
||||
head_outputs=['new_policy_means','new_policy_stds'],
|
||||
agent_inputs=['actions', 'old_policy_means', 'old_policy_stds', 'clip_param_rescaler'],
|
||||
targets=['advantages']
|
||||
)
|
||||
|
||||
def loss_forward(self,
|
||||
F: ModuleType,
|
||||
new_policy_means: nd_sym_type,
|
||||
new_policy_stds: nd_sym_type,
|
||||
actions: nd_sym_type,
|
||||
old_policy_means: nd_sym_type,
|
||||
old_policy_stds: nd_sym_type,
|
||||
clip_param_rescaler: nd_sym_type,
|
||||
advantages: nd_sym_type,
|
||||
kl_coefficient: nd_sym_type) -> List[Tuple[nd_sym_type, str]]:
|
||||
"""
|
||||
Used for forward pass through loss computations.
|
||||
Works with batches of data, and optionally time_steps, but be consistent in usage: i.e. if using time_step,
|
||||
new_policy_means, old_policy_means, actions and advantages all must include a time_step dimension.
|
||||
|
||||
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
|
||||
:param new_policy_means: action means predicted by MultivariateNormalDist network,
|
||||
of shape (batch_size, num_actions) or
|
||||
of shape (batch_size, time_step, num_actions).
|
||||
:param new_policy_stds: action standard deviation returned by head,
|
||||
of shape (batch_size, num_actions) or
|
||||
of shape (batch_size, time_step, num_actions).
|
||||
:param actions: true actions taken during rollout,
|
||||
of shape (batch_size) or
|
||||
of shape (batch_size, time_step).
|
||||
:param old_policy_means: action means for previous policy,
|
||||
of shape (batch_size, num_actions) or
|
||||
of shape (batch_size, time_step, num_actions).
|
||||
:param old_policy_stds: action standard deviation returned by head previously,
|
||||
of shape (batch_size, num_actions) or
|
||||
of shape (batch_size, time_step, num_actions).
|
||||
:param clip_param_rescaler: scales epsilon to use for likelihood ratio clipping.
|
||||
:param advantages: change in state value after taking action (a.k.a advantage)
|
||||
of shape (batch_size) or
|
||||
of shape (batch_size, time_step).
|
||||
:param kl_coefficient: loss coefficient applied kl divergence loss (also see high_kl_penalty_coefficient).
|
||||
:return: loss, of shape (batch_size).
|
||||
"""
|
||||
old_var = old_policy_stds ** 2
|
||||
# sets diagonal in (batch size and time step) covariance matrices
|
||||
old_covar = mx.nd.eye(N=self.num_actions) * (old_var + eps).broadcast_like(old_policy_means).expand_dims(-2)
|
||||
old_policy_dist = MultivariateNormalDist(self.num_actions, old_policy_means, old_covar, F=F)
|
||||
action_probs_wrt_old_policy = old_policy_dist.log_prob(actions)
|
||||
|
||||
new_var = new_policy_stds ** 2
|
||||
# sets diagonal in (batch size and time step) covariance matrices
|
||||
new_covar = mx.nd.eye(N=self.num_actions) * (new_var + eps).broadcast_like(new_policy_means).expand_dims(-2)
|
||||
new_policy_dist = MultivariateNormalDist(self.num_actions, new_policy_means, new_covar, F=F)
|
||||
action_probs_wrt_new_policy = new_policy_dist.log_prob(actions)
|
||||
|
||||
entropy_loss = - self.beta * new_policy_dist.entropy().mean()
|
||||
|
||||
if self.use_kl_regularization:
|
||||
kl_div = old_policy_dist.kl_div(new_policy_dist).mean()
|
||||
weighted_kl_div = kl_coefficient * kl_div
|
||||
high_kl_div = F.stack(F.zeros_like(kl_div), kl_div - self.kl_cutoff).max().square()
|
||||
weighted_high_kl_div = self.high_kl_penalty_coefficient * high_kl_div
|
||||
kl_div_loss = weighted_kl_div + weighted_high_kl_div
|
||||
else:
|
||||
kl_div_loss = F.zeros(shape=(1,))
|
||||
|
||||
# working with log probs, so minus first, then exponential (same as division)
|
||||
likelihood_ratio = (action_probs_wrt_new_policy - action_probs_wrt_old_policy).exp()
|
||||
|
||||
if self.clip_likelihood_ratio_using_epsilon is not None:
|
||||
# clipping of likelihood ratio
|
||||
min_value = 1 - self.clip_likelihood_ratio_using_epsilon * clip_param_rescaler
|
||||
max_value = 1 + self.clip_likelihood_ratio_using_epsilon * clip_param_rescaler
|
||||
|
||||
# can't use F.clip (with variable clipping bounds), hence custom implementation
|
||||
clipped_likelihood_ratio = hybrid_clip(F, likelihood_ratio, clip_lower=min_value, clip_upper=max_value)
|
||||
|
||||
# lower bound of original, and clipped versions or each scaled advantage
|
||||
# element-wise min between the two ndarrays
|
||||
unclipped_scaled_advantages = likelihood_ratio * advantages
|
||||
clipped_scaled_advantages = clipped_likelihood_ratio * advantages
|
||||
scaled_advantages = F.stack(unclipped_scaled_advantages, clipped_scaled_advantages).min(axis=0)
|
||||
else:
|
||||
scaled_advantages = likelihood_ratio * advantages
|
||||
clipped_likelihood_ratio = F.zeros_like(likelihood_ratio)
|
||||
|
||||
# for each batch, calculate expectation of scaled_advantages across time steps,
|
||||
# but want code to work with data without time step too, so reshape to add timestep if doesn't exist.
|
||||
scaled_advantages_w_time = scaled_advantages.reshape(shape=(0, -1))
|
||||
expected_scaled_advantages = scaled_advantages_w_time.mean(axis=1)
|
||||
# want to maximize expected_scaled_advantages, add minus so can minimize.
|
||||
surrogate_loss = (-expected_scaled_advantages * self.weight).mean()
|
||||
|
||||
return [
|
||||
(surrogate_loss, LOSS_OUT_TYPE_LOSS),
|
||||
(entropy_loss + kl_div_loss, LOSS_OUT_TYPE_REGULARIZATION),
|
||||
(kl_div_loss, LOSS_OUT_TYPE_KL),
|
||||
(entropy_loss, LOSS_OUT_TYPE_ENTROPY),
|
||||
(likelihood_ratio, LOSS_OUT_TYPE_LIKELIHOOD_RATIO),
|
||||
(clipped_likelihood_ratio, LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO)
|
||||
]
|
||||
|
||||
|
||||
class PPOHead(Head):
|
||||
def __init__(self,
|
||||
agent_parameters: AgentParameters,
|
||||
spaces: SpacesDefinition,
|
||||
network_name: str,
|
||||
head_type_idx: int=0,
|
||||
loss_weight: float=1.,
|
||||
is_local: bool=True,
|
||||
activation_function: str='tanh',
|
||||
dense_layer: None=None) -> None:
|
||||
"""
|
||||
Head block for Proximal Policy Optimization, to calculate probabilities for each action given middleware
|
||||
representation of the environment state.
|
||||
|
||||
:param agent_parameters: containing algorithm parameters such as clip_likelihood_ratio_using_epsilon
|
||||
and beta_entropy.
|
||||
:param spaces: containing action spaces used for defining size of network output.
|
||||
:param network_name: name of head network. currently unused.
|
||||
:param head_type_idx: index of head network. currently unused.
|
||||
:param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
|
||||
:param is_local: flag to denote if network is local. currently unused.
|
||||
:param activation_function: activation function to use between layers. currently unused.
|
||||
:param dense_layer: type of dense layer to use in network. currently unused.
|
||||
"""
|
||||
super().__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight, is_local, activation_function,
|
||||
dense_layer=dense_layer)
|
||||
self.return_type = ActionProbabilities
|
||||
|
||||
self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon
|
||||
self.beta = agent_parameters.algorithm.beta_entropy
|
||||
self.use_kl_regularization = agent_parameters.algorithm.use_kl_regularization
|
||||
if self.use_kl_regularization:
|
||||
self.initial_kl_coefficient = agent_parameters.algorithm.initial_kl_coefficient
|
||||
self.kl_cutoff = 2 * agent_parameters.algorithm.target_kl_divergence
|
||||
self.high_kl_penalty_coefficient = agent_parameters.algorithm.high_kl_penalty_coefficient
|
||||
else:
|
||||
self.initial_kl_coefficient, self.kl_cutoff, self.high_kl_penalty_coefficient = (None, None, None)
|
||||
self._loss = []
|
||||
|
||||
if isinstance(self.spaces.action, DiscreteActionSpace):
|
||||
self.net = DiscretePPOHead(num_actions=len(self.spaces.action.actions))
|
||||
elif isinstance(self.spaces.action, BoxActionSpace):
|
||||
self.net = ContinuousPPOHead(num_actions=len(self.spaces.action.actions))
|
||||
else:
|
||||
raise ValueError("Only discrete or continuous action spaces are supported for PPO.")
|
||||
|
||||
def hybrid_forward(self,
|
||||
F: ModuleType,
|
||||
x: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
|
||||
:param x: middleware embedding
|
||||
:return: policy parameters/probabilities
|
||||
"""
|
||||
return self.net(x)
|
||||
|
||||
def loss(self) -> mx.gluon.loss.Loss:
|
||||
"""
|
||||
Specifies loss block to be used for this policy head.
|
||||
|
||||
:return: loss block (can be called as function) for action probabilities returned by this policy network.
|
||||
"""
|
||||
if isinstance(self.spaces.action, DiscreteActionSpace):
|
||||
loss = ClippedPPOLossDiscrete(len(self.spaces.action.actions),
|
||||
self.clip_likelihood_ratio_using_epsilon,
|
||||
self.beta,
|
||||
self.use_kl_regularization, self.initial_kl_coefficient,
|
||||
self.kl_cutoff, self.high_kl_penalty_coefficient,
|
||||
self.loss_weight)
|
||||
elif isinstance(self.spaces.action, BoxActionSpace):
|
||||
loss = ClippedPPOLossContinuous(len(self.spaces.action.actions),
|
||||
self.clip_likelihood_ratio_using_epsilon,
|
||||
self.beta,
|
||||
self.use_kl_regularization, self.initial_kl_coefficient,
|
||||
self.kl_cutoff, self.high_kl_penalty_coefficient,
|
||||
self.loss_weight)
|
||||
else:
|
||||
raise ValueError("Only discrete or continuous action spaces are supported for PPO.")
|
||||
loss.initialize()
|
||||
# set a property so can assign_kl_coefficient in future,
|
||||
# make a list, otherwise it would be added as a child of Head Block (due to type check)
|
||||
self._loss = [loss]
|
||||
return loss
|
||||
|
||||
@property
|
||||
def kl_divergence(self):
|
||||
return self.head_type_idx, LOSS_OUT_TYPE_KL
|
||||
|
||||
@property
|
||||
def entropy(self):
|
||||
return self.head_type_idx, LOSS_OUT_TYPE_ENTROPY
|
||||
|
||||
@property
|
||||
def likelihood_ratio(self):
|
||||
return self.head_type_idx, LOSS_OUT_TYPE_LIKELIHOOD_RATIO
|
||||
|
||||
@property
|
||||
def clipped_likelihood_ratio(self):
|
||||
return self.head_type_idx, LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO
|
||||
|
||||
def assign_kl_coefficient(self, kl_coefficient: float) -> None:
|
||||
self._loss[0].kl_coefficient.set_data(mx.nd.array((kl_coefficient,)))
|
||||
123
rl_coach/architectures/mxnet_components/heads/ppo_v_head.py
Normal file
123
rl_coach/architectures/mxnet_components/heads/ppo_v_head.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from typing import List, Tuple, Union
|
||||
from types import ModuleType
|
||||
|
||||
import mxnet as mx
|
||||
from mxnet.gluon import nn
|
||||
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
|
||||
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.core_types import ActionProbabilities
|
||||
from rl_coach.spaces import SpacesDefinition
|
||||
|
||||
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
|
||||
|
||||
|
||||
class PPOVHeadLoss(HeadLoss):
|
||||
def __init__(self, clip_likelihood_ratio_using_epsilon: float, weight: float=1, batch_axis: int=0) -> None:
|
||||
"""
|
||||
Loss for PPO Value network.
|
||||
Schulman implemented this extension in OpenAI baselines for PPO2
|
||||
See https://github.com/openai/baselines/blob/master/baselines/ppo2/ppo2.py#L72
|
||||
|
||||
:param clip_likelihood_ratio_using_epsilon: epsilon to use for likelihood ratio clipping.
|
||||
:param weight: scalar used to adjust relative weight of loss (if using this loss with others).
|
||||
:param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
|
||||
"""
|
||||
super(PPOVHeadLoss, self).__init__(weight=weight, batch_axis=batch_axis)
|
||||
self.weight = weight
|
||||
self.clip_likelihood_ratio_using_epsilon = clip_likelihood_ratio_using_epsilon
|
||||
|
||||
@property
|
||||
def input_schema(self) -> LossInputSchema:
|
||||
return LossInputSchema(
|
||||
head_outputs=['new_policy_values'],
|
||||
agent_inputs=['old_policy_values'],
|
||||
targets=['target_values']
|
||||
)
|
||||
|
||||
def loss_forward(self,
|
||||
F: ModuleType,
|
||||
new_policy_values: nd_sym_type,
|
||||
old_policy_values: nd_sym_type,
|
||||
target_values: nd_sym_type) -> List[Tuple[nd_sym_type, str]]:
|
||||
"""
|
||||
Used for forward pass through loss computations.
|
||||
Calculates two losses (L2 and a clipped difference L2 loss) and takes the maximum of the two.
|
||||
Works with batches of data, and optionally time_steps, but be consistent in usage: i.e. if using time_step,
|
||||
new_policy_values, old_policy_values and target_values all must include a time_step dimension.
|
||||
|
||||
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
|
||||
:param new_policy_values: values predicted by PPOVHead network,
|
||||
of shape (batch_size) or
|
||||
of shape (batch_size, time_step).
|
||||
:param old_policy_values: values predicted by old value network,
|
||||
of shape (batch_size) or
|
||||
of shape (batch_size, time_step).
|
||||
:param target_values: actual state values,
|
||||
of shape (batch_size) or
|
||||
of shape (batch_size, time_step).
|
||||
:return: loss, of shape (batch_size).
|
||||
"""
|
||||
# L2 loss
|
||||
value_loss_1 = (new_policy_values - target_values).square()
|
||||
# Clipped difference L2 loss
|
||||
diff = new_policy_values - old_policy_values
|
||||
clipped_diff = diff.clip(a_min=-self.clip_likelihood_ratio_using_epsilon,
|
||||
a_max=self.clip_likelihood_ratio_using_epsilon)
|
||||
value_loss_2 = (old_policy_values + clipped_diff - target_values).square()
|
||||
# Maximum of the two losses, element-wise maximum.
|
||||
value_loss_max = mx.nd.stack(value_loss_1, value_loss_2).max(axis=0)
|
||||
# Aggregate over temporal axis, adding if doesn't exist (hense reshape)
|
||||
value_loss_max_w_time = value_loss_max.reshape(shape=(0, -1))
|
||||
value_loss = value_loss_max_w_time.mean(axis=1)
|
||||
# Weight the loss (and average over samples of batch)
|
||||
value_loss_weighted = value_loss.mean() * self.weight
|
||||
return [(value_loss_weighted, LOSS_OUT_TYPE_LOSS)]
|
||||
|
||||
|
||||
class PPOVHead(Head):
|
||||
def __init__(self,
|
||||
agent_parameters: AgentParameters,
|
||||
spaces: SpacesDefinition,
|
||||
network_name: str,
|
||||
head_type_idx: int=0,
|
||||
loss_weight: float=1.,
|
||||
is_local: bool = True,
|
||||
activation_function: str='relu',
|
||||
dense_layer: None=None) -> None:
|
||||
"""
|
||||
PPO Value Head for predicting state values.
|
||||
|
||||
:param agent_parameters: containing algorithm parameters, but currently unused.
|
||||
:param spaces: containing action spaces, but currently unused.
|
||||
:param network_name: name of head network. currently unused.
|
||||
:param head_type_idx: index of head network. currently unused.
|
||||
:param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
|
||||
:param is_local: flag to denote if network is local. currently unused.
|
||||
:param activation_function: activation function to use between layers. currently unused.
|
||||
:param dense_layer: type of dense layer to use in network. currently unused.
|
||||
"""
|
||||
super(PPOVHead, self).__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight, is_local,
|
||||
activation_function, dense_layer=dense_layer)
|
||||
self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon
|
||||
self.return_type = ActionProbabilities
|
||||
with self.name_scope():
|
||||
self.dense = nn.Dense(units=1)
|
||||
|
||||
def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
Used for forward pass through value head network.
|
||||
|
||||
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
|
||||
:param x: middleware state representation, of shape (batch_size, in_channels).
|
||||
:return: final value output of network, of shape (batch_size).
|
||||
"""
|
||||
return self.dense(x).squeeze()
|
||||
|
||||
def loss(self) -> mx.gluon.loss.Loss:
|
||||
"""
|
||||
Specifies loss block to be used for specific value head implementation.
|
||||
|
||||
:return: loss block (can be called as function) for outputs returned by the value head network.
|
||||
"""
|
||||
return PPOVHeadLoss(self.clip_likelihood_ratio_using_epsilon, weight=self.loss_weight)
|
||||
106
rl_coach/architectures/mxnet_components/heads/q_head.py
Normal file
106
rl_coach/architectures/mxnet_components/heads/q_head.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from typing import Union, List, Tuple
|
||||
from types import ModuleType
|
||||
|
||||
import mxnet as mx
|
||||
from mxnet.gluon.loss import Loss, HuberLoss, L2Loss
|
||||
from mxnet.gluon import nn
|
||||
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
|
||||
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.core_types import QActionStateValue
|
||||
from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace
|
||||
|
||||
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
|
||||
|
||||
|
||||
class QHeadLoss(HeadLoss):
|
||||
def __init__(self, loss_type: Loss=L2Loss, weight: float=1, batch_axis: int=0) -> None:
|
||||
"""
|
||||
Loss for Q-Value Head.
|
||||
|
||||
:param loss_type: loss function with default of mean squared error (i.e. L2Loss).
|
||||
:param weight: scalar used to adjust relative weight of loss (if using this loss with others).
|
||||
:param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
|
||||
"""
|
||||
super(QHeadLoss, self).__init__(weight=weight, batch_axis=batch_axis)
|
||||
with self.name_scope():
|
||||
self.loss_fn = loss_type(weight=weight, batch_axis=batch_axis)
|
||||
|
||||
@property
|
||||
def input_schema(self) -> LossInputSchema:
|
||||
return LossInputSchema(
|
||||
head_outputs=['pred'],
|
||||
agent_inputs=[],
|
||||
targets=['target']
|
||||
)
|
||||
|
||||
def loss_forward(self,
|
||||
F: ModuleType,
|
||||
pred: nd_sym_type,
|
||||
target: nd_sym_type) -> List[Tuple[nd_sym_type, str]]:
|
||||
"""
|
||||
Used for forward pass through loss computations.
|
||||
|
||||
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
|
||||
:param pred: state-action q-values predicted by QHead network, of shape (batch_size, num_actions).
|
||||
:param target: actual state-action q-values, of shape (batch_size, num_actions).
|
||||
:return: loss, of shape (batch_size).
|
||||
"""
|
||||
loss = self.loss_fn(pred, target).mean()
|
||||
return [(loss, LOSS_OUT_TYPE_LOSS)]
|
||||
|
||||
|
||||
class QHead(Head):
|
||||
def __init__(self,
|
||||
agent_parameters: AgentParameters,
|
||||
spaces: SpacesDefinition,
|
||||
network_name: str,
|
||||
head_type_idx: int=0,
|
||||
loss_weight: float=1.,
|
||||
is_local: bool=True,
|
||||
activation_function: str='relu',
|
||||
dense_layer: None=None,
|
||||
loss_type: Union[HuberLoss, L2Loss]=L2Loss) -> None:
|
||||
"""
|
||||
Q-Value Head for predicting state-action Q-Values.
|
||||
|
||||
:param agent_parameters: containing algorithm parameters, but currently unused.
|
||||
:param spaces: containing action spaces used for defining size of network output.
|
||||
:param network_name: name of head network. currently unused.
|
||||
:param head_type_idx: index of head network. currently unused.
|
||||
:param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
|
||||
:param is_local: flag to denote if network is local. currently unused.
|
||||
:param activation_function: activation function to use between layers. currently unused.
|
||||
:param dense_layer: type of dense layer to use in network. currently unused.
|
||||
:param loss_type: loss function to use.
|
||||
"""
|
||||
super(QHead, self).__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight,
|
||||
is_local, activation_function, dense_layer)
|
||||
if isinstance(self.spaces.action, BoxActionSpace):
|
||||
self.num_actions = 1
|
||||
elif isinstance(self.spaces.action, DiscreteActionSpace):
|
||||
self.num_actions = len(self.spaces.action.actions)
|
||||
self.return_type = QActionStateValue
|
||||
assert (loss_type == L2Loss) or (loss_type == HuberLoss), "Only expecting L2Loss or HuberLoss."
|
||||
self.loss_type = loss_type
|
||||
|
||||
with self.name_scope():
|
||||
self.dense = nn.Dense(units=self.num_actions)
|
||||
|
||||
def loss(self) -> Loss:
|
||||
"""
|
||||
Specifies loss block to be used for specific value head implementation.
|
||||
|
||||
:return: loss block (can be called as function) for outputs returned by the head network.
|
||||
"""
|
||||
return QHeadLoss(loss_type=self.loss_type, weight=self.loss_weight)
|
||||
|
||||
def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
Used for forward pass through Q-Value head network.
|
||||
|
||||
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
|
||||
:param x: middleware state representation, of shape (batch_size, in_channels).
|
||||
:return: predicted state-action q-values, of shape (batch_size, num_actions).
|
||||
"""
|
||||
return self.dense(x)
|
||||
100
rl_coach/architectures/mxnet_components/heads/v_head.py
Normal file
100
rl_coach/architectures/mxnet_components/heads/v_head.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from typing import Union, List, Tuple
|
||||
from types import ModuleType
|
||||
|
||||
import mxnet as mx
|
||||
from mxnet.gluon.loss import Loss, HuberLoss, L2Loss
|
||||
from mxnet.gluon import nn
|
||||
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
|
||||
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.core_types import VStateValue
|
||||
from rl_coach.spaces import SpacesDefinition
|
||||
|
||||
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
|
||||
|
||||
|
||||
class VHeadLoss(HeadLoss):
|
||||
def __init__(self, loss_type: Loss=L2Loss, weight: float=1, batch_axis: int=0) -> None:
|
||||
"""
|
||||
Loss for Value Head.
|
||||
|
||||
:param loss_type: loss function with default of mean squared error (i.e. L2Loss).
|
||||
:param weight: scalar used to adjust relative weight of loss (if using this loss with others).
|
||||
:param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
|
||||
"""
|
||||
super(VHeadLoss, self).__init__(weight=weight, batch_axis=batch_axis)
|
||||
with self.name_scope():
|
||||
self.loss_fn = loss_type(weight=weight, batch_axis=batch_axis)
|
||||
|
||||
@property
|
||||
def input_schema(self) -> LossInputSchema:
|
||||
return LossInputSchema(
|
||||
head_outputs=['pred'],
|
||||
agent_inputs=[],
|
||||
targets=['target']
|
||||
)
|
||||
|
||||
def loss_forward(self,
|
||||
F: ModuleType,
|
||||
pred: nd_sym_type,
|
||||
target: nd_sym_type) -> List[Tuple[nd_sym_type, str]]:
|
||||
"""
|
||||
Used for forward pass through loss computations.
|
||||
|
||||
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
|
||||
:param pred: state values predicted by VHead network, of shape (batch_size).
|
||||
:param target: actual state values, of shape (batch_size).
|
||||
:return: loss, of shape (batch_size).
|
||||
"""
|
||||
loss = self.loss_fn(pred, target).mean()
|
||||
return [(loss, LOSS_OUT_TYPE_LOSS)]
|
||||
|
||||
|
||||
class VHead(Head):
|
||||
def __init__(self,
|
||||
agent_parameters: AgentParameters,
|
||||
spaces: SpacesDefinition,
|
||||
network_name: str,
|
||||
head_type_idx: int=0,
|
||||
loss_weight: float=1.,
|
||||
is_local: bool=True,
|
||||
activation_function: str='relu',
|
||||
dense_layer: None=None,
|
||||
loss_type: Union[HuberLoss, L2Loss]=L2Loss):
|
||||
"""
|
||||
Value Head for predicting state values.
|
||||
:param agent_parameters: containing algorithm parameters, but currently unused.
|
||||
:param spaces: containing action spaces, but currently unused.
|
||||
:param network_name: name of head network. currently unused.
|
||||
:param head_type_idx: index of head network. currently unused.
|
||||
:param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
|
||||
:param is_local: flag to denote if network is local. currently unused.
|
||||
:param activation_function: activation function to use between layers. currently unused.
|
||||
:param dense_layer: type of dense layer to use in network. currently unused.
|
||||
:param loss_type: loss function with default of mean squared error (i.e. L2Loss), or alternatively HuberLoss.
|
||||
"""
|
||||
super(VHead, self).__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight,
|
||||
is_local, activation_function, dense_layer)
|
||||
assert (loss_type == L2Loss) or (loss_type == HuberLoss), "Only expecting L2Loss or HuberLoss."
|
||||
self.loss_type = loss_type
|
||||
self.return_type = VStateValue
|
||||
with self.name_scope():
|
||||
self.dense = nn.Dense(units=1)
|
||||
|
||||
def loss(self) -> Loss:
|
||||
"""
|
||||
Specifies loss block to be used for specific value head implementation.
|
||||
|
||||
:return: loss block (can be called as function) for outputs returned by the head network.
|
||||
"""
|
||||
return VHeadLoss(loss_type=self.loss_type, weight=self.loss_weight)
|
||||
|
||||
def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
Used for forward pass through value head network.
|
||||
|
||||
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
|
||||
:param x: middleware state representation, of shape (batch_size, in_channels).
|
||||
:return: final output of value network, of shape (batch_size).
|
||||
"""
|
||||
return self.dense(x).squeeze()
|
||||
99
rl_coach/architectures/mxnet_components/layers.py
Normal file
99
rl_coach/architectures/mxnet_components/layers.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Module implementing basic layers in mxnet
|
||||
"""
|
||||
|
||||
from types import FunctionType
|
||||
|
||||
from mxnet.gluon import nn
|
||||
|
||||
from rl_coach.architectures import layers
|
||||
from rl_coach.architectures.mxnet_components import utils
|
||||
|
||||
|
||||
# define global dictionary for storing layer type to layer implementation mapping
|
||||
mx_layer_dict = dict()
|
||||
|
||||
|
||||
def reg_to_mx(layer_type) -> FunctionType:
|
||||
""" function decorator that registers layer implementation
|
||||
:return: decorated function
|
||||
"""
|
||||
def reg_impl_decorator(func):
|
||||
assert layer_type not in mx_layer_dict
|
||||
mx_layer_dict[layer_type] = func
|
||||
return func
|
||||
return reg_impl_decorator
|
||||
|
||||
|
||||
def convert_layer(layer):
|
||||
"""
|
||||
If layer is callable, return layer, otherwise convert to MX type
|
||||
:param layer: layer to be converted
|
||||
:return: converted layer if not callable, otherwise layer itself
|
||||
"""
|
||||
if callable(layer):
|
||||
return layer
|
||||
return mx_layer_dict[type(layer)](layer)
|
||||
|
||||
|
||||
class Conv2d(layers.Conv2d):
|
||||
def __init__(self, num_filters: int, kernel_size: int, strides: int):
|
||||
super(Conv2d, self).__init__(num_filters=num_filters, kernel_size=kernel_size, strides=strides)
|
||||
|
||||
def __call__(self) -> nn.Conv2D:
|
||||
"""
|
||||
returns a conv2d block
|
||||
:return: conv2d block
|
||||
"""
|
||||
return nn.Conv2D(channels=self.num_filters, kernel_size=self.kernel_size, strides=self.strides)
|
||||
|
||||
@staticmethod
|
||||
@reg_to_mx(layers.Conv2d)
|
||||
def to_mx(base: layers.Conv2d):
|
||||
return Conv2d(num_filters=base.num_filters, kernel_size=base.kernel_size, strides=base.strides)
|
||||
|
||||
|
||||
class BatchnormActivationDropout(layers.BatchnormActivationDropout):
|
||||
def __init__(self, batchnorm: bool=False, activation_function=None, dropout_rate: float=0):
|
||||
super(BatchnormActivationDropout, self).__init__(
|
||||
batchnorm=batchnorm, activation_function=activation_function, dropout_rate=dropout_rate)
|
||||
|
||||
def __call__(self):
|
||||
"""
|
||||
returns a list of mxnet batchnorm, activation and dropout layers
|
||||
:return: batchnorm, activation and dropout layers
|
||||
"""
|
||||
block = nn.HybridSequential()
|
||||
if self.batchnorm:
|
||||
block.add(nn.BatchNorm())
|
||||
if self.activation_function:
|
||||
block.add(nn.Activation(activation=utils.get_mxnet_activation_name(self.activation_function)))
|
||||
if self.dropout_rate:
|
||||
block.add(nn.Dropout(self.dropout_rate))
|
||||
return block
|
||||
|
||||
@staticmethod
|
||||
@reg_to_mx(layers.BatchnormActivationDropout)
|
||||
def to_mx(base: layers.BatchnormActivationDropout):
|
||||
return BatchnormActivationDropout(
|
||||
batchnorm=base.batchnorm,
|
||||
activation_function=base.activation_function,
|
||||
dropout_rate=base.dropout_rate)
|
||||
|
||||
|
||||
class Dense(layers.Dense):
|
||||
def __init__(self, units: int):
|
||||
super(Dense, self).__init__(units=units)
|
||||
|
||||
def __call__(self):
|
||||
"""
|
||||
returns a mxnet dense layer
|
||||
:return: dense layer
|
||||
"""
|
||||
# Set flatten to False for consistent behavior with tf.layers.dense
|
||||
return nn.Dense(self.units, flatten=False)
|
||||
|
||||
@staticmethod
|
||||
@reg_to_mx(layers.Dense)
|
||||
def to_mx(base: layers.Dense):
|
||||
return Dense(units=base.units)
|
||||
@@ -0,0 +1,4 @@
|
||||
from .fc_middleware import FCMiddleware
|
||||
from .lstm_middleware import LSTMMiddleware
|
||||
|
||||
__all__ = ["FCMiddleware", "LSTMMiddleware"]
|
||||
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
Module that defines the fully-connected middleware class
|
||||
"""
|
||||
|
||||
from rl_coach.architectures.mxnet_components.layers import Dense
|
||||
from rl_coach.architectures.mxnet_components.middlewares.middleware import Middleware
|
||||
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters
|
||||
from rl_coach.base_parameters import MiddlewareScheme
|
||||
|
||||
|
||||
class FCMiddleware(Middleware):
|
||||
def __init__(self, params: FCMiddlewareParameters):
|
||||
"""
|
||||
FCMiddleware or Fully-Connected Middleware can be used in the middle part of the network. It takes the
|
||||
embeddings from the input embedders, after they were aggregated in some method (for example, concatenation)
|
||||
and passes it through a neural network which can be customizable but shared between the heads of the network.
|
||||
|
||||
:param params: parameters object containing batchnorm, activation_function and dropout properties.
|
||||
"""
|
||||
super(FCMiddleware, self).__init__(params)
|
||||
|
||||
@property
|
||||
def schemes(self) -> dict:
|
||||
"""
|
||||
Schemes are the pre-defined network architectures of various depths and complexities that can be used for the
|
||||
Middleware. Are used to create Block when FCMiddleware is initialised.
|
||||
|
||||
:return: dictionary of schemes, with key of type MiddlewareScheme enum and value being list of mxnet.gluon.Block.
|
||||
"""
|
||||
return {
|
||||
MiddlewareScheme.Empty:
|
||||
[],
|
||||
|
||||
# Use for PPO
|
||||
MiddlewareScheme.Shallow:
|
||||
[
|
||||
Dense(units=64)
|
||||
],
|
||||
|
||||
# Use for DQN
|
||||
MiddlewareScheme.Medium:
|
||||
[
|
||||
Dense(units=512)
|
||||
],
|
||||
|
||||
MiddlewareScheme.Deep:
|
||||
[
|
||||
Dense(units=128),
|
||||
Dense(units=128),
|
||||
Dense(units=128)
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
Module that defines the LSTM middleware class
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
from types import ModuleType
|
||||
|
||||
import mxnet as mx
|
||||
from mxnet.gluon import rnn
|
||||
from rl_coach.architectures.mxnet_components.layers import Dense
|
||||
from rl_coach.architectures.mxnet_components.middlewares.middleware import Middleware
|
||||
from rl_coach.architectures.middleware_parameters import LSTMMiddlewareParameters
|
||||
from rl_coach.base_parameters import MiddlewareScheme
|
||||
|
||||
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
|
||||
|
||||
|
||||
class LSTMMiddleware(Middleware):
|
||||
def __init__(self, params: LSTMMiddlewareParameters):
|
||||
"""
|
||||
LSTMMiddleware or Long Short Term Memory Middleware can be used in the middle part of the network. It takes the
|
||||
embeddings from the input embedders, after they were aggregated in some method (for example, concatenation)
|
||||
and passes it through a neural network which can be customizable but shared between the heads of the network.
|
||||
|
||||
:param params: parameters object containing batchnorm, activation_function, dropout and
|
||||
number_of_lstm_cells properties.
|
||||
"""
|
||||
super(LSTMMiddleware, self).__init__(params)
|
||||
self.number_of_lstm_cells = params.number_of_lstm_cells
|
||||
with self.name_scope():
|
||||
self.lstm = rnn.LSTM(hidden_size=self.number_of_lstm_cells)
|
||||
|
||||
@property
|
||||
def schemes(self) -> dict:
|
||||
"""
|
||||
Schemes are the pre-defined network architectures of various depths and complexities that can be used for the
|
||||
Middleware. Are used to create Block when LSTMMiddleware is initialised, and are applied before the LSTM.
|
||||
|
||||
:return: dictionary of schemes, with key of type MiddlewareScheme enum and value being list of mxnet.gluon.Block.
|
||||
"""
|
||||
return {
|
||||
MiddlewareScheme.Empty:
|
||||
[],
|
||||
|
||||
# Use for PPO
|
||||
MiddlewareScheme.Shallow:
|
||||
[
|
||||
Dense(units=64)
|
||||
],
|
||||
|
||||
# Use for DQN
|
||||
MiddlewareScheme.Medium:
|
||||
[
|
||||
Dense(units=512)
|
||||
],
|
||||
|
||||
MiddlewareScheme.Deep:
|
||||
[
|
||||
Dense(units=128),
|
||||
Dense(units=128),
|
||||
Dense(units=128)
|
||||
]
|
||||
}
|
||||
|
||||
def hybrid_forward(self,
|
||||
F: ModuleType,
|
||||
x: nd_sym_type,
|
||||
*args, **kwargs) -> nd_sym_type:
|
||||
"""
|
||||
Used for forward pass through LSTM middleware network.
|
||||
Applies dense layers from selected scheme before passing result to LSTM layer.
|
||||
|
||||
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
|
||||
:param x: state embedding, of shape (batch_size, in_channels).
|
||||
:return: state middleware embedding, where shape is (batch_size, channels).
|
||||
"""
|
||||
x_ntc = x.reshape(shape=(0, 0, -1))
|
||||
emb_ntc = super(LSTMMiddleware, self).hybrid_forward(F, x_ntc, *args, **kwargs)
|
||||
emb_tnc = emb_ntc.transpose(axes=(1, 0, 2))
|
||||
return self.lstm(emb_tnc)
|
||||
@@ -0,0 +1,61 @@
|
||||
from typing import Union
|
||||
from types import ModuleType
|
||||
|
||||
import mxnet as mx
|
||||
from mxnet.gluon import nn
|
||||
from rl_coach.architectures.middleware_parameters import MiddlewareParameters
|
||||
from rl_coach.architectures.mxnet_components.layers import convert_layer
|
||||
from rl_coach.base_parameters import MiddlewareScheme
|
||||
|
||||
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
|
||||
|
||||
|
||||
class Middleware(nn.HybridBlock):
|
||||
def __init__(self, params: MiddlewareParameters):
|
||||
"""
|
||||
Middleware is the middle part of the network. It takes the embeddings from the input embedders,
|
||||
after they were aggregated in some method (for example, concatenation) and passes it through a neural network
|
||||
which can be customizable but shared between the heads of the network.
|
||||
|
||||
:param params: parameters object containing batchnorm, activation_function and dropout properties.
|
||||
"""
|
||||
super(Middleware, self).__init__()
|
||||
self.scheme = params.scheme
|
||||
|
||||
with self.name_scope():
|
||||
self.net = nn.HybridSequential()
|
||||
if isinstance(self.scheme, MiddlewareScheme):
|
||||
blocks = self.schemes[self.scheme]
|
||||
else:
|
||||
# if scheme is specified directly, convert to MX layer if it's not a callable object
|
||||
# NOTE: if layer object is callable, it must return a gluon block when invoked
|
||||
blocks = [convert_layer(l) for l in self.scheme]
|
||||
for block in blocks:
|
||||
self.net.add(block())
|
||||
if params.batchnorm:
|
||||
self.net.add(nn.BatchNorm())
|
||||
if params.activation_function:
|
||||
self.net.add(nn.Activation(params.activation_function))
|
||||
if params.dropout:
|
||||
self.net.add(nn.Dropout(rate=params.dropout))
|
||||
|
||||
@property
|
||||
def schemes(self) -> dict:
|
||||
"""
|
||||
Schemes are the pre-defined network architectures of various depths and complexities that can be used for the
|
||||
Middleware. Should be implemented in child classes, and are used to create Block when Middleware is initialised.
|
||||
|
||||
:return: dictionary of schemes, with key of type MiddlewareScheme enum and value being list of mxnet.gluon.Block.
|
||||
"""
|
||||
raise NotImplementedError("Inheriting embedder must define schemes matching its allowed default "
|
||||
"configurations.")
|
||||
|
||||
def hybrid_forward(self, F: ModuleType, x: nd_sym_type, *args, **kwargs) -> nd_sym_type:
|
||||
"""
|
||||
Used for forward pass through middleware network.
|
||||
|
||||
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
|
||||
:param x: state embedding, of shape (batch_size, in_channels).
|
||||
:return: state middleware embedding, where shape is (batch_size, channels).
|
||||
"""
|
||||
return self.net(x)
|
||||
280
rl_coach/architectures/mxnet_components/utils.py
Normal file
280
rl_coach/architectures/mxnet_components/utils.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""
|
||||
Module defining utility functions
|
||||
"""
|
||||
import inspect
|
||||
from typing import Any, Dict, Generator, Iterable, List, Tuple, Union
|
||||
from types import ModuleType
|
||||
|
||||
import mxnet as mx
|
||||
from mxnet import nd
|
||||
from mxnet.ndarray import NDArray
|
||||
import numpy as np
|
||||
|
||||
from rl_coach.core_types import GradientClippingMethod
|
||||
|
||||
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
|
||||
|
||||
|
||||
def to_mx_ndarray(data: Union[list, tuple, np.ndarray, NDArray, int, float]) ->\
|
||||
Union[List[NDArray], Tuple[NDArray], NDArray]:
|
||||
"""
|
||||
Convert data to mx.nd.NDArray. Data can be a list or tuple of np.ndarray, int, or float or
|
||||
it can be np.ndarray, int, or float
|
||||
:param data: input data to be converted
|
||||
:return: converted output data
|
||||
"""
|
||||
if isinstance(data, list):
|
||||
data = [to_mx_ndarray(d) for d in data]
|
||||
elif isinstance(data, tuple):
|
||||
data = tuple(to_mx_ndarray(d) for d in data)
|
||||
elif isinstance(data, np.ndarray):
|
||||
data = nd.array(data)
|
||||
elif isinstance(data, NDArray):
|
||||
pass
|
||||
elif isinstance(data, int) or isinstance(data, float):
|
||||
data = nd.array([data])
|
||||
else:
|
||||
raise TypeError('Unsupported data type: {}'.format(type(data)))
|
||||
return data
|
||||
|
||||
|
||||
def asnumpy_or_asscalar(data: Union[NDArray, list, tuple]) -> Union[np.ndarray, np.number, list, tuple]:
|
||||
"""
|
||||
Convert NDArray (or list or tuple of NDArray) to numpy. If shape is (1,), then convert to scalar instead.
|
||||
NOTE: This behavior is consistent with tensorflow
|
||||
:param data: NDArray or list or tuple of NDArray
|
||||
:return: data converted to numpy ndarray or to numpy scalar
|
||||
"""
|
||||
if isinstance(data, list):
|
||||
data = [asnumpy_or_asscalar(d) for d in data]
|
||||
elif isinstance(data, tuple):
|
||||
data = tuple(asnumpy_or_asscalar(d) for d in data)
|
||||
elif isinstance(data, NDArray):
|
||||
data = data.asscalar() if data.shape == (1,) else data.asnumpy()
|
||||
else:
|
||||
raise TypeError('Unsupported data type: {}'.format(type(data)))
|
||||
return data
|
||||
|
||||
|
||||
def global_norm(arrays: Union[Generator[NDArray, NDArray, NDArray], List[NDArray], Tuple[NDArray]]) -> NDArray:
|
||||
"""
|
||||
Calculate global norm on list or tuple of NDArrays using this formula:
|
||||
`global_norm = sqrt(sum([l2norm(p)**2 for p in parameters]))`
|
||||
|
||||
:param arrays: list or tuple of parameters to calculate global norm on
|
||||
:return: single-value NDArray
|
||||
"""
|
||||
def _norm(array):
|
||||
if array.stype == 'default':
|
||||
x = array.reshape((-1,))
|
||||
return nd.dot(x, x)
|
||||
return array.norm().square()
|
||||
|
||||
total_norm = nd.add_n(*[_norm(arr) for arr in arrays])
|
||||
total_norm = nd.sqrt(total_norm)
|
||||
return total_norm
|
||||
|
||||
|
||||
def split_outputs_per_head(outputs: Tuple[NDArray], heads: list) -> List[List[NDArray]]:
|
||||
"""
|
||||
Split outputs into outputs per head
|
||||
:param outputs: list of all outputs
|
||||
:param heads: list of all heads
|
||||
:return: list of outputs for each head
|
||||
"""
|
||||
head_outputs = []
|
||||
for h in heads:
|
||||
head_outputs.append(list(outputs[:h.num_outputs]))
|
||||
outputs = outputs[h.num_outputs:]
|
||||
assert len(outputs) == 0
|
||||
return head_outputs
|
||||
|
||||
|
||||
def split_targets_per_loss(targets: list, losses: list) -> List[list]:
|
||||
"""
|
||||
Splits targets into targets per loss
|
||||
:param targets: list of all targets (typically numpy ndarray)
|
||||
:param losses: list of all losses
|
||||
:return: list of targets for each loss
|
||||
"""
|
||||
loss_targets = list()
|
||||
for l in losses:
|
||||
loss_data_len = len(l.input_schema.targets)
|
||||
assert len(targets) >= loss_data_len, "Data length doesn't match schema"
|
||||
loss_targets.append(targets[:loss_data_len])
|
||||
targets = targets[loss_data_len:]
|
||||
assert len(targets) == 0
|
||||
return loss_targets
|
||||
|
||||
|
||||
def get_loss_agent_inputs(inputs: Dict[str, np.ndarray], head_type_idx: int, loss: Any) -> List[np.ndarray]:
|
||||
"""
|
||||
Collects all inputs with prefix 'output_<head_idx>_' and matches them against agent_inputs in loss input schema.
|
||||
:param inputs: list of all agent inputs
|
||||
:param head_type_idx: head-type index of the corresponding head
|
||||
:param loss: corresponding loss
|
||||
:return: list of agent inputs for this loss. This list matches the length in loss input schema.
|
||||
"""
|
||||
loss_inputs = list()
|
||||
for k in sorted(inputs.keys()):
|
||||
if k.startswith('output_{}_'.format(head_type_idx)):
|
||||
loss_inputs.append(inputs[k])
|
||||
# Enforce that number of inputs for head_type are the same as agent_inputs specified by loss input_schema
|
||||
assert len(loss_inputs) == len(loss.input_schema.agent_inputs), "agent_input length doesn't match schema"
|
||||
return loss_inputs
|
||||
|
||||
|
||||
def align_loss_args(
|
||||
head_outputs: List[NDArray],
|
||||
agent_inputs: List[np.ndarray],
|
||||
targets: List[np.ndarray],
|
||||
loss: Any) -> List[np.ndarray]:
|
||||
"""
|
||||
Creates a list of arguments from head_outputs, agent_inputs, and targets aligned with parameters of
|
||||
loss.loss_forward() based on their name in loss input_schema
|
||||
:param head_outputs: list of all head_outputs for this loss
|
||||
:param agent_inputs: list of all agent_inputs for this loss
|
||||
:param targets: list of all targets for this loss
|
||||
:param loss: corresponding loss
|
||||
:return: list of arguments in correct order to be passed to loss
|
||||
"""
|
||||
arg_list = list()
|
||||
schema = loss.input_schema
|
||||
assert len(schema.head_outputs) == len(head_outputs)
|
||||
assert len(schema.agent_inputs) == len(agent_inputs)
|
||||
assert len(schema.targets) == len(targets)
|
||||
|
||||
prev_found = True
|
||||
for arg_name in inspect.getfullargspec(loss.loss_forward).args[2:]: # First two args are self and F
|
||||
found = False
|
||||
for schema_list, data in [(schema.head_outputs, head_outputs),
|
||||
(schema.agent_inputs, agent_inputs),
|
||||
(schema.targets, targets)]:
|
||||
try:
|
||||
arg_list.append(data[schema_list.index(arg_name)])
|
||||
found = True
|
||||
break
|
||||
except ValueError:
|
||||
continue
|
||||
assert not found or prev_found, "missing arguments detected!"
|
||||
prev_found = found
|
||||
return arg_list
|
||||
|
||||
|
||||
def to_tuple(data: Union[tuple, list, Any]):
|
||||
"""
|
||||
If input is list, it is converted to tuple. If it's tuple, it is returned untouched. Otherwise
|
||||
returns a single-element tuple of the data.
|
||||
:return: tuple-ified data
|
||||
"""
|
||||
if isinstance(data, tuple):
|
||||
pass
|
||||
elif isinstance(data, list):
|
||||
data = tuple(data)
|
||||
else:
|
||||
data = (data,)
|
||||
return data
|
||||
|
||||
|
||||
def to_list(data: Union[tuple, list, Any]):
|
||||
"""
|
||||
If input is tuple, it is converted to list. If it's list, it is returned untouched. Otherwise
|
||||
returns a single-element list of the data.
|
||||
:return: list-ified data
|
||||
"""
|
||||
if isinstance(data, list):
|
||||
pass
|
||||
elif isinstance(data, tuple):
|
||||
data = list(data)
|
||||
else:
|
||||
data = [data]
|
||||
return data
|
||||
|
||||
|
||||
def loss_output_dict(output: List[NDArray], schema: List[str]) -> Dict[str, List[NDArray]]:
|
||||
"""
|
||||
Creates a dictionary for loss output based on the output schema. If two output values have the same
|
||||
type string in the schema they are concatenated in the same dicrionary item.
|
||||
:param output: list of output values
|
||||
:param schema: list of type-strings for output values
|
||||
:return: dictionary of keyword to list of NDArrays
|
||||
"""
|
||||
assert len(output) == len(schema)
|
||||
output_dict = dict()
|
||||
for name, val in zip(schema, output):
|
||||
if name in output_dict:
|
||||
output_dict[name].append(val)
|
||||
else:
|
||||
output_dict[name] = [val]
|
||||
return output_dict
|
||||
|
||||
|
||||
def clip_grad(
|
||||
grads: Union[Generator[NDArray, NDArray, NDArray], List[NDArray], Tuple[NDArray]],
|
||||
clip_method: GradientClippingMethod,
|
||||
clip_val: float,
|
||||
inplace=True) -> List[NDArray]:
|
||||
"""
|
||||
Clip gradient values inplace
|
||||
:param grads: gradients to be clipped
|
||||
:param clip_method: clipping method
|
||||
:param clip_val: clipping value. Interpreted differently depending on clipping method.
|
||||
:param inplace: modify grads if True, otherwise create NDArrays
|
||||
:return: clipped gradients
|
||||
"""
|
||||
output = list(grads) if inplace else list(nd.empty(g.shape) for g in grads)
|
||||
if clip_method == GradientClippingMethod.ClipByGlobalNorm:
|
||||
norm_unclipped_grads = global_norm(grads)
|
||||
scale = clip_val / (norm_unclipped_grads.asscalar() + 1e-8) # todo: use branching operators?
|
||||
if scale < 1.0:
|
||||
for g, o in zip(grads, output):
|
||||
nd.broadcast_mul(g, nd.array([scale]), out=o)
|
||||
elif clip_method == GradientClippingMethod.ClipByValue:
|
||||
for g, o in zip(grads, output):
|
||||
g.clip(-clip_val, clip_val, out=o)
|
||||
elif clip_method == GradientClippingMethod.ClipByNorm:
|
||||
for g, o in zip(grads, output):
|
||||
nd.broadcast_mul(g, nd.minimum(1.0, clip_val / (g.norm() + 1e-8)), out=o)
|
||||
else:
|
||||
raise KeyError('Unsupported gradient clipping method')
|
||||
return output
|
||||
|
||||
|
||||
def hybrid_clip(F: ModuleType, x: nd_sym_type, clip_lower: nd_sym_type, clip_upper: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
Apply clipping to input x between clip_lower and clip_upper.
|
||||
Added because F.clip doesn't support clipping bounds that are mx.nd.NDArray or mx.sym.Symbol.
|
||||
|
||||
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
|
||||
:param x: input data
|
||||
:param clip_lower: lower bound used for clipping, should be of shape (1,)
|
||||
:param clip_upper: upper bound used for clipping, should be of shape (1,)
|
||||
:return: clipped data
|
||||
"""
|
||||
x_clip_lower = clip_lower.broadcast_like(x)
|
||||
x_clip_upper = clip_upper.broadcast_like(x)
|
||||
x_clipped = F.stack(x, x_clip_lower, axis=0).max(axis=0)
|
||||
x_clipped = F.stack(x_clipped, x_clip_upper, axis=0).min(axis=0)
|
||||
return x_clipped
|
||||
|
||||
|
||||
def get_mxnet_activation_name(activation_name: str):
|
||||
"""
|
||||
Convert coach activation name to mxnet specific activation name
|
||||
:param activation_name: name of the activation inc coach
|
||||
:return: name of the activation in mxnet
|
||||
"""
|
||||
activation_functions = {
|
||||
'relu': 'relu',
|
||||
'tanh': 'tanh',
|
||||
'sigmoid': 'sigmoid',
|
||||
# FIXME Add other activations
|
||||
# 'elu': tf.nn.elu,
|
||||
'selu': 'softrelu',
|
||||
# 'leaky_relu': tf.nn.leaky_relu,
|
||||
'none': None
|
||||
}
|
||||
assert activation_name in activation_functions, \
|
||||
"Activation function must be one of the following {}. instead it was: {}".format(
|
||||
activation_functions.keys(), activation_name)
|
||||
return activation_functions[activation_name]
|
||||
@@ -183,16 +183,7 @@ class NetworkWrapper(object):
|
||||
target_network or global_network) and the second element is the inputs
|
||||
:return: the outputs of all the networks in the same order as the inputs were given
|
||||
"""
|
||||
feed_dict = {}
|
||||
fetches = []
|
||||
|
||||
for idx, (network, input) in enumerate(network_input_tuples):
|
||||
feed_dict.update(network.create_feed_dict(input))
|
||||
fetches += network.outputs
|
||||
|
||||
outputs = self.sess.run(fetches, feed_dict)
|
||||
|
||||
return outputs
|
||||
return type(self.online_network).parallel_predict(self.sess, network_input_tuples)
|
||||
|
||||
def get_local_variables(self):
|
||||
"""
|
||||
|
||||
@@ -13,8 +13,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import time
|
||||
import os
|
||||
import time
|
||||
from typing import Any, List, Tuple, Dict
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
@@ -544,6 +545,26 @@ class TensorFlowArchitecture(Architecture):
|
||||
output = squeeze_list(output)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def parallel_predict(sess: Any,
|
||||
network_input_tuples: List[Tuple['TensorFlowArchitecture', Dict[str, np.ndarray]]]) ->\
|
||||
List[np.ndarray]:
|
||||
"""
|
||||
:param sess: active session to use for prediction
|
||||
:param network_input_tuples: tuple of network and corresponding input
|
||||
:return: list of outputs from all networks
|
||||
"""
|
||||
feed_dict = {}
|
||||
fetches = []
|
||||
|
||||
for network, input in network_input_tuples:
|
||||
feed_dict.update(network.create_feed_dict(input))
|
||||
fetches += network.outputs
|
||||
|
||||
outputs = sess.run(fetches, feed_dict)
|
||||
|
||||
return outputs
|
||||
|
||||
def train_on_batch(self, inputs, targets, scaler=1., additional_fetches=None, importance_weights=None):
|
||||
"""
|
||||
Given a batch of examples and targets, runs a forward pass & backward pass and then applies the gradients
|
||||
|
||||
Reference in New Issue
Block a user