
N-step returns for rainbow (#67)

* n_step returns for rainbow
* Rename CartPole_PPO -> CartPole_ClippedPPO
Author: Gal Leibovich
Date: 2018-11-07 18:33:08 +02:00
Committed by: GitHub
parent 35c477c922
commit 49dea39d34
18 changed files with 173 additions and 117 deletions
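
Background for the diff below: an n-step return accumulates up to n discounted rewards and then, where a full n-step window exists, bootstraps from a value estimate of the state n steps ahead with a discount of discount**n. A minimal sketch of that computation, written independently of the code in this commit (the function name, default values and toy inputs are illustrative assumptions, not part of the change):

    import numpy as np

    def n_step_returns(rewards, state_values, discount=0.99, n_step=3):
        # sum of discounted rewards over a window of up to n_step steps
        n_step = min(n_step, len(rewards))
        rewards = np.asarray(rewards, dtype=float)
        returns = rewards.copy()
        current_discount = discount
        for i in range(1, n_step):
            returns += current_discount * np.pad(rewards[i:], (0, i), 'constant')
            current_discount *= discount
        # bootstrap with discount**n_step from the value of the state n_step ahead;
        # the last n_step transitions have no such state and get no bootstrap
        values = np.asarray(state_values, dtype=float)
        return returns + current_discount * np.pad(values[n_step:], (0, n_step), 'constant')

    # toy usage: five steps with reward 1 and zero value estimates
    print(n_step_returns([1, 1, 1, 1, 1], [0, 0, 0, 0, 0]))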

View File

@@ -116,7 +116,6 @@ class Agent(AgentInterface):
self.output_filter.set_device(device)
self.pre_network_filter.set_device(device)
# initialize all internal variables
self._phase = RunPhase.HEATUP
self.total_shaped_reward_in_current_episode = 0
@@ -143,7 +142,7 @@ class Agent(AgentInterface):
self.accumulated_shaped_rewards_across_evaluation_episodes = 0
self.num_successes_across_evaluation_episodes = 0
self.num_evaluation_episodes_completed = 0
-self.current_episode_buffer = Episode(discount=self.ap.algorithm.discount)
+self.current_episode_buffer = Episode(discount=self.ap.algorithm.discount, n_step=self.ap.algorithm.n_step)
# TODO: add agents observation rendering for debugging purposes (not the same as the environment rendering)
# environment parameters
@@ -452,10 +451,10 @@ class Agent(AgentInterface):
:return: None
"""
self.current_episode_buffer.is_complete = True
-self.current_episode_buffer.update_returns()
+self.current_episode_buffer.update_transitions_rewards_and_bootstrap_data()
for transition in self.current_episode_buffer.transitions:
-self.discounted_return.add_sample(transition.total_return)
+self.discounted_return.add_sample(transition.n_step_discounted_rewards)
if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only:
self.current_episode += 1
@@ -497,7 +496,7 @@ class Agent(AgentInterface):
self.curr_state = {}
self.current_episode_steps_counter = 0
self.episode_running_info = {}
-self.current_episode_buffer = Episode(discount=self.ap.algorithm.discount)
+self.current_episode_buffer = Episode(discount=self.ap.algorithm.discount, n_step=self.ap.algorithm.n_step)
if self.exploration_policy:
self.exploration_policy.reset()
self.input_filter.reset()

View File

@@ -17,14 +17,11 @@
from typing import Union
import numpy as np
from rl_coach.agents.dqn_agent import DQNNetworkParameters, DQNAlgorithmParameters, DQNAgentParameters
from rl_coach.agents.value_optimization_agent import ValueOptimizationAgent
from rl_coach.architectures.head_parameters import CategoricalQHeadParameters
-from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import StateType
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
-from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters
from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplay
from rl_coach.schedules import LinearSchedule
@@ -85,28 +82,47 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
# for the action we actually took, the error is calculated by the atoms distribution
# for all other actions, the error is 0
-distributed_q_st_plus_1, TD_targets = self.networks['main'].parallel_prediction([
+distributional_q_st_plus_1, TD_targets = self.networks['main'].parallel_prediction([
(self.networks['main'].target_network, batch.next_states(network_keys)),
(self.networks['main'].online_network, batch.states(network_keys))
])
-# only update the action that we have actually done in this transition
+# select the optimal actions for the next state
-target_actions = np.argmax(self.distribution_prediction_to_q_values(distributed_q_st_plus_1), axis=1)
+target_actions = np.argmax(self.distribution_prediction_to_q_values(distributional_q_st_plus_1), axis=1)
m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
batches = np.arange(self.ap.network_wrappers['main'].batch_size)
+# an alternative to the for loop. 3.7x perf improvement vs. the same code done with for looping.
+# only 10% speedup overall - leaving commented out as the code is not as clear.
+# tzj_ = np.fmax(np.fmin(batch.rewards() + (1.0 - batch.game_overs()) * self.ap.algorithm.discount *
+#                np.transpose(np.repeat(self.z_values[np.newaxis, :], batch.size, axis=0), (1, 0)),
+#                self.z_values[-1]),
+#                self.z_values[0])
+#
+# bj_ = (tzj_ - self.z_values[0]) / (self.z_values[1] - self.z_values[0])
+# u_ = (np.ceil(bj_)).astype(int)
+# l_ = (np.floor(bj_)).astype(int)
+# m_ = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
+# np.add.at(m_, [batches, l_],
+#           np.transpose(distributional_q_st_plus_1[batches, target_actions], (1, 0)) * (u_ - bj_))
+# np.add.at(m_, [batches, u_],
+#           np.transpose(distributional_q_st_plus_1[batches, target_actions], (1, 0)) * (bj_ - l_))
for j in range(self.z_values.size):
tzj = np.fmax(np.fmin(batch.rewards() +
    (1.0 - batch.game_overs()) * self.ap.algorithm.discount * self.z_values[j],
-    self.z_values[self.z_values.size - 1]),
+    self.z_values[-1]),
    self.z_values[0])
bj = (tzj - self.z_values[0])/(self.z_values[1] - self.z_values[0])
u = (np.ceil(bj)).astype(int)
l = (np.floor(bj)).astype(int)
-m[batches, l] = m[batches, l] + (distributed_q_st_plus_1[batches, target_actions, j] * (u - bj))
-m[batches, u] = m[batches, u] + (distributed_q_st_plus_1[batches, target_actions, j] * (bj - l))
+m[batches, l] += (distributional_q_st_plus_1[batches, target_actions, j] * (u - bj))
+m[batches, u] += (distributional_q_st_plus_1[batches, target_actions, j] * (bj - l))
# total_loss = cross entropy between actual result above and predicted result for the given action
+# only update the action that we have actually done in this transition
TD_targets[batches, batch.actions()] = m
# update errors in prioritized replay buffer
@@ -120,7 +136,7 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
# TODO: fix this spaghetti code
if isinstance(self.memory, PrioritizedExperienceReplay):
errors = losses[0][np.arange(batch.size), batch.actions()]
-self.memory.update_priorities(batch.info('idx'), errors)
+self.call_memory('update_priorities', (batch.info('idx'), errors))
return total_loss, losses, unclipped_grads
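
For context on the projection loop above (this commit only renames distributed_q_st_plus_1 to distributional_q_st_plus_1 and tidies the arithmetic): each target atom r + discount * (1 - done) * z_j is clipped to the support and its probability mass is split between the two neighbouring atoms l = floor(b_j) and u = ceil(b_j). A self-contained sketch for a single transition and a single atom; the support bounds, atom count and numbers are assumed, not taken from the diff:

    import numpy as np

    z = np.linspace(-10.0, 10.0, 51)      # assumed C51 support
    reward, discount, game_over = 1.0, 0.99, 0.0
    j, p_j = 30, 0.02                     # atom index and its predicted probability

    tzj = np.clip(reward + (1.0 - game_over) * discount * z[j], z[0], z[-1])
    bj = (tzj - z[0]) / (z[1] - z[0])
    l, u = int(np.floor(bj)), int(np.ceil(bj))

    m = np.zeros_like(z)                  # projected target distribution
    m[l] += p_j * (u - bj)                # mass assigned to the lower neighbour
    m[u] += p_j * (bj - l)                # mass assigned to the upper neighbour

Summing this projection over all atoms j and all batch entries yields the m matrix used as the cross-entropy target in the loop above.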

View File

@@ -116,8 +116,10 @@ class ClippedPPOAgent(ActorCriticAgent):
# calculate advantages
advantages = []
value_targets = []
+total_returns = batch.n_step_discounted_rewards()
if self.policy_gradient_rescaler == PolicyGradientRescaler.A_VALUE:
-advantages = batch.total_returns() - current_state_values
+advantages = total_returns - current_state_values
elif self.policy_gradient_rescaler == PolicyGradientRescaler.GAE:
# get bootstraps
episode_start_idx = 0
@@ -181,11 +183,13 @@ class ClippedPPOAgent(ActorCriticAgent):
result = self.networks['main'].target_network.predict({k: v[start:end] for k, v in batch.states(network_keys).items()})
old_policy_distribution = result[1:]
+total_returns = batch.n_step_discounted_rewards(expand_dims=True)
# calculate gradients and apply on both the local policy network and on the global policy network
if self.ap.algorithm.estimate_state_value_using_gae:
value_targets = np.expand_dims(gae_based_value_targets, -1)
else:
-value_targets = batch.total_returns(expand_dims=True)[start:end]
+value_targets = total_returns[start:end]
inputs = copy.copy({k: v[start:end] for k, v in batch.states(network_keys).items()})
inputs['output_1_0'] = actions

View File

@@ -58,11 +58,13 @@ class MixedMonteCarloAgent(ValueOptimizationAgent):
(self.networks['main'].online_network, batch.states(network_keys))
])
+total_returns = batch.n_step_discounted_rewards()
for i in range(self.ap.network_wrappers['main'].batch_size):
one_step_target = batch.rewards()[i] + \
    (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * \
    q_st_plus_1[i][selected_actions[i]]
-monte_carlo_target = batch.total_returns()[i]
+monte_carlo_target = total_returns[i]
TD_targets[i, batch.actions()[i]] = (1 - self.mixing_rate) * one_step_target + \
    self.mixing_rate * monte_carlo_target

View File

@@ -98,10 +98,10 @@ class NECAgent(ValueOptimizationAgent):
network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()
TD_targets = self.networks['main'].online_network.predict(batch.states(network_keys))
+bootstrapped_return_from_old_policy = batch.n_step_discounted_rewards()
# only update the action that we have actually done in this transition
for i in range(self.ap.network_wrappers['main'].batch_size):
-TD_targets[i, batch.actions()[i]] = batch.total_returns()[i]
+TD_targets[i, batch.actions()[i]] = bootstrapped_return_from_old_policy[i]
# set the gradients to fetch for the DND update
fetches = []
@@ -165,10 +165,10 @@ class NECAgent(ValueOptimizationAgent):
episode = self.call_memory('get_last_complete_episode')
if episode is not None and self.phase != RunPhase.TEST:
assert len(self.current_episode_state_embeddings) == episode.length()
-returns = episode.get_transitions_attribute('total_return')
+discounted_rewards = episode.get_transitions_attribute('n_step_discounted_rewards')
actions = episode.get_transitions_attribute('action')
self.networks['main'].online_network.output_heads[0].DND.add(self.current_episode_state_embeddings,
-    actions, returns)
+    actions, discounted_rewards)
def save_checkpoint(self, checkpoint_id):
with open(os.path.join(self.ap.task_parameters.checkpoint_save_dir, str(checkpoint_id) + '.dnd'), 'wb') as f:

View File

@@ -70,6 +70,7 @@ class PALAgent(ValueOptimizationAgent):
# calculate TD error
TD_targets = np.copy(q_st_online)
+total_returns = batch.n_step_discounted_rewards()
for i in range(self.ap.network_wrappers['main'].batch_size):
TD_targets[i, batch.actions()[i]] = batch.rewards()[i] + \
    (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * \
@@ -83,7 +84,7 @@ class PALAgent(ValueOptimizationAgent):
TD_targets[i, batch.actions()[i]] -= self.alpha * advantage_learning_update
# mixing monte carlo updates
-monte_carlo_target = batch.total_returns()[i]
+monte_carlo_target = total_returns[i]
TD_targets[i, batch.actions()[i]] = (1 - self.monte_carlo_mixing_rate) * TD_targets[i, batch.actions()[i]] \
    + self.monte_carlo_mixing_rate * monte_carlo_target

View File

@@ -74,7 +74,7 @@ class PolicyGradientsAgent(PolicyOptimizationAgent):
# batch contains a list of episodes to learn from
network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()
-total_returns = batch.total_returns()
+total_returns = batch.n_step_discounted_rewards()
for i in reversed(range(batch.size)):
if self.policy_gradient_rescaler == PolicyGradientRescaler.TOTAL_RETURN:
total_returns[i] = total_returns[0]

View File

@@ -73,11 +73,11 @@ class PolicyOptimizationAgent(Agent):
episode_discounted_returns = []
for i in range(episode.length()):
transition = episode.get_transition(i)
-episode_discounted_returns.append(transition.total_return)
+episode_discounted_returns.append(transition.n_step_discounted_rewards)
self.num_episodes_where_step_has_been_seen[i] += 1
self.mean_return_over_multiple_episodes[i] -= self.mean_return_over_multiple_episodes[i] / \
    self.num_episodes_where_step_has_been_seen[i]
-self.mean_return_over_multiple_episodes[i] += transition.total_return / \
+self.mean_return_over_multiple_episodes[i] += transition.n_step_discounted_rewards / \
    self.num_episodes_where_step_has_been_seen[i]
self.mean_discounted_return = np.mean(episode_discounted_returns)
self.std_discounted_return = np.std(episode_discounted_returns)
@@ -97,7 +97,7 @@ class PolicyOptimizationAgent(Agent):
network.set_is_training(True)
# we need to update the returns of the episode until now
-episode.update_returns()
+episode.update_transitions_rewards_and_bootstrap_data()
# get t_max transitions or less if the we got to a terminal state
# will be used for both actor-critic and vanilla PG.

View File

@@ -112,11 +112,11 @@ class PPOAgent(ActorCriticAgent):
# current_states_with_timestep = self.concat_state_and_timestep(batch)
current_state_values = self.networks['critic'].online_network.predict(batch.states(network_keys)).squeeze()
+total_returns = batch.n_step_discounted_rewards()
# calculate advantages
advantages = []
if self.policy_gradient_rescaler == PolicyGradientRescaler.A_VALUE:
-advantages = batch.total_returns() - current_state_values
+advantages = total_returns - current_state_values
elif self.policy_gradient_rescaler == PolicyGradientRescaler.GAE:
# get bootstraps
episode_start_idx = 0
@@ -155,6 +155,7 @@ class PPOAgent(ActorCriticAgent):
# current_states_with_timestep = self.concat_state_and_timestep(dataset)
mix_fraction = self.ap.algorithm.value_targets_mix_fraction
+total_returns = batch.n_step_discounted_rewards(True)
for j in range(epochs):
curr_batch_size = batch.size
if self.networks['critic'].online_network.optimizer_type != 'LBFGS':
@@ -165,7 +166,7 @@ class PPOAgent(ActorCriticAgent):
k: v[i * curr_batch_size:(i + 1) * curr_batch_size]
for k, v in batch.states(network_keys).items()
}
-total_return_batch = batch.total_returns(True)[i * curr_batch_size:(i + 1) * curr_batch_size]
+total_return_batch = total_returns[i * curr_batch_size:(i + 1) * curr_batch_size]
old_policy_values = force_list(self.networks['critic'].target_network.predict(
    current_states_batch).squeeze())
if self.networks['critic'].online_network.optimizer_type != 'LBFGS':

View File

@@ -39,23 +39,17 @@ class RainbowDQNNetworkParameters(DQNNetworkParameters):
class RainbowDQNAlgorithmParameters(CategoricalDQNAlgorithmParameters):
def __init__(self):
super().__init__()
+self.n_step = 3
+# needed for n-step updates to work. i.e. waiting for a full episode to be closed before storing each transition
+self.store_transitions_only_when_episodes_are_terminated = True
-class RainbowDQNExplorationParameters(ParameterNoiseParameters):
-def __init__(self, agent_params):
-super().__init__(agent_params)
-class RainbowDQNMemoryParameters(PrioritizedExperienceReplayParameters):
-def __init__(self):
-super().__init__()
class RainbowDQNAgentParameters(CategoricalDQNAgentParameters):
def __init__(self):
super().__init__()
self.algorithm = RainbowDQNAlgorithmParameters()
-self.exploration = RainbowDQNExplorationParameters(self)
+self.exploration = ParameterNoiseParameters(self)
self.memory = PrioritizedExperienceReplayParameters()
self.network_wrappers = {"main": RainbowDQNNetworkParameters()}
@@ -65,15 +59,13 @@ class RainbowDQNAgentParameters(CategoricalDQNAgentParameters):
# Rainbow Deep Q Network - https://arxiv.org/abs/1710.02298
-# Agent implementation is WIP. Currently is composed of:
+# Agent implementation is composed of:
# 1. NoisyNets
# 2. C51
# 3. Prioritized ER
# 4. DDQN
# 5. Dueling DQN
-#
-# still missing:
-# 1. N-Step
+# 6. N-step returns
class RainbowDQNAgent(CategoricalDQNAgent):
def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
@@ -87,7 +79,7 @@ class RainbowDQNAgent(CategoricalDQNAgent):
# for the action we actually took, the error is calculated by the atoms distribution
# for all other actions, the error is 0
-distributed_q_st_plus_1, TD_targets = self.networks['main'].parallel_prediction([
+distributional_q_st_plus_n, TD_targets = self.networks['main'].parallel_prediction([
(self.networks['main'].target_network, batch.next_states(network_keys)),
(self.networks['main'].online_network, batch.states(network_keys))
])
@@ -98,15 +90,16 @@ class RainbowDQNAgent(CategoricalDQNAgent):
batches = np.arange(self.ap.network_wrappers['main'].batch_size)
for j in range(self.z_values.size):
-tzj = np.fmax(np.fmin(batch.rewards() +
-    (1.0 - batch.game_overs()) * self.ap.algorithm.discount * self.z_values[j],
-    self.z_values[self.z_values.size - 1]),
-    self.z_values[0])
+# we use batch.info('should_bootstrap_next_state') instead of (1 - batch.game_overs()) since with n-step,
+# we will not bootstrap for the last n-step transitions in the episode
+tzj = np.fmax(np.fmin(batch.n_step_discounted_rewards() + batch.info('should_bootstrap_next_state') *
+    (self.ap.algorithm.discount ** self.ap.algorithm.n_step) * self.z_values[j],
+    self.z_values[-1]), self.z_values[0])
bj = (tzj - self.z_values[0])/(self.z_values[1] - self.z_values[0])
u = (np.ceil(bj)).astype(int)
l = (np.floor(bj)).astype(int)
-m[batches, l] = m[batches, l] + (distributed_q_st_plus_1[batches, target_actions, j] * (u - bj))
-m[batches, u] = m[batches, u] + (distributed_q_st_plus_1[batches, target_actions, j] * (bj - l))
+m[batches, l] += (distributional_q_st_plus_n[batches, target_actions, j] * (u - bj))
+m[batches, u] += (distributional_q_st_plus_n[batches, target_actions, j] * (bj - l))
# total_loss = cross entropy between actual result above and predicted result for the given action
TD_targets[batches, batch.actions()] = m
@@ -122,7 +115,7 @@ class RainbowDQNAgent(CategoricalDQNAgent):
# TODO: fix this spaghetti code
if isinstance(self.memory, PrioritizedExperienceReplay):
errors = losses[0][np.arange(batch.size), batch.actions()]
-self.memory.update_priorities(batch.info('idx'), errors)
+self.call_memory('update_priorities', (batch.info('idx'), errors))
return total_loss, losses, unclipped_grads
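
Relative to the one-step projection in the CategoricalDQN update shown earlier, the only substantive change in the target atom is that the one-step reward and the (1 - game_over) mask are replaced by the n-step discounted reward sum, a per-transition bootstrap flag and discount ** n_step. A one-line illustration with made-up numbers (the support, discount and n_step values are assumptions):

    import numpy as np

    z = np.linspace(-10.0, 10.0, 51)      # assumed support
    discount, n_step = 0.99, 3
    n_step_reward = 2.7                   # sum of 3 discounted rewards for this transition
    should_bootstrap = 1.0                # 0.0 for the last n_step transitions of an episode
    j = 30

    tzj = np.clip(n_step_reward + should_bootstrap * (discount ** n_step) * z[j], z[0], z[-1])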

View File

@@ -166,6 +166,9 @@ class AlgorithmParameters(Parameters):
# intrinsic reward
self.scale_external_reward_by_intrinsic_reward_value = False
+# n-step returns
+self.n_step = -1  # calculate the total return (no bootstrap, by default)
# Distributed Coach params
self.distributed_coach_synchronization_type = None

View File

@@ -125,6 +125,7 @@ class Middleware_LSTM_Embedding(MiddlewareEmbedding):
class Measurements(PredictionType):
pass
PlayingStepsType = Union[EnvironmentSteps, EnvironmentEpisodes, Frames]
@@ -162,7 +163,7 @@ class Transition(object):
self._state = self.state = state
self._action = self.action = action
self._reward = self.reward = reward
-self._total_return = self.total_return = None
+self._n_step_discounted_rewards = self.n_step_discounted_rewards = None
if not next_state:
next_state = state
self._next_state = self._next_state = next_state
@@ -207,15 +208,15 @@ class Transition(object):
self._reward = val
@property
-def total_return(self):
-if self._total_return is None:
-raise Exception("The total_return was not filled by any of the modules between the environment and the "
-    "agent. Make sure that you are using an episodic experience replay.")
-return self._total_return
-@total_return.setter
-def total_return(self, val):
-self._total_return = val
+def n_step_discounted_rewards(self):
+if self._n_step_discounted_rewards is None:
+raise Exception("The n_step_discounted_rewards were not filled by any of the modules between the "
+    "environment and the agent. Make sure that you are using an episodic experience replay.")
+return self._n_step_discounted_rewards
+@n_step_discounted_rewards.setter
+def n_step_discounted_rewards(self, val):
+self._n_step_discounted_rewards = val
@property
def game_over(self):
@@ -322,6 +323,7 @@ class ActionInfo(object):
"""
Action info is a class that holds an action and various additional information details about it
"""
def __init__(self, action: ActionType, action_probability: float=0,
    action_value: float=0., state_value: float=0., max_action_value: float=None,
    action_intrinsic_reward: float=0):
@@ -359,7 +361,7 @@ class Batch(object):
self._states = {}
self._actions = None
self._rewards = None
-self._total_returns = None
+self._n_step_discounted_rewards = None
self._game_overs = None
self._next_states = {}
self._goals = None
@@ -380,8 +382,8 @@ class Batch(object):
self._actions = self._actions[start:end]
if self._rewards is not None:
self._rewards = self._rewards[start:end]
-if self._total_returns is not None:
-self._total_returns = self._total_returns[start:end]
+if self._n_step_discounted_rewards is not None:
+self._n_step_discounted_rewards = self._n_step_discounted_rewards[start:end]
if self._game_overs is not None:
self._game_overs = self._game_overs[start:end]
for k, v in self._next_states.items():
@@ -402,7 +404,7 @@ class Batch(object):
self._states = {}
self._actions = None
self._rewards = None
-self._total_returns = None
+self._n_step_discounted_rewards = None
self._game_overs = None
self._next_states = {}
self._goals = None
@@ -471,18 +473,20 @@ class Batch(object):
return np.expand_dims(self._rewards, -1)
return self._rewards
-def total_returns(self, expand_dims=False) -> np.ndarray:
+def n_step_discounted_rewards(self, expand_dims=False) -> np.ndarray:
"""
-if the total_returns were not converted to a batch before, extract them to a batch and then return the batch
-if the total return was not filled, this will raise an exception
+if the n_step_discounted_rewards were not converted to a batch before, extract them to a batch and then return
+the batch
+if the n step discounted rewards were not filled, this will raise an exception
:param expand_dims: add an extra dimension to the total_returns batch
:return: a numpy array containing all the total return values of the batch
"""
-if self._total_returns is None:
-self._total_returns = np.array([transition.total_return for transition in self.transitions])
+if self._n_step_discounted_rewards is None:
+self._n_step_discounted_rewards = np.array([transition.n_step_discounted_rewards for transition in
+    self.transitions])
if expand_dims:
-return np.expand_dims(self._total_returns, -1)
-return self._total_returns
+return np.expand_dims(self._n_step_discounted_rewards, -1)
+return self._n_step_discounted_rewards
def game_overs(self, expand_dims=False) -> np.ndarray:
"""
@@ -510,7 +514,8 @@ class Batch(object):
# addition to the current_state, so that all the inputs of the network will be filled)
for key in set(fetches).intersection(self.transitions[0].next_state.keys()):
if key not in self._next_states.keys():
-self._next_states[key] = np.array([np.array(transition.next_state[key]) for transition in self.transitions])
+self._next_states[key] = np.array(
+    [np.array(transition.next_state[key]) for transition in self.transitions])
if expand_dims:
next_states[key] = np.expand_dims(self._next_states[key], -1)
else:
@@ -530,6 +535,16 @@ class Batch(object):
return np.expand_dims(self._goals, -1)
return self._goals
+def info_as_list(self, key) -> list:
+"""
+get the info and store it internally as a list, if wasn't stored before. return it as a list
+:param key: the info dictionary key to collect values for
+:return: a list containing all the info values of the batch corresponding to the given key
+"""
+if key not in self._info.keys():
+self._info[key] = [transition.info[key] for transition in self.transitions]
+return self._info[key]
def info(self, key, expand_dims=False) -> np.ndarray:
"""
if the given info dictionary key was not converted to a batch before, extract it to a batch and then return the
@@ -537,11 +552,11 @@ class Batch(object):
:param expand_dims: add an extra dimension to the info batch
:return: a numpy array containing all the info values of the batch corresponding to the given key
"""
-if key not in self._info.keys():
-self._info[key] = np.array([transition.info[key] for transition in self.transitions])
+info_list = self.info_as_list(key)
if expand_dims:
-return np.expand_dims(self._info[key], -1)
-return self._info[key]
+return np.expand_dims(info_list, -1)
+return np.array(info_list)
@property
def size(self) -> int:
@@ -572,6 +587,7 @@ class TotalStepsCounter(object):
""" """
A wrapper around a dictionary counting different StepMethods steps done. A wrapper around a dictionary counting different StepMethods steps done.
""" """
def __init__(self): def __init__(self):
self.counters = { self.counters = {
EnvironmentEpisodes: 0, EnvironmentEpisodes: 0,
@@ -619,7 +635,6 @@ class Episode(object):
""" """
self.transitions = [] self.transitions = []
# a num_transitions x num_transitions table with the n step return in the n'th row # a num_transitions x num_transitions table with the n step return in the n'th row
self.returns_table = None
self._length = 0 self._length = 0
self.discount = discount self.discount = discount
self.bootstrap_total_return_from_old_policy = bootstrap_total_return_from_old_policy self.bootstrap_total_return_from_old_policy = bootstrap_total_return_from_old_policy
@@ -650,28 +665,48 @@ class Episode(object):
def get_first_transition(self):
return self.get_transition(0) if self.length() > 0 else None
-def update_returns(self):
+def update_discounted_rewards(self):
if self.n_step == -1 or self.n_step > self.length():
curr_n_step = self.length()
else:
curr_n_step = self.n_step
rewards = np.array([t.reward for t in self.transitions])
rewards = rewards.astype('float')
-total_return = rewards.copy()
+discounted_rewards = rewards.copy()
current_discount = self.discount
for i in range(1, curr_n_step):
-total_return += current_discount * np.pad(rewards[i:], (0, i), 'constant', constant_values=0)
+discounted_rewards += current_discount * np.pad(rewards[i:], (0, i), 'constant', constant_values=0)
current_discount *= self.discount
# calculate the bootstrapped returns
if self.bootstrap_total_return_from_old_policy:
bootstraps = np.array([np.squeeze(t.info['max_action_value']) for t in self.transitions[curr_n_step:]])
-bootstrapped_return = total_return + current_discount * np.pad(bootstraps, (0, curr_n_step), 'constant',
-    constant_values=0)
-total_return = bootstrapped_return
+bootstrapped_return = discounted_rewards + current_discount * np.pad(bootstraps, (0, curr_n_step),
+    'constant', constant_values=0)
+discounted_rewards = bootstrapped_return
for transition_idx in range(self.length()):
-self.transitions[transition_idx].total_return = total_return[transition_idx]
+self.transitions[transition_idx].n_step_discounted_rewards = discounted_rewards[transition_idx]
+def update_transitions_rewards_and_bootstrap_data(self):
+if not isinstance(self.n_step, int) or (self.n_step < 1 and self.n_step != -1):
+raise ValueError("n-step should be an integer with value >= 1, or set to -1 for always setting to episode"
+    " length.")
+elif self.n_step > 1:
+curr_n_step = self.n_step if self.n_step < self.length() else self.length()
+for idx, transition in enumerate(self.transitions):
+next_n_step_transition_idx = (idx + curr_n_step)
+if next_n_step_transition_idx < len(self.transitions):
+# next state will now point to the n-step next state
+transition.next_state = self.transitions[next_n_step_transition_idx].state
+transition.info['should_bootstrap_next_state'] = True
+else:
+transition.next_state = self.transitions[-1].next_state
+transition.info['should_bootstrap_next_state'] = False
+self.update_discounted_rewards()
def update_actions_probabilities(self):
probability_product = 1
@@ -681,12 +716,6 @@ class Episode(object):
for transition_idx, transition in enumerate(self.transitions):
transition.info['probability_product'] = probability_product
-def get_returns_table(self):
-return self.returns_table
-def get_returns(self):
-return self.get_transitions_attribute('total_return')
def get_transitions_attribute(self, attribute_name):
if len(self.transitions) > 0 and hasattr(self.transitions[0], attribute_name):
return [getattr(t, attribute_name) for t in self.transitions]
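
Taken together, the new update_transitions_rewards_and_bootstrap_data() re-points each transition's next_state to the state n steps ahead, records whether such a state exists in info['should_bootstrap_next_state'], and then fills the n-step discounted rewards. The same bookkeeping on plain dicts, as a rough sketch rather than the class's actual implementation:

    def repoint_for_n_step(transitions, n_step):
        # transitions: list of dicts with 'state', 'next_state' and 'info' keys
        curr_n_step = min(n_step, len(transitions))
        for idx, t in enumerate(transitions):
            nth_idx = idx + curr_n_step
            if nth_idx < len(transitions):
                t['next_state'] = transitions[nth_idx]['state']   # n-step next state
                t['info']['should_bootstrap_next_state'] = True
            else:
                t['next_state'] = transitions[-1]['next_state']   # tail of the episode
                t['info']['should_bootstrap_next_state'] = False
        return transitions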

View File

@@ -27,6 +27,7 @@ class EpisodicExperienceReplayParameters(MemoryParameters):
def __init__(self):
super().__init__()
self.max_size = (MemoryGranularity.Transitions, 1000000)
+self.n_step = -1
@property
def path(self):
@@ -39,12 +40,13 @@ class EpisodicExperienceReplay(Memory):
calculations of total return and other values that depend on the sequential behavior of the transitions
in the episode.
"""
-def __init__(self, max_size: Tuple[MemoryGranularity, int]):
+def __init__(self, max_size: Tuple[MemoryGranularity, int]=(MemoryGranularity.Transitions, 1000000), n_step=-1):
"""
:param max_size: the maximum number of transitions or episodes to hold in the memory
"""
super().__init__(max_size)
-self._buffer = [Episode()]  # list of episodes
+self.n_step = n_step
+self._buffer = [Episode(n_step=self.n_step)]  # list of episodes
self.transitions = []
self._length = 1  # the episodic replay buffer starts with a single empty episode
self._num_transitions = 0
@@ -109,7 +111,7 @@ class EpisodicExperienceReplay(Memory):
self._remove_episode(0)
def _update_episode(self, episode: Episode) -> None:
-episode.update_returns()
+episode.update_transitions_rewards_and_bootstrap_data()
def verify_last_episode_is_closed(self) -> None:
"""
@@ -138,7 +140,7 @@ class EpisodicExperienceReplay(Memory):
self._length += 1
# create a new Episode for the next transitions to be placed into
-self._buffer.append(Episode())
+self._buffer.append(Episode(n_step=self.n_step))
# if update episode adds to the buffer, a new Episode needs to be ready first
# it would be better if this were less state full
@@ -158,12 +160,14 @@ class EpisodicExperienceReplay(Memory):
:param transition: a transition to store
:return: None
"""
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
super().store(transition)
self.reader_writer_lock.lock_writing_and_reading()
if len(self._buffer) == 0:
-self._buffer.append(Episode())
+self._buffer.append(Episode(n_step=self.n_step))
last_episode = self._buffer[-1]
last_episode.insert(transition)
self.transitions.append(transition)
@@ -284,7 +288,7 @@ class EpisodicExperienceReplay(Memory):
self.reader_writer_lock.lock_writing_and_reading()
self.transitions = []
-self._buffer = [Episode()]
+self._buffer = [Episode(n_step=self.n_step)]
self._length = 1
self._num_transitions = 0
self._num_transitions_in_complete_episodes = 0

View File

@@ -139,7 +139,7 @@ class EpisodicHindsightExperienceReplay(EpisodicExperienceReplay):
hindsight_transition.reward, hindsight_transition.game_over = \
    self.goals_space.get_reward_for_goal_and_state(goal, hindsight_transition.next_state)
-hindsight_transition.total_return = None
+hindsight_transition.n_step_discounted_rewards = None
episode.insert(hindsight_transition)
super().store_episode(episode)

View File

@@ -128,14 +128,14 @@ class SegmentTree(object):
self.tree[node_idx] = new_val
self._propagate(node_idx)
-def get(self, val: float) -> Tuple[int, float, Any]:
+def get_element_by_partial_sum(self, val: float) -> Tuple[int, float, Any]:
"""
Given a value between 0 and the tree sum, return the object which this value is in it's range.
For example, if we have 3 leaves: 10, 20, 30, and val=35, this will return the 3rd leaf, by accumulating
leaves by their order until getting to 35. This allows sampling leaves according to their proportional
probability.
:param val: a value within the range 0 and the tree sum
-:return: the index of the resulting leaf in the tree, it's probability and
+:return: the index of the resulting leaf in the tree, its probability and
the object itself
"""
node_idx = self._retrieve(0, val)
@@ -237,12 +237,12 @@ class PrioritizedExperienceReplay(ExperienceReplay):
# sample a batch
for i in range(size):
-start_probability = segment_size * i
-end_probability = segment_size * (i + 1)
+segment_start = segment_size * i
+segment_end = segment_size * (i + 1)
# sample leaf and calculate its weight
-val = random.uniform(start_probability, end_probability)
-leaf_idx, priority, transition = self.sum_tree.get(val)
+val = random.uniform(segment_start, segment_end)
+leaf_idx, priority, transition = self.sum_tree.get_element_by_partial_sum(val)
priority /= self.sum_tree.total_value()  # P(j) = p^a / sum(p^a)
weight = (self.num_transitions() * priority) ** -self.beta.current_value  # (N * P(j)) ^ -beta
normalized_weight = weight / max_weight  # wj = ((N * P(j)) ^ -beta) / max wi
@@ -261,7 +261,7 @@ class PrioritizedExperienceReplay(ExperienceReplay):
self.reader_writer_lock.release_writing()
return batch
-def store(self, transition: Transition) -> None:
+def store(self, transition: Transition, lock=True) -> None:
"""
Store a new transition in the memory.
:param transition: a transition to store
@@ -270,7 +270,8 @@ class PrioritizedExperienceReplay(ExperienceReplay):
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
super().store(transition)
-self.reader_writer_lock.lock_writing_and_reading()
+if lock:
+    self.reader_writer_lock.lock_writing_and_reading()
transition_priority = self.maximal_priority
self.sum_tree.add(transition_priority ** self.alpha, transition)
@@ -278,18 +279,21 @@ class PrioritizedExperienceReplay(ExperienceReplay):
self.max_tree.add(transition_priority, transition)
super().store(transition, False)
-self.reader_writer_lock.release_writing_and_reading()
+if lock:
+    self.reader_writer_lock.release_writing_and_reading()
-def clean(self) -> None:
+def clean(self, lock=True) -> None:
"""
Clean the memory by removing all the episodes
:return: None
"""
-self.reader_writer_lock.lock_writing_and_reading()
+if lock:
+    self.reader_writer_lock.lock_writing_and_reading()
super().clean(lock=False)
self.sum_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.SUM)
self.min_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.MIN)
self.max_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.MAX)
-self.reader_writer_lock.release_writing_and_reading()
+if lock:
+    self.reader_writer_lock.release_writing_and_reading()
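
As a reminder of what the sampling loop above computes (this commit only renames the segment variables and the sum-tree getter): the tree total is split into equal-probability segments, one leaf is drawn per segment, and each sampled transition gets an importance-sampling weight wj = (N * P(j)) ** -beta / max wi with P(j) = p^a / sum(p^a). A standalone numeric sketch; the priorities, alpha handling and beta value are made up:

    import numpy as np

    priorities = np.array([2.5, 5.0, 5.0, 7.5])   # p_i, assumed already raised to alpha
    beta = 0.4
    N = len(priorities)

    P = priorities / priorities.sum()             # P(j) = p^a / sum(p^a)
    weights = (N * P) ** -beta                    # (N * P(j)) ^ -beta
    normalized_weights = weights / weights.max()  # wj = ((N * P(j)) ^ -beta) / max wi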

View File

@@ -63,7 +63,7 @@ env_params = GymVectorEnvironment(level='CartPole-v0')
preset_validation_params = PresetValidationParameters()
preset_validation_params.test = True
preset_validation_params.min_reward_threshold = 150
-preset_validation_params.max_episodes_to_achieve_reward = 250
+preset_validation_params.max_episodes_to_achieve_reward = 400
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
    schedule_params=schedule_params, vis_params=VisualizationParameters(),

View File

@@ -26,10 +26,10 @@ def test_sum_tree():
sum_tree.add(5, "5")
assert sum_tree.total_value() == 20
-assert sum_tree.get(2) == (0, 2.5, '2.5')
-assert sum_tree.get(3) == (1, 5.0, '5')
-assert sum_tree.get(10) == (2, 5.0, '5')
-assert sum_tree.get(13) == (3, 7.5, '7.5')
+assert sum_tree.get_element_by_partial_sum(2) == (0, 2.5, '2.5')
+assert sum_tree.get_element_by_partial_sum(3) == (1, 5.0, '5')
+assert sum_tree.get_element_by_partial_sum(10) == (2, 5.0, '5')
+assert sum_tree.get_element_by_partial_sum(13) == (3, 7.5, '7.5')
sum_tree.update(2, 10)
assert sum_tree.__str__() == "[25.]\n[ 7.5 17.5]\n[ 2.5 5. 10. 7.5]\n"

View File

@@ -41,8 +41,8 @@ def test_store_and_get(buffer: SingleEpisodeBuffer):
# check that the episode is valid
episode = buffer.get(0)
assert episode.length() == 2
-assert episode.get_transition(0).total_return == 1 + 0.99
-assert episode.get_transition(1).total_return == 1
+assert episode.get_transition(0).n_step_discounted_rewards == 1 + 0.99
+assert episode.get_transition(1).n_step_discounted_rewards == 1
assert buffer.mean_reward() == 1
# only one episode in the replay buffer