1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00
Files
coach/rl_coach/architectures/mxnet_components/heads/v_head.py
Sina Afrooze 67eb9e4c28 Adding checkpointing framework (#74)
* Adding checkpointing framework as well as mxnet checkpointing implementation.

- MXNet checkpoint for each network is saved in a separate file.

* Adding checkpoint restore for mxnet to graph-manager

* Add unit-test for get_checkpoint_state()

* Added match.group() to fix unit-test failing on CI

* Added ONNX export support for MXNet
2018-11-19 19:45:49 +02:00

102 lines
4.5 KiB
Python

from typing import Type, Union, List, Tuple
from types import ModuleType

import mxnet as mx
from mxnet.gluon import nn
from mxnet.gluon.loss import Loss, HuberLoss, L2Loss

from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema,\
    NormalizedRSSInitializer
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import VStateValue
from rl_coach.spaces import SpacesDefinition
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
class VHeadLoss(HeadLoss):
    """Regression loss block for the value head (L2 by default, Huber optionally)."""

    def __init__(self, loss_type: Type[Loss]=L2Loss, weight: float=1, batch_axis: int=0) -> None:
        """
        Loss for Value Head.

        :param loss_type: loss *class* (not an instance) with default of mean squared error
            (i.e. L2Loss); it is instantiated here with the given weight and batch axis.
        :param weight: scalar used to adjust relative weight of loss (if using this loss with others).
        :param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
        """
        super(VHeadLoss, self).__init__(weight=weight, batch_axis=batch_axis)
        with self.name_scope():
            self.loss_fn = loss_type(weight=weight, batch_axis=batch_axis)

    @property
    def input_schema(self) -> LossInputSchema:
        """
        :return: schema describing loss inputs: the head's 'pred' output and the 'target'
            state values; no additional agent inputs are required.
        """
        return LossInputSchema(
            head_outputs=['pred'],
            agent_inputs=[],
            targets=['target']
        )

    def loss_forward(self,
                     F: ModuleType,
                     pred: nd_sym_type,
                     target: nd_sym_type) -> List[Tuple[nd_sym_type, str]]:
        """
        Used for forward pass through loss computations.

        :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
        :param pred: state values predicted by VHead network, of shape (batch_size).
        :param target: actual state values, of shape (batch_size).
        :return: single-element list of (loss, LOSS_OUT_TYPE_LOSS); the loss is averaged
            over the batch, so it is a scalar of shape (1,), not per-sample.
        """
        # Mean-reduce the per-sample losses to a single scalar for this mini-batch.
        loss = self.loss_fn(pred, target).mean()
        return [(loss, LOSS_OUT_TYPE_LOSS)]
class VHead(Head):
    """Value head: a single dense unit predicting the scalar state value V(s)."""

    def __init__(self,
                 agent_parameters: AgentParameters,
                 spaces: SpacesDefinition,
                 network_name: str,
                 head_type_idx: int=0,
                 loss_weight: float=1.,
                 is_local: bool=True,
                 activation_function: str='relu',
                 dense_layer: None=None,
                 loss_type: Type[Union[HuberLoss, L2Loss]]=L2Loss):
        """
        Value Head for predicting state values.

        :param agent_parameters: containing algorithm parameters, but currently unused.
        :param spaces: containing action spaces, but currently unused.
        :param network_name: name of head network. currently unused.
        :param head_type_idx: index of head network. currently unused.
        :param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
        :param is_local: flag to denote if network is local. currently unused.
        :param activation_function: activation function to use between layers. currently unused.
        :param dense_layer: type of dense layer to use in network. currently unused.
        :param loss_type: loss *class* (not an instance) with default of mean squared error
            (i.e. L2Loss), or alternatively HuberLoss; forwarded to VHeadLoss in loss().
        """
        super(VHead, self).__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight,
                                    is_local, activation_function, dense_layer)
        # NOTE: this compares the class itself, not an instance; stripped under `python -O`.
        assert (loss_type == L2Loss) or (loss_type == HuberLoss), "Only expecting L2Loss or HuberLoss."
        self.loss_type = loss_type
        self.return_type = VStateValue
        with self.name_scope():
            # Single output unit: one scalar value per state.
            self.dense = nn.Dense(units=1, weight_initializer=NormalizedRSSInitializer(1.0))

    def loss(self) -> Loss:
        """
        Specifies loss block to be used for specific value head implementation.

        :return: loss block (can be called as function) for outputs returned by the head network.
        """
        return VHeadLoss(loss_type=self.loss_type, weight=self.loss_weight)

    def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
        """
        Used for forward pass through value head network.

        :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
        :param x: middleware state representation, of shape (batch_size, in_channels).
        :return: final output of value network, of shape (batch_size); the dense layer's
            (batch_size, 1) output has its trailing axis squeezed away.
        """
        return self.dense(x).squeeze(axis=1)