From 5fadb9c18e3de16cc5633175199f9e9e2c381102 Mon Sep 17 00:00:00 2001 From: Sina Afrooze Date: Wed, 7 Nov 2018 07:07:15 -0800 Subject: [PATCH] Adding mxnet components to rl_coach/architectures (#60) Adding mxnet components to rl_coach architectures. - Supports PPO and DQN - Tested with CartPole_PPO and CarPole_DQN - Normalizing filters don't work right now (see #49) and are disabled in CartPole_PPO preset - Checkpointing is disabled for MXNet --- docker/Dockerfile | 4 +- rl_coach/architectures/architecture.py | 37 +- .../mxnet_components/__init__.py | 0 .../mxnet_components/architecture.py | 405 +++++++++++ .../mxnet_components/embedders/__init__.py | 4 + .../mxnet_components/embedders/embedder.py | 71 ++ .../embedders/image_embedder.py | 76 ++ .../embedders/vector_embedder.py | 71 ++ .../mxnet_components/general_network.py | 501 +++++++++++++ .../mxnet_components/heads/__init__.py | 14 + .../mxnet_components/heads/head.py | 181 +++++ .../mxnet_components/heads/ppo_head.py | 669 ++++++++++++++++++ .../mxnet_components/heads/ppo_v_head.py | 123 ++++ .../mxnet_components/heads/q_head.py | 106 +++ .../mxnet_components/heads/v_head.py | 100 +++ .../architectures/mxnet_components/layers.py | 99 +++ .../mxnet_components/middlewares/__init__.py | 4 + .../middlewares/fc_middleware.py | 52 ++ .../middlewares/lstm_middleware.py | 80 +++ .../middlewares/middleware.py | 61 ++ .../architectures/mxnet_components/utils.py | 280 ++++++++ rl_coach/architectures/network_wrapper.py | 11 +- .../tensorflow_components/architecture.py | 23 +- rl_coach/graph_managers/graph_manager.py | 73 +- rl_coach/presets/CartPole_PPO.py | 4 +- .../mxnet_components/__init__.py | 0 .../mxnet_components/embedders/__init__.py | 0 .../embedders/test_image_embedder.py | 21 + .../embedders/test_vector_embedder.py | 22 + .../mxnet_components/heads/__init__.py | 0 .../mxnet_components/heads/test_ppo_head.py | 406 +++++++++++ .../mxnet_components/heads/test_ppo_v_head.py | 90 +++ .../mxnet_components/heads/test_q_head.py | 60 ++ .../mxnet_components/heads/test_v_head.py | 57 ++ .../mxnet_components/middlewares/__init__.py | 0 .../middlewares/test_fc_middleware.py | 22 + .../middlewares/test_lstm_middleware.py | 25 + .../mxnet_components/test_utils.py | 144 ++++ setup.py | 12 + 39 files changed, 3864 insertions(+), 44 deletions(-) create mode 100644 rl_coach/architectures/mxnet_components/__init__.py create mode 100644 rl_coach/architectures/mxnet_components/architecture.py create mode 100644 rl_coach/architectures/mxnet_components/embedders/__init__.py create mode 100644 rl_coach/architectures/mxnet_components/embedders/embedder.py create mode 100644 rl_coach/architectures/mxnet_components/embedders/image_embedder.py create mode 100644 rl_coach/architectures/mxnet_components/embedders/vector_embedder.py create mode 100644 rl_coach/architectures/mxnet_components/general_network.py create mode 100644 rl_coach/architectures/mxnet_components/heads/__init__.py create mode 100644 rl_coach/architectures/mxnet_components/heads/head.py create mode 100644 rl_coach/architectures/mxnet_components/heads/ppo_head.py create mode 100644 rl_coach/architectures/mxnet_components/heads/ppo_v_head.py create mode 100644 rl_coach/architectures/mxnet_components/heads/q_head.py create mode 100644 rl_coach/architectures/mxnet_components/heads/v_head.py create mode 100644 rl_coach/architectures/mxnet_components/layers.py create mode 100644 rl_coach/architectures/mxnet_components/middlewares/__init__.py create mode 100644 
rl_coach/architectures/mxnet_components/middlewares/fc_middleware.py create mode 100644 rl_coach/architectures/mxnet_components/middlewares/lstm_middleware.py create mode 100644 rl_coach/architectures/mxnet_components/middlewares/middleware.py create mode 100644 rl_coach/architectures/mxnet_components/utils.py create mode 100644 rl_coach/tests/architectures/mxnet_components/__init__.py create mode 100644 rl_coach/tests/architectures/mxnet_components/embedders/__init__.py create mode 100644 rl_coach/tests/architectures/mxnet_components/embedders/test_image_embedder.py create mode 100644 rl_coach/tests/architectures/mxnet_components/embedders/test_vector_embedder.py create mode 100644 rl_coach/tests/architectures/mxnet_components/heads/__init__.py create mode 100644 rl_coach/tests/architectures/mxnet_components/heads/test_ppo_head.py create mode 100644 rl_coach/tests/architectures/mxnet_components/heads/test_ppo_v_head.py create mode 100644 rl_coach/tests/architectures/mxnet_components/heads/test_q_head.py create mode 100644 rl_coach/tests/architectures/mxnet_components/heads/test_v_head.py create mode 100644 rl_coach/tests/architectures/mxnet_components/middlewares/__init__.py create mode 100644 rl_coach/tests/architectures/mxnet_components/middlewares/test_fc_middleware.py create mode 100644 rl_coach/tests/architectures/mxnet_components/middlewares/test_lstm_middleware.py create mode 100644 rl_coach/tests/architectures/mxnet_components/test_utils.py diff --git a/docker/Dockerfile b/docker/Dockerfile index 1fb5dc3..1237e64 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -5,12 +5,12 @@ COPY setup.py /root/src/. COPY requirements.txt /root/src/. COPY README.md /root/src/. WORKDIR /root/src -RUN pip3 install -e . +RUN pip3 install -e .[all] # everything above here should be cached most of the time COPY . /root/src WORKDIR /root/src -RUN pip3 install -e . +RUN pip3 install -e .[all] RUN chmod 777 /root/src/docker/docker_entrypoint.sh ENTRYPOINT ["/root/src/docker/docker_entrypoint.sh"] diff --git a/rl_coach/architectures/architecture.py b/rl_coach/architectures/architecture.py index 1ae2d47..78b66cb 100644 --- a/rl_coach/architectures/architecture.py +++ b/rl_coach/architectures/architecture.py @@ -41,22 +41,41 @@ class Architecture(object): self.optimizer = None self.ap = agent_parameters - def predict(self, inputs: Dict[str, np.ndarray]) -> List[np.ndarray]: + def predict(self, + inputs: Dict[str, np.ndarray], + outputs: List[Any] = None, + squeeze_output: bool = True, + initial_feed_dict: Dict[Any, np.ndarray] = None) -> Tuple[np.ndarray, ...]: """ Given input observations, use the model to make predictions (e.g. action or value). :param inputs: current state (i.e. observations, measurements, goals, etc.) (e.g. `{'observation': numpy.ndarray}` of shape (batch_size, observation_space_size)) + :param outputs: list of outputs to return. Return all outputs if unspecified. Type of the list elements + depends on the framework backend. + :param squeeze_output: call squeeze_list on output before returning if True + :param initial_feed_dict: a dictionary of extra inputs for forward pass. 
:return: predictions of action or value of shape (batch_size, action_space_size) for action predictions) """ pass + @staticmethod + def parallel_predict(sess: Any, + network_input_tuples: List[Tuple['Architecture', Dict[str, np.ndarray]]]) -> \ + Tuple[np.ndarray, ...]: + """ + :param sess: active session to use for prediction + :param network_input_tuples: tuple of network and corresponding input + :return: list or tuple of outputs from all networks + """ + pass + def train_on_batch(self, inputs: Dict[str, np.ndarray], targets: List[np.ndarray], scaler: float=1., additional_fetches: list=None, - importance_weights: np.ndarray=None) -> tuple: + importance_weights: np.ndarray=None) -> Tuple[float, List[float], float, list]: """ Given a batch of inputs (e.g. states) and targets (e.g. discounted rewards), takes a training step: i.e. runs a forward pass and backward pass of the network, accumulates the gradients and applies an optimization step to @@ -118,8 +137,7 @@ class Architecture(object): targets: List[np.ndarray], additional_fetches: list=None, importance_weights: np.ndarray=None, - no_accumulation: bool=False) ->\ - Tuple[float, List[float], float, list]: + no_accumulation: bool=False) -> Tuple[float, List[float], float, list]: """ Given a batch of inputs (i.e. states) and targets (e.g. discounted rewards), computes and accumulates the gradients for model parameters. Will run forward and backward pass to compute gradients, clip the gradient @@ -142,30 +160,33 @@ class Architecture(object): calculated gradients :return: tuple of total_loss, losses, norm_unclipped_grads, fetched_tensors total_loss (float): sum of all head losses - losses (list of float): list of all losses. The order is list of target losses followed by list of regularization losses. - The specifics of losses is dependant on the network parameters (number of heads, etc.) + losses (list of float): list of all losses. The order is list of target losses followed by list of + regularization losses. The specifics of losses is dependant on the network parameters + (number of heads, etc.) norm_unclippsed_grads (float): global norm of all gradients before any gradient clipping is applied fetched_tensors: all values for additional_fetches """ pass - def apply_and_reset_gradients(self, gradients: List[np.ndarray]) -> None: + def apply_and_reset_gradients(self, gradients: List[np.ndarray], scaler: float=1.) -> None: """ Applies the given gradients to the network weights and resets the gradient accumulations. Has the same impact as calling `apply_gradients`, then `reset_accumulated_gradients`. :param gradients: gradients for the parameter weights, taken from `accumulated_gradients` property of an identical network (either self or another identical network) + :param scaler: A scaling factor that allows rescaling the gradients before applying them """ pass - def apply_gradients(self, gradients: List[np.ndarray]) -> None: + def apply_gradients(self, gradients: List[np.ndarray], scaler: float=1.) -> None: """ Applies the given gradients to the network weights. 
Will be performed sync or async depending on `network_parameters.async_training` :param gradients: gradients for the parameter weights, taken from `accumulated_gradients` property of an identical network (either self or another identical network) + :param scaler: A scaling factor that allows rescaling the gradients before applying them """ pass diff --git a/rl_coach/architectures/mxnet_components/__init__.py b/rl_coach/architectures/mxnet_components/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rl_coach/architectures/mxnet_components/architecture.py b/rl_coach/architectures/mxnet_components/architecture.py new file mode 100644 index 0000000..dd860fb --- /dev/null +++ b/rl_coach/architectures/mxnet_components/architecture.py @@ -0,0 +1,405 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import copy +from typing import Any, Dict, Generator, List, Tuple, Union + +import numpy as np +import mxnet as mx +from mxnet import autograd, gluon, nd +from mxnet.ndarray import NDArray + +from rl_coach.architectures.architecture import Architecture +from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION +from rl_coach.architectures.mxnet_components import utils +from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters +from rl_coach.spaces import SpacesDefinition +from rl_coach.utils import force_list, squeeze_list + + +class MxnetArchitecture(Architecture): + def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str= "", + global_network=None, network_is_local: bool=True, network_is_trainable: bool=False): + """ + :param agent_parameters: the agent parameters + :param spaces: the spaces definition of the agent + :param name: the name of the network + :param global_network: the global network replica that is shared between all the workers + :param network_is_local: is the network global (shared between workers) or local (dedicated to the worker) + :param network_is_trainable: is the network trainable (we can apply gradients on it) + """ + super().__init__(agent_parameters, spaces, name) + self.middleware = None + self.network_is_local = network_is_local + self.global_network = global_network + if not self.network_parameters.tensorflow_support: + raise ValueError('TensorFlow is not supported for this agent') + self.losses = [] # type: List[HeadLoss] + self.shared_accumulated_gradients = [] + self.curr_rnn_c_in = None + self.curr_rnn_h_in = None + self.gradients_wrt_inputs = [] + self.train_writer = None + self.accumulated_gradients = None + self.network_is_trainable = network_is_trainable + self.is_training = False + self.model = None # type: GeneralModel + + self.is_chief = self.ap.task_parameters.task_index == 0 + self.network_is_global = not self.network_is_local and global_network is None + self.distributed_training = self.network_is_global or self.network_is_local and global_network is not None + + self.optimizer_type = 
self.network_parameters.optimizer_type + if self.ap.task_parameters.seed is not None: + mx.random.seed(self.ap.task_parameters.seed) + + # Call to child class to create the model + self.construct_model() + + self.trainer = None # type: gluon.Trainer + + def __str__(self): + return self.model.summary(*self._dummy_model_inputs()) + + @property + def _model_grads(self) -> Generator[NDArray, NDArray, Any]: + """ + Creates a copy of model gradients and returns them in a list, in the same order as collect_params() + :return: a generator for model gradient values + """ + return (p.list_grad()[0].copy() for p in self.model.collect_params().values() if p.grad_req != 'null') + + def _dummy_model_inputs(self) -> Tuple[NDArray, ...]: + """ + Creates a tuple of input arrays with correct shapes that can be used for shape inference + of the model weights and for printing the summary + :return: tuple of inputs for model forward pass + """ + allowed_inputs = copy.copy(self.spaces.state.sub_spaces) + allowed_inputs["action"] = copy.copy(self.spaces.action) + allowed_inputs["goal"] = copy.copy(self.spaces.goal) + embedders = self.model.nets[0].input_embedders + inputs = tuple(nd.zeros((1,) + tuple(allowed_inputs[emb.embedder_name].shape.tolist())) for emb in embedders) + return inputs + + def construct_model(self) -> None: + """ + Construct network model. Implemented by child class. + """ + raise NotImplementedError + + def set_session(self, sess) -> None: + """ + Initializes the model parameters and creates the model trainer. + NOTEL Session for mxnet backend must be None. + :param sess: must be None + """ + assert sess is None + # FIXME Add GPU initialization + # FIXME Add initializer + self.model.collect_params().initialize(ctx=mx.cpu()) + # Hybridize model and losses + self.model.hybridize() + for l in self.losses: + l.hybridize() + + # Pass dummy data with correct shape to trigger shape inference and full parameter initialization + self.model(*self._dummy_model_inputs()) + + if self.network_is_trainable: + self.trainer = gluon.Trainer( + self.model.collect_params(), optimizer=self.optimizer, update_on_kvstore=False) + + def reset_accumulated_gradients(self) -> None: + """ + Reset model gradients as well as accumulated gradients to zero. If accumulated gradients + have not been created yet, it constructs them on CPU. + """ + # Set model gradients to zero + for p in self.model.collect_params().values(): + p.zero_grad() + # Set accumulated gradients to zero if already initialized, otherwise create a copy + if self.accumulated_gradients: + for a in self.accumulated_gradients: + a *= 0 + else: + self.accumulated_gradients = [g.copy() for g in self._model_grads] + + def accumulate_gradients(self, + inputs: Dict[str, np.ndarray], + targets: List[np.ndarray], + additional_fetches: List[Tuple[int, str]] = None, + importance_weights: np.ndarray = None, + no_accumulation: bool = False) -> Tuple[float, List[float], float, list]: + """ + Runs a forward & backward pass, clips gradients if needed and accumulates them into the accumulation + :param inputs: environment states (observation, etc.) as well extra inputs required by loss. Shape of ndarray + is (batch_size, observation_space_size) or (batch_size, observation_space_size, stack_size) + :param targets: targets required by loss (e.g. sum of discounted rewards) + :param additional_fetches: additional fetches to calculate and return. Each fetch is specified as (int, str) + tuple of head-type-index and fetch-name. The tuple is obtained from each head. 
+ :param importance_weights: ndarray of shape (batch_size,) to multiply with batch loss. + :param no_accumulation: if True, set gradient values to the new gradients, otherwise sum with previously + calculated gradients + :return: tuple of total_loss, losses, norm_unclipped_grads, fetched_tensors + total_loss (float): sum of all head losses + losses (list of float): list of all losses. The order is list of target losses followed by list of + regularization losses. The specifics of losses is dependant on the network parameters + (number of heads, etc.) + norm_unclippsed_grads (float): global norm of all gradients before any gradient clipping is applied + fetched_tensors: all values for additional_fetches + """ + if self.accumulated_gradients is None: + self.reset_accumulated_gradients() + + embedders = [emb.embedder_name for emb in self.model.nets[0].input_embedders] + nd_inputs = tuple(nd.array(inputs[emb]) for emb in embedders) + + assert self.middleware.__class__.__name__ != 'LSTMMiddleware', "LSTM middleware not supported" + + targets = force_list(targets) + with autograd.record(): + out_per_head = utils.split_outputs_per_head(self.model(*nd_inputs), self.model.output_heads) + tgt_per_loss = utils.split_targets_per_loss(targets, self.losses) + + losses = list() + regularizations = list() + additional_fetches = [(k, None) for k in additional_fetches] + for h, h_loss, h_out, l_tgt in zip(self.model.output_heads, self.losses, out_per_head, tgt_per_loss): + l_in = utils.get_loss_agent_inputs(inputs, head_type_idx=h.head_type_idx, loss=h_loss) + # Align arguments with loss.loss_forward and convert to NDArray + l_args = utils.to_mx_ndarray(utils.align_loss_args(h_out, l_in, l_tgt, h_loss)) + # Calculate loss and all auxiliary outputs + loss_outputs = utils.loss_output_dict(utils.to_list(h_loss(*l_args)), h_loss.output_schema) + if LOSS_OUT_TYPE_LOSS in loss_outputs: + losses.extend(loss_outputs[LOSS_OUT_TYPE_LOSS]) + if LOSS_OUT_TYPE_REGULARIZATION in loss_outputs: + regularizations.extend(loss_outputs[LOSS_OUT_TYPE_REGULARIZATION]) + # Set additional fetches + for i, fetch in enumerate(additional_fetches): + head_type_idx, fetch_name = fetch[0] # fetch key is a tuple of (head_type_index, fetch_name) + if head_type_idx == h.head_type_idx: + assert fetch[1] is None # sanity check that fetch is None + additional_fetches[i] = (fetch[0], loss_outputs[fetch_name]) + + # Total loss is losses and regularization (NOTE: order is important) + total_loss_list = losses + regularizations + total_loss = nd.add_n(*total_loss_list) + + # Calculate gradients + total_loss.backward() + + assert self.optimizer_type != 'LBFGS', 'LBFGS not supported' + + # allreduce gradients from all contexts + self.trainer.allreduce_grads() + + # Calculate global norm of gradients + # FIXME global norm is returned even when not used for clipping! Is this necessary? 
+ # FIXME global norm might be calculated twice if clipping method is global norm + norm_unclipped_grads = utils.global_norm(self._model_grads) + + # Clip gradients + if self.network_parameters.clip_gradients: + utils.clip_grad( + self._model_grads, + clip_method=self.network_parameters.gradients_clipping_method, + clip_val=self.network_parameters.clip_gradients, + inplace=True) + + # Update self.accumulated_gradients depending on no_accumulation flag + if no_accumulation: + for acc_grad, model_grad in zip(self.accumulated_gradients, self._model_grads): + acc_grad[:] = model_grad + else: + for acc_grad, model_grad in zip(self.accumulated_gradients, self._model_grads): + acc_grad += model_grad + + # result of of additional fetches + fetched_tensors = [fetch[1] for fetch in additional_fetches] + + # convert everything to numpy or scalar before returning + result = utils.asnumpy_or_asscalar((total_loss, total_loss_list, norm_unclipped_grads, fetched_tensors)) + return result + + def apply_and_reset_gradients(self, gradients: List[np.ndarray], scaler: float=1.) -> None: + """ + Applies the given gradients to the network weights and resets accumulated gradients to zero + :param gradients: The gradients to use for the update + :param scaler: A scaling factor that allows rescaling the gradients before applying them + """ + self.apply_gradients(gradients, scaler) + self.reset_accumulated_gradients() + + def apply_gradients(self, gradients: List[np.ndarray], scaler: float=1.) -> None: + """ + Applies the given gradients to the network weights + :param gradients: The gradients to use for the update + :param scaler: A scaling factor that allows rescaling the gradients before applying them. + The gradients will be MULTIPLIED by this factor + """ + assert self.optimizer_type != 'LBFGS' + + batch_size = 1 + if self.distributed_training and not self.network_parameters.async_training: + # rescale the gradients so that they average out with the gradients from the other workers + if self.network_parameters.scale_down_gradients_by_number_of_workers_for_sync_training: + batch_size = self.ap.task_parameters.num_training_tasks + + # set parameter gradients to gradients passed in + for param_grad, gradient in zip(self._model_grads, gradients): + param_grad[:] = gradient + # update gradients + self.trainer.update(batch_size=batch_size) + + def _predict(self, inputs: Dict[str, np.ndarray]) -> Tuple[NDArray, ...]: + """ + Run a forward pass of the network using the given input + :param inputs: The input dictionary for the network. Key is name of the embedder. + :return: The network output + + WARNING: must only call once per state since each call is assumed by LSTM to be a new time step. + """ + embedders = [emb.embedder_name for emb in self.model.nets[0].input_embedders] + nd_inputs = tuple(nd.array(inputs[emb]) for emb in embedders) + + assert self.middleware.__class__.__name__ != 'LSTMMiddleware' + + output = self.model(*nd_inputs) + return output + + def predict(self, + inputs: Dict[str, np.ndarray], + outputs: List[str]=None, + squeeze_output: bool=True, + initial_feed_dict: Dict[str, np.ndarray]=None) -> Tuple[np.ndarray, ...]: + """ + Run a forward pass of the network using the given input + :param inputs: The input dictionary for the network. Key is name of the embedder. + :param outputs: list of outputs to return. 
Return all outputs if unspecified (currently not supported) + :param squeeze_output: call squeeze_list on output if True + :param initial_feed_dict: a dictionary of extra inputs for forward pass (currently not supported) + :return: The network output + + WARNING: must only call once per state since each call is assumed by LSTM to be a new time step. + """ + assert initial_feed_dict is None, "initial_feed_dict must be None" + assert outputs is None, "outputs must be None" + + output = self._predict(inputs) + + output = tuple(o.asnumpy() for o in output) + if squeeze_output: + output = squeeze_list(output) + return output + + @staticmethod + def parallel_predict(sess: Any, + network_input_tuples: List[Tuple['MxnetArchitecture', Dict[str, np.ndarray]]]) -> \ + Tuple[np.ndarray, ...]: + """ + :param sess: active session to use for prediction (must be None for MXNet) + :param network_input_tuples: tuple of network and corresponding input + :return: tuple of outputs from all networks + """ + assert sess is None + output = list() + for net, inputs in network_input_tuples: + output += net._predict(inputs) + return tuple(o.asnumpy() for o in output) + + def train_on_batch(self, + inputs: Dict[str, np.ndarray], + targets: List[np.ndarray], + scaler: float = 1., + additional_fetches: list = None, + importance_weights: np.ndarray = None) -> Tuple[float, List[float], float, list]: + """ + Given a batch of inputs (e.g. states) and targets (e.g. discounted rewards), takes a training step: i.e. runs a + forward pass and backward pass of the network, accumulates the gradients and applies an optimization step to + update the weights. + :param inputs: environment states (observation, etc.) as well extra inputs required by loss. Shape of ndarray + is (batch_size, observation_space_size) or (batch_size, observation_space_size, stack_size) + :param targets: targets required by loss (e.g. sum of discounted rewards) + :param scaler: value to scale gradients by before optimizing network weights + :param additional_fetches: additional fetches to calculate and return. Each fetch is specified as (int, str) + tuple of head-type-index and fetch-name. The tuple is obtained from each head. + :param importance_weights: ndarray of shape (batch_size,) to multiply with batch loss. + :return: tuple of total_loss, losses, norm_unclipped_grads, fetched_tensors + total_loss (float): sum of all head losses + losses (list of float): list of all losses. The order is list of target losses followed by list + of regularization losses. The specifics of losses is dependant on the network parameters + (number of heads, etc.) 
+ norm_unclippsed_grads (float): global norm of all gradients before any gradient clipping is applied + fetched_tensors: all values for additional_fetches + """ + loss = self.accumulate_gradients(inputs, targets, additional_fetches=additional_fetches, + importance_weights=importance_weights) + self.apply_and_reset_gradients(self.accumulated_gradients, scaler) + return loss + + def get_weights(self) -> gluon.ParameterDict: + """ + :return: a ParameterDict containing all network weights + """ + return self.model.collect_params() + + def set_weights(self, weights: gluon.ParameterDict, new_rate: float=1.0) -> None: + """ + Sets the network weights from the given ParameterDict + :param new_rate: ratio for adding new and old weight values: val=rate*weights + (1-rate)*old_weights + """ + old_weights = self.model.collect_params() + for name, p in weights.items(): + name = name[len(weights.prefix):] # Strip prefix + old_p = old_weights[old_weights.prefix + name] # Add prefix + old_p.set_data(new_rate * p._reduce() + (1 - new_rate) * old_p._reduce()) + + def get_variable_value(self, variable: Union[gluon.Parameter, NDArray]) -> np.ndarray: + """ + Get the value of a variable + :param variable: the variable + :return: the value of the variable + """ + if isinstance(variable, gluon.Parameter): + variable = variable._reduce().asnumpy() + if isinstance(variable, NDArray): + return variable.asnumpy() + return variable + + def set_variable_value(self, assign_op: callable, value: Any, placeholder=None) -> None: + """ + Updates value of a variable. + :param assign_op: a callable assign function for setting the variable + :param value: a value to set the variable to + :param placeholder: unused (placeholder in symbolic framework backends) + """ + assert callable(assign_op) + assign_op(value) + + def set_is_training(self, state: bool) -> None: + """ + Set the phase of the network between training and testing + :param state: The current state (True = Training, False = Testing) + :return: None + """ + self.is_training = state + + def reset_internal_memory(self) -> None: + """ + Reset any internal memory used by the network. 
For example, an LSTM internal state + :return: None + """ + assert self.middleware.__class__.__name__ != 'LSTMMiddleware', 'LSTM middleware not supported' diff --git a/rl_coach/architectures/mxnet_components/embedders/__init__.py b/rl_coach/architectures/mxnet_components/embedders/__init__.py new file mode 100644 index 0000000..eb0482f --- /dev/null +++ b/rl_coach/architectures/mxnet_components/embedders/__init__.py @@ -0,0 +1,4 @@ +from .image_embedder import ImageEmbedder +from .vector_embedder import VectorEmbedder + +__all__ = ['ImageEmbedder', 'VectorEmbedder'] diff --git a/rl_coach/architectures/mxnet_components/embedders/embedder.py b/rl_coach/architectures/mxnet_components/embedders/embedder.py new file mode 100644 index 0000000..c2b6340 --- /dev/null +++ b/rl_coach/architectures/mxnet_components/embedders/embedder.py @@ -0,0 +1,71 @@ +from typing import Union +from types import ModuleType + +import mxnet as mx +from mxnet.gluon import nn +from rl_coach.architectures.embedder_parameters import InputEmbedderParameters +from rl_coach.architectures.mxnet_components.layers import convert_layer +from rl_coach.base_parameters import EmbedderScheme + +nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol] + + +class InputEmbedder(nn.HybridBlock): + def __init__(self, params: InputEmbedderParameters): + """ + An input embedder is the first part of the network, which takes the input from the state and produces a vector + embedding by passing it through a neural network. The embedder will mostly be input type dependent, and there + can be multiple embedders in a single network. + + :param params: parameters object containing input_clipping, input_rescaling, batchnorm, activation_function + and dropout properties. + """ + super(InputEmbedder, self).__init__() + self.embedder_name = params.name + self.input_clipping = params.input_clipping + self.scheme = params.scheme + + with self.name_scope(): + self.net = nn.HybridSequential() + if isinstance(self.scheme, EmbedderScheme): + blocks = self.schemes[self.scheme] + else: + # if scheme is specified directly, convert to MX layer if it's not a callable object + # NOTE: if layer object is callable, it must return a gluon block when invoked + blocks = [convert_layer(l) for l in self.scheme] + for block in blocks: + self.net.add(block()) + if params.batchnorm: + self.net.add(nn.BatchNorm()) + if params.activation_function: + self.net.add(nn.Activation(params.activation_function)) + if params.dropout: + self.net.add(nn.Dropout(rate=params.dropout)) + + @property + def schemes(self) -> dict: + """ + Schemes are the pre-defined network architectures of various depths and complexities that can be used for the + InputEmbedder. Should be implemented in child classes, and are used to create Block when InputEmbedder is + initialised. + + :return: dictionary of schemes, with key of type EmbedderScheme enum and value being list of mxnet.gluon.Block. + """ + raise NotImplementedError("Inheriting embedder must define schemes matching its allowed default " + "configurations.") + + def hybrid_forward(self, F: ModuleType, x: nd_sym_type, *args, **kwargs) -> nd_sym_type: + """ + Used for forward pass through embedder network. + + :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized). + :param x: environment state, where first dimension is batch_size, then dimensions are data type dependent. + :return: embedding of environment state, where shape is (batch_size, channels). 
+ """ + # `input_rescaling` and `input_offset` set on inheriting embedder + x = x / self.input_rescaling + x = x - self.input_offset + if self.input_clipping is not None: + x.clip(a_min=self.input_clipping[0], a_max=self.input_clipping[1]) + x = self.net(x) + return x.flatten() diff --git a/rl_coach/architectures/mxnet_components/embedders/image_embedder.py b/rl_coach/architectures/mxnet_components/embedders/image_embedder.py new file mode 100644 index 0000000..36842d8 --- /dev/null +++ b/rl_coach/architectures/mxnet_components/embedders/image_embedder.py @@ -0,0 +1,76 @@ +from typing import Union +from types import ModuleType + +import mxnet as mx +from rl_coach.architectures.embedder_parameters import InputEmbedderParameters +from rl_coach.architectures.mxnet_components.embedders.embedder import InputEmbedder +from rl_coach.architectures.mxnet_components.layers import Conv2d +from rl_coach.base_parameters import EmbedderScheme + +nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol] + + +class ImageEmbedder(InputEmbedder): + def __init__(self, params: InputEmbedderParameters): + """ + An image embedder is an input embedder that takes an image input from the state and produces a vector + embedding by passing it through a neural network. + + :param params: parameters object containing input_clipping, input_rescaling, batchnorm, activation_function + and dropout properties. + """ + super(ImageEmbedder, self).__init__(params) + self.input_rescaling = params.input_rescaling['image'] + self.input_offset = params.input_offset['image'] + + @property + def schemes(self) -> dict: + """ + Schemes are the pre-defined network architectures of various depths and complexities that can be used. Are used + to create Block when ImageEmbedder is initialised. + + :return: dictionary of schemes, with key of type EmbedderScheme enum and value being list of mxnet.gluon.Block. + """ + return { + EmbedderScheme.Empty: + [], + + EmbedderScheme.Shallow: + [ + Conv2d(num_filters=32, kernel_size=8, strides=4) + ], + + # Use for Atari DQN + EmbedderScheme.Medium: + [ + Conv2d(num_filters=32, kernel_size=8, strides=4), + Conv2d(num_filters=64, kernel_size=4, strides=2), + Conv2d(num_filters=64, kernel_size=3, strides=1) + ], + + # Use for Carla + EmbedderScheme.Deep: + [ + Conv2d(num_filters=32, kernel_size=5, strides=2), + Conv2d(num_filters=32, kernel_size=3, strides=1), + Conv2d(num_filters=64, kernel_size=3, strides=2), + Conv2d(num_filters=64, kernel_size=3, strides=1), + Conv2d(num_filters=128, kernel_size=3, strides=2), + Conv2d(num_filters=128, kernel_size=3, strides=1), + Conv2d(num_filters=256, kernel_size=3, strides=2), + Conv2d(num_filters=256, kernel_size=3, strides=1) + ] + } + + def hybrid_forward(self, F: ModuleType, x: nd_sym_type, *args, **kwargs) -> nd_sym_type: + """ + Used for forward pass through embedder network. + + :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized). + :param x: image representing environment state, of shape (batch_size, in_channels, height, width). + :return: embedding of environment state, of shape (batch_size, channels). + """ + if len(x.shape) != 4 and self.scheme != EmbedderScheme.Empty: + raise ValueError("Image embedders expect the input size to have 4 dimensions. 
The given size is: {}" + .format(x.shape)) + return super(ImageEmbedder, self).hybrid_forward(F, x, *args, **kwargs) diff --git a/rl_coach/architectures/mxnet_components/embedders/vector_embedder.py b/rl_coach/architectures/mxnet_components/embedders/vector_embedder.py new file mode 100644 index 0000000..4a36357 --- /dev/null +++ b/rl_coach/architectures/mxnet_components/embedders/vector_embedder.py @@ -0,0 +1,71 @@ +from typing import Union +from types import ModuleType + +import mxnet as mx +from mxnet import nd, sym +from rl_coach.architectures.embedder_parameters import InputEmbedderParameters +from rl_coach.architectures.mxnet_components.embedders.embedder import InputEmbedder +from rl_coach.architectures.mxnet_components.layers import Dense +from rl_coach.base_parameters import EmbedderScheme + +nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol] + + +class VectorEmbedder(InputEmbedder): + def __init__(self, params: InputEmbedderParameters): + """ + An vector embedder is an input embedder that takes an vector input from the state and produces a vector + embedding by passing it through a neural network. + + :param params: parameters object containing input_clipping, input_rescaling, batchnorm, activation_function + and dropout properties. + """ + super(VectorEmbedder, self).__init__(params) + self.input_rescaling = params.input_rescaling['vector'] + self.input_offset = params.input_offset['vector'] + + @property + def schemes(self): + """ + Schemes are the pre-defined network architectures of various depths and complexities that can be used. Are used + to create Block when VectorEmbedder is initialised. + + :return: dictionary of schemes, with key of type EmbedderScheme enum and value being list of mxnet.gluon.Block. + """ + return { + EmbedderScheme.Empty: + [], + + EmbedderScheme.Shallow: + [ + Dense(units=128) + ], + + # Use for DQN + EmbedderScheme.Medium: + [ + Dense(units=256) + ], + + # Use for Carla + EmbedderScheme.Deep: + [ + Dense(units=128), + Dense(units=128), + Dense(units=128) + ] + } + + def hybrid_forward(self, F: ModuleType, x: nd_sym_type, *args, **kwargs) -> nd_sym_type: + """ + Used for forward pass through embedder network. + + :param F: backend api, either `nd` or `sym` (if block has been hybridized). + :type F: nd or sym + :param x: vector representing environment state, of shape (batch_size, in_channels). + :return: embedding of environment state, of shape (batch_size, channels). + """ + if isinstance(x, nd.NDArray) and len(x.shape) != 2 and self.scheme != EmbedderScheme.Empty: + raise ValueError("Vector embedders expect the input size to have 2 dimensions. The given size is: {}" + .format(x.shape)) + return super(VectorEmbedder, self).hybrid_forward(F, x, *args, **kwargs) diff --git a/rl_coach/architectures/mxnet_components/general_network.py b/rl_coach/architectures/mxnet_components/general_network.py new file mode 100644 index 0000000..bb1b176 --- /dev/null +++ b/rl_coach/architectures/mxnet_components/general_network.py @@ -0,0 +1,501 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +import copy +from itertools import chain +from typing import List, Tuple, Union +from types import ModuleType + +import numpy as np +import mxnet as mx +from mxnet import nd, sym +from mxnet.gluon import HybridBlock +from mxnet.ndarray import NDArray +from mxnet.symbol import Symbol + +from rl_coach.base_parameters import NetworkParameters +from rl_coach.architectures.embedder_parameters import InputEmbedderParameters +from rl_coach.architectures.head_parameters import HeadParameters, PPOHeadParameters +from rl_coach.architectures.head_parameters import PPOVHeadParameters, VHeadParameters, QHeadParameters +from rl_coach.architectures.middleware_parameters import MiddlewareParameters +from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters, LSTMMiddlewareParameters +from rl_coach.architectures.mxnet_components.architecture import MxnetArchitecture +from rl_coach.architectures.mxnet_components.embedders import ImageEmbedder, VectorEmbedder +from rl_coach.architectures.mxnet_components.heads import Head, HeadLoss, PPOHead, PPOVHead, VHead, QHead +from rl_coach.architectures.mxnet_components.middlewares import FCMiddleware, LSTMMiddleware +from rl_coach.architectures.mxnet_components import utils +from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType +from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace + + +class GeneralMxnetNetwork(MxnetArchitecture): + """ + A generalized version of all possible networks implemented using mxnet. + """ + def __init__(self, + agent_parameters: AgentParameters, + spaces: SpacesDefinition, + name: str, + global_network=None, + network_is_local: bool=True, + network_is_trainable: bool=False): + """ + :param agent_parameters: the agent parameters + :param spaces: the spaces definition of the agent + :param name: the name of the network + :param global_network: the global network replica that is shared between all the workers + :param network_is_local: is the network global (shared between workers) or local (dedicated to the worker) + :param network_is_trainable: is the network trainable (we can apply gradients on it) + """ + self.network_wrapper_name = name.split('/')[0] + self.network_parameters = agent_parameters.network_wrappers[self.network_wrapper_name] + if self.network_parameters.use_separate_networks_per_head: + self.num_heads_per_network = 1 + self.num_networks = len(self.network_parameters.heads_parameters) + else: + self.num_heads_per_network = len(self.network_parameters.heads_parameters) + self.num_networks = 1 + + super().__init__(agent_parameters, spaces, name, global_network, + network_is_local, network_is_trainable) + + def construct_model(self): + # validate the configuration + if len(self.network_parameters.input_embedders_parameters) == 0: + raise ValueError("At least one input type should be defined") + + if len(self.network_parameters.heads_parameters) == 0: + raise ValueError("At least one output type should be defined") + + if self.network_parameters.middleware_parameters is None: + raise ValueError("Exactly one middleware type should be defined") + + self.model = GeneralModel( + num_networks=self.num_networks, + num_heads_per_network=self.num_heads_per_network, + network_is_local=self.network_is_local, + network_name=self.network_wrapper_name, + agent_parameters=self.ap, + network_parameters=self.network_parameters, + spaces=self.spaces) + + self.losses = 
self.model.losses() + + # Learning rate + lr_scheduler = None + if self.network_parameters.learning_rate_decay_rate != 0: + lr_scheduler = mx.lr_scheduler.FactorScheduler( + step=self.network_parameters.learning_rate_decay_steps, + factor=self.network_parameters.learning_rate_decay_rate) + + # Optimizer + # FIXME Does this code for distributed training make sense? + if self.distributed_training and self.network_is_local and self.network_parameters.shared_optimizer: + # distributed training + is a local network + optimizer shared -> take the global optimizer + self.optimizer = self.global_network.optimizer + elif (self.distributed_training and self.network_is_local and not self.network_parameters.shared_optimizer)\ + or self.network_parameters.shared_optimizer or not self.distributed_training: + + if self.network_parameters.optimizer_type == 'Adam': + self.optimizer = mx.optimizer.Adam( + learning_rate=self.network_parameters.learning_rate, + beta1=self.network_parameters.adam_optimizer_beta1, + beta2=self.network_parameters.adam_optimizer_beta2, + epsilon=self.network_parameters.optimizer_epsilon, + lr_scheduler=lr_scheduler) + elif self.network_parameters.optimizer_type == 'RMSProp': + self.optimizer = mx.optimizer.RMSProp( + learning_rate=self.network_parameters.learning_rate, + gamma1=self.network_parameters.rms_prop_optimizer_decay, + epsilon=self.network_parameters.optimizer_epsilon, + lr_scheduler=lr_scheduler) + elif self.network_parameters.optimizer_type == 'LBFGS': + raise NotImplementedError('LBFGS optimizer not implemented') + else: + raise Exception("{} is not a valid optimizer type".format(self.network_parameters.optimizer_type)) + + @property + def output_heads(self): + return self.model.output_heads + + +def _get_activation(activation_function_string: str): + """ + Map the activation function from a string to the mxnet framework equivalent + :param activation_function_string: the type of the activation function + :return: mxnet activation function string + """ + return utils.get_mxnet_activation_name(activation_function_string) + + +def _sanitize_activation(params: Union[InputEmbedderParameters, MiddlewareParameters, HeadParameters]) ->\ + Union[InputEmbedderParameters, MiddlewareParameters, HeadParameters]: + """ + Change activation function to the mxnet specific value + :param params: any parameter that has activation_function property + :return: copy of params with activation function correctly set + """ + params_copy = copy.copy(params) + params_copy.activation_function = _get_activation(params.activation_function) + return params_copy + + +def _get_input_embedder(spaces: SpacesDefinition, + input_name: str, + embedder_params: InputEmbedderParameters) -> ModuleType: + """ + Given an input embedder parameters class, creates the input embedder and returns it + :param input_name: the name of the input to the embedder (used for retrieving the shape). The input should + be a value within the state or the action. 
+ :param embedder_params: the parameters of the class of the embedder + :return: the embedder instance + """ + allowed_inputs = copy.copy(spaces.state.sub_spaces) + allowed_inputs["action"] = copy.copy(spaces.action) + allowed_inputs["goal"] = copy.copy(spaces.goal) + + if input_name not in allowed_inputs.keys(): + raise ValueError("The key for the input embedder ({}) must match one of the following keys: {}" + .format(input_name, allowed_inputs.keys())) + + type = "vector" + if isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace): + type = "image" + + def sanitize_params(params: InputEmbedderParameters): + params_copy = _sanitize_activation(params) + # params_copy.input_rescaling = params_copy.input_rescaling[type] + # params_copy.input_offset = params_copy.input_offset[type] + params_copy.name = input_name + return params_copy + + embedder_params = sanitize_params(embedder_params) + if type == 'vector': + module = VectorEmbedder(embedder_params) + elif type == 'image': + module = ImageEmbedder(embedder_params) + else: + raise KeyError('Unsupported embedder type: {}'.format(type)) + return module + + +def _get_middleware(middleware_params: MiddlewareParameters) -> ModuleType: + """ + Given a middleware type, creates the middleware and returns it + :param middleware_params: the paramaeters of the middleware class + :return: the middleware instance + """ + middleware_params = _sanitize_activation(middleware_params) + if isinstance(middleware_params, FCMiddlewareParameters): + module = FCMiddleware(middleware_params) + elif isinstance(middleware_params, LSTMMiddlewareParameters): + module = LSTMMiddleware(middleware_params) + else: + raise KeyError('Unsupported middleware type: {}'.format(type(middleware_params))) + + return module + + +def _get_output_head( + head_params: HeadParameters, + head_idx: int, + head_type_index: int, + agent_params: AgentParameters, + spaces: SpacesDefinition, + network_name: str, + is_local: bool) -> Head: + """ + Given a head type, creates the head and returns it + :param head_params: the parameters of the head to create + :param head_idx: the head index + :param head_type_index: the head type index (same index if head_param.num_output_head_copies>0) + :param agent_params: agent parameters + :param spaces: state and action space definitions + :param network_name: name of the network + :param is_local: + :return: head block + """ + head_params = _sanitize_activation(head_params) + if isinstance(head_params, PPOHeadParameters): + module = PPOHead( + agent_parameters=agent_params, + spaces=spaces, + network_name=network_name, + head_type_idx=head_type_index, + loss_weight=head_params.loss_weight, + is_local=is_local, + activation_function=head_params.activation_function, + dense_layer=head_params.dense_layer) + elif isinstance(head_params, VHeadParameters): + module = VHead( + agent_parameters=agent_params, + spaces=spaces, + network_name=network_name, + head_type_idx=head_type_index, + loss_weight=head_params.loss_weight, + is_local=is_local, + activation_function=head_params.activation_function, + dense_layer=head_params.dense_layer) + elif isinstance(head_params, PPOVHeadParameters): + module = PPOVHead( + agent_parameters=agent_params, + spaces=spaces, + network_name=network_name, + head_type_idx=head_type_index, + loss_weight=head_params.loss_weight, + is_local=is_local, + activation_function=head_params.activation_function, + dense_layer=head_params.dense_layer) + elif isinstance(head_params, QHeadParameters): + module = QHead( + 
agent_parameters=agent_params, + spaces=spaces, + network_name=network_name, + head_type_idx=head_type_index, + loss_weight=head_params.loss_weight, + is_local=is_local, + activation_function=head_params.activation_function, + dense_layer=head_params.dense_layer) + else: + raise KeyError('Unsupported head type: {}'.format(type(head_params))) + + return module + + +class ScaledGradHead(HybridBlock): + """ + Wrapper block for applying gradient scaling to input before feeding the head network + """ + def __init__(self, + head_index: int, + head_type_index: int, + network_name: str, + spaces: SpacesDefinition, + network_is_local: bool, + agent_params: AgentParameters, + head_params: HeadParameters) -> None: + """ + :param head_idx: the head index + :param head_type_index: the head type index (same index if head_param.num_output_head_copies>0) + :param network_name: name of the network + :param spaces: state and action space definitions + :param network_is_local: whether network is local + :param agent_params: agent parameters + :param head_params: head parameters + """ + super(ScaledGradHead, self).__init__() + + head_params = _sanitize_activation(head_params) + with self.name_scope(): + self.head = _get_output_head( + head_params=head_params, + head_idx=head_index, + head_type_index=head_type_index, + agent_params=agent_params, + spaces=spaces, + network_name=network_name, + is_local=network_is_local) + self.gradient_rescaler = self.params.get_constant( + name='gradient_rescaler', + value=np.array([float(head_params.rescale_gradient_from_head_by_factor)])) + # self.gradient_rescaler = self.params.get( + # name='gradient_rescaler', + # shape=(1,), + # init=mx.init.Constant(float(head_params.rescale_gradient_from_head_by_factor))) + + def hybrid_forward(self, + F: ModuleType, + x: Union[NDArray, Symbol], + gradient_rescaler: Union[NDArray, Symbol]) -> Tuple[Union[NDArray, Symbol], ...]: + """ Overrides gluon.HybridBlock.hybrid_forward + :param nd or sym F: ndarray or symbol module + :param x: head input + :param gradient_rescaler: gradient rescaler for partial blocking of gradient + :return: head output + """ + grad_scaled_x = F.broadcast_mul((1 - gradient_rescaler), F.BlockGrad(x)) + F.broadcast_mul(gradient_rescaler, x) + out = self.head(grad_scaled_x) + return out + + +class SingleModel(HybridBlock): + """ + Block that connects a single embedder, with middleware and one to multiple heads + """ + def __init__(self, + network_is_local: bool, + network_name: str, + agent_parameters: AgentParameters, + in_emb_param_dict: {str: InputEmbedderParameters}, + embedding_merger_type: EmbeddingMergerType, + middleware_param: MiddlewareParameters, + head_param_list: [HeadParameters], + head_type_idx_start: int, + spaces: SpacesDefinition, + *args, **kwargs): + """ + :param network_is_local: True if network is local + :param network_name: name of the network + :param agent_parameters: agent parameters + :param in_emb_param_dict: dictionary of embedder name to embedding parameters + :param embedding_merger_type: type of merging output of embedders: concatenate or sum + :param middleware_param: middleware parameters + :param head_param_list: list of head parameters, one per head type + :param head_type_idx_start: start index for head type index counting + :param spaces: state and action space definition + """ + super(SingleModel, self).__init__(*args, **kwargs) + + self._embedding_merger_type = embedding_merger_type + self._input_embedders = list() # type: List[HybridBlock] + self._output_heads = list() # 
type: List[ScaledGradHead] + + with self.name_scope(): + for input_name in sorted(in_emb_param_dict): + input_type = in_emb_param_dict[input_name] + input_embedder = _get_input_embedder(spaces, input_name, input_type) + self.register_child(input_embedder) + self._input_embedders.append(input_embedder) + + self.middleware = _get_middleware(middleware_param) + + for i, head_param in enumerate(head_param_list): + for head_copy_idx in range(head_param.num_output_head_copies): + # create output head and add it to the output heads list + output_head = ScaledGradHead( + head_index=(head_type_idx_start + i) * head_param.num_output_head_copies + head_copy_idx, + head_type_index=head_type_idx_start + i, + network_name=network_name, + spaces=spaces, + network_is_local=network_is_local, + agent_params=agent_parameters, + head_params=head_param) + self.register_child(output_head) + self._output_heads.append(output_head) + + def hybrid_forward(self, F, *inputs: Union[NDArray, Symbol]) -> Tuple[Union[NDArray, Symbol], ...]: + """ Overrides gluon.HybridBlock.hybrid_forward + :param nd or sym F: ndarray or symbol block + :param inputs: model inputs, one for each embedder + :return: head outputs in a tuple + """ + # Input Embeddings + state_embedding = list() + for input, embedder in zip(inputs, self._input_embedders): + state_embedding.append(embedder(input)) + + # Merger + if len(state_embedding) == 1: + state_embedding = state_embedding[0] + else: + if self._embedding_merger_type == EmbeddingMergerType.Concat: + state_embedding = F.concat(*state_embedding, dim=1, name='merger') # NC or NCHW layout + elif self._embedding_merger_type == EmbeddingMergerType.Sum: + state_embedding = F.add_n(*state_embedding, name='merger') + + # Middleware + state_embedding = self.middleware(state_embedding) + + # Head + outputs = tuple() + for head in self._output_heads: + outputs += (head(state_embedding),) + + return outputs + + @property + def input_embedders(self) -> List[HybridBlock]: + """ + :return: list of input embedders + """ + return self._input_embedders + + @property + def output_heads(self) -> List[Head]: + """ + :return: list of output heads + """ + return [h.head for h in self._output_heads] + + +class GeneralModel(HybridBlock): + """ + Block that creates multiple single models + """ + def __init__(self, + num_networks: int, + num_heads_per_network: int, + network_is_local: bool, + network_name: str, + agent_parameters: AgentParameters, + network_parameters: NetworkParameters, + spaces: SpacesDefinition, + *args, **kwargs): + """ + :param num_networks: number of networks to create + :param num_heads_per_network: number of heads per network to create + :param network_is_local: True if network is local + :param network_name: name of the network + :param agent_parameters: agent parameters + :param network_parameters: network parameters + :param spaces: state and action space definitions + """ + super(GeneralModel, self).__init__(*args, **kwargs) + + with self.name_scope(): + self.nets = list() + for network_idx in range(num_networks): + head_type_idx_start = network_idx * num_heads_per_network + head_type_idx_end = head_type_idx_start + num_heads_per_network + net = SingleModel( + head_type_idx_start=head_type_idx_start, + network_name=network_name, + network_is_local=network_is_local, + agent_parameters=agent_parameters, + in_emb_param_dict=network_parameters.input_embedders_parameters, + embedding_merger_type=network_parameters.embedding_merger_type, + 
middleware_param=network_parameters.middleware_parameters, + head_param_list=network_parameters.heads_parameters[head_type_idx_start:head_type_idx_end], + spaces=spaces) + self.register_child(net) + self.nets.append(net) + + def hybrid_forward(self, F, *inputs): + """ Overrides gluon.HybridBlock.hybrid_forward + :param nd or sym F: ndarray or symbol block + :param inputs: model inputs, one for each embedder. Passed to all networks. + :return: head outputs in a tuple + """ + outputs = tuple() + for net in self.nets: + out = net(*inputs) + outputs += out + return outputs + + @property + def output_heads(self) -> List[Head]: + """ Return all heads in a single list + Note: There is a one-to-one mapping between output_heads and losses + :return: list of heads + """ + return list(chain.from_iterable(net.output_heads for net in self.nets)) + + def losses(self) -> List[HeadLoss]: + """ Construct loss blocks for network training + Note: There is a one-to-one mapping between output_heads and losses + :return: list of loss blocks + """ + return [h.loss() for net in self.nets for h in net.output_heads] diff --git a/rl_coach/architectures/mxnet_components/heads/__init__.py b/rl_coach/architectures/mxnet_components/heads/__init__.py new file mode 100644 index 0000000..47c1878 --- /dev/null +++ b/rl_coach/architectures/mxnet_components/heads/__init__.py @@ -0,0 +1,14 @@ +from .head import Head, HeadLoss +from .q_head import QHead +from .ppo_head import PPOHead +from .ppo_v_head import PPOVHead +from .v_head import VHead + +__all__ = [ + 'Head', + 'HeadLoss', + 'QHead', + 'PPOHead', + 'PPOVHead', + 'VHead' +] diff --git a/rl_coach/architectures/mxnet_components/heads/head.py b/rl_coach/architectures/mxnet_components/heads/head.py new file mode 100644 index 0000000..4a83152 --- /dev/null +++ b/rl_coach/architectures/mxnet_components/heads/head.py @@ -0,0 +1,181 @@ +from typing import Dict, List, Union, Tuple + +from mxnet.gluon import nn, loss +from mxnet.ndarray import NDArray +from mxnet.symbol import Symbol +from rl_coach.base_parameters import AgentParameters +from rl_coach.spaces import SpacesDefinition + + +LOSS_OUT_TYPE_LOSS = 'loss' +LOSS_OUT_TYPE_REGULARIZATION = 'regularization' + + +class LossInputSchema(object): + """ + Helper class to contain schema for loss hybrid_forward input + """ + def __init__(self, head_outputs: List[str], agent_inputs: List[str], targets: List[str]): + """ + :param head_outputs: list of argument names in hybrid_forward that are outputs of the head. + The order and number MUST MATCH the output from the head. + :param agent_inputs: list of argument names in hybrid_forward that are inputs from the agent. + The order and number MUST MATCH `output__` for this head. + :param targets: list of argument names in hybrid_forward that are targets for the loss. + The order and number MUST MATCH targets passed from the agent. + """ + self._head_outputs = head_outputs + self._agent_inputs = agent_inputs + self._targets = targets + + @property + def head_outputs(self): + return self._head_outputs + + @property + def agent_inputs(self): + return self._agent_inputs + + @property + def targets(self): + return self._targets + + +class HeadLoss(loss.Loss): + """ + ABC for loss functions of each head. 
Child classes must implement input_schema() and loss_forward(). + """ + def __init__(self, *args, **kwargs): + super(HeadLoss, self).__init__(*args, **kwargs) + self._output_schema = None # type: List[str] + + @property + def input_schema(self) -> LossInputSchema: + """ + :return: schema for input of hybrid_forward. Read docstring for LossInputSchema for details. + """ + raise NotImplementedError + + @property + def output_schema(self) -> List[str]: + """ + :return: schema for output of hybrid_forward. Must contain 'loss' and 'regularization' keys at least once. + The order and total number must match that of returned values from the loss. 'loss' and 'regularization' + are special keys. Any other string is treated as an auxiliary output and must match the auxiliary + fetch names returned by the head. + """ + return self._output_schema + + def forward(self, *args): + """ + Override forward() so that the number of outputs can be checked against the schema + """ + outputs = super(HeadLoss, self).forward(*args) + if isinstance(outputs, tuple) or isinstance(outputs, list): + num_outputs = len(outputs) + else: + assert isinstance(outputs, NDArray) or isinstance(outputs, Symbol) + num_outputs = 1 + assert num_outputs == len(self.output_schema), "Number of outputs doesn't match schema ({} != {})".format( + num_outputs, len(self.output_schema)) + return outputs + + def _loss_output(self, outputs: List[Tuple[Union[NDArray, Symbol], str]]): + """ + Must be called on the output from hybrid_forward(). + Saves the returned output types as the schema and returns the output values in a tuple + :return: tuple of output values + """ + output_schema = [o[1] for o in outputs] + assert self._output_schema is None or self._output_schema == output_schema + self._output_schema = output_schema + return tuple(o[0] for o in outputs) + + def hybrid_forward(self, F, x, *args, **kwargs): + """ + Passes the call to loss_forward() and constructs the output schema from its output by calling _loss_output() + """ + return self._loss_output(self.loss_forward(F, x, *args, **kwargs)) + + def loss_forward(self, F, x, *args, **kwargs) -> List[Tuple[Union[NDArray, Symbol], str]]: + """ + Similar to hybrid_forward, but returns a list of (NDArray, type_str) tuples + """ + raise NotImplementedError + + +class Head(nn.HybridBlock): + def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, + network_name: str, head_type_idx: int=0, loss_weight: float=1., is_local: bool=True, + activation_function: str='relu', dense_layer: None=None): + """ + A head is the final part of the network. It takes the embedding from the middleware embedder and passes it + through a neural network to produce the output of the network. There can be multiple heads in a network, and + each one has an assigned loss function. The heads are algorithm dependent. + + :param agent_parameters: containing algorithm parameters such as clip_likelihood_ratio_using_epsilon + and beta_entropy. + :param spaces: containing action spaces used for defining size of network output. + :param network_name: name of head network. currently unused. + :param head_type_idx: index of head network. currently unused. + :param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others). + :param is_local: flag to denote if network is local. currently unused. + :param activation_function: activation function to use between layers. currently unused. + :param dense_layer: type of dense layer to use in network. currently unused.
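# Editor's note: a minimal sketch, not part of the patch, of how a concrete HeadLoss
# subclass plugs into the schema machinery above. The class and schema names here are
# hypothetical; a subclass only needs to provide input_schema() and loss_forward().
class L2HeadLoss(HeadLoss):
    @property
    def input_schema(self) -> LossInputSchema:
        # one head output ('pred'), no extra agent inputs, one target ('target')
        return LossInputSchema(head_outputs=['pred'], agent_inputs=[], targets=['target'])

    def loss_forward(self, F, pred, target):
        # return (value, type_str) pairs; _loss_output() records ['loss'] as the output schema
        return [(((pred - target) ** 2).mean(), LOSS_OUT_TYPE_LOSS)]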
+ """ + super(Head, self).__init__() + self.head_type_idx = head_type_idx + self.network_name = network_name + self.loss_weight = loss_weight + self.is_local = is_local + self.ap = agent_parameters + self.spaces = spaces + self.return_type = None + self.activation_function = activation_function + self.dense_layer = dense_layer + self._num_outputs = None + + def loss(self) -> HeadLoss: + """ + Returns loss block to be used for specific head implementation. + + :return: loss block (can be called as function) for outputs returned by the head network. + """ + raise NotImplementedError() + + @property + def num_outputs(self): + """ Returns number of outputs that forward() call will return + + :return: + """ + assert self._num_outputs is not None, 'must call forward() once to configure number of outputs' + return self._num_outputs + + def forward(self, *args): + """ + Override forward() so that number of outputs can be automatically set + """ + outputs = super(Head, self).forward(*args) + if isinstance(outputs, tuple): + num_outputs = len(outputs) + else: + assert isinstance(outputs, NDArray) or isinstance(outputs, Symbol) + num_outputs = 1 + if self._num_outputs is None: + self._num_outputs = num_outputs + else: + assert self._num_outputs == num_outputs, 'Number of outputs cannot change ({} != {})'.format( + self._num_outputs, num_outputs) + assert self._num_outputs == len(self.loss().input_schema.head_outputs) + return outputs + + def hybrid_forward(self, F, x, *args, **kwargs): + """ + Used for forward pass through head network. + + :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized). + :param x: middleware state representation, of shape (batch_size, in_channels). + :return: final output of network, that will be used in loss calculations. + """ + raise NotImplementedError() diff --git a/rl_coach/architectures/mxnet_components/heads/ppo_head.py b/rl_coach/architectures/mxnet_components/heads/ppo_head.py new file mode 100644 index 0000000..01b2192 --- /dev/null +++ b/rl_coach/architectures/mxnet_components/heads/ppo_head.py @@ -0,0 +1,669 @@ +from typing import List, Tuple, Union +from types import ModuleType + +import math +import mxnet as mx +from mxnet.gluon import nn +from rl_coach.base_parameters import AgentParameters +from rl_coach.core_types import ActionProbabilities +from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace +from rl_coach.utils import eps +from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema +from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION +from rl_coach.architectures.mxnet_components.utils import hybrid_clip + + +LOSS_OUT_TYPE_KL = 'kl_divergence' +LOSS_OUT_TYPE_ENTROPY = 'entropy' +LOSS_OUT_TYPE_LIKELIHOOD_RATIO = 'likelihood_ratio' +LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO = 'clipped_likelihood_ratio' + +nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol] + + +class MultivariateNormalDist: + def __init__(self, + num_var: int, + mean: nd_sym_type, + sigma: nd_sym_type, + F: ModuleType=mx.nd) -> None: + """ + Distribution object for Multivariate Normal. Works with batches. + Optionally works with batches and time steps, but be consistent in usage: i.e. if using time_step, + mean, sigma and data for log_prob must all include a time_step dimension. 
+ + :param num_var: number of variables in distribution + :param mean: mean for each variable, + of shape (num_var) or + of shape (batch_size, num_var) or + of shape (batch_size, time_step, num_var). + :param sigma: covariance matrix, + of shape (num_var, num_var) or + of shape (batch_size, num_var, num_var) or + of shape (batch_size, time_step, num_var, num_var). + :param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized). + """ + self.num_var = num_var + self.mean = mean + self.sigma = sigma + self.F = F + + def inverse_using_cholesky(self, matrix: nd_sym_type) -> nd_sym_type: + """ + Calculate inverses for a batch of matrices using the Cholesky decomposition method. + + :param matrix: matrix (or matrices) to invert, + of shape (num_var, num_var) or + of shape (batch_size, num_var, num_var) or + of shape (batch_size, time_step, num_var, num_var). + :return: inverted matrix (or matrices), + of shape (num_var, num_var) or + of shape (batch_size, num_var, num_var) or + of shape (batch_size, time_step, num_var, num_var). + """ + cholesky_factor = self.F.linalg.potrf(matrix) + return self.F.linalg.potri(cholesky_factor) + + def log_det(self, matrix: nd_sym_type) -> nd_sym_type: + """ + Calculate log of the determinant for a batch of matrices using the Cholesky decomposition method. + + :param matrix: matrix (or matrices) to take the log determinant of, + of shape (num_var, num_var) or + of shape (batch_size, num_var, num_var) or + of shape (batch_size, time_step, num_var, num_var). + :return: log determinant (or determinants), + of shape (1) or + of shape (batch_size) or + of shape (batch_size, time_step). + """ + cholesky_factor = self.F.linalg.potrf(matrix) + return 2 * self.F.linalg.sumlogdiag(cholesky_factor) + + def log_prob(self, x: nd_sym_type) -> nd_sym_type: + """ + Calculate the log probability of data given the current distribution. + + See http://www.notenoughthoughts.net/posts/normal-log-likelihood-gradient.html + and https://discuss.mxnet.io/t/multivariate-gaussian-log-density-operator/1169/7 + + :param x: input data, + of shape (num_var) or + of shape (batch_size, num_var) or + of shape (batch_size, time_step, num_var). + :return: log_probability, + of shape (1) or + of shape (batch_size) or + of shape (batch_size, time_step). + """ + a = (self.num_var / 2) * math.log(2 * math.pi) + log_det_sigma = self.log_det(self.sigma) + b = (1 / 2) * log_det_sigma + sigma_inv = self.inverse_using_cholesky(self.sigma) + # deviation from mean, and dev_t is equivalent to transpose on last two dims. + dev = (x - self.mean).expand_dims(-1) + dev_t = (x - self.mean).expand_dims(-2) + + # since batch_dot only works with ndarrays with ndim of 3, + # and we could have ndarrays with ndim of 4, + # we flatten batch_size and time_step into single dim. + dev_flat = dev.reshape(shape=(-1, 0, 0), reverse=1) + sigma_inv_flat = sigma_inv.reshape(shape=(-1, 0, 0), reverse=1) + dev_t_flat = dev_t.reshape(shape=(-1, 0, 0), reverse=1) + c = (1 / 2) * self.F.batch_dot(self.F.batch_dot(dev_t_flat, sigma_inv_flat), dev_flat) + # and now reshape back to (batch_size, time_step) if required. + c = c.reshape_like(b) + + log_likelihood = -a - b - c + return log_likelihood + + def entropy(self) -> nd_sym_type: + """ + Calculate entropy of current distribution. + + See http://www.nowozin.net/sebastian/blog/the-entropy-of-a-normal-distribution.html + :return: entropy, + of shape (1) or + of shape (batch_size) or + of shape (batch_size, time_step).
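# Editor's note: a small usage sketch, not part of the patch, exercising log_prob above.
# For a unit-covariance Gaussian evaluated at its mean, the log density is
# -(num_var / 2) * log(2 * pi), roughly -1.8379 for num_var=2.
import mxnet as mx
dist = MultivariateNormalDist(num_var=2,
                              mean=mx.nd.zeros((1, 2)),           # batch_size=1
                              sigma=mx.nd.eye(2).expand_dims(0))  # (1, 2, 2) covariance
log_p = dist.log_prob(mx.nd.zeros((1, 2)))                        # approx. [-1.8379]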
+ """ + # todo: check if differential entropy is correct + log_det_sigma = self.log_det(self.sigma) + return (self.num_var / 2) + ((self.num_var / 2) * math.log(2 * math.pi)) + ((1 / 2) * log_det_sigma) + + def kl_div(self, alt_dist) -> nd_sym_type: + """ + Calculate KL-Divergence with another MultivariateNormalDist distribution + See https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence + Specifically https://wikimedia.org/api/rest_v1/media/math/render/svg/a3bf3b4917bd1fcb8be48d6d6139e2e387bdc7d3 + + :param alt_dist: alternative distribution used for kl divergence calculation + :type alt_dist: MultivariateNormalDist + :return: KL-Divergence, of shape (1,) + """ + sigma_a_inv = self.F.linalg.potri(self.F.linalg.potrf(self.sigma)) + sigma_b_inv = self.F.linalg.potri(self.F.linalg.potrf(alt_dist.sigma)) + term1a = self.F.batch_dot(sigma_b_inv, self.sigma) + # sum of diagonal for batch of matrices + term1 = (self.F.eye(self.num_var).broadcast_like(term1a) * term1a).sum(axis=-1).sum(axis=-1) + mean_diff = (alt_dist.mean - self.mean).expand_dims(-1) + mean_diff_t = (alt_dist.mean - self.mean).expand_dims(-2) + term2 = self.F.batch_dot(self.F.batch_dot(mean_diff_t, sigma_b_inv), mean_diff).reshape_like(term1) + term3 = (2 * self.F.linalg.sumlogdiag(self.F.linalg.potrf(alt_dist.sigma))) -\ + (2 * self.F.linalg.sumlogdiag(self.F.linalg.potrf(self.sigma))) + return 0.5 * (term1 + term2 - self.num_var + term3) + + +class CategoricalDist: + def __init__(self, n_classes: int, probs: nd_sym_type, F: ModuleType=mx.nd) -> None: + """ + Distribution object for Categorical data. + Optionally works with batches and time steps, but be consistent in usage: i.e. if using time_step, + probs and data for log_prob must all include a time_step dimension. + + :param n_classes: number of classes in distribution + :param probs: probabilities for each class, + of shape (n_classes), + of shape (batch_size, n_classes) or + of shape (batch_size, time_step, n_classes) + :param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized). + """ + self.n_classes = n_classes + self.probs = probs + self.F = F + + def log_prob(self, actions: nd_sym_type) -> nd_sym_type: + """ + Calculate the log probability of data given the current distribution. + + :param actions: actions, with int8 data type, + of shape (1) if probs was (n_classes), + of shape (batch_size) if probs was (batch_size, n_classes) and + of shape (batch_size, time_step) if probs was (batch_size, time_step, n_classes) + :return: log_probability, + of shape (1) if probs was (n_classes), + of shape (batch_size) if probs was (batch_size, n_classes) and + of shape (batch_size, time_step) if probs was (batch_size, time_step, n_classes) + """ + action_mask = actions.one_hot(depth=self.n_classes) + action_probs = (self.probs * action_mask).sum(axis=-1) + return action_probs.log() + + def entropy(self) -> nd_sym_type: + """ + Calculate entropy of current distribution.
+ + :return: entropy, + of shape (1) if probs was (n_classes), + of shape (batch_size) if probs was (batch_size, n_classes) and + of shape (batch_size, time_step) if probs was (batch_size, time_step, n_classes) + """ + # todo: look into numerical stability + return -(self.probs.log()*self.probs).sum(axis=-1) + + def kl_div(self, alt_dist) -> nd_sym_type: + """ + Calculate KL-Divergence with another Categorical distribution + + :param alt_dist: alternative distribution used for kl divergence calculation + :type alt_dist: CategoricalDist + :return: KL-Divergence + """ + logits_a = self.probs.clip(a_min=eps, a_max=1 - eps).log() + logits_b = alt_dist.probs.clip(a_min=eps, a_max=1 - eps).log() + t = self.probs * (logits_a - logits_b) + t = self.F.where(condition=(alt_dist.probs == 0), x=self.F.ones_like(alt_dist.probs) * math.inf, y=t) + t = self.F.where(condition=(self.probs == 0), x=self.F.zeros_like(self.probs), y=t) + return t.sum(axis=-1) + + +class DiscretePPOHead(nn.HybridBlock): + def __init__(self, num_actions: int) -> None: + """ + Head block for Discrete Proximal Policy Optimization, to calculate probabilities for each action given + middleware representation of the environment state. + + :param num_actions: number of actions in action space. + """ + super(DiscretePPOHead, self).__init__() + with self.name_scope(): + self.dense = nn.Dense(units=num_actions, flatten=False) + + def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type: + """ + Used for forward pass through head network. + + :param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized). + :param x: middleware state representation, + of shape (batch_size, in_channels) or + of shape (batch_size, time_step, in_channels). + :return: batch of probabilities for each action, + of shape (batch_size, num_actions) or + of shape (batch_size, time_step, num_actions). + """ + policy_values = self.dense(x) + policy_probs = F.softmax(policy_values) + return policy_probs + + +class ContinuousPPOHead(nn.HybridBlock): + def __init__(self, num_actions: int) -> None: + """ + Head block for Continuous Proximal Policy Optimization, to calculate the mean and standard deviation of + each action given middleware representation of the environment state. + + :param num_actions: number of actions in action space. + """ + super(ContinuousPPOHead, self).__init__() + with self.name_scope(): + # todo: change initialization strategy + self.dense = nn.Dense(units=num_actions, flatten=False) + # all samples (across batch, and time step) share the same covariance, which is learnt, + # but since we assume the action probability variables are independent, + # only the diagonal entries of the covariance matrix are specified. + self.log_std = self.params.get('log_std', + shape=num_actions, + init=mx.init.Zero(), + allow_deferred_init=True) + # todo: is_local? + + def hybrid_forward(self, F: ModuleType, x: nd_sym_type, log_std: nd_sym_type) -> List[nd_sym_type]: + """ + Used for forward pass through head network. + + :param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized). + :param x: middleware state representation, + of shape (batch_size, in_channels) or + of shape (batch_size, time_step, in_channels). + :param log_std: learnt log of standard deviation for each action, of shape (num_actions). + :return: list of policy means and policy standard deviations, with means + of shape (batch_size, num_actions) or + of shape (batch_size, time_step, num_actions), and standard deviations of shape (num_actions).
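# Editor's note: a brief usage sketch, not part of the patch, for CategoricalDist above;
# values in the comments are approximate.
import mxnet as mx
probs = mx.nd.array([[0.9, 0.1], [0.5, 0.5]])
dist = CategoricalDist(n_classes=2, probs=probs)
log_p = dist.log_prob(mx.nd.array([0, 1]))  # [log(0.9), log(0.5)] = [-0.105, -0.693]
ent = dist.entropy()                        # [0.325, 0.693]; the uniform row has maximal entropy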
+ """ + policy_means = self.dense(x) + policy_std = log_std.exp() + return [policy_means, policy_std] + + +class ClippedPPOLossDiscrete(HeadLoss): + def __init__(self, + num_actions: int, + clip_likelihood_ratio_using_epsilon: float, + beta: float=0, + use_kl_regularization: bool=False, + initial_kl_coefficient: float=1, + kl_cutoff: float=0, + high_kl_penalty_coefficient: float=1, + weight: float=1, + batch_axis: int=0) -> None: + """ + Loss for discrete version of Clipped PPO. + + :param num_actions: number of actions in action space. + :param clip_likelihood_ratio_using_epsilon: epsilon to use for likelihood ratio clipping. + :param beta: loss coefficient applied to entropy + :param use_kl_regularization: option to add kl divergence loss + :param initial_kl_coefficient: loss coefficient applied to KL divergence loss (also see high_kl_penalty_coefficient). + :param kl_cutoff: threshold for using high_kl_penalty_coefficient + :param high_kl_penalty_coefficient: loss coefficient applied to KL divergence above kl_cutoff + :param weight: scalar used to adjust relative weight of loss (if using this loss with others). + :param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation. + """ + super(ClippedPPOLossDiscrete, self).__init__(weight=weight, batch_axis=batch_axis) + self.weight = weight + self.num_actions = num_actions + self.clip_likelihood_ratio_using_epsilon = clip_likelihood_ratio_using_epsilon + self.beta = beta + self.use_kl_regularization = use_kl_regularization + self.initial_kl_coefficient = initial_kl_coefficient if self.use_kl_regularization else 0.0 + self.kl_coefficient = self.params.get('kl_coefficient', + shape=(1,), + init=mx.init.Constant([initial_kl_coefficient,]), + differentiable=False) + self.kl_cutoff = kl_cutoff + self.high_kl_penalty_coefficient = high_kl_penalty_coefficient + + @property + def input_schema(self) -> LossInputSchema: + return LossInputSchema( + head_outputs=['new_policy_probs'], + agent_inputs=['actions', 'old_policy_probs', 'clip_param_rescaler'], + targets=['advantages'] + ) + + def loss_forward(self, + F: ModuleType, + new_policy_probs: nd_sym_type, + actions: nd_sym_type, + old_policy_probs: nd_sym_type, + clip_param_rescaler: nd_sym_type, + advantages: nd_sym_type, + kl_coefficient: nd_sym_type) -> List[Tuple[nd_sym_type, str]]: + """ + Used for forward pass through loss computations. + Works with batches of data, and optionally time_steps, but be consistent in usage: i.e. if using time_step, + new_policy_probs, old_policy_probs, actions and advantages all must include a time_step dimension. + + NOTE: order of input arguments MUST NOT CHANGE because it matches the order + parameters are passed in ppo_agent:train_network() + + :param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized). + :param new_policy_probs: action probabilities predicted by DiscretePPOHead network, + of shape (batch_size, num_actions) or + of shape (batch_size, time_step, num_actions). + :param actions: true actions taken during rollout, + of shape (batch_size) or + of shape (batch_size, time_step). + :param old_policy_probs: action probabilities for previous policy, + of shape (batch_size, num_actions) or + of shape (batch_size, time_step, num_actions). + :param clip_param_rescaler: scales epsilon to use for likelihood ratio clipping. + :param advantages: change in state value after taking action (a.k.a. advantage), + of shape (batch_size) or + of shape (batch_size, time_step).
+ :param kl_coefficient: loss coefficient applied to KL divergence loss (also see high_kl_penalty_coefficient). + :return: loss, of shape (batch_size). + """ + + old_policy_dist = CategoricalDist(self.num_actions, old_policy_probs, F=F) + action_probs_wrt_old_policy = old_policy_dist.log_prob(actions) + + new_policy_dist = CategoricalDist(self.num_actions, new_policy_probs, F=F) + action_probs_wrt_new_policy = new_policy_dist.log_prob(actions) + + entropy_loss = - self.beta * new_policy_dist.entropy().mean() + + if self.use_kl_regularization: + kl_div = old_policy_dist.kl_div(new_policy_dist).mean() + weighted_kl_div = kl_coefficient * kl_div + high_kl_div = F.stack(F.zeros_like(kl_div), kl_div - self.kl_cutoff).max().square() + weighted_high_kl_div = self.high_kl_penalty_coefficient * high_kl_div + kl_div_loss = weighted_kl_div + weighted_high_kl_div + else: + kl_div_loss = F.zeros(shape=(1,)) + + # working with log probs, so minus first, then exponential (same as division) + likelihood_ratio = (action_probs_wrt_new_policy - action_probs_wrt_old_policy).exp() + + if self.clip_likelihood_ratio_using_epsilon is not None: + # clipping of likelihood ratio + min_value = 1 - self.clip_likelihood_ratio_using_epsilon * clip_param_rescaler + max_value = 1 + self.clip_likelihood_ratio_using_epsilon * clip_param_rescaler + + # can't use F.clip (with variable clipping bounds), hence custom implementation + clipped_likelihood_ratio = hybrid_clip(F, likelihood_ratio, clip_lower=min_value, clip_upper=max_value) + + # lower bound of original, and clipped versions of each scaled advantage + # element-wise min between the two ndarrays + unclipped_scaled_advantages = likelihood_ratio * advantages + clipped_scaled_advantages = clipped_likelihood_ratio * advantages + scaled_advantages = F.stack(unclipped_scaled_advantages, clipped_scaled_advantages).min(axis=0) + else: + scaled_advantages = likelihood_ratio * advantages + clipped_likelihood_ratio = F.zeros_like(likelihood_ratio) + + # for each batch, calculate expectation of scaled_advantages across time steps, + # but want code to work with data without time step too, so reshape to add a time step if one doesn't exist. + scaled_advantages_w_time = scaled_advantages.reshape(shape=(0, -1)) + expected_scaled_advantages = scaled_advantages_w_time.mean(axis=1) + # want to maximize expected_scaled_advantages, so negate it to obtain a loss to minimize. + surrogate_loss = (-expected_scaled_advantages * self.weight).mean() + + return [ + (surrogate_loss, LOSS_OUT_TYPE_LOSS), + (entropy_loss + kl_div_loss, LOSS_OUT_TYPE_REGULARIZATION), + (kl_div_loss, LOSS_OUT_TYPE_KL), + (entropy_loss, LOSS_OUT_TYPE_ENTROPY), + (likelihood_ratio, LOSS_OUT_TYPE_LIKELIHOOD_RATIO), + (clipped_likelihood_ratio, LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO) + ] + + +class ClippedPPOLossContinuous(HeadLoss): + def __init__(self, + num_actions: int, + clip_likelihood_ratio_using_epsilon: float, + beta: float=0, + use_kl_regularization: bool=False, + initial_kl_coefficient: float=1, + kl_cutoff: float=0, + high_kl_penalty_coefficient: float=1, + weight: float=1, + batch_axis: int=0): + """ + Loss for continuous version of Clipped PPO. + + :param num_actions: number of actions in action space. + :param clip_likelihood_ratio_using_epsilon: epsilon to use for likelihood ratio clipping. + :param beta: loss coefficient applied to entropy
+ :param use_kl_regularization: option to add kl divergence loss + :param initial_kl_coefficient: initial loss coefficient applied to KL divergence loss (also see high_kl_penalty_coefficient). + :param kl_cutoff: threshold for using high_kl_penalty_coefficient + :param high_kl_penalty_coefficient: loss coefficient applied to KL divergence above kl_cutoff + :param weight: scalar used to adjust relative weight of loss (if using this loss with others). + :param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation. + """ + super(ClippedPPOLossContinuous, self).__init__(weight=weight, batch_axis=batch_axis) + self.weight = weight + self.num_actions = num_actions + self.clip_likelihood_ratio_using_epsilon = clip_likelihood_ratio_using_epsilon + self.beta = beta + self.use_kl_regularization = use_kl_regularization + self.initial_kl_coefficient = initial_kl_coefficient if self.use_kl_regularization else 0.0 + self.kl_coefficient = self.params.get('kl_coefficient', + shape=(1,), + init=mx.init.Constant([initial_kl_coefficient,]), + differentiable=False) + self.kl_cutoff = kl_cutoff + self.high_kl_penalty_coefficient = high_kl_penalty_coefficient + + @property + def input_schema(self) -> LossInputSchema: + return LossInputSchema( + head_outputs=['new_policy_means', 'new_policy_stds'], + agent_inputs=['actions', 'old_policy_means', 'old_policy_stds', 'clip_param_rescaler'], + targets=['advantages'] + ) + + def loss_forward(self, + F: ModuleType, + new_policy_means: nd_sym_type, + new_policy_stds: nd_sym_type, + actions: nd_sym_type, + old_policy_means: nd_sym_type, + old_policy_stds: nd_sym_type, + clip_param_rescaler: nd_sym_type, + advantages: nd_sym_type, + kl_coefficient: nd_sym_type) -> List[Tuple[nd_sym_type, str]]: + """ + Used for forward pass through loss computations. + Works with batches of data, and optionally time_steps, but be consistent in usage: i.e. if using time_step, + new_policy_means, old_policy_means, actions and advantages all must include a time_step dimension. + + :param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized). + :param new_policy_means: action means predicted by ContinuousPPOHead network, + of shape (batch_size, num_actions) or + of shape (batch_size, time_step, num_actions). + :param new_policy_stds: action standard deviation returned by head, + of shape (batch_size, num_actions) or + of shape (batch_size, time_step, num_actions). + :param actions: true actions taken during rollout, + of shape (batch_size) or + of shape (batch_size, time_step). + :param old_policy_means: action means for previous policy, + of shape (batch_size, num_actions) or + of shape (batch_size, time_step, num_actions). + :param old_policy_stds: action standard deviation returned by head previously, + of shape (batch_size, num_actions) or + of shape (batch_size, time_step, num_actions). + :param clip_param_rescaler: scales epsilon to use for likelihood ratio clipping. + :param advantages: change in state value after taking action (a.k.a. advantage), + of shape (batch_size) or + of shape (batch_size, time_step). + :param kl_coefficient: loss coefficient applied to KL divergence loss (also see high_kl_penalty_coefficient). + :return: loss, of shape (batch_size).
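# Editor's note: an illustration, not part of the patch, of the clipped surrogate objective
# that both ClippedPPOLoss variants above implement, restated in plain NumPy with
# hypothetical per-sample inputs:
import numpy as np
new_log_prob, old_log_prob = np.array([-0.1, -0.7]), np.array([-0.2, -0.5])
advantages, epsilon, rescaler = np.array([1.0, -0.5]), 0.2, 1.0
ratio = np.exp(new_log_prob - old_log_prob)  # likelihood ratio of new to old policy
clipped = np.clip(ratio, 1 - epsilon * rescaler, 1 + epsilon * rescaler)
# pessimistic (lower) bound of the two scaled advantages, negated to form a loss
surrogate_loss = -np.minimum(ratio * advantages, clipped * advantages).mean()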
+ """ + old_var = old_policy_stds ** 2 + # sets diagonal in (batch size and time step) covariance matrices + old_covar = mx.nd.eye(N=self.num_actions) * (old_var + eps).broadcast_like(old_policy_means).expand_dims(-2) + old_policy_dist = MultivariateNormalDist(self.num_actions, old_policy_means, old_covar, F=F) + action_probs_wrt_old_policy = old_policy_dist.log_prob(actions) + + new_var = new_policy_stds ** 2 + # sets diagonal in (batch size and time step) covariance matrices + new_covar = mx.nd.eye(N=self.num_actions) * (new_var + eps).broadcast_like(new_policy_means).expand_dims(-2) + new_policy_dist = MultivariateNormalDist(self.num_actions, new_policy_means, new_covar, F=F) + action_probs_wrt_new_policy = new_policy_dist.log_prob(actions) + + entropy_loss = - self.beta * new_policy_dist.entropy().mean() + + if self.use_kl_regularization: + kl_div = old_policy_dist.kl_div(new_policy_dist).mean() + weighted_kl_div = kl_coefficient * kl_div + high_kl_div = F.stack(F.zeros_like(kl_div), kl_div - self.kl_cutoff).max().square() + weighted_high_kl_div = self.high_kl_penalty_coefficient * high_kl_div + kl_div_loss = weighted_kl_div + weighted_high_kl_div + else: + kl_div_loss = F.zeros(shape=(1,)) + + # working with log probs, so minus first, then exponential (same as division) + likelihood_ratio = (action_probs_wrt_new_policy - action_probs_wrt_old_policy).exp() + + if self.clip_likelihood_ratio_using_epsilon is not None: + # clipping of likelihood ratio + min_value = 1 - self.clip_likelihood_ratio_using_epsilon * clip_param_rescaler + max_value = 1 + self.clip_likelihood_ratio_using_epsilon * clip_param_rescaler + + # can't use F.clip (with variable clipping bounds), hence custom implementation + clipped_likelihood_ratio = hybrid_clip(F, likelihood_ratio, clip_lower=min_value, clip_upper=max_value) + + # lower bound of original, and clipped versions of each scaled advantage + # element-wise min between the two ndarrays + unclipped_scaled_advantages = likelihood_ratio * advantages + clipped_scaled_advantages = clipped_likelihood_ratio * advantages + scaled_advantages = F.stack(unclipped_scaled_advantages, clipped_scaled_advantages).min(axis=0) + else: + scaled_advantages = likelihood_ratio * advantages + clipped_likelihood_ratio = F.zeros_like(likelihood_ratio) + + # for each batch, calculate expectation of scaled_advantages across time steps, + # but want code to work with data without time step too, so reshape to add a time step if one doesn't exist. + scaled_advantages_w_time = scaled_advantages.reshape(shape=(0, -1)) + expected_scaled_advantages = scaled_advantages_w_time.mean(axis=1) + # want to maximize expected_scaled_advantages, so negate it to obtain a loss to minimize. + surrogate_loss = (-expected_scaled_advantages * self.weight).mean() + + return [ + (surrogate_loss, LOSS_OUT_TYPE_LOSS), + (entropy_loss + kl_div_loss, LOSS_OUT_TYPE_REGULARIZATION), + (kl_div_loss, LOSS_OUT_TYPE_KL), + (entropy_loss, LOSS_OUT_TYPE_ENTROPY), + (likelihood_ratio, LOSS_OUT_TYPE_LIKELIHOOD_RATIO), + (clipped_likelihood_ratio, LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO) + ] + + +class PPOHead(Head): + def __init__(self, + agent_parameters: AgentParameters, + spaces: SpacesDefinition, + network_name: str, + head_type_idx: int=0, + loss_weight: float=1., + is_local: bool=True, + activation_function: str='tanh', + dense_layer: None=None) -> None: + """ + Head block for Proximal Policy Optimization, to calculate probabilities for each action given middleware + representation of the environment state.
+ + :param agent_parameters: containing algorithm parameters such as clip_likelihood_ratio_using_epsilon + and beta_entropy. + :param spaces: containing action spaces used for defining size of network output. + :param network_name: name of head network. currently unused. + :param head_type_idx: index of head network. currently unused. + :param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others). + :param is_local: flag to denote if network is local. currently unused. + :param activation_function: activation function to use between layers. currently unused. + :param dense_layer: type of dense layer to use in network. currently unused. + """ + super().__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight, is_local, activation_function, + dense_layer=dense_layer) + self.return_type = ActionProbabilities + + self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon + self.beta = agent_parameters.algorithm.beta_entropy + self.use_kl_regularization = agent_parameters.algorithm.use_kl_regularization + if self.use_kl_regularization: + self.initial_kl_coefficient = agent_parameters.algorithm.initial_kl_coefficient + self.kl_cutoff = 2 * agent_parameters.algorithm.target_kl_divergence + self.high_kl_penalty_coefficient = agent_parameters.algorithm.high_kl_penalty_coefficient + else: + self.initial_kl_coefficient, self.kl_cutoff, self.high_kl_penalty_coefficient = (None, None, None) + self._loss = [] + + if isinstance(self.spaces.action, DiscreteActionSpace): + self.net = DiscretePPOHead(num_actions=len(self.spaces.action.actions)) + elif isinstance(self.spaces.action, BoxActionSpace): + self.net = ContinuousPPOHead(num_actions=len(self.spaces.action.actions)) + else: + raise ValueError("Only discrete or continuous action spaces are supported for PPO.") + + def hybrid_forward(self, + F: ModuleType, + x: nd_sym_type) -> nd_sym_type: + """ + :param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized). + :param x: middleware embedding + :return: policy parameters/probabilities + """ + return self.net(x) + + def loss(self) -> mx.gluon.loss.Loss: + """ + Specifies loss block to be used for this policy head. + + :return: loss block (can be called as function) for action probabilities returned by this policy network. 
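# Editor's note: assign_kl_coefficient (defined further below) lets the agent adapt the KL
# penalty between training iterations when use_kl_regularization is set. The rule sketched
# here is a typical adaptive-KL schedule for illustration only; the exact schedule lives in
# the PPO agent, and 'ppo_head', 'kl' and 'target_kl' are hypothetical:
kl, target_kl, coefficient = 0.02, 0.01, 1.0
if kl > 1.5 * target_kl:
    coefficient *= 2.0   # policy moved too far from the old policy: penalize more
elif kl < target_kl / 1.5:
    coefficient *= 0.5   # policy barely moved: relax the penalty
ppo_head.assign_kl_coefficient(coefficient)  # valid once ppo_head.loss() has been called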
+ """ + if isinstance(self.spaces.action, DiscreteActionSpace): + loss = ClippedPPOLossDiscrete(len(self.spaces.action.actions), + self.clip_likelihood_ratio_using_epsilon, + self.beta, + self.use_kl_regularization, self.initial_kl_coefficient, + self.kl_cutoff, self.high_kl_penalty_coefficient, + self.loss_weight) + elif isinstance(self.spaces.action, BoxActionSpace): + loss = ClippedPPOLossContinuous(len(self.spaces.action.actions), + self.clip_likelihood_ratio_using_epsilon, + self.beta, + self.use_kl_regularization, self.initial_kl_coefficient, + self.kl_cutoff, self.high_kl_penalty_coefficient, + self.loss_weight) + else: + raise ValueError("Only discrete or continuous action spaces are supported for PPO.") + loss.initialize() + # set a property so can assign_kl_coefficient in future, + # make a list, otherwise it would be added as a child of Head Block (due to type check) + self._loss = [loss] + return loss + + @property + def kl_divergence(self): + return self.head_type_idx, LOSS_OUT_TYPE_KL + + @property + def entropy(self): + return self.head_type_idx, LOSS_OUT_TYPE_ENTROPY + + @property + def likelihood_ratio(self): + return self.head_type_idx, LOSS_OUT_TYPE_LIKELIHOOD_RATIO + + @property + def clipped_likelihood_ratio(self): + return self.head_type_idx, LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO + + def assign_kl_coefficient(self, kl_coefficient: float) -> None: + self._loss[0].kl_coefficient.set_data(mx.nd.array((kl_coefficient,))) \ No newline at end of file diff --git a/rl_coach/architectures/mxnet_components/heads/ppo_v_head.py b/rl_coach/architectures/mxnet_components/heads/ppo_v_head.py new file mode 100644 index 0000000..5512b90 --- /dev/null +++ b/rl_coach/architectures/mxnet_components/heads/ppo_v_head.py @@ -0,0 +1,123 @@ +from typing import List, Tuple, Union +from types import ModuleType + +import mxnet as mx +from mxnet.gluon import nn +from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema +from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS +from rl_coach.base_parameters import AgentParameters +from rl_coach.core_types import ActionProbabilities +from rl_coach.spaces import SpacesDefinition + +nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol] + + +class PPOVHeadLoss(HeadLoss): + def __init__(self, clip_likelihood_ratio_using_epsilon: float, weight: float=1, batch_axis: int=0) -> None: + """ + Loss for PPO Value network. + Schulman implemented this extension in OpenAI baselines for PPO2 + See https://github.com/openai/baselines/blob/master/baselines/ppo2/ppo2.py#L72 + + :param clip_likelihood_ratio_using_epsilon: epsilon to use for likelihood ratio clipping. + :param weight: scalar used to adjust relative weight of loss (if using this loss with others). + :param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation. + """ + super(PPOVHeadLoss, self).__init__(weight=weight, batch_axis=batch_axis) + self.weight = weight + self.clip_likelihood_ratio_using_epsilon = clip_likelihood_ratio_using_epsilon + + @property + def input_schema(self) -> LossInputSchema: + return LossInputSchema( + head_outputs=['new_policy_values'], + agent_inputs=['old_policy_values'], + targets=['target_values'] + ) + + def loss_forward(self, + F: ModuleType, + new_policy_values: nd_sym_type, + old_policy_values: nd_sym_type, + target_values: nd_sym_type) -> List[Tuple[nd_sym_type, str]]: + """ + Used for forward pass through loss computations. 
+ Calculates two losses (L2 and a clipped difference L2 loss) and takes the maximum of the two. + Works with batches of data, and optionally time_steps, but be consistent in usage: i.e. if using time_step, + new_policy_values, old_policy_values and target_values all must include a time_step dimension. + + :param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized). + :param new_policy_values: values predicted by PPOVHead network, + of shape (batch_size) or + of shape (batch_size, time_step). + :param old_policy_values: values predicted by old value network, + of shape (batch_size) or + of shape (batch_size, time_step). + :param target_values: actual state values, + of shape (batch_size) or + of shape (batch_size, time_step). + :return: loss, of shape (batch_size). + """ + # L2 loss + value_loss_1 = (new_policy_values - target_values).square() + # Clipped difference L2 loss + diff = new_policy_values - old_policy_values + clipped_diff = diff.clip(a_min=-self.clip_likelihood_ratio_using_epsilon, + a_max=self.clip_likelihood_ratio_using_epsilon) + value_loss_2 = (old_policy_values + clipped_diff - target_values).square() + # Element-wise maximum of the two losses. + value_loss_max = F.stack(value_loss_1, value_loss_2).max(axis=0) + # Aggregate over the temporal axis, adding one if it doesn't exist (hence the reshape). + value_loss_max_w_time = value_loss_max.reshape(shape=(0, -1)) + value_loss = value_loss_max_w_time.mean(axis=1) + # Weight the loss (and average over samples of batch) + value_loss_weighted = value_loss.mean() * self.weight + return [(value_loss_weighted, LOSS_OUT_TYPE_LOSS)] + + +class PPOVHead(Head): + def __init__(self, + agent_parameters: AgentParameters, + spaces: SpacesDefinition, + network_name: str, + head_type_idx: int=0, + loss_weight: float=1., + is_local: bool = True, + activation_function: str='relu', + dense_layer: None=None) -> None: + """ + PPO Value Head for predicting state values. + + :param agent_parameters: containing algorithm parameters, but currently unused. + :param spaces: containing action spaces, but currently unused. + :param network_name: name of head network. currently unused. + :param head_type_idx: index of head network. currently unused. + :param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others). + :param is_local: flag to denote if network is local. currently unused. + :param activation_function: activation function to use between layers. currently unused. + :param dense_layer: type of dense layer to use in network. currently unused. + """ + super(PPOVHead, self).__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight, is_local, + activation_function, dense_layer=dense_layer) + self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon + self.return_type = ActionProbabilities + with self.name_scope(): + self.dense = nn.Dense(units=1) + + def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type: + """ + Used for forward pass through value head network. + + :param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized). + :param x: middleware state representation, of shape (batch_size, in_channels). + :return: final value output of network, of shape (batch_size). + """ + return self.dense(x).squeeze() + + def loss(self) -> mx.gluon.loss.Loss: + """ + Specifies loss block to be used for specific value head implementation.
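# Editor's note: the PPO2-style clipped value loss above, restated in plain NumPy as an
# illustration, not part of the patch, with hypothetical inputs:
import numpy as np
v_new, v_old, v_target, clip_eps = np.array([1.5]), np.array([1.0]), np.array([2.0]), 0.2
loss_unclipped = (v_new - v_target) ** 2                         # (1.5 - 2)^2 = 0.25
clipped_v = v_old + np.clip(v_new - v_old, -clip_eps, clip_eps)  # 1.0 + 0.2 = 1.2
loss_clipped = (clipped_v - v_target) ** 2                       # (1.2 - 2)^2 = 0.64
value_loss = np.maximum(loss_unclipped, loss_clipped).mean()     # 0.64: the clipped term dominates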
+ + :return: loss block (can be called as function) for outputs returned by the value head network. + """ + return PPOVHeadLoss(self.clip_likelihood_ratio_using_epsilon, weight=self.loss_weight) diff --git a/rl_coach/architectures/mxnet_components/heads/q_head.py b/rl_coach/architectures/mxnet_components/heads/q_head.py new file mode 100644 index 0000000..88e107f --- /dev/null +++ b/rl_coach/architectures/mxnet_components/heads/q_head.py @@ -0,0 +1,106 @@ +from typing import Union, List, Tuple +from types import ModuleType + +import mxnet as mx +from mxnet.gluon.loss import Loss, HuberLoss, L2Loss +from mxnet.gluon import nn +from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema +from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS +from rl_coach.base_parameters import AgentParameters +from rl_coach.core_types import QActionStateValue +from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace + +nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol] + + +class QHeadLoss(HeadLoss): + def __init__(self, loss_type: Loss=L2Loss, weight: float=1, batch_axis: int=0) -> None: + """ + Loss for Q-Value Head. + + :param loss_type: loss function with default of mean squared error (i.e. L2Loss). + :param weight: scalar used to adjust relative weight of loss (if using this loss with others). + :param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation. + """ + super(QHeadLoss, self).__init__(weight=weight, batch_axis=batch_axis) + with self.name_scope(): + self.loss_fn = loss_type(weight=weight, batch_axis=batch_axis) + + @property + def input_schema(self) -> LossInputSchema: + return LossInputSchema( + head_outputs=['pred'], + agent_inputs=[], + targets=['target'] + ) + + def loss_forward(self, + F: ModuleType, + pred: nd_sym_type, + target: nd_sym_type) -> List[Tuple[nd_sym_type, str]]: + """ + Used for forward pass through loss computations. + + :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized). + :param pred: state-action q-values predicted by QHead network, of shape (batch_size, num_actions). + :param target: actual state-action q-values, of shape (batch_size, num_actions). + :return: loss, of shape (batch_size). + """ + loss = self.loss_fn(pred, target).mean() + return [(loss, LOSS_OUT_TYPE_LOSS)] + + +class QHead(Head): + def __init__(self, + agent_parameters: AgentParameters, + spaces: SpacesDefinition, + network_name: str, + head_type_idx: int=0, + loss_weight: float=1., + is_local: bool=True, + activation_function: str='relu', + dense_layer: None=None, + loss_type: Union[HuberLoss, L2Loss]=L2Loss) -> None: + """ + Q-Value Head for predicting state-action Q-Values. + + :param agent_parameters: containing algorithm parameters, but currently unused. + :param spaces: containing action spaces used for defining size of network output. + :param network_name: name of head network. currently unused. + :param head_type_idx: index of head network. currently unused. + :param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others). + :param is_local: flag to denote if network is local. currently unused. + :param activation_function: activation function to use between layers. currently unused. + :param dense_layer: type of dense layer to use in network. currently unused. + :param loss_type: loss function to use. 
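# Editor's note: an illustration, not part of the patch, of the regression target that
# QHeadLoss is typically paired with in DQN, using hypothetical NumPy inputs:
import numpy as np
q_next = np.array([[1.0, 3.0]])  # target-network Q-values for the next state
reward, discount, done = 0.5, 0.99, False
td_target = reward + (0.0 if done else discount * q_next.max(axis=1))  # 0.5 + 0.99 * 3.0 = 3.47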
+ """ + super(QHead, self).__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight, + is_local, activation_function, dense_layer) + if isinstance(self.spaces.action, BoxActionSpace): + self.num_actions = 1 + elif isinstance(self.spaces.action, DiscreteActionSpace): + self.num_actions = len(self.spaces.action.actions) + self.return_type = QActionStateValue + assert (loss_type == L2Loss) or (loss_type == HuberLoss), "Only expecting L2Loss or HuberLoss." + self.loss_type = loss_type + + with self.name_scope(): + self.dense = nn.Dense(units=self.num_actions) + + def loss(self) -> Loss: + """ + Specifies loss block to be used for specific value head implementation. + + :return: loss block (can be called as function) for outputs returned by the head network. + """ + return QHeadLoss(loss_type=self.loss_type, weight=self.loss_weight) + + def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type: + """ + Used for forward pass through Q-Value head network. + + :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized). + :param x: middleware state representation, of shape (batch_size, in_channels). + :return: predicted state-action q-values, of shape (batch_size, num_actions). + """ + return self.dense(x) diff --git a/rl_coach/architectures/mxnet_components/heads/v_head.py b/rl_coach/architectures/mxnet_components/heads/v_head.py new file mode 100644 index 0000000..a04cafd --- /dev/null +++ b/rl_coach/architectures/mxnet_components/heads/v_head.py @@ -0,0 +1,100 @@ +from typing import Union, List, Tuple +from types import ModuleType + +import mxnet as mx +from mxnet.gluon.loss import Loss, HuberLoss, L2Loss +from mxnet.gluon import nn +from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema +from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS +from rl_coach.base_parameters import AgentParameters +from rl_coach.core_types import VStateValue +from rl_coach.spaces import SpacesDefinition + +nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol] + + +class VHeadLoss(HeadLoss): + def __init__(self, loss_type: Loss=L2Loss, weight: float=1, batch_axis: int=0) -> None: + """ + Loss for Value Head. + + :param loss_type: loss function with default of mean squared error (i.e. L2Loss). + :param weight: scalar used to adjust relative weight of loss (if using this loss with others). + :param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation. + """ + super(VHeadLoss, self).__init__(weight=weight, batch_axis=batch_axis) + with self.name_scope(): + self.loss_fn = loss_type(weight=weight, batch_axis=batch_axis) + + @property + def input_schema(self) -> LossInputSchema: + return LossInputSchema( + head_outputs=['pred'], + agent_inputs=[], + targets=['target'] + ) + + def loss_forward(self, + F: ModuleType, + pred: nd_sym_type, + target: nd_sym_type) -> List[Tuple[nd_sym_type, str]]: + """ + Used for forward pass through loss computations. + + :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized). + :param pred: state values predicted by VHead network, of shape (batch_size). + :param target: actual state values, of shape (batch_size). + :return: loss, of shape (batch_size). 
+ """ + loss = self.loss_fn(pred, target).mean() + return [(loss, LOSS_OUT_TYPE_LOSS)] + + +class VHead(Head): + def __init__(self, + agent_parameters: AgentParameters, + spaces: SpacesDefinition, + network_name: str, + head_type_idx: int=0, + loss_weight: float=1., + is_local: bool=True, + activation_function: str='relu', + dense_layer: None=None, + loss_type: Union[HuberLoss, L2Loss]=L2Loss): + """ + Value Head for predicting state values. + :param agent_parameters: containing algorithm parameters, but currently unused. + :param spaces: containing action spaces, but currently unused. + :param network_name: name of head network. currently unused. + :param head_type_idx: index of head network. currently unused. + :param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others). + :param is_local: flag to denote if network is local. currently unused. + :param activation_function: activation function to use between layers. currently unused. + :param dense_layer: type of dense layer to use in network. currently unused. + :param loss_type: loss function with default of mean squared error (i.e. L2Loss), or alternatively HuberLoss. + """ + super(VHead, self).__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight, + is_local, activation_function, dense_layer) + assert (loss_type == L2Loss) or (loss_type == HuberLoss), "Only expecting L2Loss or HuberLoss." + self.loss_type = loss_type + self.return_type = VStateValue + with self.name_scope(): + self.dense = nn.Dense(units=1) + + def loss(self) -> Loss: + """ + Specifies loss block to be used for specific value head implementation. + + :return: loss block (can be called as function) for outputs returned by the head network. + """ + return VHeadLoss(loss_type=self.loss_type, weight=self.loss_weight) + + def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type: + """ + Used for forward pass through value head network. + + :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized). + :param x: middleware state representation, of shape (batch_size, in_channels). + :return: final output of value network, of shape (batch_size). 
+ """ + return self.dense(x).squeeze() diff --git a/rl_coach/architectures/mxnet_components/layers.py b/rl_coach/architectures/mxnet_components/layers.py new file mode 100644 index 0000000..233d225 --- /dev/null +++ b/rl_coach/architectures/mxnet_components/layers.py @@ -0,0 +1,99 @@ +""" +Module implementing basic layers in mxnet +""" + +from types import FunctionType + +from mxnet.gluon import nn + +from rl_coach.architectures import layers +from rl_coach.architectures.mxnet_components import utils + + +# define global dictionary for storing layer type to layer implementation mapping +mx_layer_dict = dict() + + +def reg_to_mx(layer_type) -> FunctionType: + """ function decorator that registers layer implementation + :return: decorated function + """ + def reg_impl_decorator(func): + assert layer_type not in mx_layer_dict + mx_layer_dict[layer_type] = func + return func + return reg_impl_decorator + + +def convert_layer(layer): + """ + If layer is callable, return layer, otherwise convert to MX type + :param layer: layer to be converted + :return: converted layer if not callable, otherwise layer itself + """ + if callable(layer): + return layer + return mx_layer_dict[type(layer)](layer) + + +class Conv2d(layers.Conv2d): + def __init__(self, num_filters: int, kernel_size: int, strides: int): + super(Conv2d, self).__init__(num_filters=num_filters, kernel_size=kernel_size, strides=strides) + + def __call__(self) -> nn.Conv2D: + """ + returns a conv2d block + :return: conv2d block + """ + return nn.Conv2D(channels=self.num_filters, kernel_size=self.kernel_size, strides=self.strides) + + @staticmethod + @reg_to_mx(layers.Conv2d) + def to_mx(base: layers.Conv2d): + return Conv2d(num_filters=base.num_filters, kernel_size=base.kernel_size, strides=base.strides) + + +class BatchnormActivationDropout(layers.BatchnormActivationDropout): + def __init__(self, batchnorm: bool=False, activation_function=None, dropout_rate: float=0): + super(BatchnormActivationDropout, self).__init__( + batchnorm=batchnorm, activation_function=activation_function, dropout_rate=dropout_rate) + + def __call__(self): + """ + returns a list of mxnet batchnorm, activation and dropout layers + :return: batchnorm, activation and dropout layers + """ + block = nn.HybridSequential() + if self.batchnorm: + block.add(nn.BatchNorm()) + if self.activation_function: + block.add(nn.Activation(activation=utils.get_mxnet_activation_name(self.activation_function))) + if self.dropout_rate: + block.add(nn.Dropout(self.dropout_rate)) + return block + + @staticmethod + @reg_to_mx(layers.BatchnormActivationDropout) + def to_mx(base: layers.BatchnormActivationDropout): + return BatchnormActivationDropout( + batchnorm=base.batchnorm, + activation_function=base.activation_function, + dropout_rate=base.dropout_rate) + + +class Dense(layers.Dense): + def __init__(self, units: int): + super(Dense, self).__init__(units=units) + + def __call__(self): + """ + returns a mxnet dense layer + :return: dense layer + """ + # Set flatten to False for consistent behavior with tf.layers.dense + return nn.Dense(self.units, flatten=False) + + @staticmethod + @reg_to_mx(layers.Dense) + def to_mx(base: layers.Dense): + return Dense(units=base.units) diff --git a/rl_coach/architectures/mxnet_components/middlewares/__init__.py b/rl_coach/architectures/mxnet_components/middlewares/__init__.py new file mode 100644 index 0000000..3647dfa --- /dev/null +++ b/rl_coach/architectures/mxnet_components/middlewares/__init__.py @@ -0,0 +1,4 @@ +from .fc_middleware import 
FCMiddleware +from .lstm_middleware import LSTMMiddleware + +__all__ = ["FCMiddleware", "LSTMMiddleware"] \ No newline at end of file diff --git a/rl_coach/architectures/mxnet_components/middlewares/fc_middleware.py b/rl_coach/architectures/mxnet_components/middlewares/fc_middleware.py new file mode 100644 index 0000000..de3377f --- /dev/null +++ b/rl_coach/architectures/mxnet_components/middlewares/fc_middleware.py @@ -0,0 +1,52 @@ +""" +Module that defines the fully-connected middleware class +""" + +from rl_coach.architectures.mxnet_components.layers import Dense +from rl_coach.architectures.mxnet_components.middlewares.middleware import Middleware +from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters +from rl_coach.base_parameters import MiddlewareScheme + + +class FCMiddleware(Middleware): + def __init__(self, params: FCMiddlewareParameters): + """ + FCMiddleware or Fully-Connected Middleware can be used in the middle part of the network. It takes the + embeddings from the input embedders, after they were aggregated in some method (for example, concatenation) + and passes it through a neural network which can be customizable but shared between the heads of the network. + + :param params: parameters object containing batchnorm, activation_function and dropout properties. + """ + super(FCMiddleware, self).__init__(params) + + @property + def schemes(self) -> dict: + """ + Schemes are the pre-defined network architectures of various depths and complexities that can be used for the + Middleware. Are used to create Block when FCMiddleware is initialised. + + :return: dictionary of schemes, with key of type MiddlewareScheme enum and value being list of mxnet.gluon.Block. + """ + return { + MiddlewareScheme.Empty: + [], + + # Use for PPO + MiddlewareScheme.Shallow: + [ + Dense(units=64) + ], + + # Use for DQN + MiddlewareScheme.Medium: + [ + Dense(units=512) + ], + + MiddlewareScheme.Deep: + [ + Dense(units=128), + Dense(units=128), + Dense(units=128) + ] + } diff --git a/rl_coach/architectures/mxnet_components/middlewares/lstm_middleware.py b/rl_coach/architectures/mxnet_components/middlewares/lstm_middleware.py new file mode 100644 index 0000000..b8316d4 --- /dev/null +++ b/rl_coach/architectures/mxnet_components/middlewares/lstm_middleware.py @@ -0,0 +1,80 @@ +""" +Module that defines the LSTM middleware class +""" + +from typing import Union +from types import ModuleType + +import mxnet as mx +from mxnet.gluon import rnn +from rl_coach.architectures.mxnet_components.layers import Dense +from rl_coach.architectures.mxnet_components.middlewares.middleware import Middleware +from rl_coach.architectures.middleware_parameters import LSTMMiddlewareParameters +from rl_coach.base_parameters import MiddlewareScheme + +nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol] + + +class LSTMMiddleware(Middleware): + def __init__(self, params: LSTMMiddlewareParameters): + """ + LSTMMiddleware or Long Short Term Memory Middleware can be used in the middle part of the network. It takes the + embeddings from the input embedders, after they were aggregated in some method (for example, concatenation) + and passes it through a neural network which can be customizable but shared between the heads of the network. + + :param params: parameters object containing batchnorm, activation_function, dropout and + number_of_lstm_cells properties. 
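# Editor's note: a sketch, not part of the patch, of how convert_layer (layers.py above)
# turns the framework-agnostic rl_coach layer descriptions used in schemes into the MX
# wrappers registered via reg_to_mx; the scheme contents here are illustrative:
from rl_coach.architectures import layers
from rl_coach.architectures.mxnet_components.layers import convert_layer
custom_scheme = [layers.Dense(units=256), layers.Dense(units=128)]
mx_wrappers = [convert_layer(l) for l in custom_scheme]  # looked up by type in mx_layer_dict
gluon_blocks = [w() for w in mx_wrappers]                # calling a wrapper yields a gluon block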
+ """ + super(LSTMMiddleware, self).__init__(params) + self.number_of_lstm_cells = params.number_of_lstm_cells + with self.name_scope(): + self.lstm = rnn.LSTM(hidden_size=self.number_of_lstm_cells) + + @property + def schemes(self) -> dict: + """ + Schemes are the pre-defined network architectures of various depths and complexities that can be used for the + Middleware. Are used to create Block when LSTMMiddleware is initialised, and are applied before the LSTM. + + :return: dictionary of schemes, with key of type MiddlewareScheme enum and value being list of mxnet.gluon.Block. + """ + return { + MiddlewareScheme.Empty: + [], + + # Use for PPO + MiddlewareScheme.Shallow: + [ + Dense(units=64) + ], + + # Use for DQN + MiddlewareScheme.Medium: + [ + Dense(units=512) + ], + + MiddlewareScheme.Deep: + [ + Dense(units=128), + Dense(units=128), + Dense(units=128) + ] + } + + def hybrid_forward(self, + F: ModuleType, + x: nd_sym_type, + *args, **kwargs) -> nd_sym_type: + """ + Used for forward pass through LSTM middleware network. + Applies dense layers from selected scheme before passing result to LSTM layer. + + :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized). + :param x: state embedding, of shape (batch_size, in_channels). + :return: state middleware embedding, where shape is (batch_size, channels). + """ + x_ntc = x.reshape(shape=(0, 0, -1)) + emb_ntc = super(LSTMMiddleware, self).hybrid_forward(F, x_ntc, *args, **kwargs) + emb_tnc = emb_ntc.transpose(axes=(1, 0, 2)) + return self.lstm(emb_tnc) diff --git a/rl_coach/architectures/mxnet_components/middlewares/middleware.py b/rl_coach/architectures/mxnet_components/middlewares/middleware.py new file mode 100644 index 0000000..8b9db01 --- /dev/null +++ b/rl_coach/architectures/mxnet_components/middlewares/middleware.py @@ -0,0 +1,61 @@ +from typing import Union +from types import ModuleType + +import mxnet as mx +from mxnet.gluon import nn +from rl_coach.architectures.middleware_parameters import MiddlewareParameters +from rl_coach.architectures.mxnet_components.layers import convert_layer +from rl_coach.base_parameters import MiddlewareScheme + +nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol] + + +class Middleware(nn.HybridBlock): + def __init__(self, params: MiddlewareParameters): + """ + Middleware is the middle part of the network. It takes the embeddings from the input embedders, + after they were aggregated in some method (for example, concatenation) and passes it through a neural network + which can be customizable but shared between the heads of the network. + + :param params: parameters object containing batchnorm, activation_function and dropout properties. 
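# Editor's note: a usage sketch, not part of the patch, wiring an FCMiddleware from its
# parameters and running a forward pass. It assumes default FCMiddlewareParameters
# (Medium scheme, i.e. a single Dense(512)) and arbitrary batch/channel sizes:
import mxnet as mx
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters
from rl_coach.architectures.mxnet_components.middlewares.fc_middleware import FCMiddleware
middleware = FCMiddleware(FCMiddlewareParameters())
middleware.initialize()
out = middleware(mx.nd.ones((32, 128)))  # (batch_size, in_channels) -> (batch_size, 512)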
+ """ + super(Middleware, self).__init__() + self.scheme = params.scheme + + with self.name_scope(): + self.net = nn.HybridSequential() + if isinstance(self.scheme, MiddlewareScheme): + blocks = self.schemes[self.scheme] + else: + # if scheme is specified directly, convert to MX layer if it's not a callable object + # NOTE: if layer object is callable, it must return a gluon block when invoked + blocks = [convert_layer(l) for l in self.scheme] + for block in blocks: + self.net.add(block()) + if params.batchnorm: + self.net.add(nn.BatchNorm()) + if params.activation_function: + self.net.add(nn.Activation(params.activation_function)) + if params.dropout: + self.net.add(nn.Dropout(rate=params.dropout)) + + @property + def schemes(self) -> dict: + """ + Schemes are the pre-defined network architectures of various depths and complexities that can be used for the + Middleware. Should be implemented in child classes, and are used to create Block when Middleware is initialised. + + :return: dictionary of schemes, with key of type MiddlewareScheme enum and value being list of mxnet.gluon.Block. + """ + raise NotImplementedError("Inheriting embedder must define schemes matching its allowed default " + "configurations.") + + def hybrid_forward(self, F: ModuleType, x: nd_sym_type, *args, **kwargs) -> nd_sym_type: + """ + Used for forward pass through middleware network. + + :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized). + :param x: state embedding, of shape (batch_size, in_channels). + :return: state middleware embedding, where shape is (batch_size, channels). + """ + return self.net(x) diff --git a/rl_coach/architectures/mxnet_components/utils.py b/rl_coach/architectures/mxnet_components/utils.py new file mode 100644 index 0000000..5f1659c --- /dev/null +++ b/rl_coach/architectures/mxnet_components/utils.py @@ -0,0 +1,280 @@ +""" +Module defining utility functions +""" +import inspect +from typing import Any, Dict, Generator, Iterable, List, Tuple, Union +from types import ModuleType + +import mxnet as mx +from mxnet import nd +from mxnet.ndarray import NDArray +import numpy as np + +from rl_coach.core_types import GradientClippingMethod + +nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol] + + +def to_mx_ndarray(data: Union[list, tuple, np.ndarray, NDArray, int, float]) ->\ + Union[List[NDArray], Tuple[NDArray], NDArray]: + """ + Convert data to mx.nd.NDArray. Data can be a list or tuple of np.ndarray, int, or float or + it can be np.ndarray, int, or float + :param data: input data to be converted + :return: converted output data + """ + if isinstance(data, list): + data = [to_mx_ndarray(d) for d in data] + elif isinstance(data, tuple): + data = tuple(to_mx_ndarray(d) for d in data) + elif isinstance(data, np.ndarray): + data = nd.array(data) + elif isinstance(data, NDArray): + pass + elif isinstance(data, int) or isinstance(data, float): + data = nd.array([data]) + else: + raise TypeError('Unsupported data type: {}'.format(type(data))) + return data + + +def asnumpy_or_asscalar(data: Union[NDArray, list, tuple]) -> Union[np.ndarray, np.number, list, tuple]: + """ + Convert NDArray (or list or tuple of NDArray) to numpy. If shape is (1,), then convert to scalar instead. 
+ NOTE: This behavior is consistent with TensorFlow. + :param data: NDArray or list or tuple of NDArray + :return: data converted to numpy ndarray or to numpy scalar + """ + if isinstance(data, list): + data = [asnumpy_or_asscalar(d) for d in data] + elif isinstance(data, tuple): + data = tuple(asnumpy_or_asscalar(d) for d in data) + elif isinstance(data, NDArray): + data = data.asscalar() if data.shape == (1,) else data.asnumpy() + else: + raise TypeError('Unsupported data type: {}'.format(type(data))) + return data + + +def global_norm(arrays: Union[Generator[NDArray, NDArray, NDArray], List[NDArray], Tuple[NDArray]]) -> NDArray: + """ + Calculate global norm on list or tuple of NDArrays using this formula: + `global_norm = sqrt(sum([l2norm(p)**2 for p in parameters]))` + + :param arrays: list or tuple of parameters to calculate global norm on + :return: single-value NDArray + """ + def _norm(array): + if array.stype == 'default': + x = array.reshape((-1,)) + return nd.dot(x, x) + return array.norm().square() + + total_norm = nd.add_n(*[_norm(arr) for arr in arrays]) + total_norm = nd.sqrt(total_norm) + return total_norm + + +def split_outputs_per_head(outputs: Tuple[NDArray], heads: list) -> List[List[NDArray]]: + """ + Split outputs into outputs per head + :param outputs: list of all outputs + :param heads: list of all heads + :return: list of outputs for each head + """ + head_outputs = [] + for h in heads: + head_outputs.append(list(outputs[:h.num_outputs])) + outputs = outputs[h.num_outputs:] + assert len(outputs) == 0 + return head_outputs + + +def split_targets_per_loss(targets: list, losses: list) -> List[list]: + """ + Splits targets into targets per loss + :param targets: list of all targets (typically numpy ndarray) + :param losses: list of all losses + :return: list of targets for each loss + """ + loss_targets = list() + for l in losses: + loss_data_len = len(l.input_schema.targets) + assert len(targets) >= loss_data_len, "Data length doesn't match schema" + loss_targets.append(targets[:loss_data_len]) + targets = targets[loss_data_len:] + assert len(targets) == 0 + return loss_targets + + +def get_loss_agent_inputs(inputs: Dict[str, np.ndarray], head_type_idx: int, loss: Any) -> List[np.ndarray]: + """ + Collects all inputs with prefix 'output_<head_type_idx>_' and matches them against agent_inputs in the loss + input schema. + :param inputs: dictionary of all agent inputs + :param head_type_idx: head-type index of the corresponding head + :param loss: corresponding loss + :return: list of agent inputs for this loss. This list matches the length in loss input schema.
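+ e.g. (mirroring the unit test) inputs={'output_0_0': a, 'output_0_1': b, 'output_1_0': c} with + head_type_idx=0 returns [a, b], and with head_type_idx=1 returns [c].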
+ """ + loss_inputs = list() + for k in sorted(inputs.keys()): + if k.startswith('output_{}_'.format(head_type_idx)): + loss_inputs.append(inputs[k]) + # Enforce that number of inputs for head_type are the same as agent_inputs specified by loss input_schema + assert len(loss_inputs) == len(loss.input_schema.agent_inputs), "agent_input length doesn't match schema" + return loss_inputs + + +def align_loss_args( + head_outputs: List[NDArray], + agent_inputs: List[np.ndarray], + targets: List[np.ndarray], + loss: Any) -> List[np.ndarray]: + """ + Creates a list of arguments from head_outputs, agent_inputs, and targets aligned with parameters of + loss.loss_forward() based on their name in loss input_schema + :param head_outputs: list of all head_outputs for this loss + :param agent_inputs: list of all agent_inputs for this loss + :param targets: list of all targets for this loss + :param loss: corresponding loss + :return: list of arguments in correct order to be passed to loss + """ + arg_list = list() + schema = loss.input_schema + assert len(schema.head_outputs) == len(head_outputs) + assert len(schema.agent_inputs) == len(agent_inputs) + assert len(schema.targets) == len(targets) + + prev_found = True + for arg_name in inspect.getfullargspec(loss.loss_forward).args[2:]: # First two args are self and F + found = False + for schema_list, data in [(schema.head_outputs, head_outputs), + (schema.agent_inputs, agent_inputs), + (schema.targets, targets)]: + try: + arg_list.append(data[schema_list.index(arg_name)]) + found = True + break + except ValueError: + continue + assert not found or prev_found, "missing arguments detected!" + prev_found = found + return arg_list + + +def to_tuple(data: Union[tuple, list, Any]): + """ + If input is list, it is converted to tuple. If it's tuple, it is returned untouched. Otherwise + returns a single-element tuple of the data. + :return: tuple-ified data + """ + if isinstance(data, tuple): + pass + elif isinstance(data, list): + data = tuple(data) + else: + data = (data,) + return data + + +def to_list(data: Union[tuple, list, Any]): + """ + If input is tuple, it is converted to list. If it's list, it is returned untouched. Otherwise + returns a single-element list of the data. + :return: list-ified data + """ + if isinstance(data, list): + pass + elif isinstance(data, tuple): + data = list(data) + else: + data = [data] + return data + + +def loss_output_dict(output: List[NDArray], schema: List[str]) -> Dict[str, List[NDArray]]: + """ + Creates a dictionary for loss output based on the output schema. If two output values have the same + type string in the schema they are concatenated in the same dicrionary item. + :param output: list of output values + :param schema: list of type-strings for output values + :return: dictionary of keyword to list of NDArrays + """ + assert len(output) == len(schema) + output_dict = dict() + for name, val in zip(schema, output): + if name in output_dict: + output_dict[name].append(val) + else: + output_dict[name] = [val] + return output_dict + + +def clip_grad( + grads: Union[Generator[NDArray, NDArray, NDArray], List[NDArray], Tuple[NDArray]], + clip_method: GradientClippingMethod, + clip_val: float, + inplace=True) -> List[NDArray]: + """ + Clip gradient values inplace + :param grads: gradients to be clipped + :param clip_method: clipping method + :param clip_val: clipping value. Interpreted differently depending on clipping method. 
+ :param inplace: modify grads if True, otherwise create NDArrays + :return: clipped gradients + """ + output = list(grads) if inplace else list(nd.empty(g.shape) for g in grads) + if clip_method == GradientClippingMethod.ClipByGlobalNorm: + norm_unclipped_grads = global_norm(grads) + scale = clip_val / (norm_unclipped_grads.asscalar() + 1e-8) # todo: use branching operators? + if scale < 1.0: + for g, o in zip(grads, output): + nd.broadcast_mul(g, nd.array([scale]), out=o) + elif not inplace: + # no clipping needed, but the output buffers must still be filled when not modifying in place + for g, o in zip(grads, output): + g.copyto(o) + elif clip_method == GradientClippingMethod.ClipByValue: + for g, o in zip(grads, output): + g.clip(-clip_val, clip_val, out=o) + elif clip_method == GradientClippingMethod.ClipByNorm: + for g, o in zip(grads, output): + nd.broadcast_mul(g, nd.minimum(1.0, clip_val / (g.norm() + 1e-8)), out=o) + else: + raise KeyError('Unsupported gradient clipping method') + return output + + +def hybrid_clip(F: ModuleType, x: nd_sym_type, clip_lower: nd_sym_type, clip_upper: nd_sym_type) -> nd_sym_type: + """ + Apply clipping to input x between clip_lower and clip_upper. + Added because F.clip doesn't support clipping bounds that are mx.nd.NDArray or mx.sym.Symbol. + + :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized). + :param x: input data + :param clip_lower: lower bound used for clipping, should be of shape (1,) + :param clip_upper: upper bound used for clipping, should be of shape (1,) + :return: clipped data + """ + x_clip_lower = clip_lower.broadcast_like(x) + x_clip_upper = clip_upper.broadcast_like(x) + x_clipped = F.stack(x, x_clip_lower, axis=0).max(axis=0) + x_clipped = F.stack(x_clipped, x_clip_upper, axis=0).min(axis=0) + return x_clipped + + +def get_mxnet_activation_name(activation_name: str): + """ + Convert coach activation name to mxnet specific activation name + :param activation_name: name of the activation in coach + :return: name of the activation in mxnet + """ + activation_functions = { + 'relu': 'relu', + 'tanh': 'tanh', + 'sigmoid': 'sigmoid', + # FIXME Add other activations + # 'elu': tf.nn.elu, + 'selu': 'softrelu', + # 'leaky_relu': tf.nn.leaky_relu, + 'none': None + } + assert activation_name in activation_functions, \ + "Activation function must be one of the following {}. 
instead it was: {}".format( + activation_functions.keys(), activation_name) + return activation_functions[activation_name] diff --git a/rl_coach/architectures/network_wrapper.py b/rl_coach/architectures/network_wrapper.py index 042c93c..b9decee 100644 --- a/rl_coach/architectures/network_wrapper.py +++ b/rl_coach/architectures/network_wrapper.py @@ -183,16 +183,7 @@ class NetworkWrapper(object): target_network or global_network) and the second element is the inputs :return: the outputs of all the networks in the same order as the inputs were given """ - feed_dict = {} - fetches = [] - - for idx, (network, input) in enumerate(network_input_tuples): - feed_dict.update(network.create_feed_dict(input)) - fetches += network.outputs - - outputs = self.sess.run(fetches, feed_dict) - - return outputs + return type(self.online_network).parallel_predict(self.sess, network_input_tuples) def get_local_variables(self): """ diff --git a/rl_coach/architectures/tensorflow_components/architecture.py b/rl_coach/architectures/tensorflow_components/architecture.py index 3353b52..8e7deae 100644 --- a/rl_coach/architectures/tensorflow_components/architecture.py +++ b/rl_coach/architectures/tensorflow_components/architecture.py @@ -13,8 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import time import os +import time +from typing import Any, List, Tuple, Dict import numpy as np import tensorflow as tf @@ -544,6 +545,26 @@ class TensorFlowArchitecture(Architecture): output = squeeze_list(output) return output + @staticmethod + def parallel_predict(sess: Any, + network_input_tuples: List[Tuple['TensorFlowArchitecture', Dict[str, np.ndarray]]]) ->\ + List[np.ndarray]: + """ + :param sess: active session to use for prediction + :param network_input_tuples: tuple of network and corresponding input + :return: list of outputs from all networks + """ + feed_dict = {} + fetches = [] + + for network, input in network_input_tuples: + feed_dict.update(network.create_feed_dict(input)) + fetches += network.outputs + + outputs = sess.run(fetches, feed_dict) + + return outputs + def train_on_batch(self, inputs, targets, scaler=1., additional_fetches=None, importance_weights=None): """ Given a batch of examples and targets, runs a forward pass & backward pass and then applies the gradients diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index 422f724..778e4ed 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -22,7 +22,7 @@ from distutils.dir_util import copy_tree, remove_tree from typing import List, Tuple import contextlib -from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, \ +from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, Frameworks, \ VisualizationParameters, \ Parameters, PresetValidationParameters from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \ @@ -161,7 +161,8 @@ class GraphManager(object): """ raise NotImplementedError("") - def create_worker_or_parameters_server(self, task_parameters: DistributedTaskParameters): + @staticmethod + def _create_worker_or_parameters_server_tf(task_parameters: DistributedTaskParameters): import tensorflow as tf config = tf.ConfigProto() config.allow_soft_placement = True # allow placing ops on cpu if they are not fit for gpu @@ -170,7 +171,8 @@ class 
GraphManager(object): config.intra_op_parallelism_threads = 1 config.inter_op_parallelism_threads = 1 - from rl_coach.architectures.tensorflow_components.distributed_tf_utils import create_and_start_parameters_server, \ + from rl_coach.architectures.tensorflow_components.distributed_tf_utils import \ + create_and_start_parameters_server, \ create_cluster_spec, create_worker_server_and_device # create cluster spec @@ -190,7 +192,16 @@ class GraphManager(object): raise ValueError("The job type should be either ps or worker and not {}" .format(task_parameters.job_type)) - def create_session(self, task_parameters: DistributedTaskParameters): + @staticmethod + def create_worker_or_parameters_server(task_parameters: DistributedTaskParameters): + if task_parameters.framework_type == Frameworks.tensorflow: + GraphManager._create_worker_or_parameters_server_tf(task_parameters) + elif task_parameters.framework_type == Frameworks.mxnet: + raise NotImplementedError('Distributed training not implemented for MXNet') + else: + raise ValueError('Invalid framework {}'.format(task_parameters.framework_type)) + + def _create_session_tf(self, task_parameters: TaskParameters): import tensorflow as tf config = tf.ConfigProto() config.allow_soft_placement = True # allow placing ops on cpu if they are not fit for gpu @@ -235,6 +246,15 @@ class GraphManager(object): if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir: self.save_graph() + def create_session(self, task_parameters: TaskParameters): + if task_parameters.framework_type == Frameworks.tensorflow: + self._create_session_tf(task_parameters) + elif task_parameters.framework_type == Frameworks.mxnet: + self.set_session(sess=None) # Initialize all modules + # TODO add checkpoint loading + else: + raise ValueError('Invalid framework {}'.format(task_parameters.framework_type)) + def save_graph(self) -> None: """ Save the TF graph to a protobuf description file in the experiment directory @@ -490,27 +510,35 @@ class GraphManager(object): self.train_and_act(self.steps_between_evaluation_periods) self.evaluate(self.evaluation_steps) + def _restore_checkpoint_tf(self, checkpoint_dir: str): + import tensorflow as tf + checkpoint = tf.train.get_checkpoint_state(checkpoint_dir) + screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path)) + variables = {} + for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir): + # Load the variable + var = tf.contrib.framework.load_variable(checkpoint_dir, var_name) + + # Set the new name + new_name = var_name + new_name = new_name.replace('global/', 'online/') + variables[new_name] = var + + for v in self.variables_to_restore: + self.sess.run(v.assign(variables[v.name.split(':')[0]])) + def restore_checkpoint(self): self.verify_graph_was_created() # TODO: find better way to load checkpoints that were saved with a global network into the online network if hasattr(self.task_parameters, 'checkpoint_restore_dir') and self.task_parameters.checkpoint_restore_dir: - import tensorflow as tf - checkpoint_dir = self.task_parameters.checkpoint_restore_dir - checkpoint = tf.train.get_checkpoint_state(checkpoint_dir) - screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path)) - variables = {} - for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir): - # Load the variable - var = tf.contrib.framework.load_variable(checkpoint_dir, var_name) - - # Set the new name - new_name = var_name - new_name = 
new_name.replace('global/', 'online/') - variables[new_name] = var - - for v in self.variables_to_restore: - self.sess.run(v.assign(variables[v.name.split(':')[0]])) + if self.task_parameters.framework_type == Frameworks.tensorflow: + self._restore_checkpoint_tf(self.task_parameters.checkpoint_restore_dir) + elif self.task_parameters.framework_type == Frameworks.mxnet: + # TODO implement checkpoint restore + pass + else: + raise ValueError('Invalid framework {}'.format(self.task_parameters.framework_type)) def occasionally_save_checkpoint(self): # only the chief process saves checkpoints @@ -529,7 +557,10 @@ class GraphManager(object): self.checkpoint_id, self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps])) if not isinstance(self.task_parameters, DistributedTaskParameters): - saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path) + if self.checkpoint_saver is not None: + saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path) + else: + saved_checkpoint_path = "" else: saved_checkpoint_path = checkpoint_path diff --git a/rl_coach/presets/CartPole_PPO.py b/rl_coach/presets/CartPole_PPO.py index e8af4c1..0c13abb 100644 --- a/rl_coach/presets/CartPole_PPO.py +++ b/rl_coach/presets/CartPole_PPO.py @@ -49,8 +49,8 @@ agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoach agent_params.exploration = EGreedyParameters() agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000) -agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation', - ObservationNormalizationFilter(name='normalize_observation')) +# agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation', +# ObservationNormalizationFilter(name='normalize_observation')) ############### # Environment # diff --git a/rl_coach/tests/architectures/mxnet_components/__init__.py b/rl_coach/tests/architectures/mxnet_components/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rl_coach/tests/architectures/mxnet_components/embedders/__init__.py b/rl_coach/tests/architectures/mxnet_components/embedders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rl_coach/tests/architectures/mxnet_components/embedders/test_image_embedder.py b/rl_coach/tests/architectures/mxnet_components/embedders/test_image_embedder.py new file mode 100644 index 0000000..0e7f9da --- /dev/null +++ b/rl_coach/tests/architectures/mxnet_components/embedders/test_image_embedder.py @@ -0,0 +1,21 @@ +import mxnet as mx +import os +import pytest +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + + +from rl_coach.base_parameters import EmbedderScheme +from rl_coach.architectures.embedder_parameters import InputEmbedderParameters +from rl_coach.architectures.mxnet_components.embedders.image_embedder import ImageEmbedder + + +@pytest.mark.unit_test +def test_image_embedder(): + params = InputEmbedderParameters(scheme=EmbedderScheme.Medium) + emb = ImageEmbedder(params=params) + emb.initialize() + input_data = mx.nd.random.uniform(low=0, high=1, shape=(10, 3, 244, 244)) + output = emb(input_data) + assert len(output.shape) == 2 # since last block was flatten + assert output.shape[0] == 10 # since batch_size is 10 diff --git a/rl_coach/tests/architectures/mxnet_components/embedders/test_vector_embedder.py b/rl_coach/tests/architectures/mxnet_components/embedders/test_vector_embedder.py new file mode 100644 index 0000000..8d51840 --- /dev/null +++ 
b/rl_coach/tests/architectures/mxnet_components/embedders/test_vector_embedder.py @@ -0,0 +1,22 @@ +import mxnet as mx +import os +import pytest +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + + +from rl_coach.architectures.embedder_parameters import InputEmbedderParameters +from rl_coach.architectures.mxnet_components.embedders.vector_embedder import VectorEmbedder +from rl_coach.base_parameters import EmbedderScheme + + +@pytest.mark.unit_test +def test_vector_embedder(): + params = InputEmbedderParameters(scheme=EmbedderScheme.Medium) + emb = VectorEmbedder(params=params) + emb.initialize() + input_data = mx.nd.random.uniform(low=0, high=255, shape=(10, 100)) + output = emb(input_data) + assert len(output.shape) == 2 # since last block was flatten + assert output.shape[0] == 10 # since batch_size is 10 + assert output.shape[1] == 256 # since last dense layer has 256 units diff --git a/rl_coach/tests/architectures/mxnet_components/heads/__init__.py b/rl_coach/tests/architectures/mxnet_components/heads/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rl_coach/tests/architectures/mxnet_components/heads/test_ppo_head.py b/rl_coach/tests/architectures/mxnet_components/heads/test_ppo_head.py new file mode 100644 index 0000000..9814925 --- /dev/null +++ b/rl_coach/tests/architectures/mxnet_components/heads/test_ppo_head.py @@ -0,0 +1,406 @@ +import mxnet as mx +import numpy as np +import os +import pytest +from scipy import stats as sp_stats +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + + +from rl_coach.architectures.head_parameters import PPOHeadParameters +from rl_coach.architectures.mxnet_components.heads.ppo_head import CategoricalDist, MultivariateNormalDist,\ + DiscretePPOHead, ClippedPPOLossDiscrete, ClippedPPOLossContinuous, PPOHead +from rl_coach.agents.clipped_ppo_agent import ClippedPPOAlgorithmParameters, ClippedPPOAgentParameters +from rl_coach.spaces import SpacesDefinition, DiscreteActionSpace + + +@pytest.mark.unit_test +def test_multivariate_normal_dist_shape(): + num_var = 2 + means = mx.nd.array((0, 1)) + covar = mx.nd.array(((1, 0),(0, 0.5))) + data = mx.nd.array((0.5, 0.8)) + policy_dist = MultivariateNormalDist(num_var, means, covar) + log_probs = policy_dist.log_prob(data) + assert log_probs.ndim == 1 + assert log_probs.shape[0] == 1 + + +@pytest.mark.unit_test +def test_multivariate_normal_dist_batch_shape(): + num_var = 2 + batch_size = 3 + means = mx.nd.random.uniform(shape=(batch_size, num_var)) + # create batch of covariance matrices only defined on diagonal + std = mx.nd.array((1, 0.5)).broadcast_like(means).expand_dims(-2) + covar = mx.nd.eye(N=num_var) * std + data = mx.nd.random.uniform(shape=(batch_size, num_var)) + policy_dist = MultivariateNormalDist(num_var, means, covar) + log_probs = policy_dist.log_prob(data) + assert log_probs.ndim == 1 + assert log_probs.shape[0] == batch_size + + +@pytest.mark.unit_test +def test_multivariate_normal_dist_batch_time_shape(): + num_var = 2 + batch_size = 3 + time_steps = 4 + means = mx.nd.random.uniform(shape=(batch_size, time_steps, num_var)) + # create batch (per time step) of covariance matrices only defined on diagonal + std = mx.nd.array((1, 0.5)).broadcast_like(means).expand_dims(-2) + covar = mx.nd.eye(N=num_var) * std + data = mx.nd.random.uniform(shape=(batch_size, time_steps, num_var)) + policy_dist = MultivariateNormalDist(num_var, means, covar) + log_probs = policy_dist.log_prob(data) + assert 
log_probs.ndim == 2 + assert log_probs.shape[0] == batch_size + assert log_probs.shape[1] == time_steps + + +@pytest.mark.unit_test +def test_multivariate_normal_dist_kl_div(): + n_classes = 2 + dist_a = MultivariateNormalDist(num_var=n_classes, + mean = mx.nd.array([0.2, 0.8]).expand_dims(0), + sigma = mx.nd.array([[1, 0.5], [0.5, 0.5]]).expand_dims(0)) + dist_b = MultivariateNormalDist(num_var=n_classes, + mean = mx.nd.array([0.3, 0.7]).expand_dims(0), + sigma = mx.nd.array([[1, 0.2], [0.2, 0.5]]).expand_dims(0)) + + actual = dist_a.kl_div(dist_b).asnumpy() + np.testing.assert_almost_equal(actual=actual, desired=0.195100128) + + +@pytest.mark.unit_test +def test_multivariate_normal_dist_kl_div_batch(): + n_classes = 2 + dist_a = MultivariateNormalDist(num_var=n_classes, + mean = mx.nd.array([[0.2, 0.8], + [0.2, 0.8]]), + sigma = mx.nd.array([[[1, 0.5], [0.5, 0.5]], + [[1, 0.5], [0.5, 0.5]]])) + dist_b = MultivariateNormalDist(num_var=n_classes, + mean = mx.nd.array([[0.3, 0.7], + [0.3, 0.7]]), + sigma = mx.nd.array([[[1, 0.2], [0.2, 0.5]], + [[1, 0.2], [0.2, 0.5]]])) + + actual = dist_a.kl_div(dist_b).asnumpy() + np.testing.assert_almost_equal(actual=actual, desired=[0.195100128, 0.195100128]) + + +@pytest.mark.unit_test +def test_categorical_dist_shape(): + num_actions = 2 + # actions taken, of shape (batch_size, time_steps) + actions = mx.nd.array((1,)) + # action probabilities, of shape (batch_size, time_steps, num_actions) + policy_probs = mx.nd.array((0.8, 0.2)) + policy_dist = CategoricalDist(num_actions, policy_probs) + action_probs = policy_dist.log_prob(actions) + assert action_probs.ndim == 1 + assert action_probs.shape[0] == 1 + + +@pytest.mark.unit_test +def test_categorical_dist_batch_shape(): + batch_size = 3 + num_actions = 2 + # actions taken, of shape (batch_size, time_steps) + actions = mx.nd.array((0, 1, 0)) + # action probabilities, of shape (batch_size, time_steps, num_actions) + policy_probs = mx.nd.array(((0.8, 0.2), (0.5, 0.5), (0.5, 0.5))) + policy_dist = CategoricalDist(num_actions, policy_probs) + action_probs = policy_dist.log_prob(actions) + assert action_probs.ndim == 1 + assert action_probs.shape[0] == batch_size + + +@pytest.mark.unit_test +def test_categorical_dist_batch_time_shape(): + batch_size = 3 + time_steps = 4 + num_actions = 2 + # actions taken, of shape (batch_size, time_steps) + actions = mx.nd.array(((0, 1, 0, 0), + (1, 1, 0, 0), + (0, 0, 0, 0))) + # action probabilities, of shape (batch_size, time_steps, num_actions) + policy_probs = mx.nd.array((((0.8, 0.2), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)), + ((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)), + ((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)))) + policy_dist = CategoricalDist(num_actions, policy_probs) + action_probs = policy_dist.log_prob(actions) + assert action_probs.ndim == 2 + assert action_probs.shape[0] == batch_size + assert action_probs.shape[1] == time_steps + + +@pytest.mark.unit_test +def test_categorical_dist_batch(): + n_classes = 2 + probs = mx.nd.array(((0.8, 0.2), + (0.7, 0.3), + (0.5, 0.5))) + + dist = CategoricalDist(n_classes, probs) + # check log_prob + actions = mx.nd.array((0, 1, 0)) + manual_log_prob = np.array((-0.22314353, -1.20397282, -0.69314718)) + np.testing.assert_almost_equal(actual=dist.log_prob(actions).asnumpy(), desired=manual_log_prob) + # check entropy + sp_entropy = np.array([sp_stats.entropy(pk=(0.8, 0.2)), + sp_stats.entropy(pk=(0.7, 0.3)), + sp_stats.entropy(pk=(0.5, 0.5))]) + np.testing.assert_almost_equal(actual=dist.entropy().asnumpy(), 
desired=sp_entropy) + + +@pytest.mark.unit_test +def test_categorical_dist_kl_div(): + n_classes = 3 + dist_a = CategoricalDist(n_classes=n_classes, probs=mx.nd.array([0.4, 0.2, 0.4])) + dist_b = CategoricalDist(n_classes=n_classes, probs=mx.nd.array([0.3, 0.4, 0.3])) + dist_c = CategoricalDist(n_classes=n_classes, probs=mx.nd.array([0.2, 0.6, 0.2])) + dist_d = CategoricalDist(n_classes=n_classes, probs=mx.nd.array([0.0, 1.0, 0.0])) + np.testing.assert_almost_equal(actual=dist_a.kl_div(dist_b).asnumpy(), desired=0.09151624) + np.testing.assert_almost_equal(actual=dist_a.kl_div(dist_c).asnumpy(), desired=0.33479536) + np.testing.assert_almost_equal(actual=dist_c.kl_div(dist_a).asnumpy(), desired=0.38190854) + np.testing.assert_almost_equal(actual=dist_a.kl_div(dist_d).asnumpy(), desired=np.nan) + np.testing.assert_almost_equal(actual=dist_d.kl_div(dist_a).asnumpy(), desired=1.60943782) + + +@pytest.mark.unit_test +def test_categorical_dist_kl_div_batch(): + n_classes = 3 + dist_a = CategoricalDist(n_classes=n_classes, probs=mx.nd.array([[0.4, 0.2, 0.4], + [0.4, 0.2, 0.4], + [0.4, 0.2, 0.4]])) + dist_b = CategoricalDist(n_classes=n_classes, probs=mx.nd.array([[0.3, 0.4, 0.3], + [0.3, 0.4, 0.3], + [0.3, 0.4, 0.3]])) + actual = dist_a.kl_div(dist_b).asnumpy() + np.testing.assert_almost_equal(actual=actual, desired=[0.09151624, 0.09151624, 0.09151624]) + + +@pytest.mark.unit_test +def test_clipped_ppo_loss_continuous_batch(): + # check lower loss for policy with better probabilities: + # i.e. higher probability on high advantage actions, low probability on low advantage actions. + loss_fn = ClippedPPOLossContinuous(num_actions=2, + clip_likelihood_ratio_using_epsilon=0.2) + loss_fn.initialize() + # actual actions taken, of shape (batch_size, num_actions) + actions = mx.nd.array(((0.5, -0.5), (0.2, 0.3), (0.4, 2.0))) + # advantages from taking action, of shape (batch_size) + advantages = mx.nd.array((2, -2, 1)) + # policy means, of shape (batch_size, num_actions) + old_policy_means = mx.nd.array(((1, 0), (0, 0), (-1, 0))) + new_policy_means_worse = mx.nd.array(((2, 0), (0, 0), (-1, 0))) + new_policy_means_better = mx.nd.array(((0.5, 0), (0, 0), (-1, 0))) + + policy_stds = mx.nd.array(((1, 1), (1, 1), (1, 1))) + clip_param_rescaler = mx.nd.array((1,)) + + loss_worse = loss_fn(new_policy_means_worse, policy_stds, + actions, old_policy_means, policy_stds, + clip_param_rescaler, advantages) + loss_better = loss_fn(new_policy_means_better, policy_stds, + actions, old_policy_means, policy_stds, + clip_param_rescaler, advantages) + + assert len(loss_worse) == 6 # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO) + loss_worse_val = loss_worse[0] + assert loss_worse_val.ndim == 1 + assert loss_worse_val.shape[0] == 1 + assert len(loss_better) == 6 # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO) + loss_better_val = loss_better[0] + assert loss_better_val.ndim == 1 + assert loss_better_val.shape[0] == 1 + assert loss_worse_val > loss_better_val + + +@pytest.mark.unit_test +def test_clipped_ppo_loss_discrete_batch(): + # check lower loss for policy with better probabilities: + # i.e. higher probability on high advantage actions, low probability on low advantage actions. 
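+ # A minimal numpy sketch of the clipped-surrogate term at the core of these loss tests, assuming the + # standard PPO objective: loss = -mean(min(r * A, clip(r, 1 - eps, 1 + eps) * A)), with likelihood + # ratio r = pi_new(action) / pi_old(action). `clipped_surrogate` is an illustrative helper, not part + # of the loss API; `np` is numpy, imported at the top of this file. + def clipped_surrogate(new_probs, old_probs, actions, advantages, eps=0.2): + idx = np.arange(len(actions)) # actions are integer action indices + r = new_probs[idx, actions] / old_probs[idx, actions] # per-sample likelihood ratio + return -np.mean(np.minimum(r * advantages, np.clip(r, 1 - eps, 1 + eps) * advantages)) + # With eps=0.2 and the arrays used in test_clipped_ppo_loss_discrete_hybridize below, this evaluates + # to -0.142857, matching the asserted loss value.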
+ loss_fn = ClippedPPOLossDiscrete(num_actions=2, + clip_likelihood_ratio_using_epsilon=None, + use_kl_regularization=True, + initial_kl_coefficient=1) + loss_fn.initialize() + + # actual actions taken, of shape (batch_size) + actions = mx.nd.array((0, 1, 0)) + # advantages from taking action, of shape (batch_size) + advantages = mx.nd.array((-2, 2, 1)) + # action probabilities, of shape (batch_size, num_actions) + old_policy_probs = mx.nd.array(((0.7, 0.3), (0.2, 0.8), (0.4, 0.6))) + new_policy_probs_worse = mx.nd.array(((0.9, 0.1), (0.2, 0.8), (0.4, 0.6))) + new_policy_probs_better = mx.nd.array(((0.5, 0.5), (0.2, 0.8), (0.4, 0.6))) + + clip_param_rescaler = mx.nd.array((1,)) + + loss_worse = loss_fn(new_policy_probs_worse, actions, old_policy_probs, clip_param_rescaler, advantages) + loss_better = loss_fn(new_policy_probs_better, actions, old_policy_probs, clip_param_rescaler, advantages) + + assert len(loss_worse) == 6 # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO) + lw_loss, lw_reg, lw_kl, lw_ent, lw_lr, lw_clip_lr = loss_worse + assert lw_loss.ndim == 1 + assert lw_loss.shape[0] == 1 + assert len(loss_better) == 6 # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO) + lb_loss, lb_reg, lb_kl, lb_ent, lb_lr, lb_clip_lr = loss_better + assert lb_loss.ndim == 1 + assert lb_loss.shape[0] == 1 + assert lw_loss > lb_loss + assert lw_kl > lb_kl + + +@pytest.mark.unit_test +def test_clipped_ppo_loss_discrete_batch_kl_div(): + # check lower loss for policy with better probabilities: + # i.e. higher probability on high advantage actions, low probability on low advantage actions. + loss_fn = ClippedPPOLossDiscrete(num_actions=2, + clip_likelihood_ratio_using_epsilon=None, + use_kl_regularization=True, + initial_kl_coefficient=0.5) + loss_fn.initialize() + + # actual actions taken, of shape (batch_size) + actions = mx.nd.array((0, 1, 0)) + # advantages from taking action, of shape (batch_size) + advantages = mx.nd.array((-2, 2, 1)) + # action probabilities, of shape (batch_size, num_actions) + old_policy_probs = mx.nd.array(((0.7, 0.3), (0.2, 0.8), (0.4, 0.6))) + new_policy_probs_worse = mx.nd.array(((0.9, 0.1), (0.2, 0.8), (0.4, 0.6))) + new_policy_probs_better = mx.nd.array(((0.5, 0.5), (0.2, 0.8), (0.4, 0.6))) + + clip_param_rescaler = mx.nd.array((1,)) + + loss_worse = loss_fn(new_policy_probs_worse, actions, old_policy_probs, clip_param_rescaler, advantages) + loss_better = loss_fn(new_policy_probs_better, actions, old_policy_probs, clip_param_rescaler, advantages) + + assert len(loss_worse) == 6 # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO) + lw_loss, lw_reg, lw_kl, lw_ent, lw_lr, lw_clip_lr = loss_worse + assert lw_kl.ndim == 1 + assert lw_kl.shape[0] == 1 + assert len(loss_better) == 6 # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO) + lb_loss, lb_reg, lb_kl, lb_ent, lb_lr, lb_clip_lr = loss_better + assert lb_kl.ndim == 1 + assert lb_kl.shape[0] == 1 + assert lw_kl > lb_kl + assert lw_reg > lb_reg + + +@pytest.mark.unit_test +def test_clipped_ppo_loss_discrete_batch_time(): + batch_size = 3 + time_steps = 4 + num_actions = 2 + + # actions taken, of shape (batch_size, time_steps) + actions = mx.nd.array(((0, 1, 0, 0), + (1, 1, 0, 0), + (0, 0, 0, 0))) + # advantages from taking action, of shape (batch_size, time_steps) + advantages = mx.nd.array(((-2, 2, 1, 0), + (-1, 1, 0, 1), + (-1, 0, 1, 0))) + # action probabilities, of shape (batch_size, time_steps, num_actions) + 
old_policy_probs = mx.nd.array((((0.8, 0.2), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)), + ((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)), + ((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)))) + new_policy_probs_worse = mx.nd.array((((0.9, 0.1), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)), + ((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)), + ((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)))) + new_policy_probs_better = mx.nd.array((((0.2, 0.8), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)), + ((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)), + ((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)))) + + # check lower loss for policy with better probabilities: + # i.e. higher probability on high advantage actions, low probability on low advantage actions. + loss_fn = ClippedPPOLossDiscrete(num_actions=num_actions, + clip_likelihood_ratio_using_epsilon=0.2) + loss_fn.initialize() + + clip_param_rescaler = mx.nd.array((1,)) + + loss_worse = loss_fn(new_policy_probs_worse, actions, old_policy_probs, clip_param_rescaler, advantages) + loss_better = loss_fn(new_policy_probs_better, actions, old_policy_probs, clip_param_rescaler, advantages) + + assert len(loss_worse) == 6 # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO) + loss_worse_val = loss_worse[0] + assert loss_worse_val.ndim == 1 + assert loss_worse_val.shape[0] == 1 + assert len(loss_better) == 6 # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO) + loss_better_val = loss_better[0] + assert loss_better_val.ndim == 1 + assert loss_better_val.shape[0] == 1 + assert loss_worse_val > loss_better_val + + +@pytest.mark.unit_test +def test_clipped_ppo_loss_discrete_weight(): + actions = mx.nd.array((0, 1, 0)) + advantages = mx.nd.array((-2, 2, 1)) + old_policy_probs = mx.nd.array(((0.7, 0.3), (0.2, 0.8), (0.4, 0.6))) + new_policy_probs = mx.nd.array(((0.9, 0.1), (0.2, 0.8), (0.4, 0.6))) + + clip_param_rescaler = mx.nd.array((1,)) + loss_fn = ClippedPPOLossDiscrete(num_actions=2, + clip_likelihood_ratio_using_epsilon=0.2) + loss_fn.initialize() + loss = loss_fn(new_policy_probs, actions, old_policy_probs, clip_param_rescaler, advantages) + loss_fn_weighted = ClippedPPOLossDiscrete(num_actions=2, + clip_likelihood_ratio_using_epsilon=0.2, + weight=0.5) + loss_fn_weighted.initialize() + loss_weighted = loss_fn_weighted(new_policy_probs, actions, old_policy_probs, clip_param_rescaler, advantages) + assert loss[0] == loss_weighted[0] * 2 + + +@pytest.mark.unit_test +def test_clipped_ppo_loss_discrete_hybridize(): + loss_fn = ClippedPPOLossDiscrete(num_actions=2, + clip_likelihood_ratio_using_epsilon=0.2) + loss_fn.initialize() + loss_fn.hybridize() + actions = mx.nd.array((0, 1, 0)) + advantages = mx.nd.array((-2, 2, 1)) + old_policy_probs = mx.nd.array(((0.7, 0.3), (0.2, 0.8), (0.4, 0.6))) + new_policy_probs = mx.nd.array(((0.9, 0.1), (0.2, 0.8), (0.4, 0.6))) + clip_param_rescaler = mx.nd.array((1,)) + + loss = loss_fn(new_policy_probs, actions, old_policy_probs, clip_param_rescaler, advantages) + assert loss[0] == mx.nd.array((-0.142857153,)) + + +@pytest.mark.unit_test +def test_discrete_ppo_head(): + head = DiscretePPOHead(num_actions=2) + head.initialize() + middleware_data = mx.nd.random.uniform(shape=(10, 100)) + probs = head(middleware_data) + assert probs.ndim == 2 # (batch_size, num_actions) + assert probs.shape[0] == 10 # since batch_size is 10 + assert probs.shape[1] == 2 # since num_actions is 2 + + +@pytest.mark.unit_test +def test_ppo_head(): + agent_parameters = ClippedPPOAgentParameters() + 
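# PPOHead builds its output distribution from the action space in `spaces` below; with a + # DiscreteActionSpace it is expected to wrap DiscretePPOHead, so the head returns a + # (batch_size, num_actions) matrix of action probabilities, as asserted at the end of this test. + 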
num_actions = 5 + action_space = DiscreteActionSpace(num_actions=num_actions) + spaces = SpacesDefinition(state=None, goal=None, action=action_space, reward=None) + head = PPOHead(agent_parameters=agent_parameters, + spaces=spaces, + network_name="test_ppo_head") + + head.initialize() + + batch_size = 15 + middleware_data = mx.nd.random.uniform(shape=(batch_size, 100)) + probs = head(middleware_data) + assert probs.ndim == 2 # (batch_size, num_actions) + assert probs.shape[0] == batch_size + assert probs.shape[1] == num_actions diff --git a/rl_coach/tests/architectures/mxnet_components/heads/test_ppo_v_head.py b/rl_coach/tests/architectures/mxnet_components/heads/test_ppo_v_head.py new file mode 100644 index 0000000..abd7016 --- /dev/null +++ b/rl_coach/tests/architectures/mxnet_components/heads/test_ppo_v_head.py @@ -0,0 +1,90 @@ +import mxnet as mx +import os +import pytest +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + + +from rl_coach.architectures.mxnet_components.heads.ppo_v_head import PPOVHead, PPOVHeadLoss +from rl_coach.agents.clipped_ppo_agent import ClippedPPOAlgorithmParameters, ClippedPPOAgentParameters +from rl_coach.spaces import SpacesDefinition, DiscreteActionSpace + + +@pytest.mark.unit_test +def test_ppo_v_head_loss_batch(): + loss_fn = PPOVHeadLoss(clip_likelihood_ratio_using_epsilon=0.1) + total_return = mx.nd.array((5, -3, 0)) + old_policy_values = mx.nd.array((3, -1, -1)) + new_policy_values_worse = mx.nd.array((2, 0, -1)) + new_policy_values_better = mx.nd.array((4, -2, -1)) + + loss_worse = loss_fn(new_policy_values_worse, old_policy_values, total_return) + loss_better = loss_fn(new_policy_values_better, old_policy_values, total_return) + + assert len(loss_worse) == 1 # (LOSS) + loss_worse_val = loss_worse[0] + assert loss_worse_val.ndim == 1 + assert loss_worse_val.shape[0] == 1 + assert len(loss_better) == 1 # (LOSS) + loss_better_val = loss_better[0] + assert loss_better_val.ndim == 1 + assert loss_better_val.shape[0] == 1 + assert loss_worse_val > loss_better_val + + +@pytest.mark.unit_test +def test_ppo_v_head_loss_batch_time(): + loss_fn = PPOVHeadLoss(clip_likelihood_ratio_using_epsilon=0.1) + total_return = mx.nd.array(((3, 1, 1, 0), + (1, 0, 0, 1), + (3, 0, 1, 0))) + old_policy_values = mx.nd.array(((2, 1, 1, 0), + (1, 0, 0, 1), + (0, 0, 1, 0))) + new_policy_values_worse = mx.nd.array(((2, 1, 1, 0), + (1, 0, 0, 1), + (2, 0, 1, 0))) + new_policy_values_better = mx.nd.array(((3, 1, 1, 0), + (1, 0, 0, 1), + (2, 0, 1, 0))) + + loss_worse = loss_fn(new_policy_values_worse, old_policy_values, total_return) + loss_better = loss_fn(new_policy_values_better, old_policy_values, total_return) + + assert len(loss_worse) == 1 # (LOSS) + loss_worse_val = loss_worse[0] + assert loss_worse_val.ndim == 1 + assert loss_worse_val.shape[0] == 1 + assert len(loss_better) == 1 # (LOSS) + loss_better_val = loss_better[0] + assert loss_better_val.ndim == 1 + assert loss_better_val.shape[0] == 1 + assert loss_worse_val > loss_better_val + + +@pytest.mark.unit_test +def test_ppo_v_head_loss_weight(): + total_return = mx.nd.array((5, -3, 0)) + old_policy_values = mx.nd.array((3, -1, -1)) + new_policy_values = mx.nd.array((4, -2, -1)) + loss_fn = PPOVHeadLoss(clip_likelihood_ratio_using_epsilon=0.2, weight=1) + loss = loss_fn(new_policy_values, old_policy_values, total_return) + loss_fn_weighted = PPOVHeadLoss(clip_likelihood_ratio_using_epsilon=0.2, weight=0.5) + loss_weighted = loss_fn_weighted(new_policy_values, old_policy_values, 
total_return) + assert loss[0].sum() == loss_weighted[0].sum() * 2 + + +@pytest.mark.unit_test +def test_ppo_v_head(): + agent_parameters = ClippedPPOAgentParameters() + action_space = DiscreteActionSpace(num_actions=5) + spaces = SpacesDefinition(state=None, goal=None, action=action_space, reward=None) + value_net = PPOVHead(agent_parameters=agent_parameters, + spaces=spaces, + network_name="test_ppo_v_head") + value_net.initialize() + batch_size = 15 + middleware_data = mx.nd.random.uniform(shape=(batch_size, 100)) + values = value_net(middleware_data) + assert values.ndim == 1 # (batch_size) + assert values.shape[0] == batch_size diff --git a/rl_coach/tests/architectures/mxnet_components/heads/test_q_head.py b/rl_coach/tests/architectures/mxnet_components/heads/test_q_head.py new file mode 100644 index 0000000..542b3b4 --- /dev/null +++ b/rl_coach/tests/architectures/mxnet_components/heads/test_q_head.py @@ -0,0 +1,60 @@ +import mxnet as mx +import os +import pytest +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + + +from rl_coach.architectures.mxnet_components.heads.q_head import QHead, QHeadLoss +from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters +from rl_coach.spaces import SpacesDefinition, DiscreteActionSpace + + + +@pytest.mark.unit_test +def test_q_head_loss(): + loss_fn = QHeadLoss() + # example with batch_size of 3, and num_actions of 2 + target_q_values = mx.nd.array(((3, 5), (-1, -2), (0, 2))) + pred_q_values_worse = mx.nd.array(((6, 5), (-1, -2), (0, 2))) + pred_q_values_better = mx.nd.array(((4, 5), (-2, -2), (1, 2))) + loss_worse = loss_fn(pred_q_values_worse, target_q_values) + loss_better = loss_fn(pred_q_values_better, target_q_values) + assert len(loss_worse) == 1 # (LOSS) + loss_worse_val = loss_worse[0] + assert loss_worse_val.ndim == 1 + assert loss_worse_val.shape[0] == 1 + assert len(loss_better) == 1 # (LOSS) + loss_better_val = loss_better[0] + assert loss_better_val.ndim == 1 + assert loss_better_val.shape[0] == 1 + assert loss_worse_val > loss_better_val + + +@pytest.mark.unit_test +def test_q_head_loss_weight(): + target_q_values = mx.nd.array(((3, 5), (-1, -2), (0, 2))) + pred_q_values = mx.nd.array(((4, 5), (-2, -2), (1, 2))) + loss_fn = QHeadLoss() + loss = loss_fn(pred_q_values, target_q_values) + loss_fn_weighted = QHeadLoss(weight=0.5) + loss_weighted = loss_fn_weighted(pred_q_values, target_q_values) + assert loss[0] == loss_weighted[0]*2 + + +@pytest.mark.unit_test +def test_q_head(): + agent_parameters = ClippedPPOAgentParameters() + num_actions = 5 + action_space = DiscreteActionSpace(num_actions=num_actions) + spaces = SpacesDefinition(state=None, goal=None, action=action_space, reward=None) + value_net = QHead(agent_parameters=agent_parameters, + spaces=spaces, + network_name="test_q_head") + value_net.initialize() + batch_size = 15 + middleware_data = mx.nd.random.uniform(shape=(batch_size, 100)) + values = value_net(middleware_data) + assert values.ndim == 2 # (batch_size, num_actions) + assert values.shape[0] == batch_size + assert values.shape[1] == num_actions \ No newline at end of file diff --git a/rl_coach/tests/architectures/mxnet_components/heads/test_v_head.py b/rl_coach/tests/architectures/mxnet_components/heads/test_v_head.py new file mode 100644 index 0000000..271d661 --- /dev/null +++ b/rl_coach/tests/architectures/mxnet_components/heads/test_v_head.py @@ -0,0 +1,57 @@ +import mxnet as mx +import os +import pytest +import sys 
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + + +from rl_coach.architectures.mxnet_components.heads.v_head import VHead, VHeadLoss +from rl_coach.agents.clipped_ppo_agent import ClippedPPOAlgorithmParameters, ClippedPPOAgentParameters +from rl_coach.spaces import SpacesDefinition, DiscreteActionSpace + + + +@pytest.mark.unit_test +def test_v_head_loss(): + loss_fn = VHeadLoss() + target_values = mx.nd.array((3, -1, 0)) + pred_values_worse = mx.nd.array((0, 0, 1)) + pred_values_better = mx.nd.array((2, -1, 0)) + loss_worse = loss_fn(pred_values_worse, target_values) + loss_better = loss_fn(pred_values_better, target_values) + assert len(loss_worse) == 1 # (LOSS) + loss_worse_val = loss_worse[0] + assert loss_worse_val.ndim == 1 + assert loss_worse_val.shape[0] == 1 + assert len(loss_better) == 1 # (LOSS) + loss_better_val = loss_better[0] + assert loss_better_val.ndim == 1 + assert loss_better_val.shape[0] == 1 + assert loss_worse_val > loss_better_val + + +@pytest.mark.unit_test +def test_v_head_loss_weight(): + target_values = mx.nd.array((3, -1, 0)) + pred_values = mx.nd.array((0, 0, 1)) + loss_fn = VHeadLoss() + loss = loss_fn(pred_values, target_values) + loss_fn_weighted = VHeadLoss(weight=0.5) + loss_weighted = loss_fn_weighted(pred_values, target_values) + assert loss[0] == loss_weighted[0]*2 + + +@pytest.mark.unit_test +def test_v_head(): + agent_parameters = ClippedPPOAgentParameters() + action_space = DiscreteActionSpace(num_actions=5) + spaces = SpacesDefinition(state=None, goal=None, action=action_space, reward=None) + value_net = VHead(agent_parameters=agent_parameters, + spaces=spaces, + network_name="test_v_head") + value_net.initialize() + batch_size = 15 + middleware_data = mx.nd.random.uniform(shape=(batch_size, 100)) + values = value_net(middleware_data) + assert values.ndim == 1 # (batch_size) + assert values.shape[0] == batch_size \ No newline at end of file diff --git a/rl_coach/tests/architectures/mxnet_components/middlewares/__init__.py b/rl_coach/tests/architectures/mxnet_components/middlewares/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rl_coach/tests/architectures/mxnet_components/middlewares/test_fc_middleware.py b/rl_coach/tests/architectures/mxnet_components/middlewares/test_fc_middleware.py new file mode 100644 index 0000000..3a0807b --- /dev/null +++ b/rl_coach/tests/architectures/mxnet_components/middlewares/test_fc_middleware.py @@ -0,0 +1,22 @@ +import mxnet as mx +import os +import pytest +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + + +from rl_coach.base_parameters import MiddlewareScheme +from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters +from rl_coach.architectures.mxnet_components.middlewares.fc_middleware import FCMiddleware + + +@pytest.mark.unit_test +def test_fc_middleware(): + params = FCMiddlewareParameters(scheme=MiddlewareScheme.Medium) + mid = FCMiddleware(params=params) + mid.initialize() + embedded_data = mx.nd.random.uniform(low=0, high=1, shape=(10, 100)) + output = mid(embedded_data) + assert output.ndim == 2 # (batch_size, units), since dense layers preserve the 2D shape + assert output.shape[0] == 10 # since batch_size is 10 + assert output.shape[1] == 512 # since last layer of middleware (Medium scheme) has 512 units diff --git a/rl_coach/tests/architectures/mxnet_components/middlewares/test_lstm_middleware.py b/rl_coach/tests/architectures/mxnet_components/middlewares/test_lstm_middleware.py new file mode 100644 index 
0000000..076932c --- /dev/null +++ b/rl_coach/tests/architectures/mxnet_components/middlewares/test_lstm_middleware.py @@ -0,0 +1,25 @@ +import mxnet as mx +import os +import pytest +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + + +from rl_coach.base_parameters import MiddlewareScheme +from rl_coach.architectures.middleware_parameters import LSTMMiddlewareParameters +from rl_coach.architectures.mxnet_components.middlewares.lstm_middleware import LSTMMiddleware + + +@pytest.mark.unit_test +def test_lstm_middleware(): + params = LSTMMiddlewareParameters(number_of_lstm_cells=25, scheme=MiddlewareScheme.Medium) + mid = LSTMMiddleware(params=params) + mid.initialize() + # NTC + embedded_data = mx.nd.random.uniform(low=0, high=1, shape=(10, 15, 20)) + # NTC -> TNC + output = mid(embedded_data) + assert output.ndim == 3 # TNC layout out of the LSTM layer + assert output.shape[0] == 15 # since time_steps is 15 + assert output.shape[1] == 10 # since batch_size is 10 + assert output.shape[2] == 25 # since number_of_lstm_cells is 25 diff --git a/rl_coach/tests/architectures/mxnet_components/test_utils.py b/rl_coach/tests/architectures/mxnet_components/test_utils.py new file mode 100644 index 0000000..0af729a --- /dev/null +++ b/rl_coach/tests/architectures/mxnet_components/test_utils.py @@ -0,0 +1,144 @@ +import pytest + +import mxnet as mx +from mxnet import nd +import numpy as np + +from rl_coach.architectures.mxnet_components.utils import * + + +@pytest.mark.unit_test +def test_to_mx_ndarray(): + # scalar + assert to_mx_ndarray(1.2) == nd.array([1.2]) + # list of one scalar + assert to_mx_ndarray([1.2]) == [nd.array([1.2])] + # list of multiple scalars + assert to_mx_ndarray([1.2, 3.4]) == [nd.array([1.2]), nd.array([3.4])] + # list of lists of scalars + assert to_mx_ndarray([[1.2], [3.4]]) == [[nd.array([1.2])], [nd.array([3.4])]] + # numpy + assert np.array_equal(to_mx_ndarray(np.array([[1.2], [3.4]])).asnumpy(), nd.array([[1.2], [3.4]]).asnumpy()) + # tuple + assert to_mx_ndarray(((1.2,), (3.4,))) == ((nd.array([1.2]),), (nd.array([3.4]),)) + + +@pytest.mark.unit_test +def test_asnumpy_or_asscalar(): + # scalar float32 + assert asnumpy_or_asscalar(nd.array([1.2])) == np.float32(1.2) + # scalar int32 + assert asnumpy_or_asscalar(nd.array([2], dtype=np.int32)) == np.int32(2) + # list of one scalar + assert asnumpy_or_asscalar([nd.array([1.2])]) == [np.float32(1.2)] + # list of multiple scalars + assert asnumpy_or_asscalar([nd.array([1.2]), nd.array([3.4])]) == [np.float32([1.2]), np.float32([3.4])] + # list of lists of scalars + assert asnumpy_or_asscalar([[nd.array([1.2])], [nd.array([3.4])]]) == [[np.float32([1.2])], [np.float32([3.4])]] + # tensor + assert np.array_equal(asnumpy_or_asscalar(nd.array([[1.2], [3.4]])), np.array([[1.2], [3.4]], dtype=np.float32)) + # tuple + assert (asnumpy_or_asscalar(((nd.array([1.2]),), (nd.array([3.4]),))) == + ((np.array([1.2], dtype=np.float32),), (np.array([3.4], dtype=np.float32),))) + + +@pytest.mark.unit_test +def test_global_norm(): + data = list() + for i in range(1, 6): + data.append(np.ones((i * 10, i * 10)) * i) + gnorm = np.asscalar(np.sqrt(sum([np.sum(np.square(d)) for d in data]))) + assert np.isclose(gnorm, global_norm([nd.array(d) for d in data]).asscalar()) + + +@pytest.mark.unit_test +def test_split_outputs_per_head(): + class TestHead: + def __init__(self, num_outputs): + self.num_outputs = num_outputs + + assert split_outputs_per_head((1, 2, 3, 4), [TestHead(2), TestHead(1), TestHead(1)]) == [[1, 2], [3], 
[4]] + + +class DummySchema: + def __init__(self, num_head_outputs, num_agent_inputs, num_targets): + self.head_outputs = ['head_output_{}'.format(i) for i in range(num_head_outputs)] + self.agent_inputs = ['agent_input_{}'.format(i) for i in range(num_agent_inputs)] + self.targets = ['target_{}'.format(i) for i in range(num_targets)] + + +class DummyLoss: + def __init__(self, num_head_outputs, num_agent_inputs, num_targets): + self.input_schema = DummySchema(num_head_outputs, num_agent_inputs, num_targets) + + +@pytest.mark.unit_test +def test_split_targets_per_loss(): + assert split_targets_per_loss([1, 2, 3, 4], + [DummyLoss(10, 100, 2), DummyLoss(20, 200, 1), DummyLoss(30, 300, 1)]) == \ + [[1, 2], [3], [4]] + + +@pytest.mark.unit_test +def test_get_loss_agent_inputs(): + input_dict = {'output_0_0': [1, 2], 'output_0_1': [3, 4], 'output_1_0': [5]} + assert get_loss_agent_inputs(input_dict, 0, DummyLoss(10, 2, 100)) == [[1, 2], [3, 4]] + assert get_loss_agent_inputs(input_dict, 1, DummyLoss(20, 1, 200)) == [[5]] + + +@pytest.mark.unit_test +def test_align_loss_args(): + class TestLossFwd(DummyLoss): + def __init__(self, num_head_outputs, num_agent_inputs, num_targets): + super(TestLossFwd, self).__init__(num_head_outputs, num_agent_inputs, num_targets) + + def loss_forward(self, F, head_output_2, head_output_1, agent_input_2, target_0, agent_input_1, param1, param2): + pass + + assert align_loss_args([1, 2, 3], [4, 5, 6, 7], [8, 9], TestLossFwd(3, 4, 2)) == [3, 2, 6, 8, 5] + + +@pytest.mark.unit_test +def test_to_tuple(): + assert to_tuple(123) == (123,) + assert to_tuple((1, 2, 3)) == (1, 2, 3) + assert to_tuple([1, 2, 3]) == (1, 2, 3) + + +@pytest.mark.unit_test +def test_to_list(): + assert to_list(123) == [123] + assert to_list((1, 2, 3)) == [1, 2, 3] + assert to_list([1, 2, 3]) == [1, 2, 3] + + +@pytest.mark.unit_test +def test_loss_output_dict(): + assert loss_output_dict([1, 2, 3], ['loss', 'loss', 'reg']) == {'loss': [1, 2], 'reg': [3]} + + +@pytest.mark.unit_test +def test_clip_grad(): + a = np.array([1, 2, -3]) + b = np.array([4, 5, -6]) + clip = 2 + gscale = np.minimum(1.0, clip / np.sqrt(np.sum(np.square(a)) + np.sum(np.square(b)))) + for lhs, rhs in zip(clip_grad([nd.array(a), nd.array(b)], GradientClippingMethod.ClipByGlobalNorm, clip_val=clip), + [a, b]): + assert np.allclose(lhs.asnumpy(), rhs * gscale) + for lhs, rhs in zip(clip_grad([nd.array(a), nd.array(b)], GradientClippingMethod.ClipByValue, clip_val=clip), + [a, b]): + assert np.allclose(lhs.asnumpy(), np.clip(rhs, -clip, clip)) + for lhs, rhs in zip(clip_grad([nd.array(a), nd.array(b)], GradientClippingMethod.ClipByNorm, clip_val=clip), + [a, b]): + scale = np.minimum(1.0, clip / np.sqrt(np.sum(np.square(rhs)))) + assert np.allclose(lhs.asnumpy(), rhs * scale) + + +@pytest.mark.unit_test +def test_hybrid_clip(): + x = mx.nd.array((0.5, 1.5, 2.5)) + a = mx.nd.array((1,)) + b = mx.nd.array((2,)) + clipped = hybrid_clip(F=mx.nd, x=x, clip_lower=a, clip_upper=b) + assert (np.isclose(a=clipped.asnumpy(), b=(1, 1.5, 2))).all() diff --git a/setup.py b/setup.py index e77cd70..70790b2 100644 --- a/setup.py +++ b/setup.py @@ -68,6 +68,17 @@ if not using_GPU: else: install_requires.append('tensorflow-gpu==1.9.0') +# Framework-specific dependencies. 
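+# e.g. installable via "pip3 install -e .[mxnet]", or "pip3 install -e .[all]" for every framework extra.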
+extras = { + 'mxnet': ['mxnet-cu90mkl>=1.3.0'] +} + +all_deps = [] +for group_name in extras: + all_deps += extras[group_name] +extras['all'] = all_deps + + setup( name='rl-coach', version='0.10.0', @@ -78,6 +89,7 @@ setup( packages=find_packages(), python_requires=">=3.5.*", install_requires=install_requires, + extras_require=extras, package_data={'rl_coach': ['dashboard_components/*.css', 'environments/doom/*.cfg', 'environments/doom/*.wad',