mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Adding mxnet components to rl_coach/architectures (#60)
Adding mxnet components to rl_coach architectures. - Supports PPO and DQN - Tested with CartPole_PPO and CartPole_DQN - Normalizing filters don't work right now (see #49) and are disabled in CartPole_PPO preset - Checkpointing is disabled for MXNet
This commit is contained in:
@@ -41,22 +41,41 @@ class Architecture(object):
|
||||
self.optimizer = None
|
||||
self.ap = agent_parameters
|
||||
|
||||
def predict(self, inputs: Dict[str, np.ndarray]) -> List[np.ndarray]:
|
||||
def predict(self,
|
||||
inputs: Dict[str, np.ndarray],
|
||||
outputs: List[Any] = None,
|
||||
squeeze_output: bool = True,
|
||||
initial_feed_dict: Dict[Any, np.ndarray] = None) -> Tuple[np.ndarray, ...]:
|
||||
"""
|
||||
Given input observations, use the model to make predictions (e.g. action or value).
|
||||
|
||||
:param inputs: current state (i.e. observations, measurements, goals, etc.)
|
||||
(e.g. `{'observation': numpy.ndarray}` of shape (batch_size, observation_space_size))
|
||||
:param outputs: list of outputs to return. Return all outputs if unspecified. Type of the list elements
|
||||
depends on the framework backend.
|
||||
:param squeeze_output: call squeeze_list on output before returning if True
|
||||
:param initial_feed_dict: a dictionary of extra inputs for forward pass.
|
||||
:return: predictions of action or value (of shape (batch_size, action_space_size) for action predictions)
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def parallel_predict(sess: Any,
|
||||
network_input_tuples: List[Tuple['Architecture', Dict[str, np.ndarray]]]) -> \
|
||||
Tuple[np.ndarray, ...]:
|
||||
"""
|
||||
:param sess: active session to use for prediction
|
||||
:param network_input_tuples: list of (network, inputs) tuples, pairing each network with its corresponding input
|
||||
:return: list or tuple of outputs from all networks
|
||||
"""
|
||||
pass
|
||||
|
||||
def train_on_batch(self,
|
||||
inputs: Dict[str, np.ndarray],
|
||||
targets: List[np.ndarray],
|
||||
scaler: float=1.,
|
||||
additional_fetches: list=None,
|
||||
importance_weights: np.ndarray=None) -> tuple:
|
||||
importance_weights: np.ndarray=None) -> Tuple[float, List[float], float, list]:
|
||||
"""
|
||||
Given a batch of inputs (e.g. states) and targets (e.g. discounted rewards), takes a training step: i.e. runs a
|
||||
forward pass and backward pass of the network, accumulates the gradients and applies an optimization step to
|
||||
@@ -118,8 +137,7 @@ class Architecture(object):
|
||||
targets: List[np.ndarray],
|
||||
additional_fetches: list=None,
|
||||
importance_weights: np.ndarray=None,
|
||||
no_accumulation: bool=False) ->\
|
||||
Tuple[float, List[float], float, list]:
|
||||
no_accumulation: bool=False) -> Tuple[float, List[float], float, list]:
|
||||
"""
|
||||
Given a batch of inputs (e.g. states) and targets (e.g. discounted rewards), computes and accumulates the
|
||||
gradients for model parameters. Will run forward and backward pass to compute gradients, clip the gradient
|
||||
@@ -142,30 +160,33 @@ class Architecture(object):
|
||||
calculated gradients
|
||||
:return: tuple of total_loss, losses, norm_unclipped_grads, fetched_tensors
|
||||
total_loss (float): sum of all head losses
|
||||
losses (list of float): list of all losses. The order is list of target losses followed by list of regularization losses.
|
||||
The specifics of losses is dependant on the network parameters (number of heads, etc.)
|
||||
losses (list of float): list of all losses. The order is list of target losses followed by list of
|
||||
regularization losses. The specifics of the losses are dependent on the network parameters
|
||||
(number of heads, etc.)
|
||||
norm_unclipped_grads (float): global norm of all gradients before any gradient clipping is applied
|
||||
fetched_tensors: all values for additional_fetches
|
||||
"""
|
||||
pass
|
||||
|
||||
def apply_and_reset_gradients(self, gradients: List[np.ndarray]) -> None:
|
||||
def apply_and_reset_gradients(self, gradients: List[np.ndarray], scaler: float=1.) -> None:
|
||||
"""
|
||||
Applies the given gradients to the network weights and resets the gradient accumulations.
|
||||
Has the same impact as calling `apply_gradients`, then `reset_accumulated_gradients`.
|
||||
|
||||
:param gradients: gradients for the parameter weights, taken from `accumulated_gradients` property
|
||||
of an identical network (either self or another identical network)
|
||||
:param scaler: A scaling factor that allows rescaling the gradients before applying them
|
||||
"""
|
||||
pass
|
||||
|
||||
def apply_gradients(self, gradients: List[np.ndarray]) -> None:
|
||||
def apply_gradients(self, gradients: List[np.ndarray], scaler: float=1.) -> None:
|
||||
"""
|
||||
Applies the given gradients to the network weights.
|
||||
Will be performed sync or async depending on `network_parameters.async_training`
|
||||
|
||||
:param gradients: gradients for the parameter weights, taken from `accumulated_gradients` property
|
||||
of an identical network (either self or another identical network)
|
||||
:param scaler: A scaling factor that allows rescaling the gradients before applying them
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user