mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Enabling-more-agents-for-Batch-RL-and-cleanup (#258)
Allows the last training batch drawn to be smaller than batch_size, adds support for more agents in Batch RL by adding a softmax with temperature to the corresponding heads, adds a CartPole_QR_DQN preset with a golden test, and includes general cleanups.
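For context: the "softmax with temperature" mentioned above turns a head's Q-values into a probability distribution over actions, with the temperature controlling how peaked that distribution is (batch RL uses this to estimate action probabilities). A minimal NumPy sketch of the idea, with made-up Q-values; illustrative only, not code from this commit:

import numpy as np

def softmax_with_temperature(q_values, temperature=1.0):
    # divide by the temperature, then normalize with a standard softmax
    scaled = q_values / temperature
    exp = np.exp(scaled - np.max(scaled))  # subtract the max for numerical stability
    return exp / exp.sum()

q = np.array([1.0, 2.0, 0.5])             # hypothetical Q-values for three actions
print(softmax_with_temperature(q, 0.1))   # low temperature -> nearly greedy (argmax)
print(softmax_with_temperature(q, 10.0))  # high temperature -> nearly uniform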
rl_coach/architectures/tensorflow_components/heads/__init__.py

@@ -1,3 +1,4 @@
+from .q_head import QHead
 from .categorical_q_head import CategoricalQHead
 from .ddpg_actor_head import DDPGActor
 from .dnd_q_head import DNDQHead
@@ -7,7 +8,6 @@ from .naf_head import NAFHead
 from .policy_head import PolicyHead
 from .ppo_head import PPOHead
 from .ppo_v_head import PPOVHead
-from .q_head import QHead
 from .quantile_regression_q_head import QuantileRegressionQHead
 from .rainbow_q_head import RainbowQHead
 from .v_head import VHead
rl_coach/architectures/tensorflow_components/heads/categorical_q_head.py

@@ -15,16 +15,15 @@
 #

 import tensorflow as tf

 import numpy as np
+from rl_coach.architectures.tensorflow_components.heads import QHead
 from rl_coach.architectures.tensorflow_components.layers import Dense
-from rl_coach.architectures.tensorflow_components.heads.head import Head
 from rl_coach.base_parameters import AgentParameters
 from rl_coach.core_types import QActionStateValue
 from rl_coach.spaces import SpacesDefinition


-class CategoricalQHead(Head):
+class CategoricalQHead(QHead):
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
                  head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str ='relu',
                  dense_layer=Dense):
@@ -33,7 +32,9 @@ class CategoricalQHead(Head):
         self.name = 'categorical_dqn_head'
         self.num_actions = len(self.spaces.action.actions)
         self.num_atoms = agent_parameters.algorithm.atoms
         self.return_type = QActionStateValue
+        self.z_values = tf.cast(tf.constant(np.linspace(self.ap.algorithm.v_min, self.ap.algorithm.v_max,
+                                                        self.ap.algorithm.atoms), dtype=tf.float32), dtype=tf.float64)
         self.loss_type = []

     def _build_module(self, input_layer):
         values_distribution = self.dense_layer(self.num_actions * self.num_atoms)(input_layer, name='output')
@@ -49,6 +50,11 @@ class CategoricalQHead(Head):
         self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.target, logits=values_distribution)
         tf.losses.add_loss(self.loss)

+        self.q_values = tf.tensordot(tf.cast(self.output, tf.float64), self.z_values, 1)
+
+        # used in batch-rl to estimate a probablity distribution over actions
+        self.softmax = self.add_softmax_with_temperature()
+
     def __str__(self):
         result = [
             "Dense (num outputs = {})".format(self.num_actions * self.num_atoms),
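As an aside, the categorical (C51-style) head models a distribution over a fixed support of atoms (self.z_values), and the tensordot added above collapses it into one expected Q-value per action. A rough NumPy sketch of that reduction, with arbitrary numbers rather than anything from the commit:

import numpy as np

v_min, v_max, atoms = -10.0, 10.0, 51
z_values = np.linspace(v_min, v_max, atoms)        # support of the value distribution

# hypothetical per-action probabilities over the atoms, shape (num_actions, atoms)
probs = np.random.dirichlet(np.ones(atoms), size=2)

q_values = probs @ z_values                        # expectation per action, like the tensordot
print(q_values)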
rl_coach/architectures/tensorflow_components/heads/dnd_q_head.py

@@ -56,11 +56,14 @@ class DNDQHead(QHead):

         # Retrieve info from DND dictionary
         # We assume that all actions have enough entries in the DND
-        self.output = tf.transpose([
+        self.q_values = self.output = tf.transpose([
             self._q_value(input_layer, action)
             for action in range(self.num_actions)
         ])

+        # used in batch-rl to estimate a probablity distribution over actions
+        self.softmax = self.add_softmax_with_temperature()
+
     def _q_value(self, input_layer, action):
         result = tf.py_func(self.DND.query,
                             [input_layer, action, self.number_of_nn],
rl_coach/architectures/tensorflow_components/heads/dueling_q_head.py

@@ -44,7 +44,10 @@ class DuelingQHead(QHead):
         self.action_advantage = self.action_advantage - self.action_mean

         # merge to state-action value function Q
-        self.output = tf.add(self.state_value, self.action_advantage, name='output')
+        self.q_values = self.output = tf.add(self.state_value, self.action_advantage, name='output')
+
+        # used in batch-rl to estimate a probablity distribution over actions
+        self.softmax = self.add_softmax_with_temperature()

     def __str__(self):
         result = [
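For reference, the dueling head combines a state value V and mean-centered advantages A into Q = V + (A - mean(A)); the tf.add above performs the final sum after the mean has been subtracted. A small NumPy sketch with invented numbers, purely to illustrate the aggregation:

import numpy as np

state_value = np.array([[1.5]])                    # V(s), shape (batch, 1)
advantages = np.array([[0.2, -0.1, 0.5]])          # A(s, a), shape (batch, num_actions)

advantages = advantages - advantages.mean(axis=1, keepdims=True)  # zero-mean advantages
q_values = state_value + advantages                # broadcasts V across the actions
print(q_values)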
rl_coach/architectures/tensorflow_components/heads/q_head.py

@@ -48,15 +48,19 @@ class QHead(Head):

     def _build_module(self, input_layer):
         # Standard Q Network
-        self.output = self.dense_layer(self.num_actions)(input_layer, name='output')
+        self.q_values = self.output = self.dense_layer(self.num_actions)(input_layer, name='output')

-        # TODO add this to other Q heads. e.g. dueling.
-        temperature = self.ap.network_wrappers[self.network_name].softmax_temperature
-        temperature_scaled_outputs = self.output / temperature
-        self.softmax = tf.nn.softmax(temperature_scaled_outputs, name="softmax")
+        # used in batch-rl to estimate a probablity distribution over actions
+        self.softmax = self.add_softmax_with_temperature()

     def __str__(self):
         result = [
             "Dense (num outputs = {})".format(self.num_actions)
         ]
         return '\n'.join(result)
+
+    def add_softmax_with_temperature(self):
+        temperature = self.ap.network_wrappers[self.network_name].softmax_temperature
+        temperature_scaled_outputs = self.q_values / temperature
+        return tf.nn.softmax(temperature_scaled_outputs, name="softmax")
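The temperature used by the new add_softmax_with_temperature helper is read from the network wrapper's softmax_temperature parameter, so a preset can control how soft the estimated action distribution is. A hedged sketch of how a preset might set it, assuming the usual rl_coach pattern of addressing a 'main' network wrapper on the agent parameters (the names below are illustrative, not taken from this commit):

# hypothetical preset fragment; 'agent_params' stands for whatever agent parameters the preset builds
agent_params.network_wrappers['main'].softmax_temperature = 0.2  # lower -> closer to greedy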
rl_coach/architectures/tensorflow_components/heads/quantile_regression_q_head.py

@@ -15,15 +15,14 @@
 #

 import tensorflow as tf

 import numpy as np
+from rl_coach.architectures.tensorflow_components.heads import QHead
 from rl_coach.architectures.tensorflow_components.layers import Dense
-from rl_coach.architectures.tensorflow_components.heads.head import Head
 from rl_coach.base_parameters import AgentParameters
 from rl_coach.core_types import QActionStateValue
 from rl_coach.spaces import SpacesDefinition


-class QuantileRegressionQHead(Head):
+class QuantileRegressionQHead(QHead):
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
                  head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
                  dense_layer=Dense):
@@ -33,7 +32,10 @@ class QuantileRegressionQHead(Head):
         self.num_actions = len(self.spaces.action.actions)
         self.num_atoms = agent_parameters.algorithm.atoms  # we use atom / quantile interchangeably
         self.huber_loss_interval = agent_parameters.algorithm.huber_loss_interval  # k
         self.return_type = QActionStateValue
+        self.quantile_probabilities = tf.cast(
+            tf.constant(np.ones(self.ap.algorithm.atoms) / float(self.ap.algorithm.atoms), dtype=tf.float32),
+            dtype=tf.float64)
         self.loss_type = []

     def _build_module(self, input_layer):
         self.actions = tf.placeholder(tf.int32, [None, 2], name="actions")
@@ -72,6 +74,11 @@ class QuantileRegressionQHead(Head):
         self.loss = quantile_regression_loss
         tf.losses.add_loss(self.loss)

+        self.q_values = tf.tensordot(tf.cast(self.output, tf.float64), self.quantile_probabilities, 1)
+
+        # used in batch-rl to estimate a probablity distribution over actions
+        self.softmax = self.add_softmax_with_temperature()
+
     def __str__(self):
         result = [
             "Dense (num outputs = {})".format(self.num_actions * self.num_atoms),
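The quantile-regression head is the same story with a different distribution: each of the N learned quantile locations carries equal probability 1/N (self.quantile_probabilities), so the expected Q-value reduces to the mean of the quantiles, which is what the added tensordot computes. A tiny NumPy sketch with invented quantiles, not code from the commit:

import numpy as np

atoms = 8                                          # number of quantiles, arbitrary here
quantiles = np.sort(np.random.randn(2, atoms))     # hypothetical per-action quantile locations
quantile_probabilities = np.ones(atoms) / atoms    # uniform weights, as in the head

q_values = quantiles @ quantile_probabilities      # same as taking the mean over quantiles
print(q_values, quantiles.mean(axis=1))            # the two agree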
rl_coach/architectures/tensorflow_components/heads/rainbow_q_head.py

@@ -15,15 +15,14 @@
 #

 import tensorflow as tf

 import numpy as np
+from rl_coach.architectures.tensorflow_components.heads import QHead
 from rl_coach.architectures.tensorflow_components.layers import Dense
-from rl_coach.architectures.tensorflow_components.heads.head import Head
 from rl_coach.base_parameters import AgentParameters
 from rl_coach.core_types import QActionStateValue
 from rl_coach.spaces import SpacesDefinition


-class RainbowQHead(Head):
+class RainbowQHead(QHead):
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
                  head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
                  dense_layer=Dense):
@@ -31,8 +30,10 @@ class RainbowQHead(Head):
                          dense_layer=dense_layer)
         self.num_actions = len(self.spaces.action.actions)
         self.num_atoms = agent_parameters.algorithm.atoms
         self.return_type = QActionStateValue
         self.name = 'rainbow_q_values_head'
+        self.z_values = tf.cast(tf.constant(np.linspace(self.ap.algorithm.v_min, self.ap.algorithm.v_max,
+                                                        self.ap.algorithm.atoms), dtype=tf.float32), dtype=tf.float64)
         self.loss_type = []

     def _build_module(self, input_layer):
         # state value tower - V
@@ -63,6 +64,11 @@ class RainbowQHead(Head):
         self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.target, logits=values_distribution)
         tf.losses.add_loss(self.loss)

+        self.q_values = tf.tensordot(tf.cast(self.output, tf.float64), self.z_values, 1)
+
+        # used in batch-rl to estimate a probablity distribution over actions
+        self.softmax = self.add_softmax_with_temperature()
+
     def __str__(self):
         result = [
             "State Value Stream - V",