
Enabling-more-agents-for-Batch-RL-and-cleanup (#258)

This change:
- allows the last training batch drawn to be smaller than batch_size
- adds support for more agents in Batch RL by adding a softmax with temperature to the corresponding heads
- adds a CartPole_QR_DQN preset with a golden test
- includes general cleanups
Author: Gal Leibovich, 2019-03-21 16:10:29 +02:00 (committed via GitHub)
parent abec59f367, commit 6e08c55ad5
24 changed files with 152 additions and 69 deletions
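
The central addition for Batch RL is a softmax with temperature over each Q head's outputs, which turns per-action Q-values into a probability distribution over actions for off-policy evaluation. As a rough illustration, here is a minimal NumPy sketch of the same computation (not Coach code; the temperature value is made up):

import numpy as np

def softmax_with_temperature(q_values, temperature=0.2):
    """Turn a batch of Q-values (batch_size x num_actions) into action probabilities.

    Lower temperatures concentrate the distribution on the greedy action;
    higher temperatures flatten it towards uniform.
    """
    scaled = q_values / temperature
    # subtract the per-row max for numerical stability before exponentiating
    scaled = scaled - scaled.max(axis=1, keepdims=True)
    exp = np.exp(scaled)
    return exp / exp.sum(axis=1, keepdims=True)

# toy batch of Q-values for 3 actions
q = np.array([[1.0, 2.0, 0.5],
              [0.1, 0.0, 0.3]])
print(softmax_with_temperature(q))          # soft, temperature-dependent policy
print(softmax_with_temperature(q, 1e-3))    # approaches the greedy (argmax) policy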

View File

@@ -18,11 +18,9 @@ from typing import Union
 import numpy as np
-from rl_coach.agents.dqn_agent import DQNNetworkParameters, DQNAlgorithmParameters
+from rl_coach.agents.dqn_agent import DQNNetworkParameters, DQNAgentParameters
 from rl_coach.agents.value_optimization_agent import ValueOptimizationAgent
-from rl_coach.base_parameters import AgentParameters
 from rl_coach.exploration_policies.bootstrapped import BootstrappedParameters
-from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters


 class BootstrappedDQNNetworkParameters(DQNNetworkParameters):
@@ -32,12 +30,11 @@ class BootstrappedDQNNetworkParameters(DQNNetworkParameters):
         self.heads_parameters[0].rescale_gradient_from_head_by_factor = 1.0/self.heads_parameters[0].num_output_head_copies


-class BootstrappedDQNAgentParameters(AgentParameters):
+class BootstrappedDQNAgentParameters(DQNAgentParameters):
     def __init__(self):
-        super().__init__(algorithm=DQNAlgorithmParameters(),
-                         exploration=BootstrappedParameters(),
-                         memory=ExperienceReplayParameters(),
-                         networks={"main": BootstrappedDQNNetworkParameters()})
+        super().__init__()
+        self.exploration = BootstrappedParameters()
+        self.network_wrappers = {"main": BootstrappedDQNNetworkParameters()}

     @property
     def path(self):
@@ -65,13 +62,14 @@ class BootstrappedDQNAgent(ValueOptimizationAgent):
         TD_targets = result[self.ap.exploration.architecture_num_q_heads:]

         # add Q value samples for logging
-        self.q_values.add_sample(TD_targets)
         # initialize with the current prediction so that we will
         # only update the action that we have actually done in this transition
-        for i in range(self.ap.network_wrappers['main'].batch_size):
+        for i in range(batch.size):
             mask = batch[i].info['mask']
             for head_idx in range(self.ap.exploration.architecture_num_q_heads):
+                self.q_values.add_sample(TD_targets[head_idx])
                 if mask[head_idx] == 1:
                     selected_action = np.argmax(next_states_online_values[head_idx][i], 0)
                     TD_targets[head_idx][i, batch.actions()[i]] = \

View File

@@ -84,8 +84,8 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
     # prediction's format is (batch,actions,atoms)
     def get_all_q_values_for_states(self, states: StateType):
         if self.exploration_policy.requires_action_values():
-            prediction = self.get_prediction(states)
-            q_values = self.distribution_prediction_to_q_values(prediction)
+            q_values = self.get_prediction(states,
+                                           outputs=self.networks['main'].online_network.output_heads[0].q_values)
         else:
             q_values = None
         return q_values
@@ -105,9 +105,9 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
         # select the optimal actions for the next state
         target_actions = np.argmax(self.distribution_prediction_to_q_values(distributional_q_st_plus_1), axis=1)
-        m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
-        batches = np.arange(self.ap.network_wrappers['main'].batch_size)
+        m = np.zeros((batch.size, self.z_values.size))
+        batches = np.arange(batch.size)

         # an alternative to the for loop. 3.7x perf improvement vs. the same code done with for looping.
         # only 10% speedup overall - leaving commented out as the code is not as clear.
@@ -120,7 +120,7 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
         # bj_ = (tzj_ - self.z_values[0]) / (self.z_values[1] - self.z_values[0])
         # u_ = (np.ceil(bj_)).astype(int)
         # l_ = (np.floor(bj_)).astype(int)
-        # m_ = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
+        # m_ = np.zeros((batch.size, self.z_values.size))
         # np.add.at(m_, [batches, l_],
         #           np.transpose(distributional_q_st_plus_1[batches, target_actions], (1, 0)) * (u_ - bj_))
         # np.add.at(m_, [batches, u_],

View File

@@ -207,6 +207,8 @@ class ClippedPPOAgent(ActorCriticAgent):
                    self.networks['main'].online_network.output_heads[1].likelihood_ratio,
                    self.networks['main'].online_network.output_heads[1].clipped_likelihood_ratio]

+        # TODO-fixme if batch.size / self.ap.network_wrappers['main'].batch_size is not an integer, we do not train on
+        #  some of the data
         for i in range(int(batch.size / self.ap.network_wrappers['main'].batch_size)):
             start = i * self.ap.network_wrappers['main'].batch_size
             end = (i + 1) * self.ap.network_wrappers['main'].batch_size

View File

@@ -56,7 +56,7 @@ class DDQNAgent(ValueOptimizationAgent):
         # initialize with the current prediction so that we will
         # only update the action that we have actually done in this transition
         TD_errors = []
-        for i in range(self.ap.network_wrappers['main'].batch_size):
+        for i in range(batch.size):
             new_target = batch.rewards()[i] + \
                          (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * q_st_plus_1[i][selected_actions[i]]
             TD_errors.append(np.abs(new_target - TD_targets[i, batch.actions()[i]]))
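
The loop above (like the similar loops in the other agents touched by this commit) now iterates over batch.size, i.e. the number of transitions actually drawn, rather than the configured batch_size. For reference, a minimal NumPy sketch of the Double-DQN target it computes, with illustrative array names rather than the Coach API:

import numpy as np

def double_dqn_targets(rewards, game_overs, discount,
                       q_next_online, q_next_target, td_targets, actions):
    """Compute Double-DQN targets for a whole batch at once.

    q_next_online / q_next_target: (batch, num_actions) Q-values for the next states
    td_targets: (batch, num_actions) current online predictions, updated in place
    """
    batch_size = rewards.shape[0]
    # select the next action with the online network, evaluate it with the target network
    selected_actions = np.argmax(q_next_online, axis=1)
    bootstrap = (1.0 - game_overs) * discount * q_next_target[np.arange(batch_size), selected_actions]
    td_targets[np.arange(batch_size), actions] = rewards + bootstrap
    return td_targets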

View File

@@ -146,13 +146,13 @@ class DFPAgent(Agent):
         network_inputs = batch.states(network_keys)
         network_inputs['goal'] = np.repeat(np.expand_dims(self.current_goal, 0),
-                                           self.ap.network_wrappers['main'].batch_size, axis=0)
+                                           batch.size, axis=0)

         # get the current outputs of the network
         targets = self.networks['main'].online_network.predict(network_inputs)

         # change the targets for the taken actions
-        for i in range(self.ap.network_wrappers['main'].batch_size):
+        for i in range(batch.size):
             targets[i, batch.actions()[i]] = batch[i].info['future_measurements'].flatten()

         result = self.networks['main'].train_and_sync_networks(network_inputs, targets)

View File

@@ -86,7 +86,7 @@ class DQNAgent(ValueOptimizationAgent):
         # only update the action that we have actually done in this transition
         TD_errors = []
-        for i in range(self.ap.network_wrappers['main'].batch_size):
+        for i in range(batch.size):
             new_target = batch.rewards()[i] +\
                          (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * np.max(q_st_plus_1[i], 0)
             TD_errors.append(np.abs(new_target - TD_targets[i, batch.actions()[i]]))

View File

@@ -65,7 +65,7 @@ class MixedMonteCarloAgent(ValueOptimizationAgent):
         total_returns = batch.n_step_discounted_rewards()
-        for i in range(self.ap.network_wrappers['main'].batch_size):
+        for i in range(batch.size):
             one_step_target = batch.rewards()[i] + \
                               (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * \
                               q_st_plus_1[i][selected_actions[i]]

View File

@@ -133,7 +133,7 @@ class NECAgent(ValueOptimizationAgent):
         TD_targets = self.networks['main'].online_network.predict(batch.states(network_keys))
         bootstrapped_return_from_old_policy = batch.n_step_discounted_rewards()

         # only update the action that we have actually done in this transition
-        for i in range(self.ap.network_wrappers['main'].batch_size):
+        for i in range(batch.size):
             TD_targets[i, batch.actions()[i]] = bootstrapped_return_from_old_policy[i]

         # set the gradients to fetch for the DND update

View File

@@ -84,7 +84,7 @@ class PALAgent(ValueOptimizationAgent):
         # calculate TD error
         TD_targets = np.copy(q_st_online)
         total_returns = batch.n_step_discounted_rewards()
-        for i in range(self.ap.network_wrappers['main'].batch_size):
+        for i in range(batch.size):
             TD_targets[i, batch.actions()[i]] = batch.rewards()[i] + \
                                                 (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * \
                                                 q_st_plus_1_target[i][selected_actions[i]]

View File

@@ -95,7 +95,7 @@ class QuantileRegressionDQNAgent(ValueOptimizationAgent):
         target_actions = np.argmax(self.get_q_values(next_state_quantiles), axis=1)

         # calculate the Bellman update
-        batch_idx = list(range(self.ap.network_wrappers['main'].batch_size))
+        batch_idx = list(range(batch.size))
         TD_targets = batch.rewards(True) + (1.0 - batch.game_overs(True)) * self.ap.algorithm.discount \
                      * next_state_quantiles[batch_idx, target_actions]
@@ -106,9 +106,9 @@ class QuantileRegressionDQNAgent(ValueOptimizationAgent):
         # calculate the cumulative quantile probabilities and reorder them to fit the sorted quantiles order
         cumulative_probabilities = np.array(range(self.ap.algorithm.atoms + 1)) / float(self.ap.algorithm.atoms)  # tau_i
         quantile_midpoints = 0.5*(cumulative_probabilities[1:] + cumulative_probabilities[:-1])  # tau^hat_i
-        quantile_midpoints = np.tile(quantile_midpoints, (self.ap.network_wrappers['main'].batch_size, 1))
+        quantile_midpoints = np.tile(quantile_midpoints, (batch.size, 1))
         sorted_quantiles = np.argsort(current_quantiles[batch_idx, batch.actions()])
-        for idx in range(self.ap.network_wrappers['main'].batch_size):
+        for idx in range(batch.size):
             quantile_midpoints[idx, :] = quantile_midpoints[idx, sorted_quantiles[idx]]

         # train

View File

@@ -103,9 +103,9 @@ class RainbowDQNAgent(CategoricalDQNAgent):
         # only update the action that we have actually done in this transition (using the Double-DQN selected actions)
         target_actions = ddqn_selected_actions

-        m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
-        batches = np.arange(self.ap.network_wrappers['main'].batch_size)
+        m = np.zeros((batch.size, self.z_values.size))
+        batches = np.arange(batch.size)

         for j in range(self.z_values.size):
             # we use batch.info('should_bootstrap_next_state') instead of (1 - batch.game_overs()) since with n-step,
             # we will not bootstrap for the last n-step transitions in the episode

View File

@@ -50,8 +50,9 @@ class ValueOptimizationAgent(Agent):
             actions_q_values = None
         return actions_q_values

-    def get_prediction(self, states):
-        return self.networks['main'].online_network.predict(self.prepare_batch_for_inference(states, 'main'))
+    def get_prediction(self, states, outputs=None):
+        return self.networks['main'].online_network.predict(self.prepare_batch_for_inference(states, 'main'),
+                                                             outputs=outputs)

     def update_transition_priorities_and_get_weights(self, TD_errors, batch):
         # update errors in prioritized replay buffer
@@ -151,17 +152,18 @@ class ValueOptimizationAgent(Agent):
         # this is fitted from the training dataset
         for epoch in range(epochs):
             loss = 0
+            total_transitions_processed = 0
             for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)):
                 batch = Batch(batch)
                 current_rewards_prediction_for_all_actions = self.networks['reward_model'].online_network.predict(batch.states(network_keys))
-                current_rewards_prediction_for_all_actions[range(batch_size), batch.actions()] = batch.rewards()
+                current_rewards_prediction_for_all_actions[range(batch.size), batch.actions()] = batch.rewards()
                 loss += self.networks['reward_model'].train_and_sync_networks(
                     batch.states(network_keys), current_rewards_prediction_for_all_actions)[0]
-                # print(self.networks['reward_model'].online_network.predict(batch.states(network_keys))[0])
+                total_transitions_processed += batch.size

             log = OrderedDict()
             log['Epoch'] = epoch
-            log['loss'] = loss / int(self.call_memory('num_transitions_in_complete_episodes') / batch_size)
+            log['loss'] = loss / total_transitions_processed
             screen.log_dict(log, prefix='Training Reward Model')
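
The reward-model training above regresses the head toward its own predictions for all actions except the one actually taken, whose entry is replaced by the observed reward, and it now averages the epoch loss over the transitions actually processed so a smaller final batch is weighted correctly. A small NumPy sketch of the target construction (toy data, not the Coach API):

import numpy as np

def reward_model_targets(predicted_rewards, actions, rewards):
    """Regression targets for a reward model: keep the model's own predictions
    for actions that were not taken, substitute the observed reward for the
    action that was taken."""
    targets = predicted_rewards.copy()
    targets[np.arange(len(actions)), actions] = rewards
    return targets

# toy batch: 3 transitions, 2 actions
preds = np.array([[0.1, 0.4], [0.0, 0.2], [0.3, 0.3]])
actions = np.array([1, 0, 1])
rewards = np.array([1.0, 0.0, 0.5])
print(reward_model_targets(preds, actions, rewards))
# dividing the accumulated loss by the total transitions processed (instead of an
# assumed fixed number of full batches) keeps the epoch average correct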

View File

@@ -1,3 +1,4 @@
+from .q_head import QHead
 from .categorical_q_head import CategoricalQHead
 from .ddpg_actor_head import DDPGActor
 from .dnd_q_head import DNDQHead
@@ -7,7 +8,6 @@ from .naf_head import NAFHead
 from .policy_head import PolicyHead
 from .ppo_head import PPOHead
 from .ppo_v_head import PPOVHead
-from .q_head import QHead
 from .quantile_regression_q_head import QuantileRegressionQHead
 from .rainbow_q_head import RainbowQHead
 from .v_head import VHead

View File

@@ -15,16 +15,15 @@
 #
 import tensorflow as tf
+import numpy as np

+from rl_coach.architectures.tensorflow_components.heads import QHead
 from rl_coach.architectures.tensorflow_components.layers import Dense
-from rl_coach.architectures.tensorflow_components.heads.head import Head
 from rl_coach.base_parameters import AgentParameters
-from rl_coach.core_types import QActionStateValue
 from rl_coach.spaces import SpacesDefinition


-class CategoricalQHead(Head):
+class CategoricalQHead(QHead):
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
                  head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str ='relu',
                  dense_layer=Dense):
@@ -33,7 +32,9 @@ class CategoricalQHead(Head):
         self.name = 'categorical_dqn_head'
         self.num_actions = len(self.spaces.action.actions)
         self.num_atoms = agent_parameters.algorithm.atoms
-        self.return_type = QActionStateValue
+        self.z_values = tf.cast(tf.constant(np.linspace(self.ap.algorithm.v_min, self.ap.algorithm.v_max,
+                                                        self.ap.algorithm.atoms), dtype=tf.float32), dtype=tf.float64)
+        self.loss_type = []

     def _build_module(self, input_layer):
         values_distribution = self.dense_layer(self.num_actions * self.num_atoms)(input_layer, name='output')
@@ -49,6 +50,11 @@ class CategoricalQHead(Head):
         self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.target, logits=values_distribution)
         tf.losses.add_loss(self.loss)

+        self.q_values = tf.tensordot(tf.cast(self.output, tf.float64), self.z_values, 1)
+
+        # used in batch-rl to estimate a probablity distribution over actions
+        self.softmax = self.add_softmax_with_temperature()
+
     def __str__(self):
         result = [
             "Dense (num outputs = {})".format(self.num_actions * self.num_atoms),

View File

@@ -56,11 +56,14 @@ class DNDQHead(QHead):
         # Retrieve info from DND dictionary
         # We assume that all actions have enough entries in the DND
-        self.output = tf.transpose([
+        self.q_values = self.output = tf.transpose([
             self._q_value(input_layer, action)
             for action in range(self.num_actions)
         ])

+        # used in batch-rl to estimate a probablity distribution over actions
+        self.softmax = self.add_softmax_with_temperature()
+
     def _q_value(self, input_layer, action):
         result = tf.py_func(self.DND.query,
                             [input_layer, action, self.number_of_nn],

View File

@@ -44,7 +44,10 @@ class DuelingQHead(QHead):
         self.action_advantage = self.action_advantage - self.action_mean

         # merge to state-action value function Q
-        self.output = tf.add(self.state_value, self.action_advantage, name='output')
+        self.q_values = self.output = tf.add(self.state_value, self.action_advantage, name='output')

+        # used in batch-rl to estimate a probablity distribution over actions
+        self.softmax = self.add_softmax_with_temperature()
+
     def __str__(self):
         result = [

View File

@@ -48,15 +48,19 @@ class QHead(Head):
     def _build_module(self, input_layer):
         # Standard Q Network
-        self.output = self.dense_layer(self.num_actions)(input_layer, name='output')
+        self.q_values = self.output = self.dense_layer(self.num_actions)(input_layer, name='output')

-        # TODO add this to other Q heads. e.g. dueling.
-        temperature = self.ap.network_wrappers[self.network_name].softmax_temperature
-        temperature_scaled_outputs = self.output / temperature
-        self.softmax = tf.nn.softmax(temperature_scaled_outputs, name="softmax")
+        # used in batch-rl to estimate a probablity distribution over actions
+        self.softmax = self.add_softmax_with_temperature()

     def __str__(self):
         result = [
             "Dense (num outputs = {})".format(self.num_actions)
         ]
         return '\n'.join(result)
+
+    def add_softmax_with_temperature(self):
+        temperature = self.ap.network_wrappers[self.network_name].softmax_temperature
+        temperature_scaled_outputs = self.q_values / temperature
+        return tf.nn.softmax(temperature_scaled_outputs, name="softmax")

View File

@@ -15,15 +15,14 @@
 #
 import tensorflow as tf
+import numpy as np

+from rl_coach.architectures.tensorflow_components.heads import QHead
 from rl_coach.architectures.tensorflow_components.layers import Dense
-from rl_coach.architectures.tensorflow_components.heads.head import Head
 from rl_coach.base_parameters import AgentParameters
-from rl_coach.core_types import QActionStateValue
 from rl_coach.spaces import SpacesDefinition


-class QuantileRegressionQHead(Head):
+class QuantileRegressionQHead(QHead):
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
                  head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
                  dense_layer=Dense):
@@ -33,7 +32,10 @@ class QuantileRegressionQHead(Head):
         self.num_actions = len(self.spaces.action.actions)
         self.num_atoms = agent_parameters.algorithm.atoms  # we use atom / quantile interchangeably
         self.huber_loss_interval = agent_parameters.algorithm.huber_loss_interval  # k
-        self.return_type = QActionStateValue
+        self.quantile_probabilities = tf.cast(
+            tf.constant(np.ones(self.ap.algorithm.atoms) / float(self.ap.algorithm.atoms), dtype=tf.float32),
+            dtype=tf.float64)
+        self.loss_type = []

     def _build_module(self, input_layer):
         self.actions = tf.placeholder(tf.int32, [None, 2], name="actions")
@@ -72,6 +74,11 @@ class QuantileRegressionQHead(Head):
         self.loss = quantile_regression_loss
         tf.losses.add_loss(self.loss)

+        self.q_values = tf.tensordot(tf.cast(self.output, tf.float64), self.quantile_probabilities, 1)
+
+        # used in batch-rl to estimate a probablity distribution over actions
+        self.softmax = self.add_softmax_with_temperature()
+
     def __str__(self):
         result = [
             "Dense (num outputs = {})".format(self.num_actions * self.num_atoms),

View File

@@ -15,15 +15,14 @@
 #
 import tensorflow as tf
+import numpy as np

+from rl_coach.architectures.tensorflow_components.heads import QHead
 from rl_coach.architectures.tensorflow_components.layers import Dense
-from rl_coach.architectures.tensorflow_components.heads.head import Head
 from rl_coach.base_parameters import AgentParameters
-from rl_coach.core_types import QActionStateValue
 from rl_coach.spaces import SpacesDefinition


-class RainbowQHead(Head):
+class RainbowQHead(QHead):
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
                  head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
                  dense_layer=Dense):
@@ -31,8 +30,10 @@ class RainbowQHead(Head):
                          dense_layer=dense_layer)
         self.num_actions = len(self.spaces.action.actions)
         self.num_atoms = agent_parameters.algorithm.atoms
-        self.return_type = QActionStateValue
         self.name = 'rainbow_q_values_head'
+        self.z_values = tf.cast(tf.constant(np.linspace(self.ap.algorithm.v_min, self.ap.algorithm.v_max,
+                                                        self.ap.algorithm.atoms), dtype=tf.float32), dtype=tf.float64)
+        self.loss_type = []

     def _build_module(self, input_layer):
         # state value tower - V
@@ -63,6 +64,11 @@ class RainbowQHead(Head):
         self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.target, logits=values_distribution)
         tf.losses.add_loss(self.loss)

+        self.q_values = tf.tensordot(tf.cast(self.output, tf.float64), self.z_values, 1)
+
+        # used in batch-rl to estimate a probablity distribution over actions
+        self.softmax = self.add_softmax_with_temperature()
+
     def __str__(self):
         result = [
             "State Value Stream - V",

View File

@@ -17,6 +17,7 @@ from copy import deepcopy
 from typing import Tuple, List, Union

 from rl_coach.agents.dqn_agent import DQNAgentParameters
+from rl_coach.agents.nec_agent import NECAgentParameters
 from rl_coach.base_parameters import AgentParameters, VisualizationParameters, TaskParameters, \
     PresetValidationParameters
 from rl_coach.core_types import RunPhase
@@ -65,8 +66,8 @@ class BatchRLGraphManager(BasicRLGraphManager):
         else:
             env = None

-        # Only DQN variants are supported at this point.
-        assert(isinstance(self.agent_params, DQNAgentParameters))
+        # Only DQN variants and NEC are supported at this point.
+        assert(isinstance(self.agent_params, DQNAgentParameters) or isinstance(self.agent_params, NECAgentParameters))

         # Only Episodic memories are supported,
         # for evaluating the sequential doubly robust estimator
         assert(isinstance(self.agent_params.memory, EpisodicExperienceReplayParameters))

View File

@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 import ast
+import math

 import pandas as pd
 from typing import List, Tuple, Union
@@ -163,9 +164,8 @@ class EpisodicExperienceReplay(Memory):
         shuffled_transition_indices = list(range(self.last_training_set_transition_id))
         random.shuffle(shuffled_transition_indices)

-        # we deliberately drop some of the ending data which is left after dividing to batches of size `size`
-        # for i in range(math.ceil(len(shuffled_transition_indices) / size)):
-        for i in range(int(len(shuffled_transition_indices) / size)):
+        # The last batch drawn will usually be < batch_size (=the size variable)
+        for i in range(math.ceil(len(shuffled_transition_indices) / size)):
             sample_data = [self.transitions[j] for j in shuffled_transition_indices[i * size: (i + 1) * size]]

             self.reader_writer_lock.release_writing()
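
The generator now yields a final, smaller batch instead of silently dropping the remainder. A standalone sketch of the same batching pattern (illustrative names, not the memory's actual interface):

import math
import random

def shuffled_batches(items, size):
    """Yield shuffled batches of `size`; the last batch may be smaller
    instead of being dropped."""
    indices = list(range(len(items)))
    random.shuffle(indices)
    for i in range(math.ceil(len(indices) / size)):
        yield [items[j] for j in indices[i * size: (i + 1) * size]]

# 10 transitions with batch size 4 -> batches of 4, 4 and 2
for batch in shuffled_batches(list(range(10)), 4):
    print(len(batch))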

View File

@@ -113,10 +113,6 @@ class ExperienceReplay(Memory):
                 yield sample_data

-        ## usage example
-        # for o in random_seq_generator(list(range(10)), 4):
-        #     print(o)
-
     def _enforce_max_length(self) -> None:
         """
         Make sure that the size of the replay buffer does not pass the maximum size allowed.

View File

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import math
 from collections import namedtuple

 import numpy as np
@@ -55,19 +56,18 @@ class OpeManager(object):
         all_reward_model_rewards, all_policy_probs, all_old_policy_probs = [], [], []
         all_v_values_reward_model_based, all_v_values_q_model_based, all_rewards, all_actions = [], [], [], []

-        for i in range(int(len(dataset_as_transitions) / batch_size) + 1):
+        for i in range(math.ceil(len(dataset_as_transitions) / batch_size)):
             batch = dataset_as_transitions[i * batch_size: (i + 1) * batch_size]
             batch_for_inference = Batch(batch)

             all_reward_model_rewards.append(reward_model.predict(
                 batch_for_inference.states(network_keys)))

-            # TODO can we get rid of the 'output_heads[0]', and have some way of a cleaner API?
-            # for instance, this means that for bootstrapped we always use the first QHead to calculate the OPEs.
+            # we always use the first Q head to calculate OPEs. might want to change this in the future.
             q_values, sm_values = q_network.predict(batch_for_inference.states(network_keys),
-                                                    outputs=[q_network.output_heads[0].output,
+                                                    outputs=[q_network.output_heads[0].q_values,
                                                              q_network.output_heads[0].softmax])
-            # TODO why is this needed?
-            q_values = q_values[0]

             all_policy_probs.append(sm_values)
             all_v_values_reward_model_based.append(np.sum(all_policy_probs[-1] * all_reward_model_rewards[-1], axis=1))
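
The OPE loop weights the reward model's per-action predictions by the evaluated policy's action probabilities (the softmax-with-temperature output) to get a per-state value estimate. A small NumPy sketch of that step (toy numbers, not the Coach API):

import numpy as np

def reward_model_based_v_values(policy_probs, reward_model_rewards):
    """Per-state value estimate of the evaluated policy under the reward model:
    V(s) = sum over actions a of pi(a|s) * r_hat(s, a).

    policy_probs:          (batch, num_actions) softmax-with-temperature outputs
    reward_model_rewards:  (batch, num_actions) reward model predictions
    """
    return np.sum(policy_probs * reward_model_rewards, axis=1)

# toy example: 2 states, 3 actions
pi = np.array([[0.7, 0.2, 0.1],
               [0.1, 0.1, 0.8]])
r_hat = np.array([[1.0, 0.0, -1.0],
                  [0.0, 0.5, 1.0]])
print(reward_model_based_v_values(pi, r_hat))  # [0.6, 0.85]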

View File

@@ -0,0 +1,55 @@
from rl_coach.agents.qr_dqn_agent import QuantileRegressionDQNAgentParameters
from rl_coach.agents.rainbow_dqn_agent import RainbowDQNAgentParameters
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
from rl_coach.environments.gym_environment import GymVectorEnvironment
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.memories.memory import MemoryGranularity
from rl_coach.schedules import LinearSchedule
####################
# Graph Scheduling #
####################
schedule_params = ScheduleParameters()
schedule_params.improve_steps = TrainingSteps(10000000000)
schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10)
schedule_params.evaluation_steps = EnvironmentEpisodes(1)
schedule_params.heatup_steps = EnvironmentSteps(1000)
#########
# Agent #
#########
agent_params = QuantileRegressionDQNAgentParameters()
# DQN params
agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(100)
agent_params.algorithm.discount = 0.99
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(1)
agent_params.algorithm.atoms = 50
# NN configuration
agent_params.network_wrappers['main'].learning_rate = 0.0005
# ER size
agent_params.memory.max_size = (MemoryGranularity.Transitions, 40000)
# E-Greedy schedule
agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
################
# Environment #
################
env_params = GymVectorEnvironment(level='CartPole-v0')
########
# Test #
########
preset_validation_params = PresetValidationParameters()
preset_validation_params.test = True
preset_validation_params.min_reward_threshold = 150
preset_validation_params.max_episodes_to_achieve_reward = 250
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
schedule_params=schedule_params, vis_params=VisualizationParameters(),
preset_validation_params=preset_validation_params)