From 6e08c55ad5cb7a61fa997973d5d8f9cee2f9f864 Mon Sep 17 00:00:00 2001 From: Gal Leibovich Date: Thu, 21 Mar 2019 16:10:29 +0200 Subject: [PATCH] Enabling-more-agents-for-Batch-RL-and-cleanup (#258) allowing for the last training batch drawn to be smaller than batch_size + adding support for more agents in BatchRL by adding softmax with temperature to the corresponding heads + adding a CartPole_QR_DQN preset with a golden test + cleanups --- rl_coach/agents/bootstrapped_dqn_agent.py | 18 +++--- rl_coach/agents/categorical_dqn_agent.py | 10 ++-- rl_coach/agents/clipped_ppo_agent.py | 2 + rl_coach/agents/ddqn_agent.py | 2 +- rl_coach/agents/dfp_agent.py | 4 +- rl_coach/agents/dqn_agent.py | 2 +- rl_coach/agents/mmc_agent.py | 2 +- rl_coach/agents/nec_agent.py | 2 +- rl_coach/agents/pal_agent.py | 2 +- rl_coach/agents/qr_dqn_agent.py | 6 +- rl_coach/agents/rainbow_dqn_agent.py | 4 +- rl_coach/agents/value_optimization_agent.py | 12 ++-- .../tensorflow_components/heads/__init__.py | 2 +- .../heads/categorical_q_head.py | 16 ++++-- .../tensorflow_components/heads/dnd_q_head.py | 5 +- .../heads/dueling_q_head.py | 5 +- .../tensorflow_components/heads/q_head.py | 14 +++-- .../heads/quantile_regression_q_head.py | 17 ++++-- .../heads/rainbow_q_head.py | 16 ++++-- .../graph_managers/batch_rl_graph_manager.py | 5 +- .../episodic/episodic_experience_replay.py | 6 +- .../non_episodic/experience_replay.py | 4 -- rl_coach/off_policy_evaluators/ope_manager.py | 10 ++-- rl_coach/presets/CartPole_QR_DQN.py | 55 +++++++++++++++++++ 24 files changed, 152 insertions(+), 69 deletions(-) create mode 100644 rl_coach/presets/CartPole_QR_DQN.py diff --git a/rl_coach/agents/bootstrapped_dqn_agent.py b/rl_coach/agents/bootstrapped_dqn_agent.py index e8ee7d8..b291dff 100644 --- a/rl_coach/agents/bootstrapped_dqn_agent.py +++ b/rl_coach/agents/bootstrapped_dqn_agent.py @@ -18,11 +18,9 @@ from typing import Union import numpy as np -from rl_coach.agents.dqn_agent import DQNNetworkParameters, DQNAlgorithmParameters +from rl_coach.agents.dqn_agent import DQNNetworkParameters, DQNAgentParameters from rl_coach.agents.value_optimization_agent import ValueOptimizationAgent -from rl_coach.base_parameters import AgentParameters from rl_coach.exploration_policies.bootstrapped import BootstrappedParameters -from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters class BootstrappedDQNNetworkParameters(DQNNetworkParameters): @@ -32,12 +30,11 @@ class BootstrappedDQNNetworkParameters(DQNNetworkParameters): self.heads_parameters[0].rescale_gradient_from_head_by_factor = 1.0/self.heads_parameters[0].num_output_head_copies -class BootstrappedDQNAgentParameters(AgentParameters): +class BootstrappedDQNAgentParameters(DQNAgentParameters): def __init__(self): - super().__init__(algorithm=DQNAlgorithmParameters(), - exploration=BootstrappedParameters(), - memory=ExperienceReplayParameters(), - networks={"main": BootstrappedDQNNetworkParameters()}) + super().__init__() + self.exploration = BootstrappedParameters() + self.network_wrappers = {"main": BootstrappedDQNNetworkParameters()} @property def path(self): @@ -65,13 +62,14 @@ class BootstrappedDQNAgent(ValueOptimizationAgent): TD_targets = result[self.ap.exploration.architecture_num_q_heads:] # add Q value samples for logging - self.q_values.add_sample(TD_targets) # initialize with the current prediction so that we will # only update the action that we have actually done in this transition - for i in range(self.ap.network_wrappers['main'].batch_size): + 
for i in range(batch.size): mask = batch[i].info['mask'] for head_idx in range(self.ap.exploration.architecture_num_q_heads): + self.q_values.add_sample(TD_targets[head_idx]) + if mask[head_idx] == 1: selected_action = np.argmax(next_states_online_values[head_idx][i], 0) TD_targets[head_idx][i, batch.actions()[i]] = \ diff --git a/rl_coach/agents/categorical_dqn_agent.py b/rl_coach/agents/categorical_dqn_agent.py index cfcbe9d..59738e3 100644 --- a/rl_coach/agents/categorical_dqn_agent.py +++ b/rl_coach/agents/categorical_dqn_agent.py @@ -84,8 +84,8 @@ class CategoricalDQNAgent(ValueOptimizationAgent): # prediction's format is (batch,actions,atoms) def get_all_q_values_for_states(self, states: StateType): if self.exploration_policy.requires_action_values(): - prediction = self.get_prediction(states) - q_values = self.distribution_prediction_to_q_values(prediction) + q_values = self.get_prediction(states, + outputs=self.networks['main'].online_network.output_heads[0].q_values) else: q_values = None return q_values @@ -105,9 +105,9 @@ class CategoricalDQNAgent(ValueOptimizationAgent): # select the optimal actions for the next state target_actions = np.argmax(self.distribution_prediction_to_q_values(distributional_q_st_plus_1), axis=1) - m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size)) + m = np.zeros((batch.size, self.z_values.size)) - batches = np.arange(self.ap.network_wrappers['main'].batch_size) + batches = np.arange(batch.size) # an alternative to the for loop. 3.7x perf improvement vs. the same code done with for looping. # only 10% speedup overall - leaving commented out as the code is not as clear. @@ -120,7 +120,7 @@ class CategoricalDQNAgent(ValueOptimizationAgent): # bj_ = (tzj_ - self.z_values[0]) / (self.z_values[1] - self.z_values[0]) # u_ = (np.ceil(bj_)).astype(int) # l_ = (np.floor(bj_)).astype(int) - # m_ = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size)) + # m_ = np.zeros((batch.size, self.z_values.size)) # np.add.at(m_, [batches, l_], # np.transpose(distributional_q_st_plus_1[batches, target_actions], (1, 0)) * (u_ - bj_)) # np.add.at(m_, [batches, u_], diff --git a/rl_coach/agents/clipped_ppo_agent.py b/rl_coach/agents/clipped_ppo_agent.py index 71ccdce..cc29f33 100644 --- a/rl_coach/agents/clipped_ppo_agent.py +++ b/rl_coach/agents/clipped_ppo_agent.py @@ -207,6 +207,8 @@ class ClippedPPOAgent(ActorCriticAgent): self.networks['main'].online_network.output_heads[1].likelihood_ratio, self.networks['main'].online_network.output_heads[1].clipped_likelihood_ratio] + # TODO-fixme if batch.size / self.ap.network_wrappers['main'].batch_size is not an integer, we do not train on + # some of the data for i in range(int(batch.size / self.ap.network_wrappers['main'].batch_size)): start = i * self.ap.network_wrappers['main'].batch_size end = (i + 1) * self.ap.network_wrappers['main'].batch_size diff --git a/rl_coach/agents/ddqn_agent.py b/rl_coach/agents/ddqn_agent.py index 7021f8e..8100d37 100644 --- a/rl_coach/agents/ddqn_agent.py +++ b/rl_coach/agents/ddqn_agent.py @@ -56,7 +56,7 @@ class DDQNAgent(ValueOptimizationAgent): # initialize with the current prediction so that we will # only update the action that we have actually done in this transition TD_errors = [] - for i in range(self.ap.network_wrappers['main'].batch_size): + for i in range(batch.size): new_target = batch.rewards()[i] + \ (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * q_st_plus_1[i][selected_actions[i]] TD_errors.append(np.abs(new_target - 
TD_targets[i, batch.actions()[i]])) diff --git a/rl_coach/agents/dfp_agent.py b/rl_coach/agents/dfp_agent.py index bc989bc..cbce242 100644 --- a/rl_coach/agents/dfp_agent.py +++ b/rl_coach/agents/dfp_agent.py @@ -146,13 +146,13 @@ class DFPAgent(Agent): network_inputs = batch.states(network_keys) network_inputs['goal'] = np.repeat(np.expand_dims(self.current_goal, 0), - self.ap.network_wrappers['main'].batch_size, axis=0) + batch.size, axis=0) # get the current outputs of the network targets = self.networks['main'].online_network.predict(network_inputs) # change the targets for the taken actions - for i in range(self.ap.network_wrappers['main'].batch_size): + for i in range(batch.size): targets[i, batch.actions()[i]] = batch[i].info['future_measurements'].flatten() result = self.networks['main'].train_and_sync_networks(network_inputs, targets) diff --git a/rl_coach/agents/dqn_agent.py b/rl_coach/agents/dqn_agent.py index d6c05da..cbc419f 100644 --- a/rl_coach/agents/dqn_agent.py +++ b/rl_coach/agents/dqn_agent.py @@ -86,7 +86,7 @@ class DQNAgent(ValueOptimizationAgent): # only update the action that we have actually done in this transition TD_errors = [] - for i in range(self.ap.network_wrappers['main'].batch_size): + for i in range(batch.size): new_target = batch.rewards()[i] +\ (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * np.max(q_st_plus_1[i], 0) TD_errors.append(np.abs(new_target - TD_targets[i, batch.actions()[i]])) diff --git a/rl_coach/agents/mmc_agent.py b/rl_coach/agents/mmc_agent.py index dc2765d..e0ce76d 100644 --- a/rl_coach/agents/mmc_agent.py +++ b/rl_coach/agents/mmc_agent.py @@ -65,7 +65,7 @@ class MixedMonteCarloAgent(ValueOptimizationAgent): total_returns = batch.n_step_discounted_rewards() - for i in range(self.ap.network_wrappers['main'].batch_size): + for i in range(batch.size): one_step_target = batch.rewards()[i] + \ (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * \ q_st_plus_1[i][selected_actions[i]] diff --git a/rl_coach/agents/nec_agent.py b/rl_coach/agents/nec_agent.py index 9eabb78..ce1cfe1 100644 --- a/rl_coach/agents/nec_agent.py +++ b/rl_coach/agents/nec_agent.py @@ -133,7 +133,7 @@ class NECAgent(ValueOptimizationAgent): TD_targets = self.networks['main'].online_network.predict(batch.states(network_keys)) bootstrapped_return_from_old_policy = batch.n_step_discounted_rewards() # only update the action that we have actually done in this transition - for i in range(self.ap.network_wrappers['main'].batch_size): + for i in range(batch.size): TD_targets[i, batch.actions()[i]] = bootstrapped_return_from_old_policy[i] # set the gradients to fetch for the DND update diff --git a/rl_coach/agents/pal_agent.py b/rl_coach/agents/pal_agent.py index 5256b7f..44778d6 100644 --- a/rl_coach/agents/pal_agent.py +++ b/rl_coach/agents/pal_agent.py @@ -84,7 +84,7 @@ class PALAgent(ValueOptimizationAgent): # calculate TD error TD_targets = np.copy(q_st_online) total_returns = batch.n_step_discounted_rewards() - for i in range(self.ap.network_wrappers['main'].batch_size): + for i in range(batch.size): TD_targets[i, batch.actions()[i]] = batch.rewards()[i] + \ (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * \ q_st_plus_1_target[i][selected_actions[i]] diff --git a/rl_coach/agents/qr_dqn_agent.py b/rl_coach/agents/qr_dqn_agent.py index d5cf3fd..1e4042a 100644 --- a/rl_coach/agents/qr_dqn_agent.py +++ b/rl_coach/agents/qr_dqn_agent.py @@ -95,7 +95,7 @@ class QuantileRegressionDQNAgent(ValueOptimizationAgent): target_actions = 
np.argmax(self.get_q_values(next_state_quantiles), axis=1) # calculate the Bellman update - batch_idx = list(range(self.ap.network_wrappers['main'].batch_size)) + batch_idx = list(range(batch.size)) TD_targets = batch.rewards(True) + (1.0 - batch.game_overs(True)) * self.ap.algorithm.discount \ * next_state_quantiles[batch_idx, target_actions] @@ -106,9 +106,9 @@ class QuantileRegressionDQNAgent(ValueOptimizationAgent): # calculate the cumulative quantile probabilities and reorder them to fit the sorted quantiles order cumulative_probabilities = np.array(range(self.ap.algorithm.atoms + 1)) / float(self.ap.algorithm.atoms) # tau_i quantile_midpoints = 0.5*(cumulative_probabilities[1:] + cumulative_probabilities[:-1]) # tau^hat_i - quantile_midpoints = np.tile(quantile_midpoints, (self.ap.network_wrappers['main'].batch_size, 1)) + quantile_midpoints = np.tile(quantile_midpoints, (batch.size, 1)) sorted_quantiles = np.argsort(current_quantiles[batch_idx, batch.actions()]) - for idx in range(self.ap.network_wrappers['main'].batch_size): + for idx in range(batch.size): quantile_midpoints[idx, :] = quantile_midpoints[idx, sorted_quantiles[idx]] # train diff --git a/rl_coach/agents/rainbow_dqn_agent.py b/rl_coach/agents/rainbow_dqn_agent.py index 08e417c..68de193 100644 --- a/rl_coach/agents/rainbow_dqn_agent.py +++ b/rl_coach/agents/rainbow_dqn_agent.py @@ -103,9 +103,9 @@ class RainbowDQNAgent(CategoricalDQNAgent): # only update the action that we have actually done in this transition (using the Double-DQN selected actions) target_actions = ddqn_selected_actions - m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size)) + m = np.zeros((batch.size, self.z_values.size)) - batches = np.arange(self.ap.network_wrappers['main'].batch_size) + batches = np.arange(batch.size) for j in range(self.z_values.size): # we use batch.info('should_bootstrap_next_state') instead of (1 - batch.game_overs()) since with n-step, # we will not bootstrap for the last n-step transitions in the episode diff --git a/rl_coach/agents/value_optimization_agent.py b/rl_coach/agents/value_optimization_agent.py index 917435f..1ed635d 100644 --- a/rl_coach/agents/value_optimization_agent.py +++ b/rl_coach/agents/value_optimization_agent.py @@ -50,8 +50,9 @@ class ValueOptimizationAgent(Agent): actions_q_values = None return actions_q_values - def get_prediction(self, states): - return self.networks['main'].online_network.predict(self.prepare_batch_for_inference(states, 'main')) + def get_prediction(self, states, outputs=None): + return self.networks['main'].online_network.predict(self.prepare_batch_for_inference(states, 'main'), + outputs=outputs) def update_transition_priorities_and_get_weights(self, TD_errors, batch): # update errors in prioritized replay buffer @@ -151,17 +152,18 @@ class ValueOptimizationAgent(Agent): # this is fitted from the training dataset for epoch in range(epochs): loss = 0 + total_transitions_processed = 0 for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)): batch = Batch(batch) current_rewards_prediction_for_all_actions = self.networks['reward_model'].online_network.predict(batch.states(network_keys)) - current_rewards_prediction_for_all_actions[range(batch_size), batch.actions()] = batch.rewards() + current_rewards_prediction_for_all_actions[range(batch.size), batch.actions()] = batch.rewards() loss += self.networks['reward_model'].train_and_sync_networks( batch.states(network_keys), current_rewards_prediction_for_all_actions)[0] - # 
print(self.networks['reward_model'].online_network.predict(batch.states(network_keys))[0]) + total_transitions_processed += batch.size log = OrderedDict() log['Epoch'] = epoch - log['loss'] = loss / int(self.call_memory('num_transitions_in_complete_episodes') / batch_size) + log['loss'] = loss / total_transitions_processed screen.log_dict(log, prefix='Training Reward Model') diff --git a/rl_coach/architectures/tensorflow_components/heads/__init__.py b/rl_coach/architectures/tensorflow_components/heads/__init__.py index 5631fec..daf5492 100644 --- a/rl_coach/architectures/tensorflow_components/heads/__init__.py +++ b/rl_coach/architectures/tensorflow_components/heads/__init__.py @@ -1,3 +1,4 @@ +from .q_head import QHead from .categorical_q_head import CategoricalQHead from .ddpg_actor_head import DDPGActor from .dnd_q_head import DNDQHead @@ -7,7 +8,6 @@ from .naf_head import NAFHead from .policy_head import PolicyHead from .ppo_head import PPOHead from .ppo_v_head import PPOVHead -from .q_head import QHead from .quantile_regression_q_head import QuantileRegressionQHead from .rainbow_q_head import RainbowQHead from .v_head import VHead diff --git a/rl_coach/architectures/tensorflow_components/heads/categorical_q_head.py b/rl_coach/architectures/tensorflow_components/heads/categorical_q_head.py index 1f19a59..56427ed 100644 --- a/rl_coach/architectures/tensorflow_components/heads/categorical_q_head.py +++ b/rl_coach/architectures/tensorflow_components/heads/categorical_q_head.py @@ -15,16 +15,15 @@ # import tensorflow as tf - +import numpy as np +from rl_coach.architectures.tensorflow_components.heads import QHead from rl_coach.architectures.tensorflow_components.layers import Dense -from rl_coach.architectures.tensorflow_components.heads.head import Head from rl_coach.base_parameters import AgentParameters -from rl_coach.core_types import QActionStateValue from rl_coach.spaces import SpacesDefinition -class CategoricalQHead(Head): +class CategoricalQHead(QHead): def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str, head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str ='relu', dense_layer=Dense): @@ -33,7 +32,9 @@ class CategoricalQHead(Head): self.name = 'categorical_dqn_head' self.num_actions = len(self.spaces.action.actions) self.num_atoms = agent_parameters.algorithm.atoms - self.return_type = QActionStateValue + self.z_values = tf.cast(tf.constant(np.linspace(self.ap.algorithm.v_min, self.ap.algorithm.v_max, + self.ap.algorithm.atoms), dtype=tf.float32), dtype=tf.float64) + self.loss_type = [] def _build_module(self, input_layer): values_distribution = self.dense_layer(self.num_actions * self.num_atoms)(input_layer, name='output') @@ -49,6 +50,11 @@ class CategoricalQHead(Head): self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.target, logits=values_distribution) tf.losses.add_loss(self.loss) + self.q_values = tf.tensordot(tf.cast(self.output, tf.float64), self.z_values, 1) + + # used in batch-rl to estimate a probablity distribution over actions + self.softmax = self.add_softmax_with_temperature() + def __str__(self): result = [ "Dense (num outputs = {})".format(self.num_actions * self.num_atoms), diff --git a/rl_coach/architectures/tensorflow_components/heads/dnd_q_head.py b/rl_coach/architectures/tensorflow_components/heads/dnd_q_head.py index 5605cb7..6462f83 100644 --- a/rl_coach/architectures/tensorflow_components/heads/dnd_q_head.py +++ 
b/rl_coach/architectures/tensorflow_components/heads/dnd_q_head.py @@ -56,11 +56,14 @@ class DNDQHead(QHead): # Retrieve info from DND dictionary # We assume that all actions have enough entries in the DND - self.output = tf.transpose([ + self.q_values = self.output = tf.transpose([ self._q_value(input_layer, action) for action in range(self.num_actions) ]) + # used in batch-rl to estimate a probablity distribution over actions + self.softmax = self.add_softmax_with_temperature() + def _q_value(self, input_layer, action): result = tf.py_func(self.DND.query, [input_layer, action, self.number_of_nn], diff --git a/rl_coach/architectures/tensorflow_components/heads/dueling_q_head.py b/rl_coach/architectures/tensorflow_components/heads/dueling_q_head.py index 8237a91..92692ab 100644 --- a/rl_coach/architectures/tensorflow_components/heads/dueling_q_head.py +++ b/rl_coach/architectures/tensorflow_components/heads/dueling_q_head.py @@ -44,7 +44,10 @@ class DuelingQHead(QHead): self.action_advantage = self.action_advantage - self.action_mean # merge to state-action value function Q - self.output = tf.add(self.state_value, self.action_advantage, name='output') + self.q_values = self.output = tf.add(self.state_value, self.action_advantage, name='output') + + # used in batch-rl to estimate a probablity distribution over actions + self.softmax = self.add_softmax_with_temperature() def __str__(self): result = [ diff --git a/rl_coach/architectures/tensorflow_components/heads/q_head.py b/rl_coach/architectures/tensorflow_components/heads/q_head.py index 135639c..ecc1461 100644 --- a/rl_coach/architectures/tensorflow_components/heads/q_head.py +++ b/rl_coach/architectures/tensorflow_components/heads/q_head.py @@ -48,15 +48,19 @@ class QHead(Head): def _build_module(self, input_layer): # Standard Q Network - self.output = self.dense_layer(self.num_actions)(input_layer, name='output') + self.q_values = self.output = self.dense_layer(self.num_actions)(input_layer, name='output') - # TODO add this to other Q heads. e.g. dueling. 
- temperature = self.ap.network_wrappers[self.network_name].softmax_temperature - temperature_scaled_outputs = self.output / temperature - self.softmax = tf.nn.softmax(temperature_scaled_outputs, name="softmax") + # used in batch-rl to estimate a probablity distribution over actions + self.softmax = self.add_softmax_with_temperature() def __str__(self): result = [ "Dense (num outputs = {})".format(self.num_actions) ] return '\n'.join(result) + + def add_softmax_with_temperature(self): + temperature = self.ap.network_wrappers[self.network_name].softmax_temperature + temperature_scaled_outputs = self.q_values / temperature + return tf.nn.softmax(temperature_scaled_outputs, name="softmax") + diff --git a/rl_coach/architectures/tensorflow_components/heads/quantile_regression_q_head.py b/rl_coach/architectures/tensorflow_components/heads/quantile_regression_q_head.py index fa6e1e9..5edcbfb 100644 --- a/rl_coach/architectures/tensorflow_components/heads/quantile_regression_q_head.py +++ b/rl_coach/architectures/tensorflow_components/heads/quantile_regression_q_head.py @@ -15,15 +15,14 @@ # import tensorflow as tf - +import numpy as np +from rl_coach.architectures.tensorflow_components.heads import QHead from rl_coach.architectures.tensorflow_components.layers import Dense -from rl_coach.architectures.tensorflow_components.heads.head import Head from rl_coach.base_parameters import AgentParameters -from rl_coach.core_types import QActionStateValue from rl_coach.spaces import SpacesDefinition -class QuantileRegressionQHead(Head): +class QuantileRegressionQHead(QHead): def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str, head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu', dense_layer=Dense): @@ -33,7 +32,10 @@ class QuantileRegressionQHead(Head): self.num_actions = len(self.spaces.action.actions) self.num_atoms = agent_parameters.algorithm.atoms # we use atom / quantile interchangeably self.huber_loss_interval = agent_parameters.algorithm.huber_loss_interval # k - self.return_type = QActionStateValue + self.quantile_probabilities = tf.cast( + tf.constant(np.ones(self.ap.algorithm.atoms) / float(self.ap.algorithm.atoms), dtype=tf.float32), + dtype=tf.float64) + self.loss_type = [] def _build_module(self, input_layer): self.actions = tf.placeholder(tf.int32, [None, 2], name="actions") @@ -72,6 +74,11 @@ class QuantileRegressionQHead(Head): self.loss = quantile_regression_loss tf.losses.add_loss(self.loss) + self.q_values = tf.tensordot(tf.cast(self.output, tf.float64), self.quantile_probabilities, 1) + + # used in batch-rl to estimate a probablity distribution over actions + self.softmax = self.add_softmax_with_temperature() + def __str__(self): result = [ "Dense (num outputs = {})".format(self.num_actions * self.num_atoms), diff --git a/rl_coach/architectures/tensorflow_components/heads/rainbow_q_head.py b/rl_coach/architectures/tensorflow_components/heads/rainbow_q_head.py index 2d2fb6e..f7f0ba4 100644 --- a/rl_coach/architectures/tensorflow_components/heads/rainbow_q_head.py +++ b/rl_coach/architectures/tensorflow_components/heads/rainbow_q_head.py @@ -15,15 +15,14 @@ # import tensorflow as tf - +import numpy as np +from rl_coach.architectures.tensorflow_components.heads import QHead from rl_coach.architectures.tensorflow_components.layers import Dense -from rl_coach.architectures.tensorflow_components.heads.head import Head from rl_coach.base_parameters import AgentParameters -from rl_coach.core_types 
import QActionStateValue from rl_coach.spaces import SpacesDefinition -class RainbowQHead(Head): +class RainbowQHead(QHead): def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str, head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu', dense_layer=Dense): @@ -31,8 +30,10 @@ class RainbowQHead(Head): dense_layer=dense_layer) self.num_actions = len(self.spaces.action.actions) self.num_atoms = agent_parameters.algorithm.atoms - self.return_type = QActionStateValue self.name = 'rainbow_q_values_head' + self.z_values = tf.cast(tf.constant(np.linspace(self.ap.algorithm.v_min, self.ap.algorithm.v_max, + self.ap.algorithm.atoms), dtype=tf.float32), dtype=tf.float64) + self.loss_type = [] def _build_module(self, input_layer): # state value tower - V @@ -63,6 +64,11 @@ class RainbowQHead(Head): self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.target, logits=values_distribution) tf.losses.add_loss(self.loss) + self.q_values = tf.tensordot(tf.cast(self.output, tf.float64), self.z_values, 1) + + # used in batch-rl to estimate a probablity distribution over actions + self.softmax = self.add_softmax_with_temperature() + def __str__(self): result = [ "State Value Stream - V", diff --git a/rl_coach/graph_managers/batch_rl_graph_manager.py b/rl_coach/graph_managers/batch_rl_graph_manager.py index 8bd90a0..ad65aef 100644 --- a/rl_coach/graph_managers/batch_rl_graph_manager.py +++ b/rl_coach/graph_managers/batch_rl_graph_manager.py @@ -17,6 +17,7 @@ from copy import deepcopy from typing import Tuple, List, Union from rl_coach.agents.dqn_agent import DQNAgentParameters +from rl_coach.agents.nec_agent import NECAgentParameters from rl_coach.base_parameters import AgentParameters, VisualizationParameters, TaskParameters, \ PresetValidationParameters from rl_coach.core_types import RunPhase @@ -65,8 +66,8 @@ class BatchRLGraphManager(BasicRLGraphManager): else: env = None - # Only DQN variants are supported at this point. - assert(isinstance(self.agent_params, DQNAgentParameters)) + # Only DQN variants and NEC are supported at this point. + assert(isinstance(self.agent_params, DQNAgentParameters) or isinstance(self.agent_params, NECAgentParameters)) # Only Episodic memories are supported, # for evaluating the sequential doubly robust estimator assert(isinstance(self.agent_params.memory, EpisodicExperienceReplayParameters)) diff --git a/rl_coach/memories/episodic/episodic_experience_replay.py b/rl_coach/memories/episodic/episodic_experience_replay.py index bc1e72a..f556c9d 100644 --- a/rl_coach/memories/episodic/episodic_experience_replay.py +++ b/rl_coach/memories/episodic/episodic_experience_replay.py @@ -15,6 +15,7 @@ # limitations under the License. 
# import ast +import math import pandas as pd from typing import List, Tuple, Union @@ -163,9 +164,8 @@ class EpisodicExperienceReplay(Memory): shuffled_transition_indices = list(range(self.last_training_set_transition_id)) random.shuffle(shuffled_transition_indices) - # we deliberately drop some of the ending data which is left after dividing to batches of size `size` - # for i in range(math.ceil(len(shuffled_transition_indices) / size)): - for i in range(int(len(shuffled_transition_indices) / size)): + # The last batch drawn will usually be < batch_size (=the size variable) + for i in range(math.ceil(len(shuffled_transition_indices) / size)): sample_data = [self.transitions[j] for j in shuffled_transition_indices[i * size: (i + 1) * size]] self.reader_writer_lock.release_writing() diff --git a/rl_coach/memories/non_episodic/experience_replay.py b/rl_coach/memories/non_episodic/experience_replay.py index f47d9b6..4226ba4 100644 --- a/rl_coach/memories/non_episodic/experience_replay.py +++ b/rl_coach/memories/non_episodic/experience_replay.py @@ -113,10 +113,6 @@ class ExperienceReplay(Memory): yield sample_data - ## usage example - # for o in random_seq_generator(list(range(10)), 4): - # print(o) - def _enforce_max_length(self) -> None: """ Make sure that the size of the replay buffer does not pass the maximum size allowed. diff --git a/rl_coach/off_policy_evaluators/ope_manager.py b/rl_coach/off_policy_evaluators/ope_manager.py index 1a64621..81b11ea 100644 --- a/rl_coach/off_policy_evaluators/ope_manager.py +++ b/rl_coach/off_policy_evaluators/ope_manager.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import math from collections import namedtuple import numpy as np @@ -55,19 +56,18 @@ class OpeManager(object): all_reward_model_rewards, all_policy_probs, all_old_policy_probs = [], [], [] all_v_values_reward_model_based, all_v_values_q_model_based, all_rewards, all_actions = [], [], [], [] - for i in range(int(len(dataset_as_transitions) / batch_size) + 1): + for i in range(math.ceil(len(dataset_as_transitions) / batch_size)): batch = dataset_as_transitions[i * batch_size: (i + 1) * batch_size] batch_for_inference = Batch(batch) all_reward_model_rewards.append(reward_model.predict( batch_for_inference.states(network_keys))) - # TODO can we get rid of the 'output_heads[0]', and have some way of a cleaner API? + # we always use the first Q head to calculate OPEs. might want to change this in the future. + # for instance, this means that for bootstrapped we always use the first QHead to calculate the OPEs. q_values, sm_values = q_network.predict(batch_for_inference.states(network_keys), - outputs=[q_network.output_heads[0].output, + outputs=[q_network.output_heads[0].q_values, q_network.output_heads[0].softmax]) - # TODO why is this needed? 
- q_values = q_values[0] all_policy_probs.append(sm_values) all_v_values_reward_model_based.append(np.sum(all_policy_probs[-1] * all_reward_model_rewards[-1], axis=1)) diff --git a/rl_coach/presets/CartPole_QR_DQN.py b/rl_coach/presets/CartPole_QR_DQN.py new file mode 100644 index 0000000..2694c63 --- /dev/null +++ b/rl_coach/presets/CartPole_QR_DQN.py @@ -0,0 +1,55 @@ +from rl_coach.agents.qr_dqn_agent import QuantileRegressionDQNAgentParameters +from rl_coach.agents.rainbow_dqn_agent import RainbowDQNAgentParameters +from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters +from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps +from rl_coach.environments.gym_environment import GymVectorEnvironment +from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager +from rl_coach.graph_managers.graph_manager import ScheduleParameters +from rl_coach.memories.memory import MemoryGranularity +from rl_coach.schedules import LinearSchedule + +#################### +# Graph Scheduling # +#################### + +schedule_params = ScheduleParameters() +schedule_params.improve_steps = TrainingSteps(10000000000) +schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10) +schedule_params.evaluation_steps = EnvironmentEpisodes(1) +schedule_params.heatup_steps = EnvironmentSteps(1000) + +######### +# Agent # +######### +agent_params = QuantileRegressionDQNAgentParameters() + +# DQN params +agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(100) +agent_params.algorithm.discount = 0.99 +agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(1) +agent_params.algorithm.atoms = 50 +# NN configuration +agent_params.network_wrappers['main'].learning_rate = 0.0005 + +# ER size +agent_params.memory.max_size = (MemoryGranularity.Transitions, 40000) + +# E-Greedy schedule +agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000) + +################ +# Environment # +################ +env_params = GymVectorEnvironment(level='CartPole-v0') + +######## +# Test # +######## +preset_validation_params = PresetValidationParameters() +preset_validation_params.test = True +preset_validation_params.min_reward_threshold = 150 +preset_validation_params.max_episodes_to_achieve_reward = 250 + +graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params, + schedule_params=schedule_params, vis_params=VisualizationParameters(), + preset_validation_params=preset_validation_params)