mirror of
https://github.com/gryf/coach.git
synced 2026-04-16 12:33:34 +02:00
Adding mxnet components to rl_coach/architectures (#60)
Adding mxnet components to rl_coach architectures. - Supports PPO and DQN - Tested with CartPole_PPO and CarPole_DQN - Normalizing filters don't work right now (see #49) and are disabled in CartPole_PPO preset - Checkpointing is disabled for MXNet
This commit is contained in:
14
rl_coach/architectures/mxnet_components/heads/__init__.py
Normal file
14
rl_coach/architectures/mxnet_components/heads/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from .head import Head, HeadLoss
|
||||
from .q_head import QHead
|
||||
from .ppo_head import PPOHead
|
||||
from .ppo_v_head import PPOVHead
|
||||
from .v_head import VHead
|
||||
|
||||
__all__ = [
|
||||
'Head',
|
||||
'HeadLoss',
|
||||
'QHead',
|
||||
'PPOHead',
|
||||
'PPOVHead',
|
||||
'VHead'
|
||||
]
|
||||
181
rl_coach/architectures/mxnet_components/heads/head.py
Normal file
181
rl_coach/architectures/mxnet_components/heads/head.py
Normal file
@@ -0,0 +1,181 @@
|
||||
from typing import Dict, List, Union, Tuple
|
||||
|
||||
from mxnet.gluon import nn, loss
|
||||
from mxnet.ndarray import NDArray
|
||||
from mxnet.symbol import Symbol
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.spaces import SpacesDefinition
|
||||
|
||||
|
||||
LOSS_OUT_TYPE_LOSS = 'loss'
|
||||
LOSS_OUT_TYPE_REGULARIZATION = 'regularization'
|
||||
|
||||
|
||||
class LossInputSchema(object):
|
||||
"""
|
||||
Helper class to contain schema for loss hybrid_forward input
|
||||
"""
|
||||
def __init__(self, head_outputs: List[str], agent_inputs: List[str], targets: List[str]):
|
||||
"""
|
||||
:param head_outputs: list of argument names in hybrid_forward that are outputs of the head.
|
||||
The order and number MUST MATCH the output from the head.
|
||||
:param agent_inputs: list of argument names in hybrid_forward that are inputs from the agent.
|
||||
The order and number MUST MATCH `output_<head_type_idx>_<order>` for this head.
|
||||
:param targets: list of argument names in hybrid_forward that are targets for the loss.
|
||||
The order and number MUST MATCH targets passed from the agent.
|
||||
"""
|
||||
self._head_outputs = head_outputs
|
||||
self._agent_inputs = agent_inputs
|
||||
self._targets = targets
|
||||
|
||||
@property
|
||||
def head_outputs(self):
|
||||
return self._head_outputs
|
||||
|
||||
@property
|
||||
def agent_inputs(self):
|
||||
return self._agent_inputs
|
||||
|
||||
@property
|
||||
def targets(self):
|
||||
return self._targets
|
||||
|
||||
|
||||
class HeadLoss(loss.Loss):
|
||||
"""
|
||||
ABC for loss functions of each head. Child class must implement input_schema() and loss_forward()
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(HeadLoss, self).__init__(*args, **kwargs)
|
||||
self._output_schema = None # type: List[str]
|
||||
|
||||
@property
|
||||
def input_schema(self) -> LossInputSchema:
|
||||
"""
|
||||
:return: schema for input of hybrid_forward. Read docstring for LossInputSchema for details.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def output_schema(self) -> List[str]:
|
||||
"""
|
||||
:return: schema for output of hybrid_forward. Must contain 'loss' and 'regularization' keys at least once.
|
||||
The order and total number must match that of returned values from the loss. 'loss' and 'regularization'
|
||||
are special keys. Any other string is treated as auxiliary outputs and must include match auxiliary
|
||||
fetch names returned by the head.
|
||||
"""
|
||||
return self._output_schema
|
||||
|
||||
def forward(self, *args):
|
||||
"""
|
||||
Override forward() so that number of outputs can be checked against the schema
|
||||
"""
|
||||
outputs = super(HeadLoss, self).forward(*args)
|
||||
if isinstance(outputs, tuple) or isinstance(outputs, list):
|
||||
num_outputs = len(outputs)
|
||||
else:
|
||||
assert isinstance(outputs, NDArray) or isinstance(outputs, Symbol)
|
||||
num_outputs = 1
|
||||
assert num_outputs == len(self.output_schema), "Number of outputs don't match schema ({} != {})".format(
|
||||
num_outputs, len(self.output_schema))
|
||||
return outputs
|
||||
|
||||
def _loss_output(self, outputs: List[Tuple[Union[NDArray, Symbol], str]]):
|
||||
"""
|
||||
Must be called on the output from hybrid_forward().
|
||||
Saves the returned output as the schema and returns output values in a list
|
||||
:return: list of output values
|
||||
"""
|
||||
output_schema = [o[1] for o in outputs]
|
||||
assert self._output_schema is None or self._output_schema == output_schema
|
||||
self._output_schema = output_schema
|
||||
return tuple(o[0] for o in outputs)
|
||||
|
||||
def hybrid_forward(self, F, x, *args, **kwargs):
|
||||
"""
|
||||
Passes the cal to loss_forward() and constructs output schema from its output by calling loss_output()
|
||||
"""
|
||||
return self._loss_output(self.loss_forward(F, x, *args, **kwargs))
|
||||
|
||||
def loss_forward(self, F, x, *args, **kwargs) -> List[Tuple[Union[NDArray, Symbol], str]]:
|
||||
"""
|
||||
Similar to hybrid_forward, but returns list of (NDArray, type_str)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Head(nn.HybridBlock):
|
||||
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition,
|
||||
network_name: str, head_type_idx: int=0, loss_weight: float=1., is_local: bool=True,
|
||||
activation_function: str='relu', dense_layer: None=None):
|
||||
"""
|
||||
A head is the final part of the network. It takes the embedding from the middleware embedder and passes it
|
||||
through a neural network to produce the output of the network. There can be multiple heads in a network, and
|
||||
each one has an assigned loss function. The heads are algorithm dependent.
|
||||
|
||||
:param agent_parameters: containing algorithm parameters such as clip_likelihood_ratio_using_epsilon
|
||||
and beta_entropy.
|
||||
:param spaces: containing action spaces used for defining size of network output.
|
||||
:param network_name: name of head network. currently unused.
|
||||
:param head_type_idx: index of head network. currently unused.
|
||||
:param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
|
||||
:param is_local: flag to denote if network is local. currently unused.
|
||||
:param activation_function: activation function to use between layers. currently unused.
|
||||
:param dense_layer: type of dense layer to use in network. currently unused.
|
||||
"""
|
||||
super(Head, self).__init__()
|
||||
self.head_type_idx = head_type_idx
|
||||
self.network_name = network_name
|
||||
self.loss_weight = loss_weight
|
||||
self.is_local = is_local
|
||||
self.ap = agent_parameters
|
||||
self.spaces = spaces
|
||||
self.return_type = None
|
||||
self.activation_function = activation_function
|
||||
self.dense_layer = dense_layer
|
||||
self._num_outputs = None
|
||||
|
||||
def loss(self) -> HeadLoss:
|
||||
"""
|
||||
Returns loss block to be used for specific head implementation.
|
||||
|
||||
:return: loss block (can be called as function) for outputs returned by the head network.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def num_outputs(self):
|
||||
""" Returns number of outputs that forward() call will return
|
||||
|
||||
:return:
|
||||
"""
|
||||
assert self._num_outputs is not None, 'must call forward() once to configure number of outputs'
|
||||
return self._num_outputs
|
||||
|
||||
def forward(self, *args):
|
||||
"""
|
||||
Override forward() so that number of outputs can be automatically set
|
||||
"""
|
||||
outputs = super(Head, self).forward(*args)
|
||||
if isinstance(outputs, tuple):
|
||||
num_outputs = len(outputs)
|
||||
else:
|
||||
assert isinstance(outputs, NDArray) or isinstance(outputs, Symbol)
|
||||
num_outputs = 1
|
||||
if self._num_outputs is None:
|
||||
self._num_outputs = num_outputs
|
||||
else:
|
||||
assert self._num_outputs == num_outputs, 'Number of outputs cannot change ({} != {})'.format(
|
||||
self._num_outputs, num_outputs)
|
||||
assert self._num_outputs == len(self.loss().input_schema.head_outputs)
|
||||
return outputs
|
||||
|
||||
def hybrid_forward(self, F, x, *args, **kwargs):
|
||||
"""
|
||||
Used for forward pass through head network.
|
||||
|
||||
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
|
||||
:param x: middleware state representation, of shape (batch_size, in_channels).
|
||||
:return: final output of network, that will be used in loss calculations.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
669
rl_coach/architectures/mxnet_components/heads/ppo_head.py
Normal file
669
rl_coach/architectures/mxnet_components/heads/ppo_head.py
Normal file
@@ -0,0 +1,669 @@
|
||||
from typing import List, Tuple, Union
|
||||
from types import ModuleType
|
||||
|
||||
import math
|
||||
import mxnet as mx
|
||||
from mxnet.gluon import nn
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.core_types import ActionProbabilities
|
||||
from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace
|
||||
from rl_coach.utils import eps
|
||||
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
|
||||
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION
|
||||
from rl_coach.architectures.mxnet_components.utils import hybrid_clip
|
||||
|
||||
|
||||
LOSS_OUT_TYPE_KL = 'kl_divergence'
|
||||
LOSS_OUT_TYPE_ENTROPY = 'entropy'
|
||||
LOSS_OUT_TYPE_LIKELIHOOD_RATIO = 'likelihood_ratio'
|
||||
LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO = 'clipped_likelihood_ratio'
|
||||
|
||||
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
|
||||
|
||||
|
||||
class MultivariateNormalDist:
|
||||
def __init__(self,
|
||||
num_var: int,
|
||||
mean: nd_sym_type,
|
||||
sigma: nd_sym_type,
|
||||
F: ModuleType=mx.nd) -> None:
|
||||
"""
|
||||
Distribution object for Multivariate Normal. Works with batches.
|
||||
Optionally works with batches and time steps, but be consistent in usage: i.e. if using time_step,
|
||||
mean, sigma and data for log_prob must all include a time_step dimension.
|
||||
|
||||
:param num_var: number of variables in distribution
|
||||
:param mean: mean for each variable,
|
||||
of shape (num_var) or
|
||||
of shape (batch_size, num_var) or
|
||||
of shape (batch_size, time_step, num_var).
|
||||
:param sigma: covariance matrix,
|
||||
of shape (num_var, num_var) or
|
||||
of shape (batch_size, num_var, num_var) or
|
||||
of shape (batch_size, time_step, num_var, num_var).
|
||||
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
|
||||
"""
|
||||
self.num_var = num_var
|
||||
self.mean = mean
|
||||
self.sigma = sigma
|
||||
self.F = F
|
||||
|
||||
def inverse_using_cholesky(self, matrix: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
Calculate inverses for a batch of matrices using Cholesky decomposition method.
|
||||
|
||||
:param matrix: matrix (or matrices) to invert,
|
||||
of shape (num_var, num_var) or
|
||||
of shape (batch_size, num_var, num_var) or
|
||||
of shape (batch_size, time_step, num_var, num_var).
|
||||
:return: inverted matrix (or matrices),
|
||||
of shape (num_var, num_var) or
|
||||
of shape (batch_size, num_var, num_var) or
|
||||
of shape (batch_size, time_step, num_var, num_var).
|
||||
"""
|
||||
cholesky_factor = self.F.linalg.potrf(matrix)
|
||||
return self.F.linalg.potri(cholesky_factor)
|
||||
|
||||
def log_det(self, matrix: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
Calculate log of the determinant for a batch of matrices using Cholesky decomposition method.
|
||||
|
||||
:param matrix: matrix (or matrices) to invert,
|
||||
of shape (num_var, num_var) or
|
||||
of shape (batch_size, num_var, num_var) or
|
||||
of shape (batch_size, time_step, num_var, num_var).
|
||||
:return: inverted matrix (or matrices),
|
||||
of shape (num_var, num_var) or
|
||||
of shape (batch_size, num_var, num_var) or
|
||||
of shape (batch_size, time_step, num_var, num_var).
|
||||
"""
|
||||
cholesky_factor = self.F.linalg.potrf(matrix)
|
||||
return 2 * self.F.linalg.sumlogdiag(cholesky_factor)
|
||||
|
||||
def log_prob(self, x: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
Calculate the log probability of data given the current distribution.
|
||||
|
||||
See http://www.notenoughthoughts.net/posts/normal-log-likelihood-gradient.html
|
||||
and https://discuss.mxnet.io/t/multivariate-gaussian-log-density-operator/1169/7
|
||||
|
||||
:param x: input data,
|
||||
of shape (num_var) or
|
||||
of shape (batch_size, num_var) or
|
||||
of shape (batch_size, time_step, num_var).
|
||||
:return: log_probability,
|
||||
of shape (1) or
|
||||
of shape (batch_size) or
|
||||
of shape (batch_size, time_step).
|
||||
"""
|
||||
a = (self.num_var / 2) * math.log(2 * math.pi)
|
||||
log_det_sigma = self.log_det(self.sigma)
|
||||
b = (1 / 2) * log_det_sigma
|
||||
sigma_inv = self.inverse_using_cholesky(self.sigma)
|
||||
# deviation from mean, and dev_t is equivalent to transpose on last two dims.
|
||||
dev = (x - self.mean).expand_dims(-1)
|
||||
dev_t = (x - self.mean).expand_dims(-2)
|
||||
|
||||
# since batch_dot only works with ndarrays with ndim of 3,
|
||||
# and we could have ndarrays with ndim of 4,
|
||||
# we flatten batch_size and time_step into single dim.
|
||||
dev_flat = dev.reshape(shape=(-1, 0, 0), reverse=1)
|
||||
sigma_inv_flat = sigma_inv.reshape(shape=(-1, 0, 0), reverse=1)
|
||||
dev_t_flat = dev_t.reshape(shape=(-1, 0, 0), reverse=1)
|
||||
c = (1 / 2) * self.F.batch_dot(self.F.batch_dot(dev_t_flat, sigma_inv_flat), dev_flat)
|
||||
# and now reshape back to (batch_size, time_step) if required.
|
||||
c = c.reshape_like(b)
|
||||
|
||||
log_likelihood = -a - b - c
|
||||
return log_likelihood
|
||||
|
||||
def entropy(self) -> nd_sym_type:
|
||||
"""
|
||||
Calculate entropy of current distribution.
|
||||
|
||||
See http://www.nowozin.net/sebastian/blog/the-entropy-of-a-normal-distribution.html
|
||||
:return: entropy,
|
||||
of shape (1) or
|
||||
of shape (batch_size) or
|
||||
of shape (batch_size, time_step).
|
||||
"""
|
||||
# todo: check if differential entropy is correct
|
||||
log_det_sigma = self.log_det(self.sigma)
|
||||
return (self.num_var / 2) + ((self.num_var / 2) * math.log(2 * math.pi)) + ((1 / 2) * log_det_sigma)
|
||||
|
||||
def kl_div(self, alt_dist) -> nd_sym_type:
|
||||
"""
|
||||
Calculated KL-Divergence with another MultivariateNormalDist distribution
|
||||
See https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
|
||||
Specifically https://wikimedia.org/api/rest_v1/media/math/render/svg/a3bf3b4917bd1fcb8be48d6d6139e2e387bdc7d3
|
||||
|
||||
:param alt_dist: alternative distribution used for kl divergence calculation
|
||||
:type alt_dist: MultivariateNormalDist
|
||||
:return: KL-Divergence, of shape (1,)
|
||||
"""
|
||||
sigma_a_inv = self.F.linalg.potri(self.F.linalg.potrf(self.sigma))
|
||||
sigma_b_inv = self.F.linalg.potri(self.F.linalg.potrf(alt_dist.sigma))
|
||||
term1a = mx.nd.batch_dot(sigma_b_inv, self.sigma)
|
||||
# sum of diagonal for batch of matrices
|
||||
term1 = (self.F.eye(self.num_var).broadcast_like(term1a) * term1a).sum(axis=-1).sum(axis=-1)
|
||||
mean_diff = (alt_dist.mean - self.mean).expand_dims(-1)
|
||||
mean_diff_t = (alt_dist.mean - self.mean).expand_dims(-2)
|
||||
term2 = self.F.batch_dot(self.F.batch_dot(mean_diff_t, sigma_b_inv), mean_diff).reshape_like(term1)
|
||||
term3 = (2 * self.F.linalg.sumlogdiag(self.F.linalg.potrf(alt_dist.sigma))) -\
|
||||
(2 * self.F.linalg.sumlogdiag(self.F.linalg.potrf(self.sigma)))
|
||||
return 0.5 * (term1 + term2 - self.num_var + term3)
|
||||
|
||||
|
||||
|
||||
class CategoricalDist:
|
||||
def __init__(self, n_classes: int, probs: nd_sym_type, F: ModuleType=mx.nd) -> None:
|
||||
"""
|
||||
Distribution object for Categorical data.
|
||||
Optionally works with batches and time steps, but be consistent in usage: i.e. if using time_step,
|
||||
mean, sigma and data for log_prob must all include a time_step dimension.
|
||||
|
||||
:param n_classes: number of classes in distribution
|
||||
:param probs: probabilities for each class,
|
||||
of shape (n_classes),
|
||||
of shape (batch_size, n_classes) or
|
||||
of shape (batch_size, time_step, n_classes)
|
||||
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
|
||||
"""
|
||||
self.n_classes = n_classes
|
||||
self.probs = probs
|
||||
self.F = F
|
||||
|
||||
|
||||
def log_prob(self, actions: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
Calculate the log probability of data given the current distribution.
|
||||
|
||||
:param actions: actions, with int8 data type,
|
||||
of shape (1) if probs was (n_classes),
|
||||
of shape (batch_size) if probs was (batch_size, n_classes) and
|
||||
of shape (batch_size, time_step) if probs was (batch_size, time_step, n_classes)
|
||||
:return: log_probability,
|
||||
of shape (1) if probs was (n_classes),
|
||||
of shape (batch_size) if probs was (batch_size, n_classes) and
|
||||
of shape (batch_size, time_step) if probs was (batch_size, time_step, n_classes)
|
||||
"""
|
||||
action_mask = actions.one_hot(depth=self.n_classes)
|
||||
action_probs = (self.probs * action_mask).sum(axis=-1)
|
||||
return action_probs.log()
|
||||
|
||||
def entropy(self) -> nd_sym_type:
|
||||
"""
|
||||
Calculate entropy of current distribution.
|
||||
|
||||
:return: entropy,
|
||||
of shape (1) if probs was (n_classes),
|
||||
of shape (batch_size) if probs was (batch_size, n_classes) and
|
||||
of shape (batch_size, time_step) if probs was (batch_size, time_step, n_classes)
|
||||
"""
|
||||
# todo: look into numerical stability
|
||||
return -(self.probs.log()*self.probs).sum(axis=-1)
|
||||
|
||||
def kl_div(self, alt_dist) -> nd_sym_type:
|
||||
"""
|
||||
Calculated KL-Divergence with another Categorical distribution
|
||||
|
||||
:param alt_dist: alternative distribution used for kl divergence calculation
|
||||
:type alt_dist: CategoricalDist
|
||||
:return: KL-Divergence
|
||||
"""
|
||||
logits_a = self.probs.clip(a_min=eps, a_max=1 - eps).log()
|
||||
logits_b = alt_dist.probs.clip(a_min=eps, a_max=1 - eps).log()
|
||||
t = self.probs * (logits_a - logits_b)
|
||||
t = self.F.where(condition=(alt_dist.probs == 0), x=self.F.ones_like(alt_dist.probs) * math.inf, y=t)
|
||||
t = self.F.where(condition=(self.probs == 0), x=self.F.zeros_like(self.probs), y=t)
|
||||
return t.sum(axis=-1)
|
||||
|
||||
|
||||
class DiscretePPOHead(nn.HybridBlock):
|
||||
def __init__(self, num_actions: int) -> None:
|
||||
"""
|
||||
Head block for Discrete Proximal Policy Optimization, to calculate probabilities for each action given
|
||||
middleware representation of the environment state.
|
||||
|
||||
:param num_actions: number of actions in action space.
|
||||
"""
|
||||
super(DiscretePPOHead, self).__init__()
|
||||
with self.name_scope():
|
||||
self.dense = nn.Dense(units=num_actions, flatten=False)
|
||||
|
||||
def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
Used for forward pass through head network.
|
||||
|
||||
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
|
||||
:param x: middleware state representation,
|
||||
of shape (batch_size, in_channels) or
|
||||
of shape (batch_size, time_step, in_channels).
|
||||
:return: batch of probabilities for each action,
|
||||
of shape (batch_size, num_actions) or
|
||||
of shape (batch_size, time_step, num_actions).
|
||||
"""
|
||||
policy_values = self.dense(x)
|
||||
policy_probs = F.softmax(policy_values)
|
||||
return policy_probs
|
||||
|
||||
|
||||
class ContinuousPPOHead(nn.HybridBlock):
|
||||
def __init__(self, num_actions: int) -> None:
|
||||
"""
|
||||
Head block for Continuous Proximal Policy Optimization, to calculate probabilities for each action given
|
||||
middleware representation of the environment state.
|
||||
|
||||
:param num_actions: number of actions in action space.
|
||||
"""
|
||||
super(ContinuousPPOHead, self).__init__()
|
||||
with self.name_scope():
|
||||
# todo: change initialization strategy
|
||||
self.dense = nn.Dense(units=num_actions, flatten=False)
|
||||
# all samples (across batch, and time step) share the same covariance, which is learnt,
|
||||
# but since we assume the action probability variables are independent,
|
||||
# only the diagonal entries of the covariance matrix are specified.
|
||||
self.log_std = self.params.get('log_std',
|
||||
shape=num_actions,
|
||||
init=mx.init.Zero(),
|
||||
allow_deferred_init=True)
|
||||
# todo: is_local?
|
||||
|
||||
def hybrid_forward(self, F: ModuleType, x: nd_sym_type, log_std: nd_sym_type) -> List[nd_sym_type]:
|
||||
"""
|
||||
Used for forward pass through head network.
|
||||
|
||||
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
|
||||
:param x: middleware state representation,
|
||||
of shape (batch_size, in_channels) or
|
||||
of shape (batch_size, time_step, in_channels).
|
||||
:return: batch of probabilities for each action,
|
||||
of shape (batch_size, action_mean) or
|
||||
of shape (batch_size, time_step, action_mean).
|
||||
"""
|
||||
policy_means = self.dense(x)
|
||||
policy_std = log_std.exp()
|
||||
return [policy_means, policy_std]
|
||||
|
||||
|
||||
class ClippedPPOLossDiscrete(HeadLoss):
|
||||
def __init__(self,
|
||||
num_actions: int,
|
||||
clip_likelihood_ratio_using_epsilon: float,
|
||||
beta: float=0,
|
||||
use_kl_regularization: bool=False,
|
||||
initial_kl_coefficient: float=1,
|
||||
kl_cutoff: float=0,
|
||||
high_kl_penalty_coefficient: float=1,
|
||||
weight: float=1,
|
||||
batch_axis: int=0) -> None:
|
||||
"""
|
||||
Loss for discrete version of Clipped PPO.
|
||||
|
||||
:param num_actions: number of actions in action space.
|
||||
:param clip_likelihood_ratio_using_epsilon: epsilon to use for likelihood ratio clipping.
|
||||
:param beta: loss coefficient applied to entropy
|
||||
:param use_kl_regularization: option to add kl divergence loss
|
||||
:param initial_kl_coefficient: loss coefficient applied kl divergence loss (also see high_kl_penalty_coefficient).
|
||||
:param kl_cutoff: threshold for using high_kl_penalty_coefficient
|
||||
:param high_kl_penalty_coefficient: loss coefficient applied to kv divergence above kl_cutoff
|
||||
:param weight: scalar used to adjust relative weight of loss (if using this loss with others).
|
||||
:param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
|
||||
"""
|
||||
super(ClippedPPOLossDiscrete, self).__init__(weight=weight, batch_axis=batch_axis)
|
||||
self.weight = weight
|
||||
self.num_actions = num_actions
|
||||
self.clip_likelihood_ratio_using_epsilon = clip_likelihood_ratio_using_epsilon
|
||||
self.beta = beta
|
||||
self.use_kl_regularization = use_kl_regularization
|
||||
self.initial_kl_coefficient = initial_kl_coefficient if self.use_kl_regularization else 0.0
|
||||
self.kl_coefficient = self.params.get('kl_coefficient',
|
||||
shape=(1,),
|
||||
init=mx.init.Constant([initial_kl_coefficient,]),
|
||||
differentiable=False)
|
||||
self.kl_cutoff = kl_cutoff
|
||||
self.high_kl_penalty_coefficient = high_kl_penalty_coefficient
|
||||
|
||||
@property
|
||||
def input_schema(self) -> LossInputSchema:
|
||||
return LossInputSchema(
|
||||
head_outputs=['new_policy_probs'],
|
||||
agent_inputs=['actions', 'old_policy_probs', 'clip_param_rescaler'],
|
||||
targets=['advantages']
|
||||
)
|
||||
|
||||
def loss_forward(self,
|
||||
F: ModuleType,
|
||||
new_policy_probs: nd_sym_type,
|
||||
actions: nd_sym_type,
|
||||
old_policy_probs: nd_sym_type,
|
||||
clip_param_rescaler: nd_sym_type,
|
||||
advantages: nd_sym_type,
|
||||
kl_coefficient: nd_sym_type) -> List[Tuple[nd_sym_type, str]]:
|
||||
"""
|
||||
Used for forward pass through loss computations.
|
||||
Works with batches of data, and optionally time_steps, but be consistent in usage: i.e. if using time_step,
|
||||
new_policy_probs, old_policy_probs, actions and advantages all must include a time_step dimension.
|
||||
|
||||
NOTE: order of input arguments MUST NOT CHANGE because it matches the order
|
||||
parameters are passed in ppo_agent:train_network()
|
||||
|
||||
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
|
||||
:param new_policy_probs: action probabilities predicted by DiscretePPOHead network,
|
||||
of shape (batch_size, num_actions) or
|
||||
of shape (batch_size, time_step, num_actions).
|
||||
:param old_policy_probs: action probabilities for previous policy,
|
||||
of shape (batch_size, num_actions) or
|
||||
of shape (batch_size, time_step, num_actions).
|
||||
:param actions: true actions taken during rollout,
|
||||
of shape (batch_size) or
|
||||
of shape (batch_size, time_step).
|
||||
:param clip_param_rescaler: scales epsilon to use for likelihood ratio clipping.
|
||||
:param advantages: change in state value after taking action (a.k.a advantage)
|
||||
of shape (batch_size) or
|
||||
of shape (batch_size, time_step).
|
||||
:param kl_coefficient: loss coefficient applied kl divergence loss (also see high_kl_penalty_coefficient).
|
||||
:return: loss, of shape (batch_size).
|
||||
"""
|
||||
|
||||
old_policy_dist = CategoricalDist(self.num_actions, old_policy_probs, F=F)
|
||||
action_probs_wrt_old_policy = old_policy_dist.log_prob(actions)
|
||||
|
||||
new_policy_dist = CategoricalDist(self.num_actions, new_policy_probs, F=F)
|
||||
action_probs_wrt_new_policy = new_policy_dist.log_prob(actions)
|
||||
|
||||
entropy_loss = - self.beta * new_policy_dist.entropy().mean()
|
||||
|
||||
if self.use_kl_regularization:
|
||||
kl_div = old_policy_dist.kl_div(new_policy_dist).mean()
|
||||
weighted_kl_div = kl_coefficient * kl_div
|
||||
high_kl_div = F.stack(F.zeros_like(kl_div), kl_div - self.kl_cutoff).max().square()
|
||||
weighted_high_kl_div = self.high_kl_penalty_coefficient * high_kl_div
|
||||
kl_div_loss = weighted_kl_div + weighted_high_kl_div
|
||||
else:
|
||||
kl_div_loss = F.zeros(shape=(1,))
|
||||
|
||||
# working with log probs, so minus first, then exponential (same as division)
|
||||
likelihood_ratio = (action_probs_wrt_new_policy - action_probs_wrt_old_policy).exp()
|
||||
|
||||
if self.clip_likelihood_ratio_using_epsilon is not None:
|
||||
# clipping of likelihood ratio
|
||||
min_value = 1 - self.clip_likelihood_ratio_using_epsilon * clip_param_rescaler
|
||||
max_value = 1 + self.clip_likelihood_ratio_using_epsilon * clip_param_rescaler
|
||||
|
||||
# can't use F.clip (with variable clipping bounds), hence custom implementation
|
||||
clipped_likelihood_ratio = hybrid_clip(F, likelihood_ratio, clip_lower=min_value, clip_upper=max_value)
|
||||
|
||||
# lower bound of original, and clipped versions or each scaled advantage
|
||||
# element-wise min between the two ndarrays
|
||||
unclipped_scaled_advantages = likelihood_ratio * advantages
|
||||
clipped_scaled_advantages = clipped_likelihood_ratio * advantages
|
||||
scaled_advantages = F.stack(unclipped_scaled_advantages, clipped_scaled_advantages).min(axis=0)
|
||||
else:
|
||||
scaled_advantages = likelihood_ratio * advantages
|
||||
clipped_likelihood_ratio = F.zeros_like(likelihood_ratio)
|
||||
|
||||
# for each batch, calculate expectation of scaled_advantages across time steps,
|
||||
# but want code to work with data without time step too, so reshape to add timestep if doesn't exist.
|
||||
scaled_advantages_w_time = scaled_advantages.reshape(shape=(0, -1))
|
||||
expected_scaled_advantages = scaled_advantages_w_time.mean(axis=1)
|
||||
# want to maximize expected_scaled_advantages, add minus so can minimize.
|
||||
surrogate_loss = (-expected_scaled_advantages * self.weight).mean()
|
||||
|
||||
return [
|
||||
(surrogate_loss, LOSS_OUT_TYPE_LOSS),
|
||||
(entropy_loss + kl_div_loss, LOSS_OUT_TYPE_REGULARIZATION),
|
||||
(kl_div_loss, LOSS_OUT_TYPE_KL),
|
||||
(entropy_loss, LOSS_OUT_TYPE_ENTROPY),
|
||||
(likelihood_ratio, LOSS_OUT_TYPE_LIKELIHOOD_RATIO),
|
||||
(clipped_likelihood_ratio, LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO)
|
||||
]
|
||||
|
||||
|
||||
class ClippedPPOLossContinuous(HeadLoss):
|
||||
def __init__(self,
|
||||
num_actions: int,
|
||||
clip_likelihood_ratio_using_epsilon: float,
|
||||
beta: float=0,
|
||||
use_kl_regularization: bool=False,
|
||||
initial_kl_coefficient: float=1,
|
||||
kl_cutoff: float=0,
|
||||
high_kl_penalty_coefficient: float=1,
|
||||
weight: float=1,
|
||||
batch_axis: int=0):
|
||||
"""
|
||||
Loss for continuous version of Clipped PPO.
|
||||
|
||||
:param num_actions: number of actions in action space.
|
||||
:param clip_likelihood_ratio_using_epsilon: epsilon to use for likelihood ratio clipping.
|
||||
:param beta: loss coefficient applied to entropy
|
||||
:param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
|
||||
:param use_kl_regularization: option to add kl divergence loss
|
||||
:param initial_kl_coefficient: initial loss coefficient applied kl divergence loss (also see high_kl_penalty_coefficient).
|
||||
:param kl_cutoff: threshold for using high_kl_penalty_coefficient
|
||||
:param high_kl_penalty_coefficient: loss coefficient applied to kv divergence above kl_cutoff
|
||||
:param weight: scalar used to adjust relative weight of loss (if using this loss with others).
|
||||
:param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
|
||||
"""
|
||||
super(ClippedPPOLossContinuous, self).__init__(weight=weight, batch_axis=batch_axis)
|
||||
self.weight = weight
|
||||
self.num_actions = num_actions
|
||||
self.clip_likelihood_ratio_using_epsilon = clip_likelihood_ratio_using_epsilon
|
||||
self.beta = beta
|
||||
self.use_kl_regularization = use_kl_regularization
|
||||
self.initial_kl_coefficient = initial_kl_coefficient if self.use_kl_regularization else 0.0
|
||||
self.kl_coefficient = self.params.get('kl_coefficient',
|
||||
shape=(1,),
|
||||
init=mx.init.Constant([initial_kl_coefficient,]),
|
||||
differentiable=False)
|
||||
self.kl_cutoff = kl_cutoff
|
||||
self.high_kl_penalty_coefficient = high_kl_penalty_coefficient
|
||||
|
||||
@property
|
||||
def input_schema(self) -> LossInputSchema:
|
||||
return LossInputSchema(
|
||||
head_outputs=['new_policy_means','new_policy_stds'],
|
||||
agent_inputs=['actions', 'old_policy_means', 'old_policy_stds', 'clip_param_rescaler'],
|
||||
targets=['advantages']
|
||||
)
|
||||
|
||||
def loss_forward(self,
|
||||
F: ModuleType,
|
||||
new_policy_means: nd_sym_type,
|
||||
new_policy_stds: nd_sym_type,
|
||||
actions: nd_sym_type,
|
||||
old_policy_means: nd_sym_type,
|
||||
old_policy_stds: nd_sym_type,
|
||||
clip_param_rescaler: nd_sym_type,
|
||||
advantages: nd_sym_type,
|
||||
kl_coefficient: nd_sym_type) -> List[Tuple[nd_sym_type, str]]:
|
||||
"""
|
||||
Used for forward pass through loss computations.
|
||||
Works with batches of data, and optionally time_steps, but be consistent in usage: i.e. if using time_step,
|
||||
new_policy_means, old_policy_means, actions and advantages all must include a time_step dimension.
|
||||
|
||||
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
|
||||
:param new_policy_means: action means predicted by MultivariateNormalDist network,
|
||||
of shape (batch_size, num_actions) or
|
||||
of shape (batch_size, time_step, num_actions).
|
||||
:param new_policy_stds: action standard deviation returned by head,
|
||||
of shape (batch_size, num_actions) or
|
||||
of shape (batch_size, time_step, num_actions).
|
||||
:param actions: true actions taken during rollout,
|
||||
of shape (batch_size) or
|
||||
of shape (batch_size, time_step).
|
||||
:param old_policy_means: action means for previous policy,
|
||||
of shape (batch_size, num_actions) or
|
||||
of shape (batch_size, time_step, num_actions).
|
||||
:param old_policy_stds: action standard deviation returned by head previously,
|
||||
of shape (batch_size, num_actions) or
|
||||
of shape (batch_size, time_step, num_actions).
|
||||
:param clip_param_rescaler: scales epsilon to use for likelihood ratio clipping.
|
||||
:param advantages: change in state value after taking action (a.k.a advantage)
|
||||
of shape (batch_size) or
|
||||
of shape (batch_size, time_step).
|
||||
:param kl_coefficient: loss coefficient applied kl divergence loss (also see high_kl_penalty_coefficient).
|
||||
:return: loss, of shape (batch_size).
|
||||
"""
|
||||
old_var = old_policy_stds ** 2
|
||||
# sets diagonal in (batch size and time step) covariance matrices
|
||||
old_covar = mx.nd.eye(N=self.num_actions) * (old_var + eps).broadcast_like(old_policy_means).expand_dims(-2)
|
||||
old_policy_dist = MultivariateNormalDist(self.num_actions, old_policy_means, old_covar, F=F)
|
||||
action_probs_wrt_old_policy = old_policy_dist.log_prob(actions)
|
||||
|
||||
new_var = new_policy_stds ** 2
|
||||
# sets diagonal in (batch size and time step) covariance matrices
|
||||
new_covar = mx.nd.eye(N=self.num_actions) * (new_var + eps).broadcast_like(new_policy_means).expand_dims(-2)
|
||||
new_policy_dist = MultivariateNormalDist(self.num_actions, new_policy_means, new_covar, F=F)
|
||||
action_probs_wrt_new_policy = new_policy_dist.log_prob(actions)
|
||||
|
||||
entropy_loss = - self.beta * new_policy_dist.entropy().mean()
|
||||
|
||||
if self.use_kl_regularization:
|
||||
kl_div = old_policy_dist.kl_div(new_policy_dist).mean()
|
||||
weighted_kl_div = kl_coefficient * kl_div
|
||||
high_kl_div = F.stack(F.zeros_like(kl_div), kl_div - self.kl_cutoff).max().square()
|
||||
weighted_high_kl_div = self.high_kl_penalty_coefficient * high_kl_div
|
||||
kl_div_loss = weighted_kl_div + weighted_high_kl_div
|
||||
else:
|
||||
kl_div_loss = F.zeros(shape=(1,))
|
||||
|
||||
# working with log probs, so minus first, then exponential (same as division)
|
||||
likelihood_ratio = (action_probs_wrt_new_policy - action_probs_wrt_old_policy).exp()
|
||||
|
||||
if self.clip_likelihood_ratio_using_epsilon is not None:
|
||||
# clipping of likelihood ratio
|
||||
min_value = 1 - self.clip_likelihood_ratio_using_epsilon * clip_param_rescaler
|
||||
max_value = 1 + self.clip_likelihood_ratio_using_epsilon * clip_param_rescaler
|
||||
|
||||
# can't use F.clip (with variable clipping bounds), hence custom implementation
|
||||
clipped_likelihood_ratio = hybrid_clip(F, likelihood_ratio, clip_lower=min_value, clip_upper=max_value)
|
||||
|
||||
# lower bound of original, and clipped versions or each scaled advantage
|
||||
# element-wise min between the two ndarrays
|
||||
unclipped_scaled_advantages = likelihood_ratio * advantages
|
||||
clipped_scaled_advantages = clipped_likelihood_ratio * advantages
|
||||
scaled_advantages = F.stack(unclipped_scaled_advantages, clipped_scaled_advantages).min(axis=0)
|
||||
else:
|
||||
scaled_advantages = likelihood_ratio * advantages
|
||||
clipped_likelihood_ratio = F.zeros_like(likelihood_ratio)
|
||||
|
||||
# for each batch, calculate expectation of scaled_advantages across time steps,
|
||||
# but want code to work with data without time step too, so reshape to add timestep if doesn't exist.
|
||||
scaled_advantages_w_time = scaled_advantages.reshape(shape=(0, -1))
|
||||
expected_scaled_advantages = scaled_advantages_w_time.mean(axis=1)
|
||||
# want to maximize expected_scaled_advantages, add minus so can minimize.
|
||||
surrogate_loss = (-expected_scaled_advantages * self.weight).mean()
|
||||
|
||||
return [
|
||||
(surrogate_loss, LOSS_OUT_TYPE_LOSS),
|
||||
(entropy_loss + kl_div_loss, LOSS_OUT_TYPE_REGULARIZATION),
|
||||
(kl_div_loss, LOSS_OUT_TYPE_KL),
|
||||
(entropy_loss, LOSS_OUT_TYPE_ENTROPY),
|
||||
(likelihood_ratio, LOSS_OUT_TYPE_LIKELIHOOD_RATIO),
|
||||
(clipped_likelihood_ratio, LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO)
|
||||
]
|
||||
|
||||
|
||||
class PPOHead(Head):
|
||||
def __init__(self,
|
||||
agent_parameters: AgentParameters,
|
||||
spaces: SpacesDefinition,
|
||||
network_name: str,
|
||||
head_type_idx: int=0,
|
||||
loss_weight: float=1.,
|
||||
is_local: bool=True,
|
||||
activation_function: str='tanh',
|
||||
dense_layer: None=None) -> None:
|
||||
"""
|
||||
Head block for Proximal Policy Optimization, to calculate probabilities for each action given middleware
|
||||
representation of the environment state.
|
||||
|
||||
:param agent_parameters: containing algorithm parameters such as clip_likelihood_ratio_using_epsilon
|
||||
and beta_entropy.
|
||||
:param spaces: containing action spaces used for defining size of network output.
|
||||
:param network_name: name of head network. currently unused.
|
||||
:param head_type_idx: index of head network. currently unused.
|
||||
:param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
|
||||
:param is_local: flag to denote if network is local. currently unused.
|
||||
:param activation_function: activation function to use between layers. currently unused.
|
||||
:param dense_layer: type of dense layer to use in network. currently unused.
|
||||
"""
|
||||
super().__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight, is_local, activation_function,
|
||||
dense_layer=dense_layer)
|
||||
self.return_type = ActionProbabilities
|
||||
|
||||
self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon
|
||||
self.beta = agent_parameters.algorithm.beta_entropy
|
||||
self.use_kl_regularization = agent_parameters.algorithm.use_kl_regularization
|
||||
if self.use_kl_regularization:
|
||||
self.initial_kl_coefficient = agent_parameters.algorithm.initial_kl_coefficient
|
||||
self.kl_cutoff = 2 * agent_parameters.algorithm.target_kl_divergence
|
||||
self.high_kl_penalty_coefficient = agent_parameters.algorithm.high_kl_penalty_coefficient
|
||||
else:
|
||||
self.initial_kl_coefficient, self.kl_cutoff, self.high_kl_penalty_coefficient = (None, None, None)
|
||||
self._loss = []
|
||||
|
||||
if isinstance(self.spaces.action, DiscreteActionSpace):
|
||||
self.net = DiscretePPOHead(num_actions=len(self.spaces.action.actions))
|
||||
elif isinstance(self.spaces.action, BoxActionSpace):
|
||||
self.net = ContinuousPPOHead(num_actions=len(self.spaces.action.actions))
|
||||
else:
|
||||
raise ValueError("Only discrete or continuous action spaces are supported for PPO.")
|
||||
|
||||
def hybrid_forward(self,
|
||||
F: ModuleType,
|
||||
x: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
|
||||
:param x: middleware embedding
|
||||
:return: policy parameters/probabilities
|
||||
"""
|
||||
return self.net(x)
|
||||
|
||||
def loss(self) -> mx.gluon.loss.Loss:
|
||||
"""
|
||||
Specifies loss block to be used for this policy head.
|
||||
|
||||
:return: loss block (can be called as function) for action probabilities returned by this policy network.
|
||||
"""
|
||||
if isinstance(self.spaces.action, DiscreteActionSpace):
|
||||
loss = ClippedPPOLossDiscrete(len(self.spaces.action.actions),
|
||||
self.clip_likelihood_ratio_using_epsilon,
|
||||
self.beta,
|
||||
self.use_kl_regularization, self.initial_kl_coefficient,
|
||||
self.kl_cutoff, self.high_kl_penalty_coefficient,
|
||||
self.loss_weight)
|
||||
elif isinstance(self.spaces.action, BoxActionSpace):
|
||||
loss = ClippedPPOLossContinuous(len(self.spaces.action.actions),
|
||||
self.clip_likelihood_ratio_using_epsilon,
|
||||
self.beta,
|
||||
self.use_kl_regularization, self.initial_kl_coefficient,
|
||||
self.kl_cutoff, self.high_kl_penalty_coefficient,
|
||||
self.loss_weight)
|
||||
else:
|
||||
raise ValueError("Only discrete or continuous action spaces are supported for PPO.")
|
||||
loss.initialize()
|
||||
# set a property so can assign_kl_coefficient in future,
|
||||
# make a list, otherwise it would be added as a child of Head Block (due to type check)
|
||||
self._loss = [loss]
|
||||
return loss
|
||||
|
||||
@property
|
||||
def kl_divergence(self):
|
||||
return self.head_type_idx, LOSS_OUT_TYPE_KL
|
||||
|
||||
@property
|
||||
def entropy(self):
|
||||
return self.head_type_idx, LOSS_OUT_TYPE_ENTROPY
|
||||
|
||||
@property
|
||||
def likelihood_ratio(self):
|
||||
return self.head_type_idx, LOSS_OUT_TYPE_LIKELIHOOD_RATIO
|
||||
|
||||
@property
|
||||
def clipped_likelihood_ratio(self):
|
||||
return self.head_type_idx, LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO
|
||||
|
||||
def assign_kl_coefficient(self, kl_coefficient: float) -> None:
|
||||
self._loss[0].kl_coefficient.set_data(mx.nd.array((kl_coefficient,)))
|
||||
123
rl_coach/architectures/mxnet_components/heads/ppo_v_head.py
Normal file
123
rl_coach/architectures/mxnet_components/heads/ppo_v_head.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from typing import List, Tuple, Union
|
||||
from types import ModuleType
|
||||
|
||||
import mxnet as mx
|
||||
from mxnet.gluon import nn
|
||||
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
|
||||
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.core_types import ActionProbabilities
|
||||
from rl_coach.spaces import SpacesDefinition
|
||||
|
||||
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
|
||||
|
||||
|
||||
class PPOVHeadLoss(HeadLoss):
|
||||
def __init__(self, clip_likelihood_ratio_using_epsilon: float, weight: float=1, batch_axis: int=0) -> None:
|
||||
"""
|
||||
Loss for PPO Value network.
|
||||
Schulman implemented this extension in OpenAI baselines for PPO2
|
||||
See https://github.com/openai/baselines/blob/master/baselines/ppo2/ppo2.py#L72
|
||||
|
||||
:param clip_likelihood_ratio_using_epsilon: epsilon to use for likelihood ratio clipping.
|
||||
:param weight: scalar used to adjust relative weight of loss (if using this loss with others).
|
||||
:param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
|
||||
"""
|
||||
super(PPOVHeadLoss, self).__init__(weight=weight, batch_axis=batch_axis)
|
||||
self.weight = weight
|
||||
self.clip_likelihood_ratio_using_epsilon = clip_likelihood_ratio_using_epsilon
|
||||
|
||||
@property
|
||||
def input_schema(self) -> LossInputSchema:
|
||||
return LossInputSchema(
|
||||
head_outputs=['new_policy_values'],
|
||||
agent_inputs=['old_policy_values'],
|
||||
targets=['target_values']
|
||||
)
|
||||
|
||||
def loss_forward(self,
|
||||
F: ModuleType,
|
||||
new_policy_values: nd_sym_type,
|
||||
old_policy_values: nd_sym_type,
|
||||
target_values: nd_sym_type) -> List[Tuple[nd_sym_type, str]]:
|
||||
"""
|
||||
Used for forward pass through loss computations.
|
||||
Calculates two losses (L2 and a clipped difference L2 loss) and takes the maximum of the two.
|
||||
Works with batches of data, and optionally time_steps, but be consistent in usage: i.e. if using time_step,
|
||||
new_policy_values, old_policy_values and target_values all must include a time_step dimension.
|
||||
|
||||
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
|
||||
:param new_policy_values: values predicted by PPOVHead network,
|
||||
of shape (batch_size) or
|
||||
of shape (batch_size, time_step).
|
||||
:param old_policy_values: values predicted by old value network,
|
||||
of shape (batch_size) or
|
||||
of shape (batch_size, time_step).
|
||||
:param target_values: actual state values,
|
||||
of shape (batch_size) or
|
||||
of shape (batch_size, time_step).
|
||||
:return: loss, of shape (batch_size).
|
||||
"""
|
||||
# L2 loss
|
||||
value_loss_1 = (new_policy_values - target_values).square()
|
||||
# Clipped difference L2 loss
|
||||
diff = new_policy_values - old_policy_values
|
||||
clipped_diff = diff.clip(a_min=-self.clip_likelihood_ratio_using_epsilon,
|
||||
a_max=self.clip_likelihood_ratio_using_epsilon)
|
||||
value_loss_2 = (old_policy_values + clipped_diff - target_values).square()
|
||||
# Maximum of the two losses, element-wise maximum.
|
||||
value_loss_max = mx.nd.stack(value_loss_1, value_loss_2).max(axis=0)
|
||||
# Aggregate over temporal axis, adding if doesn't exist (hense reshape)
|
||||
value_loss_max_w_time = value_loss_max.reshape(shape=(0, -1))
|
||||
value_loss = value_loss_max_w_time.mean(axis=1)
|
||||
# Weight the loss (and average over samples of batch)
|
||||
value_loss_weighted = value_loss.mean() * self.weight
|
||||
return [(value_loss_weighted, LOSS_OUT_TYPE_LOSS)]
|
||||
|
||||
|
||||
class PPOVHead(Head):
|
||||
def __init__(self,
|
||||
agent_parameters: AgentParameters,
|
||||
spaces: SpacesDefinition,
|
||||
network_name: str,
|
||||
head_type_idx: int=0,
|
||||
loss_weight: float=1.,
|
||||
is_local: bool = True,
|
||||
activation_function: str='relu',
|
||||
dense_layer: None=None) -> None:
|
||||
"""
|
||||
PPO Value Head for predicting state values.
|
||||
|
||||
:param agent_parameters: containing algorithm parameters, but currently unused.
|
||||
:param spaces: containing action spaces, but currently unused.
|
||||
:param network_name: name of head network. currently unused.
|
||||
:param head_type_idx: index of head network. currently unused.
|
||||
:param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
|
||||
:param is_local: flag to denote if network is local. currently unused.
|
||||
:param activation_function: activation function to use between layers. currently unused.
|
||||
:param dense_layer: type of dense layer to use in network. currently unused.
|
||||
"""
|
||||
super(PPOVHead, self).__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight, is_local,
|
||||
activation_function, dense_layer=dense_layer)
|
||||
self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon
|
||||
self.return_type = ActionProbabilities
|
||||
with self.name_scope():
|
||||
self.dense = nn.Dense(units=1)
|
||||
|
||||
def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
Used for forward pass through value head network.
|
||||
|
||||
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
|
||||
:param x: middleware state representation, of shape (batch_size, in_channels).
|
||||
:return: final value output of network, of shape (batch_size).
|
||||
"""
|
||||
return self.dense(x).squeeze()
|
||||
|
||||
def loss(self) -> mx.gluon.loss.Loss:
|
||||
"""
|
||||
Specifies loss block to be used for specific value head implementation.
|
||||
|
||||
:return: loss block (can be called as function) for outputs returned by the value head network.
|
||||
"""
|
||||
return PPOVHeadLoss(self.clip_likelihood_ratio_using_epsilon, weight=self.loss_weight)
|
||||
106
rl_coach/architectures/mxnet_components/heads/q_head.py
Normal file
106
rl_coach/architectures/mxnet_components/heads/q_head.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from typing import Union, List, Tuple
|
||||
from types import ModuleType
|
||||
|
||||
import mxnet as mx
|
||||
from mxnet.gluon.loss import Loss, HuberLoss, L2Loss
|
||||
from mxnet.gluon import nn
|
||||
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
|
||||
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.core_types import QActionStateValue
|
||||
from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace
|
||||
|
||||
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
|
||||
|
||||
|
||||
class QHeadLoss(HeadLoss):
|
||||
def __init__(self, loss_type: Loss=L2Loss, weight: float=1, batch_axis: int=0) -> None:
|
||||
"""
|
||||
Loss for Q-Value Head.
|
||||
|
||||
:param loss_type: loss function with default of mean squared error (i.e. L2Loss).
|
||||
:param weight: scalar used to adjust relative weight of loss (if using this loss with others).
|
||||
:param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
|
||||
"""
|
||||
super(QHeadLoss, self).__init__(weight=weight, batch_axis=batch_axis)
|
||||
with self.name_scope():
|
||||
self.loss_fn = loss_type(weight=weight, batch_axis=batch_axis)
|
||||
|
||||
@property
|
||||
def input_schema(self) -> LossInputSchema:
|
||||
return LossInputSchema(
|
||||
head_outputs=['pred'],
|
||||
agent_inputs=[],
|
||||
targets=['target']
|
||||
)
|
||||
|
||||
def loss_forward(self,
|
||||
F: ModuleType,
|
||||
pred: nd_sym_type,
|
||||
target: nd_sym_type) -> List[Tuple[nd_sym_type, str]]:
|
||||
"""
|
||||
Used for forward pass through loss computations.
|
||||
|
||||
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
|
||||
:param pred: state-action q-values predicted by QHead network, of shape (batch_size, num_actions).
|
||||
:param target: actual state-action q-values, of shape (batch_size, num_actions).
|
||||
:return: loss, of shape (batch_size).
|
||||
"""
|
||||
loss = self.loss_fn(pred, target).mean()
|
||||
return [(loss, LOSS_OUT_TYPE_LOSS)]
|
||||
|
||||
|
||||
class QHead(Head):
|
||||
def __init__(self,
|
||||
agent_parameters: AgentParameters,
|
||||
spaces: SpacesDefinition,
|
||||
network_name: str,
|
||||
head_type_idx: int=0,
|
||||
loss_weight: float=1.,
|
||||
is_local: bool=True,
|
||||
activation_function: str='relu',
|
||||
dense_layer: None=None,
|
||||
loss_type: Union[HuberLoss, L2Loss]=L2Loss) -> None:
|
||||
"""
|
||||
Q-Value Head for predicting state-action Q-Values.
|
||||
|
||||
:param agent_parameters: containing algorithm parameters, but currently unused.
|
||||
:param spaces: containing action spaces used for defining size of network output.
|
||||
:param network_name: name of head network. currently unused.
|
||||
:param head_type_idx: index of head network. currently unused.
|
||||
:param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
|
||||
:param is_local: flag to denote if network is local. currently unused.
|
||||
:param activation_function: activation function to use between layers. currently unused.
|
||||
:param dense_layer: type of dense layer to use in network. currently unused.
|
||||
:param loss_type: loss function to use.
|
||||
"""
|
||||
super(QHead, self).__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight,
|
||||
is_local, activation_function, dense_layer)
|
||||
if isinstance(self.spaces.action, BoxActionSpace):
|
||||
self.num_actions = 1
|
||||
elif isinstance(self.spaces.action, DiscreteActionSpace):
|
||||
self.num_actions = len(self.spaces.action.actions)
|
||||
self.return_type = QActionStateValue
|
||||
assert (loss_type == L2Loss) or (loss_type == HuberLoss), "Only expecting L2Loss or HuberLoss."
|
||||
self.loss_type = loss_type
|
||||
|
||||
with self.name_scope():
|
||||
self.dense = nn.Dense(units=self.num_actions)
|
||||
|
||||
def loss(self) -> Loss:
|
||||
"""
|
||||
Specifies loss block to be used for specific value head implementation.
|
||||
|
||||
:return: loss block (can be called as function) for outputs returned by the head network.
|
||||
"""
|
||||
return QHeadLoss(loss_type=self.loss_type, weight=self.loss_weight)
|
||||
|
||||
def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
Used for forward pass through Q-Value head network.
|
||||
|
||||
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
|
||||
:param x: middleware state representation, of shape (batch_size, in_channels).
|
||||
:return: predicted state-action q-values, of shape (batch_size, num_actions).
|
||||
"""
|
||||
return self.dense(x)
|
||||
100
rl_coach/architectures/mxnet_components/heads/v_head.py
Normal file
100
rl_coach/architectures/mxnet_components/heads/v_head.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from typing import Union, List, Tuple
|
||||
from types import ModuleType
|
||||
|
||||
import mxnet as mx
|
||||
from mxnet.gluon.loss import Loss, HuberLoss, L2Loss
|
||||
from mxnet.gluon import nn
|
||||
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
|
||||
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.core_types import VStateValue
|
||||
from rl_coach.spaces import SpacesDefinition
|
||||
|
||||
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
|
||||
|
||||
|
||||
class VHeadLoss(HeadLoss):
|
||||
def __init__(self, loss_type: Loss=L2Loss, weight: float=1, batch_axis: int=0) -> None:
|
||||
"""
|
||||
Loss for Value Head.
|
||||
|
||||
:param loss_type: loss function with default of mean squared error (i.e. L2Loss).
|
||||
:param weight: scalar used to adjust relative weight of loss (if using this loss with others).
|
||||
:param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
|
||||
"""
|
||||
super(VHeadLoss, self).__init__(weight=weight, batch_axis=batch_axis)
|
||||
with self.name_scope():
|
||||
self.loss_fn = loss_type(weight=weight, batch_axis=batch_axis)
|
||||
|
||||
@property
|
||||
def input_schema(self) -> LossInputSchema:
|
||||
return LossInputSchema(
|
||||
head_outputs=['pred'],
|
||||
agent_inputs=[],
|
||||
targets=['target']
|
||||
)
|
||||
|
||||
def loss_forward(self,
|
||||
F: ModuleType,
|
||||
pred: nd_sym_type,
|
||||
target: nd_sym_type) -> List[Tuple[nd_sym_type, str]]:
|
||||
"""
|
||||
Used for forward pass through loss computations.
|
||||
|
||||
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
|
||||
:param pred: state values predicted by VHead network, of shape (batch_size).
|
||||
:param target: actual state values, of shape (batch_size).
|
||||
:return: loss, of shape (batch_size).
|
||||
"""
|
||||
loss = self.loss_fn(pred, target).mean()
|
||||
return [(loss, LOSS_OUT_TYPE_LOSS)]
|
||||
|
||||
|
||||
class VHead(Head):
|
||||
def __init__(self,
|
||||
agent_parameters: AgentParameters,
|
||||
spaces: SpacesDefinition,
|
||||
network_name: str,
|
||||
head_type_idx: int=0,
|
||||
loss_weight: float=1.,
|
||||
is_local: bool=True,
|
||||
activation_function: str='relu',
|
||||
dense_layer: None=None,
|
||||
loss_type: Union[HuberLoss, L2Loss]=L2Loss):
|
||||
"""
|
||||
Value Head for predicting state values.
|
||||
:param agent_parameters: containing algorithm parameters, but currently unused.
|
||||
:param spaces: containing action spaces, but currently unused.
|
||||
:param network_name: name of head network. currently unused.
|
||||
:param head_type_idx: index of head network. currently unused.
|
||||
:param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
|
||||
:param is_local: flag to denote if network is local. currently unused.
|
||||
:param activation_function: activation function to use between layers. currently unused.
|
||||
:param dense_layer: type of dense layer to use in network. currently unused.
|
||||
:param loss_type: loss function with default of mean squared error (i.e. L2Loss), or alternatively HuberLoss.
|
||||
"""
|
||||
super(VHead, self).__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight,
|
||||
is_local, activation_function, dense_layer)
|
||||
assert (loss_type == L2Loss) or (loss_type == HuberLoss), "Only expecting L2Loss or HuberLoss."
|
||||
self.loss_type = loss_type
|
||||
self.return_type = VStateValue
|
||||
with self.name_scope():
|
||||
self.dense = nn.Dense(units=1)
|
||||
|
||||
def loss(self) -> Loss:
|
||||
"""
|
||||
Specifies loss block to be used for specific value head implementation.
|
||||
|
||||
:return: loss block (can be called as function) for outputs returned by the head network.
|
||||
"""
|
||||
return VHeadLoss(loss_type=self.loss_type, weight=self.loss_weight)
|
||||
|
||||
def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
Used for forward pass through value head network.
|
||||
|
||||
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
|
||||
:param x: middleware state representation, of shape (batch_size, in_channels).
|
||||
:return: final output of value network, of shape (batch_size).
|
||||
"""
|
||||
return self.dense(x).squeeze()
|
||||
Reference in New Issue
Block a user