From 49dea39d34a562a46daa562051a66846326f66f4 Mon Sep 17 00:00:00 2001 From: Gal Leibovich Date: Wed, 7 Nov 2018 18:33:08 +0200 Subject: [PATCH] N-step returns for rainbow (#67) * n_step returns for rainbow * Rename CartPole_PPO -> CartPole_ClippedPPO --- rl_coach/agents/agent.py | 9 +- rl_coach/agents/categorical_dqn_agent.py | 36 ++++-- rl_coach/agents/clipped_ppo_agent.py | 8 +- rl_coach/agents/mmc_agent.py | 4 +- rl_coach/agents/nec_agent.py | 8 +- rl_coach/agents/pal_agent.py | 3 +- rl_coach/agents/policy_gradients_agent.py | 2 +- rl_coach/agents/policy_optimization_agent.py | 6 +- rl_coach/agents/ppo_agent.py | 7 +- rl_coach/agents/rainbow_dqn_agent.py | 37 +++--- rl_coach/base_parameters.py | 3 + rl_coach/core_types.py | 107 +++++++++++------- .../episodic/episodic_experience_replay.py | 16 ++- .../episodic_hindsight_experience_replay.py | 2 +- .../prioritized_experience_replay.py | 28 +++-- ...CartPole_PPO.py => CartPole_ClippedPPO.py} | 2 +- .../test_prioritized_experience_replay.py | 8 +- .../memories/test_single_episode_buffer.py | 4 +- 18 files changed, 173 insertions(+), 117 deletions(-) rename rl_coach/presets/{CartPole_PPO.py => CartPole_ClippedPPO.py} (98%) diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py index 0aa258a..e3b116c 100644 --- a/rl_coach/agents/agent.py +++ b/rl_coach/agents/agent.py @@ -116,7 +116,6 @@ class Agent(AgentInterface): self.output_filter.set_device(device) self.pre_network_filter.set_device(device) - # initialize all internal variables self._phase = RunPhase.HEATUP self.total_shaped_reward_in_current_episode = 0 @@ -143,7 +142,7 @@ class Agent(AgentInterface): self.accumulated_shaped_rewards_across_evaluation_episodes = 0 self.num_successes_across_evaluation_episodes = 0 self.num_evaluation_episodes_completed = 0 - self.current_episode_buffer = Episode(discount=self.ap.algorithm.discount) + self.current_episode_buffer = Episode(discount=self.ap.algorithm.discount, n_step=self.ap.algorithm.n_step) # TODO: add agents observation rendering for debugging purposes (not the same as the environment rendering) # environment parameters @@ -452,10 +451,10 @@ class Agent(AgentInterface): :return: None """ self.current_episode_buffer.is_complete = True - self.current_episode_buffer.update_returns() + self.current_episode_buffer.update_transitions_rewards_and_bootstrap_data() for transition in self.current_episode_buffer.transitions: - self.discounted_return.add_sample(transition.total_return) + self.discounted_return.add_sample(transition.n_step_discounted_rewards) if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only: self.current_episode += 1 @@ -497,7 +496,7 @@ class Agent(AgentInterface): self.curr_state = {} self.current_episode_steps_counter = 0 self.episode_running_info = {} - self.current_episode_buffer = Episode(discount=self.ap.algorithm.discount) + self.current_episode_buffer = Episode(discount=self.ap.algorithm.discount, n_step=self.ap.algorithm.n_step) if self.exploration_policy: self.exploration_policy.reset() self.input_filter.reset() diff --git a/rl_coach/agents/categorical_dqn_agent.py b/rl_coach/agents/categorical_dqn_agent.py index 24a610b..bca506a 100644 --- a/rl_coach/agents/categorical_dqn_agent.py +++ b/rl_coach/agents/categorical_dqn_agent.py @@ -17,14 +17,11 @@ from typing import Union import numpy as np - from rl_coach.agents.dqn_agent import DQNNetworkParameters, DQNAlgorithmParameters, DQNAgentParameters from rl_coach.agents.value_optimization_agent import ValueOptimizationAgent from 
rl_coach.architectures.head_parameters import CategoricalQHeadParameters -from rl_coach.base_parameters import AgentParameters from rl_coach.core_types import StateType from rl_coach.exploration_policies.e_greedy import EGreedyParameters -from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplay from rl_coach.schedules import LinearSchedule @@ -85,28 +82,47 @@ class CategoricalDQNAgent(ValueOptimizationAgent): # for the action we actually took, the error is calculated by the atoms distribution # for all other actions, the error is 0 - distributed_q_st_plus_1, TD_targets = self.networks['main'].parallel_prediction([ + distributional_q_st_plus_1, TD_targets = self.networks['main'].parallel_prediction([ (self.networks['main'].target_network, batch.next_states(network_keys)), (self.networks['main'].online_network, batch.states(network_keys)) ]) - # only update the action that we have actually done in this transition - target_actions = np.argmax(self.distribution_prediction_to_q_values(distributed_q_st_plus_1), axis=1) + # select the optimal actions for the next state + target_actions = np.argmax(self.distribution_prediction_to_q_values(distributional_q_st_plus_1), axis=1) m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size)) batches = np.arange(self.ap.network_wrappers['main'].batch_size) + + # an alternative to the for loop. 3.7x perf improvement vs. the same code done with for looping. + # only 10% speedup overall - leaving commented out as the code is not as clear. + + # tzj_ = np.fmax(np.fmin(batch.rewards() + (1.0 - batch.game_overs()) * self.ap.algorithm.discount * + # np.transpose(np.repeat(self.z_values[np.newaxis, :], batch.size, axis=0), (1, 0)), + # self.z_values[-1]), + # self.z_values[0]) + # + # bj_ = (tzj_ - self.z_values[0]) / (self.z_values[1] - self.z_values[0]) + # u_ = (np.ceil(bj_)).astype(int) + # l_ = (np.floor(bj_)).astype(int) + # m_ = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size)) + # np.add.at(m_, [batches, l_], + # np.transpose(distributional_q_st_plus_1[batches, target_actions], (1, 0)) * (u_ - bj_)) + # np.add.at(m_, [batches, u_], + # np.transpose(distributional_q_st_plus_1[batches, target_actions], (1, 0)) * (bj_ - l_)) + for j in range(self.z_values.size): tzj = np.fmax(np.fmin(batch.rewards() + (1.0 - batch.game_overs()) * self.ap.algorithm.discount * self.z_values[j], - self.z_values[self.z_values.size - 1]), + self.z_values[-1]), self.z_values[0]) bj = (tzj - self.z_values[0])/(self.z_values[1] - self.z_values[0]) u = (np.ceil(bj)).astype(int) l = (np.floor(bj)).astype(int) - m[batches, l] = m[batches, l] + (distributed_q_st_plus_1[batches, target_actions, j] * (u - bj)) - m[batches, u] = m[batches, u] + (distributed_q_st_plus_1[batches, target_actions, j] * (bj - l)) + m[batches, l] += (distributional_q_st_plus_1[batches, target_actions, j] * (u - bj)) + m[batches, u] += (distributional_q_st_plus_1[batches, target_actions, j] * (bj - l)) # total_loss = cross entropy between actual result above and predicted result for the given action + # only update the action that we have actually done in this transition TD_targets[batches, batch.actions()] = m # update errors in prioritized replay buffer @@ -120,7 +136,7 @@ class CategoricalDQNAgent(ValueOptimizationAgent): # TODO: fix this spaghetti code if isinstance(self.memory, PrioritizedExperienceReplay): errors = 
losses[0][np.arange(batch.size), batch.actions()] - self.memory.update_priorities(batch.info('idx'), errors) + self.call_memory('update_priorities', (batch.info('idx'), errors)) return total_loss, losses, unclipped_grads diff --git a/rl_coach/agents/clipped_ppo_agent.py b/rl_coach/agents/clipped_ppo_agent.py index 080525f..c581736 100644 --- a/rl_coach/agents/clipped_ppo_agent.py +++ b/rl_coach/agents/clipped_ppo_agent.py @@ -116,8 +116,10 @@ class ClippedPPOAgent(ActorCriticAgent): # calculate advantages advantages = [] value_targets = [] + total_returns = batch.n_step_discounted_rewards() + if self.policy_gradient_rescaler == PolicyGradientRescaler.A_VALUE: - advantages = batch.total_returns() - current_state_values + advantages = total_returns - current_state_values elif self.policy_gradient_rescaler == PolicyGradientRescaler.GAE: # get bootstraps episode_start_idx = 0 @@ -181,11 +183,13 @@ class ClippedPPOAgent(ActorCriticAgent): result = self.networks['main'].target_network.predict({k: v[start:end] for k, v in batch.states(network_keys).items()}) old_policy_distribution = result[1:] + total_returns = batch.n_step_discounted_rewards(expand_dims=True) + # calculate gradients and apply on both the local policy network and on the global policy network if self.ap.algorithm.estimate_state_value_using_gae: value_targets = np.expand_dims(gae_based_value_targets, -1) else: - value_targets = batch.total_returns(expand_dims=True)[start:end] + value_targets = total_returns[start:end] inputs = copy.copy({k: v[start:end] for k, v in batch.states(network_keys).items()}) inputs['output_1_0'] = actions diff --git a/rl_coach/agents/mmc_agent.py b/rl_coach/agents/mmc_agent.py index 3ce23e1..964d922 100644 --- a/rl_coach/agents/mmc_agent.py +++ b/rl_coach/agents/mmc_agent.py @@ -58,11 +58,13 @@ class MixedMonteCarloAgent(ValueOptimizationAgent): (self.networks['main'].online_network, batch.states(network_keys)) ]) + total_returns = batch.n_step_discounted_rewards() + for i in range(self.ap.network_wrappers['main'].batch_size): one_step_target = batch.rewards()[i] + \ (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * \ q_st_plus_1[i][selected_actions[i]] - monte_carlo_target = batch.total_returns()[i] + monte_carlo_target = total_returns[i] TD_targets[i, batch.actions()[i]] = (1 - self.mixing_rate) * one_step_target + \ self.mixing_rate * monte_carlo_target diff --git a/rl_coach/agents/nec_agent.py b/rl_coach/agents/nec_agent.py index 6f168bb..1ba8abe 100644 --- a/rl_coach/agents/nec_agent.py +++ b/rl_coach/agents/nec_agent.py @@ -98,10 +98,10 @@ class NECAgent(ValueOptimizationAgent): network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys() TD_targets = self.networks['main'].online_network.predict(batch.states(network_keys)) - + bootstrapped_return_from_old_policy = batch.n_step_discounted_rewards() # only update the action that we have actually done in this transition for i in range(self.ap.network_wrappers['main'].batch_size): - TD_targets[i, batch.actions()[i]] = batch.total_returns()[i] + TD_targets[i, batch.actions()[i]] = bootstrapped_return_from_old_policy[i] # set the gradients to fetch for the DND update fetches = [] @@ -165,10 +165,10 @@ class NECAgent(ValueOptimizationAgent): episode = self.call_memory('get_last_complete_episode') if episode is not None and self.phase != RunPhase.TEST: assert len(self.current_episode_state_embeddings) == episode.length() - returns = episode.get_transitions_attribute('total_return') + discounted_rewards =
episode.get_transitions_attribute('n_step_discounted_rewards') actions = episode.get_transitions_attribute('action') self.networks['main'].online_network.output_heads[0].DND.add(self.current_episode_state_embeddings, - actions, returns) + actions, discounted_rewards) def save_checkpoint(self, checkpoint_id): with open(os.path.join(self.ap.task_parameters.checkpoint_save_dir, str(checkpoint_id) + '.dnd'), 'wb') as f: diff --git a/rl_coach/agents/pal_agent.py b/rl_coach/agents/pal_agent.py index cb928e7..cba983c 100644 --- a/rl_coach/agents/pal_agent.py +++ b/rl_coach/agents/pal_agent.py @@ -70,6 +70,7 @@ class PALAgent(ValueOptimizationAgent): # calculate TD error TD_targets = np.copy(q_st_online) + total_returns = batch.n_step_discounted_rewards() for i in range(self.ap.network_wrappers['main'].batch_size): TD_targets[i, batch.actions()[i]] = batch.rewards()[i] + \ (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * \ @@ -83,7 +84,7 @@ class PALAgent(ValueOptimizationAgent): TD_targets[i, batch.actions()[i]] -= self.alpha * advantage_learning_update # mixing monte carlo updates - monte_carlo_target = batch.total_returns()[i] + monte_carlo_target = total_returns[i] TD_targets[i, batch.actions()[i]] = (1 - self.monte_carlo_mixing_rate) * TD_targets[i, batch.actions()[i]] \ + self.monte_carlo_mixing_rate * monte_carlo_target diff --git a/rl_coach/agents/policy_gradients_agent.py b/rl_coach/agents/policy_gradients_agent.py index 7db5fd8..95ff617 100644 --- a/rl_coach/agents/policy_gradients_agent.py +++ b/rl_coach/agents/policy_gradients_agent.py @@ -74,7 +74,7 @@ class PolicyGradientsAgent(PolicyOptimizationAgent): # batch contains a list of episodes to learn from network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys() - total_returns = batch.total_returns() + total_returns = batch.n_step_discounted_rewards() for i in reversed(range(batch.size)): if self.policy_gradient_rescaler == PolicyGradientRescaler.TOTAL_RETURN: total_returns[i] = total_returns[0] diff --git a/rl_coach/agents/policy_optimization_agent.py b/rl_coach/agents/policy_optimization_agent.py index 18e81cb..390edcd 100644 --- a/rl_coach/agents/policy_optimization_agent.py +++ b/rl_coach/agents/policy_optimization_agent.py @@ -73,11 +73,11 @@ class PolicyOptimizationAgent(Agent): episode_discounted_returns = [] for i in range(episode.length()): transition = episode.get_transition(i) - episode_discounted_returns.append(transition.total_return) + episode_discounted_returns.append(transition.n_step_discounted_rewards) self.num_episodes_where_step_has_been_seen[i] += 1 self.mean_return_over_multiple_episodes[i] -= self.mean_return_over_multiple_episodes[i] / \ self.num_episodes_where_step_has_been_seen[i] - self.mean_return_over_multiple_episodes[i] += transition.total_return / \ + self.mean_return_over_multiple_episodes[i] += transition.n_step_discounted_rewards / \ self.num_episodes_where_step_has_been_seen[i] self.mean_discounted_return = np.mean(episode_discounted_returns) self.std_discounted_return = np.std(episode_discounted_returns) @@ -97,7 +97,7 @@ class PolicyOptimizationAgent(Agent): network.set_is_training(True) # we need to update the returns of the episode until now - episode.update_returns() + episode.update_transitions_rewards_and_bootstrap_data() # get t_max transitions or less if the we got to a terminal state # will be used for both actor-critic and vanilla PG. 
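The hunks above switch the agents from full-episode total returns to per-transition n-step discounted rewards. The following standalone sketch shows what that quantity is, assuming the same zero-padded accumulation that Episode.update_discounted_rewards() in core_types.py (further down in this patch) performs; the function name and inputs here are illustrative and not part of the patch:

import numpy as np

def n_step_discounted_rewards(rewards, discount, n_step=-1):
    # for each index t, accumulate discount**k * rewards[t + k] for k in [0, n_step)
    rewards = np.asarray(rewards, dtype=float)
    curr_n_step = len(rewards) if n_step == -1 or n_step > len(rewards) else n_step
    discounted = rewards.copy()
    current_discount = discount
    for i in range(1, curr_n_step):
        # shift the rewards left by i steps and zero-pad the tail, so transitions near the
        # end of the episode simply accumulate fewer than n_step rewards
        discounted += current_discount * np.pad(rewards[i:], (0, i), 'constant', constant_values=0)
        current_discount *= discount
    return discounted

# example: n_step_discounted_rewards([1., 1., 1., 1.], discount=0.99, n_step=2)
# yields [1.99, 1.99, 1.99, 1.0]; with n_step=-1 it is the full Monte Carlo return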
diff --git a/rl_coach/agents/ppo_agent.py b/rl_coach/agents/ppo_agent.py index 83b6fc4..d455caa 100644 --- a/rl_coach/agents/ppo_agent.py +++ b/rl_coach/agents/ppo_agent.py @@ -112,11 +112,11 @@ class PPOAgent(ActorCriticAgent): # current_states_with_timestep = self.concat_state_and_timestep(batch) current_state_values = self.networks['critic'].online_network.predict(batch.states(network_keys)).squeeze() - + total_returns = batch.n_step_discounted_rewards() # calculate advantages advantages = [] if self.policy_gradient_rescaler == PolicyGradientRescaler.A_VALUE: - advantages = batch.total_returns() - current_state_values + advantages = total_returns - current_state_values elif self.policy_gradient_rescaler == PolicyGradientRescaler.GAE: # get bootstraps episode_start_idx = 0 @@ -155,6 +155,7 @@ class PPOAgent(ActorCriticAgent): # current_states_with_timestep = self.concat_state_and_timestep(dataset) mix_fraction = self.ap.algorithm.value_targets_mix_fraction + total_returns = batch.n_step_discounted_rewards(True) for j in range(epochs): curr_batch_size = batch.size if self.networks['critic'].online_network.optimizer_type != 'LBFGS': @@ -165,7 +166,7 @@ class PPOAgent(ActorCriticAgent): k: v[i * curr_batch_size:(i + 1) * curr_batch_size] for k, v in batch.states(network_keys).items() } - total_return_batch = batch.total_returns(True)[i * curr_batch_size:(i + 1) * curr_batch_size] + total_return_batch = total_returns[i * curr_batch_size:(i + 1) * curr_batch_size] old_policy_values = force_list(self.networks['critic'].target_network.predict( current_states_batch).squeeze()) if self.networks['critic'].online_network.optimizer_type != 'LBFGS': diff --git a/rl_coach/agents/rainbow_dqn_agent.py b/rl_coach/agents/rainbow_dqn_agent.py index 609ea0b..446a870 100644 --- a/rl_coach/agents/rainbow_dqn_agent.py +++ b/rl_coach/agents/rainbow_dqn_agent.py @@ -39,23 +39,17 @@ class RainbowDQNNetworkParameters(DQNNetworkParameters): class RainbowDQNAlgorithmParameters(CategoricalDQNAlgorithmParameters): def __init__(self): super().__init__() + self.n_step = 3 - -class RainbowDQNExplorationParameters(ParameterNoiseParameters): - def __init__(self, agent_params): - super().__init__(agent_params) - - -class RainbowDQNMemoryParameters(PrioritizedExperienceReplayParameters): - def __init__(self): - super().__init__() + # needed for n-step updates to work. i.e. waiting for a full episode to be closed before storing each transition + self.store_transitions_only_when_episodes_are_terminated = True class RainbowDQNAgentParameters(CategoricalDQNAgentParameters): def __init__(self): super().__init__() self.algorithm = RainbowDQNAlgorithmParameters() - self.exploration = RainbowDQNExplorationParameters(self) + self.exploration = ParameterNoiseParameters(self) self.memory = PrioritizedExperienceReplayParameters() self.network_wrappers = {"main": RainbowDQNNetworkParameters()} @@ -65,15 +59,13 @@ class RainbowDQNAgentParameters(CategoricalDQNAgentParameters): # Rainbow Deep Q Network - https://arxiv.org/abs/1710.02298 -# Agent implementation is WIP. Currently is composed of: +# Agent implementation is composed of: # 1. NoisyNets # 2. C51 # 3. Prioritized ER # 4. DDQN # 5. Dueling DQN -# -# still missing: -# 1. N-Step +# 6. 
N-step returns class RainbowDQNAgent(CategoricalDQNAgent): def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None): @@ -87,7 +79,7 @@ class RainbowDQNAgent(CategoricalDQNAgent): # for the action we actually took, the error is calculated by the atoms distribution # for all other actions, the error is 0 - distributed_q_st_plus_1, TD_targets = self.networks['main'].parallel_prediction([ + distributional_q_st_plus_n, TD_targets = self.networks['main'].parallel_prediction([ (self.networks['main'].target_network, batch.next_states(network_keys)), (self.networks['main'].online_network, batch.states(network_keys)) ]) @@ -98,15 +90,16 @@ class RainbowDQNAgent(CategoricalDQNAgent): batches = np.arange(self.ap.network_wrappers['main'].batch_size) for j in range(self.z_values.size): - tzj = np.fmax(np.fmin(batch.rewards() + - (1.0 - batch.game_overs()) * self.ap.algorithm.discount * self.z_values[j], - self.z_values[self.z_values.size - 1]), - self.z_values[0]) + # we use batch.info('should_bootstrap_next_state') instead of (1 - batch.game_overs()) since with n-step, + # we will not bootstrap for the last n-step transitions in the episode + tzj = np.fmax(np.fmin(batch.n_step_discounted_rewards() + batch.info('should_bootstrap_next_state') * + (self.ap.algorithm.discount ** self.ap.algorithm.n_step) * self.z_values[j], + self.z_values[-1]), self.z_values[0]) bj = (tzj - self.z_values[0])/(self.z_values[1] - self.z_values[0]) u = (np.ceil(bj)).astype(int) l = (np.floor(bj)).astype(int) - m[batches, l] = m[batches, l] + (distributed_q_st_plus_1[batches, target_actions, j] * (u - bj)) - m[batches, u] = m[batches, u] + (distributed_q_st_plus_1[batches, target_actions, j] * (bj - l)) + m[batches, l] += (distributional_q_st_plus_n[batches, target_actions, j] * (u - bj)) + m[batches, u] += (distributional_q_st_plus_n[batches, target_actions, j] * (bj - l)) # total_loss = cross entropy between actual result above and predicted result for the given action TD_targets[batches, batch.actions()] = m @@ -122,7 +115,7 @@ class RainbowDQNAgent(CategoricalDQNAgent): # TODO: fix this spaghetti code if isinstance(self.memory, PrioritizedExperienceReplay): errors = losses[0][np.arange(batch.size), batch.actions()] - self.memory.update_priorities(batch.info('idx'), errors) + self.call_memory('update_priorities', (batch.info('idx'), errors)) return total_loss, losses, unclipped_grads diff --git a/rl_coach/base_parameters.py b/rl_coach/base_parameters.py index 03dbf25..03ed774 100644 --- a/rl_coach/base_parameters.py +++ b/rl_coach/base_parameters.py @@ -166,6 +166,9 @@ class AlgorithmParameters(Parameters): # intrinsic reward self.scale_external_reward_by_intrinsic_reward_value = False + # n-step returns + self.n_step = -1 # calculate the total return (no bootstrap, by default) + # Distributed Coach params self.distributed_coach_synchronization_type = None diff --git a/rl_coach/core_types.py b/rl_coach/core_types.py index 6d6bfe4..5180bc1 100644 --- a/rl_coach/core_types.py +++ b/rl_coach/core_types.py @@ -125,6 +125,7 @@ class Middleware_LSTM_Embedding(MiddlewareEmbedding): class Measurements(PredictionType): pass + PlayingStepsType = Union[EnvironmentSteps, EnvironmentEpisodes, Frames] @@ -162,7 +163,7 @@ class Transition(object): self._state = self.state = state self._action = self.action = action self._reward = self.reward = reward - self._total_return = self.total_return = None + self._n_step_discounted_rewards = self.n_step_discounted_rewards = None if not next_state: next_state = 
state self._next_state = self._next_state = next_state @@ -207,15 +208,15 @@ class Transition(object): self._reward = val @property - def total_return(self): - if self._total_return is None: - raise Exception("The total_return was not filled by any of the modules between the environment and the " - "agent. Make sure that you are using an episodic experience replay.") - return self._total_return + def n_step_discounted_rewards(self): + if self._n_step_discounted_rewards is None: + raise Exception("The n_step_discounted_rewards were not filled by any of the modules between the " + "environment and the agent. Make sure that you are using an episodic experience replay.") + return self._n_step_discounted_rewards - @total_return.setter - def total_return(self, val): - self._total_return = val + @n_step_discounted_rewards.setter + def n_step_discounted_rewards(self, val): + self._n_step_discounted_rewards = val @property def game_over(self): @@ -322,6 +323,7 @@ class ActionInfo(object): """ Action info is a class that holds an action and various additional information details about it """ + def __init__(self, action: ActionType, action_probability: float=0, action_value: float=0., state_value: float=0., max_action_value: float=None, action_intrinsic_reward: float=0): @@ -359,7 +361,7 @@ class Batch(object): self._states = {} self._actions = None self._rewards = None - self._total_returns = None + self._n_step_discounted_rewards = None self._game_overs = None self._next_states = {} self._goals = None @@ -380,8 +382,8 @@ class Batch(object): self._actions = self._actions[start:end] if self._rewards is not None: self._rewards = self._rewards[start:end] - if self._total_returns is not None: - self._total_returns = self._total_returns[start:end] + if self._n_step_discounted_rewards is not None: + self._n_step_discounted_rewards = self._n_step_discounted_rewards[start:end] if self._game_overs is not None: self._game_overs = self._game_overs[start:end] for k, v in self._next_states.items(): @@ -402,7 +404,7 @@ class Batch(object): self._states = {} self._actions = None self._rewards = None - self._total_returns = None + self._n_step_discounted_rewards = None self._game_overs = None self._next_states = {} self._goals = None @@ -471,18 +473,20 @@ class Batch(object): return np.expand_dims(self._rewards, -1) return self._rewards - def total_returns(self, expand_dims=False) -> np.ndarray: + def n_step_discounted_rewards(self, expand_dims=False) -> np.ndarray: """ - if the total_returns were not converted to a batch before, extract them to a batch and then return the batch - if the total return was not filled, this will raise an exception + if the n_step_discounted_rewards were not converted to a batch before, extract them to a batch and then return + the batch + if the n step discounted rewards were not filled, this will raise an exception :param expand_dims: add an extra dimension to the total_returns batch :return: a numpy array containing all the total return values of the batch """ - if self._total_returns is None: - self._total_returns = np.array([transition.total_return for transition in self.transitions]) + if self._n_step_discounted_rewards is None: + self._n_step_discounted_rewards = np.array([transition.n_step_discounted_rewards for transition in + self.transitions]) if expand_dims: - return np.expand_dims(self._total_returns, -1) - return self._total_returns + return np.expand_dims(self._n_step_discounted_rewards, -1) + return self._n_step_discounted_rewards def game_overs(self, 
expand_dims=False) -> np.ndarray: """ @@ -510,7 +514,8 @@ class Batch(object): # addition to the current_state, so that all the inputs of the network will be filled) for key in set(fetches).intersection(self.transitions[0].next_state.keys()): if key not in self._next_states.keys(): - self._next_states[key] = np.array([np.array(transition.next_state[key]) for transition in self.transitions]) + self._next_states[key] = np.array( + [np.array(transition.next_state[key]) for transition in self.transitions]) if expand_dims: next_states[key] = np.expand_dims(self._next_states[key], -1) else: @@ -530,6 +535,16 @@ class Batch(object): return np.expand_dims(self._goals, -1) return self._goals + def info_as_list(self, key) -> list: + """ + get the info values for the given key and store them internally as a list, if they weren't stored before. return them as a list + :param key: the key from the transitions' info dictionaries whose values should be collected + :return: a list containing all the info values of the batch corresponding to the given key + """ + if key not in self._info.keys(): + self._info[key] = [transition.info[key] for transition in self.transitions] + return self._info[key] + def info(self, key, expand_dims=False) -> np.ndarray: """ if the given info dictionary key was not converted to a batch before, extract it to a batch and then return the @@ -537,11 +552,11 @@ class Batch(object): :param expand_dims: add an extra dimension to the info batch :return: a numpy array containing all the info values of the batch corresponding to the given key """ - if key not in self._info.keys(): - self._info[key] = np.array([transition.info[key] for transition in self.transitions]) + info_list = self.info_as_list(key) + if expand_dims: - return np.expand_dims(self._info[key], -1) - return self._info[key] + return np.expand_dims(info_list, -1) + return np.array(info_list) @property def size(self) -> int: @@ -572,6 +587,7 @@ class TotalStepsCounter(object): """ A wrapper around a dictionary counting different StepMethods steps done.
""" + def __init__(self): self.counters = { EnvironmentEpisodes: 0, @@ -619,7 +635,6 @@ class Episode(object): """ self.transitions = [] # a num_transitions x num_transitions table with the n step return in the n'th row - self.returns_table = None self._length = 0 self.discount = discount self.bootstrap_total_return_from_old_policy = bootstrap_total_return_from_old_policy @@ -650,28 +665,48 @@ class Episode(object): def get_first_transition(self): return self.get_transition(0) if self.length() > 0 else None - def update_returns(self): + def update_discounted_rewards(self): if self.n_step == -1 or self.n_step > self.length(): curr_n_step = self.length() else: curr_n_step = self.n_step + rewards = np.array([t.reward for t in self.transitions]) rewards = rewards.astype('float') - total_return = rewards.copy() + discounted_rewards = rewards.copy() current_discount = self.discount for i in range(1, curr_n_step): - total_return += current_discount * np.pad(rewards[i:], (0, i), 'constant', constant_values=0) + discounted_rewards += current_discount * np.pad(rewards[i:], (0, i), 'constant', constant_values=0) current_discount *= self.discount # calculate the bootstrapped returns if self.bootstrap_total_return_from_old_policy: bootstraps = np.array([np.squeeze(t.info['max_action_value']) for t in self.transitions[curr_n_step:]]) - bootstrapped_return = total_return + current_discount * np.pad(bootstraps, (0, curr_n_step), 'constant', - constant_values=0) - total_return = bootstrapped_return + bootstrapped_return = discounted_rewards + current_discount * np.pad(bootstraps, (0, curr_n_step), + 'constant', constant_values=0) + discounted_rewards = bootstrapped_return for transition_idx in range(self.length()): - self.transitions[transition_idx].total_return = total_return[transition_idx] + self.transitions[transition_idx].n_step_discounted_rewards = discounted_rewards[transition_idx] + + def update_transitions_rewards_and_bootstrap_data(self): + if not isinstance(self.n_step, int) or (self.n_step < 1 and self.n_step != -1): + raise ValueError("n-step should be an integer with value >= 1, or set to -1 for always setting to episode" + " length.") + elif self.n_step > 1: + curr_n_step = self.n_step if self.n_step < self.length() else self.length() + + for idx, transition in enumerate(self.transitions): + next_n_step_transition_idx = (idx + curr_n_step) + if next_n_step_transition_idx < len(self.transitions): + # next state will now point to the n-step next state + transition.next_state = self.transitions[next_n_step_transition_idx].state + transition.info['should_bootstrap_next_state'] = True + else: + transition.next_state = self.transitions[-1].next_state + transition.info['should_bootstrap_next_state'] = False + + self.update_discounted_rewards() def update_actions_probabilities(self): probability_product = 1 @@ -681,12 +716,6 @@ class Episode(object): for transition_idx, transition in enumerate(self.transitions): transition.info['probability_product'] = probability_product - def get_returns_table(self): - return self.returns_table - - def get_returns(self): - return self.get_transitions_attribute('total_return') - def get_transitions_attribute(self, attribute_name): if len(self.transitions) > 0 and hasattr(self.transitions[0], attribute_name): return [getattr(t, attribute_name) for t in self.transitions] diff --git a/rl_coach/memories/episodic/episodic_experience_replay.py b/rl_coach/memories/episodic/episodic_experience_replay.py index e3d2eb2..2f4b393 100644 --- 
a/rl_coach/memories/episodic/episodic_experience_replay.py +++ b/rl_coach/memories/episodic/episodic_experience_replay.py @@ -27,6 +27,7 @@ class EpisodicExperienceReplayParameters(MemoryParameters): def __init__(self): super().__init__() self.max_size = (MemoryGranularity.Transitions, 1000000) + self.n_step = -1 @property def path(self): @@ -39,12 +40,14 @@ class EpisodicExperienceReplay(Memory): calculations of total return and other values that depend on the sequential behavior of the transitions in the episode. """ - def __init__(self, max_size: Tuple[MemoryGranularity, int]): + def __init__(self, max_size: Tuple[MemoryGranularity, int]=(MemoryGranularity.Transitions, 1000000), n_step=-1): """ :param max_size: the maximum number of transitions or episodes to hold in the memory + :param n_step: the number of steps over which to accumulate discounted rewards per transition (-1 uses the full episode) """ super().__init__(max_size) - self._buffer = [Episode()] # list of episodes + self.n_step = n_step + self._buffer = [Episode(n_step=self.n_step)] # list of episodes self.transitions = [] self._length = 1 # the episodic replay buffer starts with a single empty episode self._num_transitions = 0 @@ -109,7 +111,7 @@ class EpisodicExperienceReplay(Memory): self._remove_episode(0) def _update_episode(self, episode: Episode) -> None: - episode.update_returns() + episode.update_transitions_rewards_and_bootstrap_data() def verify_last_episode_is_closed(self) -> None: """ @@ -138,7 +140,7 @@ class EpisodicExperienceReplay(Memory): self._length += 1 # create a new Episode for the next transitions to be placed into - self._buffer.append(Episode()) + self._buffer.append(Episode(n_step=self.n_step)) # if update episode adds to the buffer, a new Episode needs to be ready first # it would be better if this were less stateful @@ -158,12 +160,14 @@ class EpisodicExperienceReplay(Memory): :param transition: a transition to store :return: None """ + # Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
super().store(transition) + self.reader_writer_lock.lock_writing_and_reading() if len(self._buffer) == 0: - self._buffer.append(Episode()) + self._buffer.append(Episode(n_step=self.n_step)) last_episode = self._buffer[-1] last_episode.insert(transition) self.transitions.append(transition) @@ -284,7 +288,7 @@ class EpisodicExperienceReplay(Memory): self.reader_writer_lock.lock_writing_and_reading() self.transitions = [] - self._buffer = [Episode()] + self._buffer = [Episode(n_step=self.n_step)] self._length = 1 self._num_transitions = 0 self._num_transitions_in_complete_episodes = 0 diff --git a/rl_coach/memories/episodic/episodic_hindsight_experience_replay.py b/rl_coach/memories/episodic/episodic_hindsight_experience_replay.py index c30f451..69468a1 100644 --- a/rl_coach/memories/episodic/episodic_hindsight_experience_replay.py +++ b/rl_coach/memories/episodic/episodic_hindsight_experience_replay.py @@ -139,7 +139,7 @@ class EpisodicHindsightExperienceReplay(EpisodicExperienceReplay): hindsight_transition.reward, hindsight_transition.game_over = \ self.goals_space.get_reward_for_goal_and_state(goal, hindsight_transition.next_state) - hindsight_transition.total_return = None + hindsight_transition.n_step_discounted_rewards = None episode.insert(hindsight_transition) super().store_episode(episode) diff --git a/rl_coach/memories/non_episodic/prioritized_experience_replay.py b/rl_coach/memories/non_episodic/prioritized_experience_replay.py index 544df48..120a738 100644 --- a/rl_coach/memories/non_episodic/prioritized_experience_replay.py +++ b/rl_coach/memories/non_episodic/prioritized_experience_replay.py @@ -128,14 +128,14 @@ class SegmentTree(object): self.tree[node_idx] = new_val self._propagate(node_idx) - def get(self, val: float) -> Tuple[int, float, Any]: + def get_element_by_partial_sum(self, val: float) -> Tuple[int, float, Any]: """ Given a value between 0 and the tree sum, return the object which this value is in it's range. For example, if we have 3 leaves: 10, 20, 30, and val=35, this will return the 3rd leaf, by accumulating leaves by their order until getting to 35. This allows sampling leaves according to their proportional probability. :param val: a value within the range 0 and the tree sum - :return: the index of the resulting leaf in the tree, it's probability and + :return: the index of the resulting leaf in the tree, its probability and the object itself """ node_idx = self._retrieve(0, val) @@ -237,12 +237,12 @@ class PrioritizedExperienceReplay(ExperienceReplay): # sample a batch for i in range(size): - start_probability = segment_size * i - end_probability = segment_size * (i + 1) + segment_start = segment_size * i + segment_end = segment_size * (i + 1) # sample leaf and calculate its weight - val = random.uniform(start_probability, end_probability) - leaf_idx, priority, transition = self.sum_tree.get(val) + val = random.uniform(segment_start, segment_end) + leaf_idx, priority, transition = self.sum_tree.get_element_by_partial_sum(val) priority /= self.sum_tree.total_value() # P(j) = p^a / sum(p^a) weight = (self.num_transitions() * priority) ** -self.beta.current_value # (N * P(j)) ^ -beta normalized_weight = weight / max_weight # wj = ((N * P(j)) ^ -beta) / max wi @@ -261,7 +261,7 @@ class PrioritizedExperienceReplay(ExperienceReplay): self.reader_writer_lock.release_writing() return batch - def store(self, transition: Transition) -> None: + def store(self, transition: Transition, lock=True) -> None: """ Store a new transition in the memory. 
:param transition: a transition to store @@ -270,7 +270,8 @@ class PrioritizedExperienceReplay(ExperienceReplay): # Calling super.store() so that in case a memory backend is used, the memory backend can store this transition. super().store(transition) - self.reader_writer_lock.lock_writing_and_reading() + if lock: + self.reader_writer_lock.lock_writing_and_reading() transition_priority = self.maximal_priority self.sum_tree.add(transition_priority ** self.alpha, transition) @@ -278,18 +279,21 @@ class PrioritizedExperienceReplay(ExperienceReplay): self.max_tree.add(transition_priority, transition) super().store(transition, False) - self.reader_writer_lock.release_writing_and_reading() + if lock: + self.reader_writer_lock.release_writing_and_reading() - def clean(self) -> None: + def clean(self, lock=True) -> None: """ Clean the memory by removing all the episodes :return: None """ - self.reader_writer_lock.lock_writing_and_reading() + if lock: + self.reader_writer_lock.lock_writing_and_reading() super().clean(lock=False) self.sum_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.SUM) self.min_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.MIN) self.max_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.MAX) - self.reader_writer_lock.release_writing_and_reading() + if lock: + self.reader_writer_lock.release_writing_and_reading() diff --git a/rl_coach/presets/CartPole_PPO.py b/rl_coach/presets/CartPole_ClippedPPO.py similarity index 98% rename from rl_coach/presets/CartPole_PPO.py rename to rl_coach/presets/CartPole_ClippedPPO.py index 0c13abb..7c4d3c1 100644 --- a/rl_coach/presets/CartPole_PPO.py +++ b/rl_coach/presets/CartPole_ClippedPPO.py @@ -63,7 +63,7 @@ env_params = GymVectorEnvironment(level='CartPole-v0') preset_validation_params = PresetValidationParameters() preset_validation_params.test = True preset_validation_params.min_reward_threshold = 150 -preset_validation_params.max_episodes_to_achieve_reward = 250 +preset_validation_params.max_episodes_to_achieve_reward = 400 graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params, schedule_params=schedule_params, vis_params=VisualizationParameters(), diff --git a/rl_coach/tests/memories/test_prioritized_experience_replay.py b/rl_coach/tests/memories/test_prioritized_experience_replay.py index 020c4a1..51f1a61 100644 --- a/rl_coach/tests/memories/test_prioritized_experience_replay.py +++ b/rl_coach/tests/memories/test_prioritized_experience_replay.py @@ -26,10 +26,10 @@ def test_sum_tree(): sum_tree.add(5, "5") assert sum_tree.total_value() == 20 - assert sum_tree.get(2) == (0, 2.5, '2.5') - assert sum_tree.get(3) == (1, 5.0, '5') - assert sum_tree.get(10) == (2, 5.0, '5') - assert sum_tree.get(13) == (3, 7.5, '7.5') + assert sum_tree.get_element_by_partial_sum(2) == (0, 2.5, '2.5') + assert sum_tree.get_element_by_partial_sum(3) == (1, 5.0, '5') + assert sum_tree.get_element_by_partial_sum(10) == (2, 5.0, '5') + assert sum_tree.get_element_by_partial_sum(13) == (3, 7.5, '7.5') sum_tree.update(2, 10) assert sum_tree.__str__() == "[25.]\n[ 7.5 17.5]\n[ 2.5 5. 10. 
7.5]\n" diff --git a/rl_coach/tests/memories/test_single_episode_buffer.py b/rl_coach/tests/memories/test_single_episode_buffer.py index c2f12ab..84e1e26 100644 --- a/rl_coach/tests/memories/test_single_episode_buffer.py +++ b/rl_coach/tests/memories/test_single_episode_buffer.py @@ -41,8 +41,8 @@ def test_store_and_get(buffer: SingleEpisodeBuffer): # check that the episode is valid episode = buffer.get(0) assert episode.length() == 2 - assert episode.get_transition(0).total_return == 1 + 0.99 - assert episode.get_transition(1).total_return == 1 + assert episode.get_transition(0).n_step_discounted_rewards == 1 + 0.99 + assert episode.get_transition(1).n_step_discounted_rewards == 1 assert buffer.mean_reward() == 1 # only one episode in the replay buffer
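For readers comparing the Categorical DQN and Rainbow hunks, the batched projection loops above are easier to follow for a single transition. The sketch below is illustrative only, under the same assumptions as the patch (a fixed atom support z_values, an n-step discounted return, and a should_bootstrap flag); the function and argument names are hypothetical and the batched code above remains authoritative.

import numpy as np

def project_single_transition(z_values, next_state_probs, n_step_return, discount, n_step, should_bootstrap):
    # distribute the mass of each shifted and scaled atom onto the two nearest atoms of the
    # fixed support z_values (the Bellman projection used by C51 / Rainbow)
    m = np.zeros_like(z_values, dtype=float)
    delta_z = z_values[1] - z_values[0]
    for j in range(z_values.size):
        tzj = n_step_return + should_bootstrap * (discount ** n_step) * z_values[j]
        tzj = np.clip(tzj, z_values[0], z_values[-1])
        bj = (tzj - z_values[0]) / delta_z
        u, l = int(np.ceil(bj)), int(np.floor(bj))
        m[l] += next_state_probs[j] * (u - bj)
        m[u] += next_state_probs[j] * (bj - l)
    # the cross entropy between m and the online network's predicted distribution for the
    # action that was actually taken gives the per-transition loss
    return m

With n_step=1 and should_bootstrap equal to 1 - game_over, this reduces to the one-step Categorical DQN update above, which is the relationship the Rainbow hunk relies on.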
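The SegmentTree rename above (get to get_element_by_partial_sum) only makes the sampling semantics explicit. A linear-time sketch of those semantics, with illustrative names (a real sum tree performs the same lookup in O(log N) and also returns the leaf index):

def get_element_by_partial_sum(priorities, items, val):
    # return the (priority, item) pair whose cumulative-priority interval contains val,
    # where 0 <= val < sum(priorities); drawing val uniformly samples items proportionally
    cumulative = 0.0
    for priority, item in zip(priorities, items):
        cumulative += priority
        if val < cumulative:
            return priority, item
    return priorities[-1], items[-1]

# example: get_element_by_partial_sum([2.5, 5, 5, 7.5], ['2.5', '5', '5', '7.5'], 13)
# returns (7.5, '7.5'), consistent with the leaves expected by the updated test_sum_tree()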