1
0
mirror of https://github.com/gryf/coach.git synced 2026-04-16 12:33:34 +02:00

Adding mxnet components to rl_coach/architectures (#60)

Adding mxnet components to rl_coach architectures.

- Supports PPO and DQN
- Tested with CartPole_PPO and CartPole_DQN
- Normalizing filters don't work right now (see #49) and are disabled in CartPole_PPO preset
- Checkpointing is disabled for MXNet
This commit is contained in:
Sina Afrooze
2018-11-07 07:07:15 -08:00
committed by Itai Caspi
parent e7a91b4dc3
commit 5fadb9c18e
39 changed files with 3864 additions and 44 deletions

View File

@@ -0,0 +1,14 @@
from .head import Head, HeadLoss
from .q_head import QHead
from .ppo_head import PPOHead
from .ppo_v_head import PPOVHead
from .v_head import VHead

# Public API of the mxnet heads package: base classes plus the concrete
# head implementations (Q-learning, PPO policy, PPO value, plain value).
__all__ = [
    'Head',
    'HeadLoss',
    'QHead',
    'PPOHead',
    'PPOVHead',
    'VHead'
]

View File

@@ -0,0 +1,181 @@
from typing import Dict, List, Union, Tuple
from mxnet.gluon import nn, loss
from mxnet.ndarray import NDArray
from mxnet.symbol import Symbol
from rl_coach.base_parameters import AgentParameters
from rl_coach.spaces import SpacesDefinition
LOSS_OUT_TYPE_LOSS = 'loss'
LOSS_OUT_TYPE_REGULARIZATION = 'regularization'


class LossInputSchema(object):
    """
    Describes the argument layout a loss block's hybrid_forward expects:
    which arguments come from the head, which from the agent, and which
    are training targets.
    """

    def __init__(self, head_outputs: List[str], agent_inputs: List[str], targets: List[str]):
        """
        :param head_outputs: argument names in hybrid_forward that are produced by the head.
            Order and count must match the head's outputs exactly.
        :param agent_inputs: argument names in hybrid_forward supplied by the agent.
            Order and count must match `output_<head_type_idx>_<order>` for this head.
        :param targets: argument names in hybrid_forward that are loss targets.
            Order and count must match the targets passed from the agent.
        """
        self._head_outputs = head_outputs
        self._agent_inputs = agent_inputs
        self._targets = targets

    @property
    def head_outputs(self):
        """Names of hybrid_forward arguments fed by the head, in head-output order."""
        return self._head_outputs

    @property
    def agent_inputs(self):
        """Names of hybrid_forward arguments fed by the agent, in agent-input order."""
        return self._agent_inputs

    @property
    def targets(self):
        """Names of hybrid_forward arguments that are loss targets, in target order."""
        return self._targets
class HeadLoss(loss.Loss):
    """
    ABC for loss functions of each head. Child class must implement input_schema() and loss_forward().
    """

    def __init__(self, *args, **kwargs):
        super(HeadLoss, self).__init__(*args, **kwargs)
        # The output schema is discovered lazily: it is recorded on the first call
        # to hybrid_forward() (via _loss_output) and asserted unchanged thereafter.
        self._output_schema = None  # type: List[str]

    @property
    def input_schema(self) -> LossInputSchema:
        """
        :return: schema for input of hybrid_forward. Read docstring for LossInputSchema for details.
        """
        raise NotImplementedError

    @property
    def output_schema(self) -> List[str]:
        """
        :return: schema for output of hybrid_forward. Must contain 'loss' and 'regularization' keys at least once.
            The order and total number must match that of returned values from the loss. 'loss' and 'regularization'
            are special keys. Any other string is treated as an auxiliary output and must match auxiliary
            fetch names returned by the head.
        """
        return self._output_schema

    def forward(self, *args):
        """
        Override forward() so that the number of outputs can be checked against the schema.

        :return: whatever loss.Loss.forward() (and ultimately hybrid_forward) produced,
            after validating the output count.
        """
        outputs = super(HeadLoss, self).forward(*args)
        # Loss may return a single array/symbol or a tuple/list of them.
        if isinstance(outputs, tuple) or isinstance(outputs, list):
            num_outputs = len(outputs)
        else:
            assert isinstance(outputs, NDArray) or isinstance(outputs, Symbol)
            num_outputs = 1
        # output_schema was populated by _loss_output() during hybrid_forward above.
        assert num_outputs == len(self.output_schema), "Number of outputs don't match schema ({} != {})".format(
            num_outputs, len(self.output_schema))
        return outputs

    def _loss_output(self, outputs: List[Tuple[Union[NDArray, Symbol], str]]):
        """
        Must be called on the output from hybrid_forward().
        Saves the returned output as the schema and returns output values in a list.

        :param outputs: list of (value, loss-output-type) pairs from loss_forward().
        :return: tuple of output values, with the type strings stripped off.
        """
        output_schema = [o[1] for o in outputs]
        # Schema is fixed after the first call: later calls must produce the same types.
        assert self._output_schema is None or self._output_schema == output_schema
        self._output_schema = output_schema
        return tuple(o[0] for o in outputs)

    def hybrid_forward(self, F, x, *args, **kwargs):
        """
        Passes the call to loss_forward() and constructs the output schema from its
        output by calling _loss_output().
        """
        return self._loss_output(self.loss_forward(F, x, *args, **kwargs))

    def loss_forward(self, F, x, *args, **kwargs) -> List[Tuple[Union[NDArray, Symbol], str]]:
        """
        Similar to hybrid_forward, but returns a list of (NDArray, type_str) pairs.
        """
        raise NotImplementedError
class Head(nn.HybridBlock):
    """
    Base class for the final part of a network. A head consumes the middleware
    embedding and produces the network output; each head has an associated loss.
    Heads are algorithm dependent, and a network may contain several of them.
    """

    def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition,
                 network_name: str, head_type_idx: int=0, loss_weight: float=1., is_local: bool=True,
                 activation_function: str='relu', dense_layer: None=None):
        """
        :param agent_parameters: algorithm parameters (e.g. clip_likelihood_ratio_using_epsilon,
            beta_entropy).
        :param spaces: action spaces, used to size the network output.
        :param network_name: name of head network. currently unused.
        :param head_type_idx: index of head network. currently unused.
        :param loss_weight: scalar adjusting the relative weight of this head's loss.
        :param is_local: flag to denote if network is local. currently unused.
        :param activation_function: activation function between layers. currently unused.
        :param dense_layer: type of dense layer to use. currently unused.
        """
        super(Head, self).__init__()
        self.head_type_idx = head_type_idx
        self.network_name = network_name
        self.loss_weight = loss_weight
        self.is_local = is_local
        self.ap = agent_parameters
        self.spaces = spaces
        self.return_type = None
        self.activation_function = activation_function
        self.dense_layer = dense_layer
        # Determined automatically on the first forward() call.
        self._num_outputs = None

    def loss(self) -> HeadLoss:
        """
        Returns loss block to be used for specific head implementation.

        :return: loss block (can be called as function) for outputs returned by the head network.
        """
        raise NotImplementedError()

    @property
    def num_outputs(self):
        """
        Number of outputs that a forward() call will return.

        :return: output count, available only after forward() has run once.
        """
        assert self._num_outputs is not None, 'must call forward() once to configure number of outputs'
        return self._num_outputs

    def forward(self, *args):
        """
        Override forward() so that the number of outputs is captured automatically
        and checked for consistency on every call.
        """
        result = super(Head, self).forward(*args)
        if isinstance(result, tuple):
            count = len(result)
        else:
            assert isinstance(result, (NDArray, Symbol))
            count = 1
        if self._num_outputs is None:
            # First call: record the output count.
            self._num_outputs = count
        else:
            assert self._num_outputs == count, 'Number of outputs cannot change ({} != {})'.format(
                self._num_outputs, count)
        # Output count must agree with what the loss schema expects from the head.
        assert self._num_outputs == len(self.loss().input_schema.head_outputs)
        return result

    def hybrid_forward(self, F, x, *args, **kwargs):
        """
        Used for forward pass through head network.

        :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
        :param x: middleware state representation, of shape (batch_size, in_channels).
        :return: final output of network, that will be used in loss calculations.
        """
        raise NotImplementedError()

View File

@@ -0,0 +1,669 @@
from typing import List, Tuple, Union
from types import ModuleType
import math
import mxnet as mx
from mxnet.gluon import nn
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import ActionProbabilities
from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace
from rl_coach.utils import eps
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION
from rl_coach.architectures.mxnet_components.utils import hybrid_clip
LOSS_OUT_TYPE_KL = 'kl_divergence'
LOSS_OUT_TYPE_ENTROPY = 'entropy'
LOSS_OUT_TYPE_LIKELIHOOD_RATIO = 'likelihood_ratio'
LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO = 'clipped_likelihood_ratio'

nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]


class MultivariateNormalDist:
    def __init__(self,
                 num_var: int,
                 mean: nd_sym_type,
                 sigma: nd_sym_type,
                 F: ModuleType=mx.nd) -> None:
        """
        Distribution object for Multivariate Normal. Works with batches.
        Optionally works with batches and time steps, but be consistent in usage: i.e. if using time_step,
        mean, sigma and data for log_prob must all include a time_step dimension.

        :param num_var: number of variables in distribution
        :param mean: mean for each variable,
            of shape (num_var) or
            of shape (batch_size, num_var) or
            of shape (batch_size, time_step, num_var).
        :param sigma: covariance matrix,
            of shape (num_var, num_var) or
            of shape (batch_size, num_var, num_var) or
            of shape (batch_size, time_step, num_var, num_var).
        :param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
        """
        self.num_var = num_var
        self.mean = mean
        self.sigma = sigma
        self.F = F

    def inverse_using_cholesky(self, matrix: nd_sym_type) -> nd_sym_type:
        """
        Calculate inverses for a batch of matrices using Cholesky decomposition method.

        :param matrix: matrix (or matrices) to invert,
            of shape (num_var, num_var) or
            of shape (batch_size, num_var, num_var) or
            of shape (batch_size, time_step, num_var, num_var).
        :return: inverted matrix (or matrices), same shape as the input.
        """
        cholesky_factor = self.F.linalg.potrf(matrix)
        return self.F.linalg.potri(cholesky_factor)

    def log_det(self, matrix: nd_sym_type) -> nd_sym_type:
        """
        Calculate log of the determinant for a batch of matrices using Cholesky decomposition method.

        :param matrix: matrix (or matrices) whose log-determinant to compute,
            of shape (num_var, num_var) or
            of shape (batch_size, num_var, num_var) or
            of shape (batch_size, time_step, num_var, num_var).
        :return: log-determinant(s), with the trailing (num_var, num_var) dims reduced away.
        """
        # log|M| = 2 * sum(log(diag(L))) where M = L L^T (Cholesky).
        cholesky_factor = self.F.linalg.potrf(matrix)
        return 2 * self.F.linalg.sumlogdiag(cholesky_factor)

    def log_prob(self, x: nd_sym_type) -> nd_sym_type:
        """
        Calculate the log probability of data given the current distribution.

        See http://www.notenoughthoughts.net/posts/normal-log-likelihood-gradient.html
        and https://discuss.mxnet.io/t/multivariate-gaussian-log-density-operator/1169/7

        :param x: input data,
            of shape (num_var) or
            of shape (batch_size, num_var) or
            of shape (batch_size, time_step, num_var).
        :return: log_probability,
            of shape (1) or
            of shape (batch_size) or
            of shape (batch_size, time_step).
        """
        a = (self.num_var / 2) * math.log(2 * math.pi)
        log_det_sigma = self.log_det(self.sigma)
        b = (1 / 2) * log_det_sigma
        sigma_inv = self.inverse_using_cholesky(self.sigma)
        # deviation from mean, and dev_t is equivalent to transpose on last two dims.
        dev = (x - self.mean).expand_dims(-1)
        dev_t = (x - self.mean).expand_dims(-2)
        # since batch_dot only works with ndarrays with ndim of 3,
        # and we could have ndarrays with ndim of 4,
        # we flatten batch_size and time_step into single dim.
        dev_flat = dev.reshape(shape=(-1, 0, 0), reverse=1)
        sigma_inv_flat = sigma_inv.reshape(shape=(-1, 0, 0), reverse=1)
        dev_t_flat = dev_t.reshape(shape=(-1, 0, 0), reverse=1)
        c = (1 / 2) * self.F.batch_dot(self.F.batch_dot(dev_t_flat, sigma_inv_flat), dev_flat)
        # and now reshape back to (batch_size, time_step) if required.
        c = c.reshape_like(b)
        log_likelihood = -a - b - c
        return log_likelihood

    def entropy(self) -> nd_sym_type:
        """
        Calculate entropy of current distribution.

        See http://www.nowozin.net/sebastian/blog/the-entropy-of-a-normal-distribution.html

        :return: entropy,
            of shape (1) or
            of shape (batch_size) or
            of shape (batch_size, time_step).
        """
        # todo: check if differential entropy is correct
        log_det_sigma = self.log_det(self.sigma)
        return (self.num_var / 2) + ((self.num_var / 2) * math.log(2 * math.pi)) + ((1 / 2) * log_det_sigma)

    def kl_div(self, alt_dist) -> nd_sym_type:
        """
        Calculated KL-Divergence with another MultivariateNormalDist distribution.

        See https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
        Specifically https://wikimedia.org/api/rest_v1/media/math/render/svg/a3bf3b4917bd1fcb8be48d6d6139e2e387bdc7d3

        :param alt_dist: alternative distribution used for kl divergence calculation
        :type alt_dist: MultivariateNormalDist
        :return: KL-Divergence, of shape (1,)
        """
        sigma_a_inv = self.F.linalg.potri(self.F.linalg.potrf(self.sigma))
        sigma_b_inv = self.F.linalg.potri(self.F.linalg.potrf(alt_dist.sigma))
        # BUGFIX: use self.F.batch_dot (not mx.nd.batch_dot) so this also works
        # when the enclosing block has been hybridized and F is mx.sym.
        term1a = self.F.batch_dot(sigma_b_inv, self.sigma)
        # sum of diagonal for batch of matrices (i.e. batched trace).
        term1 = (self.F.eye(self.num_var).broadcast_like(term1a) * term1a).sum(axis=-1).sum(axis=-1)
        mean_diff = (alt_dist.mean - self.mean).expand_dims(-1)
        mean_diff_t = (alt_dist.mean - self.mean).expand_dims(-2)
        term2 = self.F.batch_dot(self.F.batch_dot(mean_diff_t, sigma_b_inv), mean_diff).reshape_like(term1)
        # log(|sigma_b| / |sigma_a|) via Cholesky log-determinants.
        term3 = (2 * self.F.linalg.sumlogdiag(self.F.linalg.potrf(alt_dist.sigma))) -\
                (2 * self.F.linalg.sumlogdiag(self.F.linalg.potrf(self.sigma)))
        return 0.5 * (term1 + term2 - self.num_var + term3)
class CategoricalDist:
    def __init__(self, n_classes: int, probs: nd_sym_type, F: ModuleType=mx.nd) -> None:
        """
        Distribution object for Categorical data.
        Optionally works with batches and time steps, but be consistent in usage: i.e. if using time_step,
        mean, sigma and data for log_prob must all include a time_step dimension.

        :param n_classes: number of classes in distribution
        :param probs: probabilities for each class,
            of shape (n_classes),
            of shape (batch_size, n_classes) or
            of shape (batch_size, time_step, n_classes)
        :param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
        """
        self.n_classes = n_classes
        self.probs = probs
        self.F = F

    def log_prob(self, actions: nd_sym_type) -> nd_sym_type:
        """
        Calculate the log probability of data given the current distribution.

        :param actions: actions, with int8 data type,
            of shape (1) if probs was (n_classes),
            of shape (batch_size) if probs was (batch_size, n_classes) and
            of shape (batch_size, time_step) if probs was (batch_size, time_step, n_classes)
        :return: log_probability, of matching shape (see actions).
        """
        # Select each action's probability via a one-hot mask, then take the log.
        one_hot_mask = actions.one_hot(depth=self.n_classes)
        selected_probs = (self.probs * one_hot_mask).sum(axis=-1)
        return selected_probs.log()

    def entropy(self) -> nd_sym_type:
        """
        Calculate entropy of current distribution.

        :return: entropy,
            of shape (1) if probs was (n_classes),
            of shape (batch_size) if probs was (batch_size, n_classes) and
            of shape (batch_size, time_step) if probs was (batch_size, time_step, n_classes)
        """
        # todo: look into numerical stability
        return -(self.probs * self.probs.log()).sum(axis=-1)

    def kl_div(self, alt_dist) -> nd_sym_type:
        """
        Calculated KL-Divergence with another Categorical distribution.

        :param alt_dist: alternative distribution used for kl divergence calculation
        :type alt_dist: CategoricalDist
        :return: KL-Divergence
        """
        # Clip probabilities before taking logs to avoid log(0).
        log_p = self.probs.clip(a_min=eps, a_max=1 - eps).log()
        log_q = alt_dist.probs.clip(a_min=eps, a_max=1 - eps).log()
        contrib = self.probs * (log_p - log_q)
        # Where q == 0 the divergence is infinite; where p == 0 the contribution is zero.
        contrib = self.F.where(condition=(alt_dist.probs == 0), x=self.F.ones_like(alt_dist.probs) * math.inf, y=contrib)
        contrib = self.F.where(condition=(self.probs == 0), x=self.F.zeros_like(self.probs), y=contrib)
        return contrib.sum(axis=-1)
class DiscretePPOHead(nn.HybridBlock):
    def __init__(self, num_actions: int) -> None:
        """
        Head block for Discrete Proximal Policy Optimization: maps the middleware
        representation of the environment state to a probability per action.

        :param num_actions: number of actions in action space.
        """
        super(DiscretePPOHead, self).__init__()
        with self.name_scope():
            self.dense = nn.Dense(units=num_actions, flatten=False)

    def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
        """
        Forward pass through the head network.

        :param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
        :param x: middleware state representation,
            of shape (batch_size, in_channels) or
            of shape (batch_size, time_step, in_channels).
        :return: batch of probabilities for each action,
            of shape (batch_size, num_actions) or
            of shape (batch_size, time_step, num_actions).
        """
        # Linear projection to per-action scores, normalized with softmax.
        return F.softmax(self.dense(x))
class ContinuousPPOHead(nn.HybridBlock):
    def __init__(self, num_actions: int) -> None:
        """
        Head block for Continuous Proximal Policy Optimization: maps the middleware
        representation of the environment state to the mean and standard deviation
        of a Gaussian policy over actions.

        :param num_actions: number of actions in action space.
        """
        super(ContinuousPPOHead, self).__init__()
        with self.name_scope():
            # todo: change initialization strategy
            self.dense = nn.Dense(units=num_actions, flatten=False)
            # all samples (across batch, and time step) share the same covariance, which is learnt,
            # but since we assume the action probability variables are independent,
            # only the diagonal entries of the covariance matrix are specified.
            self.log_std = self.params.get('log_std',
                                           shape=num_actions,
                                           init=mx.init.Zero(),
                                           allow_deferred_init=True)
            # todo: is_local?

    def hybrid_forward(self, F: ModuleType, x: nd_sym_type, log_std: nd_sym_type) -> List[nd_sym_type]:
        """
        Forward pass through the head network.

        :param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
        :param x: middleware state representation,
            of shape (batch_size, in_channels) or
            of shape (batch_size, time_step, in_channels).
        :param log_std: the 'log_std' Parameter (supplied automatically by gluon),
            of shape (num_actions,); shared across batch and time step.
        :return: list of [means, stds] for the action distribution,
            of shape (batch_size, num_actions) or
            of shape (batch_size, time_step, num_actions).
        """
        means = self.dense(x)
        # Parameter is stored as log(std) so std stays positive after exp.
        stds = log_std.exp()
        return [means, stds]
class ClippedPPOLossDiscrete(HeadLoss):
    def __init__(self,
                 num_actions: int,
                 clip_likelihood_ratio_using_epsilon: float,
                 beta: float=0,
                 use_kl_regularization: bool=False,
                 initial_kl_coefficient: float=1,
                 kl_cutoff: float=0,
                 high_kl_penalty_coefficient: float=1,
                 weight: float=1,
                 batch_axis: int=0) -> None:
        """
        Loss for discrete version of Clipped PPO.

        :param num_actions: number of actions in action space.
        :param clip_likelihood_ratio_using_epsilon: epsilon to use for likelihood ratio clipping.
        :param beta: loss coefficient applied to entropy
        :param use_kl_regularization: option to add kl divergence loss
        :param initial_kl_coefficient: loss coefficient applied kl divergence loss (also see high_kl_penalty_coefficient).
        :param kl_cutoff: threshold for using high_kl_penalty_coefficient
        :param high_kl_penalty_coefficient: loss coefficient applied to kl divergence above kl_cutoff
        :param weight: scalar used to adjust relative weight of loss (if using this loss with others).
        :param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
        """
        super(ClippedPPOLossDiscrete, self).__init__(weight=weight, batch_axis=batch_axis)
        self.weight = weight
        self.num_actions = num_actions
        self.clip_likelihood_ratio_using_epsilon = clip_likelihood_ratio_using_epsilon
        self.beta = beta
        self.use_kl_regularization = use_kl_regularization
        # NOTE(review): when use_kl_regularization is False this attribute is zeroed,
        # but the 'kl_coefficient' Parameter below is still initialized from the raw
        # initial_kl_coefficient argument -- confirm this asymmetry is intended.
        self.initial_kl_coefficient = initial_kl_coefficient if self.use_kl_regularization else 0.0
        # Registered as a non-differentiable block Parameter so gluon passes it to
        # hybrid_forward automatically and it can be updated later via set_data()
        # (see PPOHead.assign_kl_coefficient).
        self.kl_coefficient = self.params.get('kl_coefficient',
                                              shape=(1,),
                                              init=mx.init.Constant([initial_kl_coefficient,]),
                                              differentiable=False)
        self.kl_cutoff = kl_cutoff
        self.high_kl_penalty_coefficient = high_kl_penalty_coefficient

    @property
    def input_schema(self) -> LossInputSchema:
        # kl_coefficient is deliberately absent here: registered Parameters are
        # appended to hybrid_forward by gluon itself, not by the agent.
        return LossInputSchema(
            head_outputs=['new_policy_probs'],
            agent_inputs=['actions', 'old_policy_probs', 'clip_param_rescaler'],
            targets=['advantages']
        )

    def loss_forward(self,
                     F: ModuleType,
                     new_policy_probs: nd_sym_type,
                     actions: nd_sym_type,
                     old_policy_probs: nd_sym_type,
                     clip_param_rescaler: nd_sym_type,
                     advantages: nd_sym_type,
                     kl_coefficient: nd_sym_type) -> List[Tuple[nd_sym_type, str]]:
        """
        Used for forward pass through loss computations.
        Works with batches of data, and optionally time_steps, but be consistent in usage: i.e. if using time_step,
        new_policy_probs, old_policy_probs, actions and advantages all must include a time_step dimension.

        NOTE: order of input arguments MUST NOT CHANGE because it matches the order
        parameters are passed in ppo_agent:train_network()

        :param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
        :param new_policy_probs: action probabilities predicted by DiscretePPOHead network,
            of shape (batch_size, num_actions) or
            of shape (batch_size, time_step, num_actions).
        :param old_policy_probs: action probabilities for previous policy,
            of shape (batch_size, num_actions) or
            of shape (batch_size, time_step, num_actions).
        :param actions: true actions taken during rollout,
            of shape (batch_size) or
            of shape (batch_size, time_step).
        :param clip_param_rescaler: scales epsilon to use for likelihood ratio clipping.
        :param advantages: change in state value after taking action (a.k.a advantage)
            of shape (batch_size) or
            of shape (batch_size, time_step).
        :param kl_coefficient: loss coefficient applied kl divergence loss (also see high_kl_penalty_coefficient).
        :return: loss, of shape (batch_size).
        """
        # Log-probabilities of the taken actions under both policies.
        old_policy_dist = CategoricalDist(self.num_actions, old_policy_probs, F=F)
        action_probs_wrt_old_policy = old_policy_dist.log_prob(actions)
        new_policy_dist = CategoricalDist(self.num_actions, new_policy_probs, F=F)
        action_probs_wrt_new_policy = new_policy_dist.log_prob(actions)
        # Entropy bonus (negated: maximizing entropy lowers the loss).
        entropy_loss = - self.beta * new_policy_dist.entropy().mean()
        if self.use_kl_regularization:
            kl_div = old_policy_dist.kl_div(new_policy_dist).mean()
            weighted_kl_div = kl_coefficient * kl_div
            # Quadratic extra penalty once KL divergence exceeds the cutoff.
            high_kl_div = F.stack(F.zeros_like(kl_div), kl_div - self.kl_cutoff).max().square()
            weighted_high_kl_div = self.high_kl_penalty_coefficient * high_kl_div
            kl_div_loss = weighted_kl_div + weighted_high_kl_div
        else:
            kl_div_loss = F.zeros(shape=(1,))
        # working with log probs, so minus first, then exponential (same as division)
        likelihood_ratio = (action_probs_wrt_new_policy - action_probs_wrt_old_policy).exp()
        if self.clip_likelihood_ratio_using_epsilon is not None:
            # clipping of likelihood ratio
            min_value = 1 - self.clip_likelihood_ratio_using_epsilon * clip_param_rescaler
            max_value = 1 + self.clip_likelihood_ratio_using_epsilon * clip_param_rescaler
            # can't use F.clip (with variable clipping bounds), hence custom implementation
            clipped_likelihood_ratio = hybrid_clip(F, likelihood_ratio, clip_lower=min_value, clip_upper=max_value)
            # lower bound of original, and clipped versions or each scaled advantage
            # element-wise min between the two ndarrays
            unclipped_scaled_advantages = likelihood_ratio * advantages
            clipped_scaled_advantages = clipped_likelihood_ratio * advantages
            scaled_advantages = F.stack(unclipped_scaled_advantages, clipped_scaled_advantages).min(axis=0)
        else:
            scaled_advantages = likelihood_ratio * advantages
            clipped_likelihood_ratio = F.zeros_like(likelihood_ratio)
        # for each batch, calculate expectation of scaled_advantages across time steps,
        # but want code to work with data without time step too, so reshape to add timestep if doesn't exist.
        scaled_advantages_w_time = scaled_advantages.reshape(shape=(0, -1))
        expected_scaled_advantages = scaled_advantages_w_time.mean(axis=1)
        # want to maximize expected_scaled_advantages, add minus so can minimize.
        surrogate_loss = (-expected_scaled_advantages * self.weight).mean()
        # Output order must match output_schema expectations (see HeadLoss).
        return [
            (surrogate_loss, LOSS_OUT_TYPE_LOSS),
            (entropy_loss + kl_div_loss, LOSS_OUT_TYPE_REGULARIZATION),
            (kl_div_loss, LOSS_OUT_TYPE_KL),
            (entropy_loss, LOSS_OUT_TYPE_ENTROPY),
            (likelihood_ratio, LOSS_OUT_TYPE_LIKELIHOOD_RATIO),
            (clipped_likelihood_ratio, LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO)
        ]
class ClippedPPOLossContinuous(HeadLoss):
    def __init__(self,
                 num_actions: int,
                 clip_likelihood_ratio_using_epsilon: float,
                 beta: float=0,
                 use_kl_regularization: bool=False,
                 initial_kl_coefficient: float=1,
                 kl_cutoff: float=0,
                 high_kl_penalty_coefficient: float=1,
                 weight: float=1,
                 batch_axis: int=0):
        """
        Loss for continuous version of Clipped PPO.

        :param num_actions: number of actions in action space.
        :param clip_likelihood_ratio_using_epsilon: epsilon to use for likelihood ratio clipping.
        :param beta: loss coefficient applied to entropy
        :param use_kl_regularization: option to add kl divergence loss
        :param initial_kl_coefficient: initial loss coefficient applied kl divergence loss (also see high_kl_penalty_coefficient).
        :param kl_cutoff: threshold for using high_kl_penalty_coefficient
        :param high_kl_penalty_coefficient: loss coefficient applied to kl divergence above kl_cutoff
        :param weight: scalar used to adjust relative weight of loss (if using this loss with others).
        :param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
        """
        super(ClippedPPOLossContinuous, self).__init__(weight=weight, batch_axis=batch_axis)
        self.weight = weight
        self.num_actions = num_actions
        self.clip_likelihood_ratio_using_epsilon = clip_likelihood_ratio_using_epsilon
        self.beta = beta
        self.use_kl_regularization = use_kl_regularization
        # NOTE(review): when use_kl_regularization is False this attribute is zeroed,
        # but the 'kl_coefficient' Parameter below is still initialized from the raw
        # initial_kl_coefficient argument -- confirm this asymmetry is intended.
        self.initial_kl_coefficient = initial_kl_coefficient if self.use_kl_regularization else 0.0
        # Registered as a non-differentiable block Parameter so gluon passes it to
        # hybrid_forward automatically and it can be updated later via set_data()
        # (see PPOHead.assign_kl_coefficient).
        self.kl_coefficient = self.params.get('kl_coefficient',
                                              shape=(1,),
                                              init=mx.init.Constant([initial_kl_coefficient,]),
                                              differentiable=False)
        self.kl_cutoff = kl_cutoff
        self.high_kl_penalty_coefficient = high_kl_penalty_coefficient

    @property
    def input_schema(self) -> LossInputSchema:
        # kl_coefficient is deliberately absent here: registered Parameters are
        # appended to hybrid_forward by gluon itself, not by the agent.
        return LossInputSchema(
            head_outputs=['new_policy_means','new_policy_stds'],
            agent_inputs=['actions', 'old_policy_means', 'old_policy_stds', 'clip_param_rescaler'],
            targets=['advantages']
        )

    def loss_forward(self,
                     F: ModuleType,
                     new_policy_means: nd_sym_type,
                     new_policy_stds: nd_sym_type,
                     actions: nd_sym_type,
                     old_policy_means: nd_sym_type,
                     old_policy_stds: nd_sym_type,
                     clip_param_rescaler: nd_sym_type,
                     advantages: nd_sym_type,
                     kl_coefficient: nd_sym_type) -> List[Tuple[nd_sym_type, str]]:
        """
        Used for forward pass through loss computations.
        Works with batches of data, and optionally time_steps, but be consistent in usage: i.e. if using time_step,
        new_policy_means, old_policy_means, actions and advantages all must include a time_step dimension.

        :param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
        :param new_policy_means: action means predicted by MultivariateNormalDist network,
            of shape (batch_size, num_actions) or
            of shape (batch_size, time_step, num_actions).
        :param new_policy_stds: action standard deviation returned by head,
            of shape (batch_size, num_actions) or
            of shape (batch_size, time_step, num_actions).
        :param actions: true actions taken during rollout,
            of shape (batch_size) or
            of shape (batch_size, time_step).
        :param old_policy_means: action means for previous policy,
            of shape (batch_size, num_actions) or
            of shape (batch_size, time_step, num_actions).
        :param old_policy_stds: action standard deviation returned by head previously,
            of shape (batch_size, num_actions) or
            of shape (batch_size, time_step, num_actions).
        :param clip_param_rescaler: scales epsilon to use for likelihood ratio clipping.
        :param advantages: change in state value after taking action (a.k.a advantage)
            of shape (batch_size) or
            of shape (batch_size, time_step).
        :param kl_coefficient: loss coefficient applied kl divergence loss (also see high_kl_penalty_coefficient).
        :return: loss, of shape (batch_size).
        """
        old_var = old_policy_stds ** 2
        # sets diagonal in (batch size and time step) covariance matrices
        # NOTE(review): mx.nd.eye is called directly instead of F.eye, so this loss
        # will not work if the block is hybridized (F = mx.sym) -- confirm.
        old_covar = mx.nd.eye(N=self.num_actions) * (old_var + eps).broadcast_like(old_policy_means).expand_dims(-2)
        old_policy_dist = MultivariateNormalDist(self.num_actions, old_policy_means, old_covar, F=F)
        action_probs_wrt_old_policy = old_policy_dist.log_prob(actions)
        new_var = new_policy_stds ** 2
        # sets diagonal in (batch size and time step) covariance matrices
        new_covar = mx.nd.eye(N=self.num_actions) * (new_var + eps).broadcast_like(new_policy_means).expand_dims(-2)
        new_policy_dist = MultivariateNormalDist(self.num_actions, new_policy_means, new_covar, F=F)
        action_probs_wrt_new_policy = new_policy_dist.log_prob(actions)
        # Entropy bonus (negated: maximizing entropy lowers the loss).
        entropy_loss = - self.beta * new_policy_dist.entropy().mean()
        if self.use_kl_regularization:
            kl_div = old_policy_dist.kl_div(new_policy_dist).mean()
            weighted_kl_div = kl_coefficient * kl_div
            # Quadratic extra penalty once KL divergence exceeds the cutoff.
            high_kl_div = F.stack(F.zeros_like(kl_div), kl_div - self.kl_cutoff).max().square()
            weighted_high_kl_div = self.high_kl_penalty_coefficient * high_kl_div
            kl_div_loss = weighted_kl_div + weighted_high_kl_div
        else:
            kl_div_loss = F.zeros(shape=(1,))
        # working with log probs, so minus first, then exponential (same as division)
        likelihood_ratio = (action_probs_wrt_new_policy - action_probs_wrt_old_policy).exp()
        if self.clip_likelihood_ratio_using_epsilon is not None:
            # clipping of likelihood ratio
            min_value = 1 - self.clip_likelihood_ratio_using_epsilon * clip_param_rescaler
            max_value = 1 + self.clip_likelihood_ratio_using_epsilon * clip_param_rescaler
            # can't use F.clip (with variable clipping bounds), hence custom implementation
            clipped_likelihood_ratio = hybrid_clip(F, likelihood_ratio, clip_lower=min_value, clip_upper=max_value)
            # lower bound of original, and clipped versions or each scaled advantage
            # element-wise min between the two ndarrays
            unclipped_scaled_advantages = likelihood_ratio * advantages
            clipped_scaled_advantages = clipped_likelihood_ratio * advantages
            scaled_advantages = F.stack(unclipped_scaled_advantages, clipped_scaled_advantages).min(axis=0)
        else:
            scaled_advantages = likelihood_ratio * advantages
            clipped_likelihood_ratio = F.zeros_like(likelihood_ratio)
        # for each batch, calculate expectation of scaled_advantages across time steps,
        # but want code to work with data without time step too, so reshape to add timestep if doesn't exist.
        scaled_advantages_w_time = scaled_advantages.reshape(shape=(0, -1))
        expected_scaled_advantages = scaled_advantages_w_time.mean(axis=1)
        # want to maximize expected_scaled_advantages, add minus so can minimize.
        surrogate_loss = (-expected_scaled_advantages * self.weight).mean()
        # Output order must match output_schema expectations (see HeadLoss).
        return [
            (surrogate_loss, LOSS_OUT_TYPE_LOSS),
            (entropy_loss + kl_div_loss, LOSS_OUT_TYPE_REGULARIZATION),
            (kl_div_loss, LOSS_OUT_TYPE_KL),
            (entropy_loss, LOSS_OUT_TYPE_ENTROPY),
            (likelihood_ratio, LOSS_OUT_TYPE_LIKELIHOOD_RATIO),
            (clipped_likelihood_ratio, LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO)
        ]
class PPOHead(Head):
def __init__(self,
agent_parameters: AgentParameters,
spaces: SpacesDefinition,
network_name: str,
head_type_idx: int=0,
loss_weight: float=1.,
is_local: bool=True,
activation_function: str='tanh',
dense_layer: None=None) -> None:
"""
Head block for Proximal Policy Optimization, to calculate probabilities for each action given middleware
representation of the environment state.
:param agent_parameters: containing algorithm parameters such as clip_likelihood_ratio_using_epsilon
and beta_entropy.
:param spaces: containing action spaces used for defining size of network output.
:param network_name: name of head network. currently unused.
:param head_type_idx: index of head network. currently unused.
:param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
:param is_local: flag to denote if network is local. currently unused.
:param activation_function: activation function to use between layers. currently unused.
:param dense_layer: type of dense layer to use in network. currently unused.
"""
super().__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer)
self.return_type = ActionProbabilities
self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon
self.beta = agent_parameters.algorithm.beta_entropy
self.use_kl_regularization = agent_parameters.algorithm.use_kl_regularization
if self.use_kl_regularization:
self.initial_kl_coefficient = agent_parameters.algorithm.initial_kl_coefficient
self.kl_cutoff = 2 * agent_parameters.algorithm.target_kl_divergence
self.high_kl_penalty_coefficient = agent_parameters.algorithm.high_kl_penalty_coefficient
else:
self.initial_kl_coefficient, self.kl_cutoff, self.high_kl_penalty_coefficient = (None, None, None)
self._loss = []
if isinstance(self.spaces.action, DiscreteActionSpace):
self.net = DiscretePPOHead(num_actions=len(self.spaces.action.actions))
elif isinstance(self.spaces.action, BoxActionSpace):
self.net = ContinuousPPOHead(num_actions=len(self.spaces.action.actions))
else:
raise ValueError("Only discrete or continuous action spaces are supported for PPO.")
def hybrid_forward(self,
F: ModuleType,
x: nd_sym_type) -> nd_sym_type:
"""
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
:param x: middleware embedding
:return: policy parameters/probabilities
"""
return self.net(x)
def loss(self) -> mx.gluon.loss.Loss:
"""
Specifies loss block to be used for this policy head.
:return: loss block (can be called as function) for action probabilities returned by this policy network.
"""
if isinstance(self.spaces.action, DiscreteActionSpace):
loss = ClippedPPOLossDiscrete(len(self.spaces.action.actions),
self.clip_likelihood_ratio_using_epsilon,
self.beta,
self.use_kl_regularization, self.initial_kl_coefficient,
self.kl_cutoff, self.high_kl_penalty_coefficient,
self.loss_weight)
elif isinstance(self.spaces.action, BoxActionSpace):
loss = ClippedPPOLossContinuous(len(self.spaces.action.actions),
self.clip_likelihood_ratio_using_epsilon,
self.beta,
self.use_kl_regularization, self.initial_kl_coefficient,
self.kl_cutoff, self.high_kl_penalty_coefficient,
self.loss_weight)
else:
raise ValueError("Only discrete or continuous action spaces are supported for PPO.")
loss.initialize()
# set a property so can assign_kl_coefficient in future,
# make a list, otherwise it would be added as a child of Head Block (due to type check)
self._loss = [loss]
return loss
    @property
    def kl_divergence(self):
        """Key (head index, loss output type) used to look up the KL divergence among loss outputs."""
        return self.head_type_idx, LOSS_OUT_TYPE_KL
    @property
    def entropy(self):
        """Key (head index, loss output type) used to look up the entropy term among loss outputs."""
        return self.head_type_idx, LOSS_OUT_TYPE_ENTROPY
    @property
    def likelihood_ratio(self):
        """Key (head index, loss output type) used to look up the likelihood ratio among loss outputs."""
        return self.head_type_idx, LOSS_OUT_TYPE_LIKELIHOOD_RATIO
    @property
    def clipped_likelihood_ratio(self):
        """Key (head index, loss output type) used to look up the clipped likelihood ratio among loss outputs."""
        return self.head_type_idx, LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO
    def assign_kl_coefficient(self, kl_coefficient: float) -> None:
        """
        Update the KL penalty coefficient parameter held by the loss block created in loss().
        :param kl_coefficient: new coefficient for the KL divergence penalty term.
        """
        # _loss is a single-element list (see loss()); set_data writes the new value in place.
        self._loss[0].kl_coefficient.set_data(mx.nd.array((kl_coefficient,)))

View File

@@ -0,0 +1,123 @@
from typing import List, Tuple, Union
from types import ModuleType
import mxnet as mx
from mxnet.gluon import nn
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import ActionProbabilities
from rl_coach.spaces import SpacesDefinition
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
class PPOVHeadLoss(HeadLoss):
    def __init__(self, clip_likelihood_ratio_using_epsilon: float, weight: float=1, batch_axis: int=0) -> None:
        """
        Loss for PPO Value network.
        Schulman implemented this extension in OpenAI baselines for PPO2
        See https://github.com/openai/baselines/blob/master/baselines/ppo2/ppo2.py#L72
        :param clip_likelihood_ratio_using_epsilon: epsilon to use for likelihood ratio clipping.
        :param weight: scalar used to adjust relative weight of loss (if using this loss with others).
        :param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
        """
        super(PPOVHeadLoss, self).__init__(weight=weight, batch_axis=batch_axis)
        self.weight = weight
        self.clip_likelihood_ratio_using_epsilon = clip_likelihood_ratio_using_epsilon

    @property
    def input_schema(self) -> LossInputSchema:
        """Head predicts new values; agent supplies the old network's values; target is the actual return."""
        return LossInputSchema(
            head_outputs=['new_policy_values'],
            agent_inputs=['old_policy_values'],
            targets=['target_values']
        )

    def loss_forward(self,
                     F: ModuleType,
                     new_policy_values: nd_sym_type,
                     old_policy_values: nd_sym_type,
                     target_values: nd_sym_type) -> List[Tuple[nd_sym_type, str]]:
        """
        Used for forward pass through loss computations.
        Calculates two losses (L2 and a clipped difference L2 loss) and takes the maximum of the two.
        Works with batches of data, and optionally time_steps, but be consistent in usage: i.e. if using time_step,
        new_policy_values, old_policy_values and target_values all must include a time_step dimension.
        :param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
        :param new_policy_values: values predicted by PPOVHead network,
            of shape (batch_size) or
            of shape (batch_size, time_step).
        :param old_policy_values: values predicted by old value network,
            of shape (batch_size) or
            of shape (batch_size, time_step).
        :param target_values: actual state values,
            of shape (batch_size) or
            of shape (batch_size, time_step).
        :return: loss, of shape (batch_size).
        """
        # L2 loss
        value_loss_1 = (new_policy_values - target_values).square()
        # Clipped difference L2 loss
        diff = new_policy_values - old_policy_values
        clipped_diff = diff.clip(a_min=-self.clip_likelihood_ratio_using_epsilon,
                                 a_max=self.clip_likelihood_ratio_using_epsilon)
        value_loss_2 = (old_policy_values + clipped_diff - target_values).square()
        # Element-wise maximum of the two losses. Must go through F (not mx.nd directly):
        # after hybridization F is mx.sym and mx.nd ops would fail on Symbol inputs.
        value_loss_max = F.maximum(value_loss_1, value_loss_2)
        # Aggregate over temporal axis, adding one if it doesn't exist (hence reshape)
        value_loss_max_w_time = value_loss_max.reshape(shape=(0, -1))
        value_loss = value_loss_max_w_time.mean(axis=1)
        # Weight the loss (and average over samples of batch)
        value_loss_weighted = value_loss.mean() * self.weight
        return [(value_loss_weighted, LOSS_OUT_TYPE_LOSS)]
class PPOVHead(Head):
    def __init__(self,
                 agent_parameters: AgentParameters,
                 spaces: SpacesDefinition,
                 network_name: str,
                 head_type_idx: int=0,
                 loss_weight: float=1.,
                 is_local: bool = True,
                 activation_function: str='relu',
                 dense_layer: None=None) -> None:
        """
        Value head used by PPO for predicting state values.
        :param agent_parameters: containing algorithm parameters, but currently unused.
        :param spaces: containing action spaces, but currently unused.
        :param network_name: name of head network. currently unused.
        :param head_type_idx: index of head network. currently unused.
        :param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
        :param is_local: flag to denote if network is local. currently unused.
        :param activation_function: activation function to use between layers. currently unused.
        :param dense_layer: type of dense layer to use in network. currently unused.
        """
        super(PPOVHead, self).__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight, is_local,
                                       activation_function, dense_layer=dense_layer)
        self.return_type = ActionProbabilities
        self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon
        # single output unit: one scalar value estimate per state
        with self.name_scope():
            self.dense = nn.Dense(units=1)

    def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
        """
        Forward pass of the value head.
        :param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
        :param x: middleware state representation, of shape (batch_size, in_channels).
        :return: state value estimates, of shape (batch_size).
        """
        values = self.dense(x)
        # drop the trailing singleton unit dimension -> (batch_size,)
        return values.squeeze()

    def loss(self) -> mx.gluon.loss.Loss:
        """
        Specifies loss block to be used for this value head.
        :return: loss block (can be called as function) for outputs returned by the value head network.
        """
        return PPOVHeadLoss(self.clip_likelihood_ratio_using_epsilon, weight=self.loss_weight)

View File

@@ -0,0 +1,106 @@
from typing import Union, List, Tuple
from types import ModuleType
import mxnet as mx
from mxnet.gluon.loss import Loss, HuberLoss, L2Loss
from mxnet.gluon import nn
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import QActionStateValue
from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
class QHeadLoss(HeadLoss):
    def __init__(self, loss_type: Loss=L2Loss, weight: float=1, batch_axis: int=0) -> None:
        """
        Loss block for the Q-Value head.
        :param loss_type: underlying gluon loss, mean squared error (i.e. L2Loss) by default.
        :param weight: scalar used to adjust relative weight of loss (if using this loss with others).
        :param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
        """
        super(QHeadLoss, self).__init__(weight=weight, batch_axis=batch_axis)
        with self.name_scope():
            self.loss_fn = loss_type(weight=weight, batch_axis=batch_axis)

    @property
    def input_schema(self) -> LossInputSchema:
        """Single head output ('pred') paired with a single target ('target'); no extra agent inputs."""
        schema = LossInputSchema(
            head_outputs=['pred'],
            agent_inputs=[],
            targets=['target']
        )
        return schema

    def loss_forward(self,
                     F: ModuleType,
                     pred: nd_sym_type,
                     target: nd_sym_type) -> List[Tuple[nd_sym_type, str]]:
        """
        Forward pass through the loss computation.
        :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
        :param pred: state-action q-values predicted by QHead network, of shape (batch_size, num_actions).
        :param target: actual state-action q-values, of shape (batch_size, num_actions).
        :return: loss, of shape (batch_size).
        """
        mean_loss = self.loss_fn(pred, target).mean()
        return [(mean_loss, LOSS_OUT_TYPE_LOSS)]
class QHead(Head):
    def __init__(self,
                 agent_parameters: AgentParameters,
                 spaces: SpacesDefinition,
                 network_name: str,
                 head_type_idx: int=0,
                 loss_weight: float=1.,
                 is_local: bool=True,
                 activation_function: str='relu',
                 dense_layer: None=None,
                 loss_type: Union[HuberLoss, L2Loss]=L2Loss) -> None:
        """
        Q-Value Head for predicting state-action Q-Values.
        :param agent_parameters: containing algorithm parameters, but currently unused.
        :param spaces: containing action spaces used for defining size of network output.
        :param network_name: name of head network. currently unused.
        :param head_type_idx: index of head network. currently unused.
        :param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
        :param is_local: flag to denote if network is local. currently unused.
        :param activation_function: activation function to use between layers. currently unused.
        :param dense_layer: type of dense layer to use in network. currently unused.
        :param loss_type: loss function to use.
        :raises ValueError: if the action space is neither discrete nor continuous (box).
        """
        super(QHead, self).__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight,
                                    is_local, activation_function, dense_layer)
        if isinstance(self.spaces.action, BoxActionSpace):
            self.num_actions = 1
        elif isinstance(self.spaces.action, DiscreteActionSpace):
            self.num_actions = len(self.spaces.action.actions)
        else:
            # fail fast with a clear message instead of leaving num_actions unset,
            # which would otherwise surface as an obscure AttributeError below
            raise ValueError("QHead only supports discrete or box action spaces, but got {}".format(
                type(self.spaces.action)))
        self.return_type = QActionStateValue
        assert (loss_type == L2Loss) or (loss_type == HuberLoss), "Only expecting L2Loss or HuberLoss."
        self.loss_type = loss_type
        with self.name_scope():
            self.dense = nn.Dense(units=self.num_actions)

    def loss(self) -> Loss:
        """
        Specifies loss block to be used for specific value head implementation.
        :return: loss block (can be called as function) for outputs returned by the head network.
        """
        return QHeadLoss(loss_type=self.loss_type, weight=self.loss_weight)

    def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
        """
        Used for forward pass through Q-Value head network.
        :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
        :param x: middleware state representation, of shape (batch_size, in_channels).
        :return: predicted state-action q-values, of shape (batch_size, num_actions).
        """
        return self.dense(x)

View File

@@ -0,0 +1,100 @@
from typing import Union, List, Tuple
from types import ModuleType
import mxnet as mx
from mxnet.gluon.loss import Loss, HuberLoss, L2Loss
from mxnet.gluon import nn
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import VStateValue
from rl_coach.spaces import SpacesDefinition
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
class VHeadLoss(HeadLoss):
    def __init__(self, loss_type: Loss=L2Loss, weight: float=1, batch_axis: int=0) -> None:
        """
        Loss block for the Value head.
        :param loss_type: underlying gluon loss, mean squared error (i.e. L2Loss) by default.
        :param weight: scalar used to adjust relative weight of loss (if using this loss with others).
        :param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
        """
        super(VHeadLoss, self).__init__(weight=weight, batch_axis=batch_axis)
        with self.name_scope():
            self.loss_fn = loss_type(weight=weight, batch_axis=batch_axis)

    @property
    def input_schema(self) -> LossInputSchema:
        """Single head output ('pred') paired with a single target ('target'); no extra agent inputs."""
        schema = LossInputSchema(
            head_outputs=['pred'],
            agent_inputs=[],
            targets=['target']
        )
        return schema

    def loss_forward(self,
                     F: ModuleType,
                     pred: nd_sym_type,
                     target: nd_sym_type) -> List[Tuple[nd_sym_type, str]]:
        """
        Forward pass through the loss computation.
        :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
        :param pred: state values predicted by VHead network, of shape (batch_size).
        :param target: actual state values, of shape (batch_size).
        :return: loss, of shape (batch_size).
        """
        mean_loss = self.loss_fn(pred, target).mean()
        return [(mean_loss, LOSS_OUT_TYPE_LOSS)]
class VHead(Head):
    def __init__(self,
                 agent_parameters: AgentParameters,
                 spaces: SpacesDefinition,
                 network_name: str,
                 head_type_idx: int=0,
                 loss_weight: float=1.,
                 is_local: bool=True,
                 activation_function: str='relu',
                 dense_layer: None=None,
                 loss_type: Union[HuberLoss, L2Loss]=L2Loss):
        """
        Value Head for predicting state values.
        :param agent_parameters: containing algorithm parameters, but currently unused.
        :param spaces: containing action spaces, but currently unused.
        :param network_name: name of head network. currently unused.
        :param head_type_idx: index of head network. currently unused.
        :param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
        :param is_local: flag to denote if network is local. currently unused.
        :param activation_function: activation function to use between layers. currently unused.
        :param dense_layer: type of dense layer to use in network. currently unused.
        :param loss_type: loss function with default of mean squared error (i.e. L2Loss), or alternatively HuberLoss.
        """
        super(VHead, self).__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight,
                                    is_local, activation_function, dense_layer)
        assert (loss_type == L2Loss) or (loss_type == HuberLoss), "Only expecting L2Loss or HuberLoss."
        self.return_type = VStateValue
        self.loss_type = loss_type
        # single output unit: one scalar value estimate per state
        with self.name_scope():
            self.dense = nn.Dense(units=1)

    def loss(self) -> Loss:
        """
        Specifies loss block to be used for this value head.
        :return: loss block (can be called as function) for outputs returned by the head network.
        """
        return VHeadLoss(loss_type=self.loss_type, weight=self.loss_weight)

    def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
        """
        Forward pass of the value head.
        :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
        :param x: middleware state representation, of shape (batch_size, in_channels).
        :return: state value estimates, of shape (batch_size).
        """
        values = self.dense(x)
        # drop the trailing singleton unit dimension -> (batch_size,)
        return values.squeeze()