mirror of https://github.com/gryf/coach.git, synced 2025-12-17 11:10:20 +01:00
N-step returns for rainbow (#67)
* n_step returns for rainbow
* Rename CartPole_PPO -> CartPole_ClippedPPO
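For context on what the commit changes: with n-step returns, each transition carries the discounted sum of the next n rewards rather than the full-episode return, and transitions that still have a state n steps ahead bootstrap from a value estimated at that state (which estimator is used depends on the agent). Roughly, and not quoting the code:

    R_t^(n)  = r_t + gamma * r_{t+1} + ... + gamma^(n-1) * r_{t+n-1}
    target_t = R_t^(n) + gamma^n * V(s_{t+n})   if at least n steps remain in the episode
    target_t = R_t^(n)                          otherwise (no bootstrap)

Setting n_step = -1 keeps the previous behaviour of a full discounted episode return, which is why most agents below simply switch from batch.total_returns() to batch.n_step_discounted_rewards().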
@@ -116,7 +116,6 @@ class Agent(AgentInterface):
self.output_filter.set_device(device)
self.pre_network_filter.set_device(device)


# initialize all internal variables
self._phase = RunPhase.HEATUP
self.total_shaped_reward_in_current_episode = 0
@@ -143,7 +142,7 @@ class Agent(AgentInterface):
self.accumulated_shaped_rewards_across_evaluation_episodes = 0
self.num_successes_across_evaluation_episodes = 0
self.num_evaluation_episodes_completed = 0
- self.current_episode_buffer = Episode(discount=self.ap.algorithm.discount)
+ self.current_episode_buffer = Episode(discount=self.ap.algorithm.discount, n_step=self.ap.algorithm.n_step)
# TODO: add agents observation rendering for debugging purposes (not the same as the environment rendering)

# environment parameters
@@ -452,10 +451,10 @@ class Agent(AgentInterface):
:return: None
"""
self.current_episode_buffer.is_complete = True
- self.current_episode_buffer.update_returns()
+ self.current_episode_buffer.update_transitions_rewards_and_bootstrap_data()

for transition in self.current_episode_buffer.transitions:
- self.discounted_return.add_sample(transition.total_return)
+ self.discounted_return.add_sample(transition.n_step_discounted_rewards)

if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only:
self.current_episode += 1
@@ -497,7 +496,7 @@ class Agent(AgentInterface):
self.curr_state = {}
self.current_episode_steps_counter = 0
self.episode_running_info = {}
- self.current_episode_buffer = Episode(discount=self.ap.algorithm.discount)
+ self.current_episode_buffer = Episode(discount=self.ap.algorithm.discount, n_step=self.ap.algorithm.n_step)
if self.exploration_policy:
self.exploration_policy.reset()
self.input_filter.reset()

@@ -17,14 +17,11 @@
from typing import Union

import numpy as np

from rl_coach.agents.dqn_agent import DQNNetworkParameters, DQNAlgorithmParameters, DQNAgentParameters
from rl_coach.agents.value_optimization_agent import ValueOptimizationAgent
from rl_coach.architectures.head_parameters import CategoricalQHeadParameters
- from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import StateType
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
- from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters
from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplay
from rl_coach.schedules import LinearSchedule

@@ -85,28 +82,47 @@ class CategoricalDQNAgent(ValueOptimizationAgent):

# for the action we actually took, the error is calculated by the atoms distribution
# for all other actions, the error is 0
- distributed_q_st_plus_1, TD_targets = self.networks['main'].parallel_prediction([
+ distributional_q_st_plus_1, TD_targets = self.networks['main'].parallel_prediction([
(self.networks['main'].target_network, batch.next_states(network_keys)),
(self.networks['main'].online_network, batch.states(network_keys))
])

- # only update the action that we have actually done in this transition
+ # select the optimal actions for the next state
- target_actions = np.argmax(self.distribution_prediction_to_q_values(distributed_q_st_plus_1), axis=1)
+ target_actions = np.argmax(self.distribution_prediction_to_q_values(distributional_q_st_plus_1), axis=1)
m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))

batches = np.arange(self.ap.network_wrappers['main'].batch_size)

+ # an alternative to the for loop. 3.7x perf improvement vs. the same code done with for looping.
+ # only 10% speedup overall - leaving commented out as the code is not as clear.

+ # tzj_ = np.fmax(np.fmin(batch.rewards() + (1.0 - batch.game_overs()) * self.ap.algorithm.discount *
+ # np.transpose(np.repeat(self.z_values[np.newaxis, :], batch.size, axis=0), (1, 0)),
+ # self.z_values[-1]),
+ # self.z_values[0])
+ #
+ # bj_ = (tzj_ - self.z_values[0]) / (self.z_values[1] - self.z_values[0])
+ # u_ = (np.ceil(bj_)).astype(int)
+ # l_ = (np.floor(bj_)).astype(int)
+ # m_ = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
+ # np.add.at(m_, [batches, l_],
+ # np.transpose(distributional_q_st_plus_1[batches, target_actions], (1, 0)) * (u_ - bj_))
+ # np.add.at(m_, [batches, u_],
+ # np.transpose(distributional_q_st_plus_1[batches, target_actions], (1, 0)) * (bj_ - l_))

for j in range(self.z_values.size):
tzj = np.fmax(np.fmin(batch.rewards() +
(1.0 - batch.game_overs()) * self.ap.algorithm.discount * self.z_values[j],
- self.z_values[self.z_values.size - 1]),
+ self.z_values[-1]),
self.z_values[0])
bj = (tzj - self.z_values[0])/(self.z_values[1] - self.z_values[0])
u = (np.ceil(bj)).astype(int)
l = (np.floor(bj)).astype(int)
- m[batches, l] = m[batches, l] + (distributed_q_st_plus_1[batches, target_actions, j] * (u - bj))
+ m[batches, l] += (distributional_q_st_plus_1[batches, target_actions, j] * (u - bj))
- m[batches, u] = m[batches, u] + (distributed_q_st_plus_1[batches, target_actions, j] * (bj - l))
+ m[batches, u] += (distributional_q_st_plus_1[batches, target_actions, j] * (bj - l))

# total_loss = cross entropy between actual result above and predicted result for the given action
+ # only update the action that we have actually done in this transition
TD_targets[batches, batch.actions()] = m

# update errors in prioritized replay buffer
@@ -120,7 +136,7 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
# TODO: fix this spaghetti code
if isinstance(self.memory, PrioritizedExperienceReplay):
errors = losses[0][np.arange(batch.size), batch.actions()]
- self.memory.update_priorities(batch.info('idx'), errors)
+ self.call_memory('update_priorities', (batch.info('idx'), errors))

return total_loss, losses, unclipped_grads

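An aside on the projection loop above, for readers new to C51: every target atom z_j is shifted by the (possibly bootstrapped) reward, clipped to the support's range, and its probability mass is split linearly between the two nearest atoms of the fixed support. Below is a small self-contained sketch of that projection for a single transition, with an explicit branch for the degenerate case where a shifted atom lands exactly on a support point. It is illustrative only, not the commit's code, and all names in it are made up.

import numpy as np

def project_single_transition(reward, done, discount, z_values, next_probs):
    # next_probs: the target network's distribution over atoms for the chosen next action
    m = np.zeros_like(z_values)
    delta_z = z_values[1] - z_values[0]
    for j in range(len(z_values)):
        # shift the atom by the reward and clip it to [z_min, z_max]
        tzj = np.clip(reward + (1.0 - done) * discount * z_values[j], z_values[0], z_values[-1])
        bj = (tzj - z_values[0]) / delta_z
        l, u = int(np.floor(bj)), int(np.ceil(bj))
        if l == u:
            m[l] += next_probs[j]                    # landed exactly on an atom
        else:
            m[l] += next_probs[j] * (u - bj)         # share of the mass to the lower neighbour
            m[u] += next_probs[j] * (bj - l)         # share of the mass to the upper neighbour
    return m

z = np.linspace(-10.0, 10.0, 51)                     # 51 fixed atoms, as in C51
p = np.full(51, 1.0 / 51)                            # a uniform next-state distribution
print(project_single_transition(1.0, 0.0, 0.99, z, p).sum())   # ~1.0: probability mass is preserved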
@@ -116,8 +116,10 @@ class ClippedPPOAgent(ActorCriticAgent):
# calculate advantages
advantages = []
value_targets = []
+ total_returns = batch.n_step_discounted_rewards()
+
if self.policy_gradient_rescaler == PolicyGradientRescaler.A_VALUE:
- advantages = batch.total_returns() - current_state_values
+ advantages = total_returns - current_state_values
elif self.policy_gradient_rescaler == PolicyGradientRescaler.GAE:
# get bootstraps
episode_start_idx = 0
@@ -181,11 +183,13 @@ class ClippedPPOAgent(ActorCriticAgent):
result = self.networks['main'].target_network.predict({k: v[start:end] for k, v in batch.states(network_keys).items()})
old_policy_distribution = result[1:]

+ total_returns = batch.n_step_discounted_rewards(expand_dims=True)
+
# calculate gradients and apply on both the local policy network and on the global policy network
if self.ap.algorithm.estimate_state_value_using_gae:
value_targets = np.expand_dims(gae_based_value_targets, -1)
else:
- value_targets = batch.total_returns(expand_dims=True)[start:end]
+ value_targets = total_returns[start:end]

inputs = copy.copy({k: v[start:end] for k, v in batch.states(network_keys).items()})
inputs['output_1_0'] = actions

@@ -58,11 +58,13 @@ class MixedMonteCarloAgent(ValueOptimizationAgent):
(self.networks['main'].online_network, batch.states(network_keys))
])

+ total_returns = batch.n_step_discounted_rewards()
+
for i in range(self.ap.network_wrappers['main'].batch_size):
one_step_target = batch.rewards()[i] + \
(1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * \
q_st_plus_1[i][selected_actions[i]]
- monte_carlo_target = batch.total_returns()[i]
+ monte_carlo_target = total_returns[i]
TD_targets[i, batch.actions()[i]] = (1 - self.mixing_rate) * one_step_target + \
self.mixing_rate * monte_carlo_target

@@ -98,10 +98,10 @@ class NECAgent(ValueOptimizationAgent):
network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()

TD_targets = self.networks['main'].online_network.predict(batch.states(network_keys))
+ bootstrapped_return_from_old_policy = batch.n_step_discounted_rewards()
# only update the action that we have actually done in this transition
for i in range(self.ap.network_wrappers['main'].batch_size):
- TD_targets[i, batch.actions()[i]] = batch.total_returns()[i]
+ TD_targets[i, batch.actions()[i]] = bootstrapped_return_from_old_policy[i]

# set the gradients to fetch for the DND update
fetches = []
@@ -165,10 +165,10 @@ class NECAgent(ValueOptimizationAgent):
episode = self.call_memory('get_last_complete_episode')
if episode is not None and self.phase != RunPhase.TEST:
assert len(self.current_episode_state_embeddings) == episode.length()
- returns = episode.get_transitions_attribute('total_return')
+ discounted_rewards = episode.get_transitions_attribute('n_step_discounted_rewards')
actions = episode.get_transitions_attribute('action')
self.networks['main'].online_network.output_heads[0].DND.add(self.current_episode_state_embeddings,
- actions, returns)
+ actions, discounted_rewards)

def save_checkpoint(self, checkpoint_id):
with open(os.path.join(self.ap.task_parameters.checkpoint_save_dir, str(checkpoint_id) + '.dnd'), 'wb') as f:

@@ -70,6 +70,7 @@ class PALAgent(ValueOptimizationAgent):

# calculate TD error
TD_targets = np.copy(q_st_online)
+ total_returns = batch.n_step_discounted_rewards()
for i in range(self.ap.network_wrappers['main'].batch_size):
TD_targets[i, batch.actions()[i]] = batch.rewards()[i] + \
(1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * \
@@ -83,7 +84,7 @@ class PALAgent(ValueOptimizationAgent):
TD_targets[i, batch.actions()[i]] -= self.alpha * advantage_learning_update

# mixing monte carlo updates
- monte_carlo_target = batch.total_returns()[i]
+ monte_carlo_target = total_returns[i]
TD_targets[i, batch.actions()[i]] = (1 - self.monte_carlo_mixing_rate) * TD_targets[i, batch.actions()[i]] \
+ self.monte_carlo_mixing_rate * monte_carlo_target

@@ -74,7 +74,7 @@ class PolicyGradientsAgent(PolicyOptimizationAgent):
# batch contains a list of episodes to learn from
network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()

- total_returns = batch.total_returns()
+ total_returns = batch.n_step_discounted_rewards()
for i in reversed(range(batch.size)):
if self.policy_gradient_rescaler == PolicyGradientRescaler.TOTAL_RETURN:
total_returns[i] = total_returns[0]

@@ -73,11 +73,11 @@ class PolicyOptimizationAgent(Agent):
episode_discounted_returns = []
for i in range(episode.length()):
transition = episode.get_transition(i)
- episode_discounted_returns.append(transition.total_return)
+ episode_discounted_returns.append(transition.n_step_discounted_rewards)
self.num_episodes_where_step_has_been_seen[i] += 1
self.mean_return_over_multiple_episodes[i] -= self.mean_return_over_multiple_episodes[i] / \
self.num_episodes_where_step_has_been_seen[i]
- self.mean_return_over_multiple_episodes[i] += transition.total_return / \
+ self.mean_return_over_multiple_episodes[i] += transition.n_step_discounted_rewards / \
self.num_episodes_where_step_has_been_seen[i]
self.mean_discounted_return = np.mean(episode_discounted_returns)
self.std_discounted_return = np.std(episode_discounted_returns)
@@ -97,7 +97,7 @@ class PolicyOptimizationAgent(Agent):
network.set_is_training(True)

# we need to update the returns of the episode until now
- episode.update_returns()
+ episode.update_transitions_rewards_and_bootstrap_data()

# get t_max transitions or less if the we got to a terminal state
# will be used for both actor-critic and vanilla PG.

@@ -112,11 +112,11 @@ class PPOAgent(ActorCriticAgent):
# current_states_with_timestep = self.concat_state_and_timestep(batch)

current_state_values = self.networks['critic'].online_network.predict(batch.states(network_keys)).squeeze()
+ total_returns = batch.n_step_discounted_rewards()
# calculate advantages
advantages = []
if self.policy_gradient_rescaler == PolicyGradientRescaler.A_VALUE:
- advantages = batch.total_returns() - current_state_values
+ advantages = total_returns - current_state_values
elif self.policy_gradient_rescaler == PolicyGradientRescaler.GAE:
# get bootstraps
episode_start_idx = 0
@@ -155,6 +155,7 @@ class PPOAgent(ActorCriticAgent):
# current_states_with_timestep = self.concat_state_and_timestep(dataset)

mix_fraction = self.ap.algorithm.value_targets_mix_fraction
+ total_returns = batch.n_step_discounted_rewards(True)
for j in range(epochs):
curr_batch_size = batch.size
if self.networks['critic'].online_network.optimizer_type != 'LBFGS':
@@ -165,7 +166,7 @@ class PPOAgent(ActorCriticAgent):
k: v[i * curr_batch_size:(i + 1) * curr_batch_size]
for k, v in batch.states(network_keys).items()
}
- total_return_batch = batch.total_returns(True)[i * curr_batch_size:(i + 1) * curr_batch_size]
+ total_return_batch = total_returns[i * curr_batch_size:(i + 1) * curr_batch_size]
old_policy_values = force_list(self.networks['critic'].target_network.predict(
current_states_batch).squeeze())
if self.networks['critic'].online_network.optimizer_type != 'LBFGS':

@@ -39,23 +39,17 @@ class RainbowDQNNetworkParameters(DQNNetworkParameters):
class RainbowDQNAlgorithmParameters(CategoricalDQNAlgorithmParameters):
def __init__(self):
super().__init__()
+ self.n_step = 3
+
+ # needed for n-step updates to work. i.e. waiting for a full episode to be closed before storing each transition
+ self.store_transitions_only_when_episodes_are_terminated = True
- class RainbowDQNExplorationParameters(ParameterNoiseParameters):
- def __init__(self, agent_params):
- super().__init__(agent_params)


- class RainbowDQNMemoryParameters(PrioritizedExperienceReplayParameters):
- def __init__(self):
- super().__init__()


class RainbowDQNAgentParameters(CategoricalDQNAgentParameters):
def __init__(self):
super().__init__()
self.algorithm = RainbowDQNAlgorithmParameters()
- self.exploration = RainbowDQNExplorationParameters(self)
+ self.exploration = ParameterNoiseParameters(self)
self.memory = PrioritizedExperienceReplayParameters()
self.network_wrappers = {"main": RainbowDQNNetworkParameters()}

@@ -65,15 +59,13 @@ class RainbowDQNAgentParameters(CategoricalDQNAgentParameters):


# Rainbow Deep Q Network - https://arxiv.org/abs/1710.02298
- # Agent implementation is WIP. Currently is composed of:
+ # Agent implementation is composed of:
# 1. NoisyNets
# 2. C51
# 3. Prioritized ER
# 4. DDQN
# 5. Dueling DQN
- #
+ # 6. N-step returns
- # still missing:
- # 1. N-Step

class RainbowDQNAgent(CategoricalDQNAgent):
def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
@@ -87,7 +79,7 @@ class RainbowDQNAgent(CategoricalDQNAgent):

# for the action we actually took, the error is calculated by the atoms distribution
# for all other actions, the error is 0
- distributed_q_st_plus_1, TD_targets = self.networks['main'].parallel_prediction([
+ distributional_q_st_plus_n, TD_targets = self.networks['main'].parallel_prediction([
(self.networks['main'].target_network, batch.next_states(network_keys)),
(self.networks['main'].online_network, batch.states(network_keys))
])
@@ -98,15 +90,16 @@ class RainbowDQNAgent(CategoricalDQNAgent):

batches = np.arange(self.ap.network_wrappers['main'].batch_size)
for j in range(self.z_values.size):
- tzj = np.fmax(np.fmin(batch.rewards() +
- (1.0 - batch.game_overs()) * self.ap.algorithm.discount * self.z_values[j],
- self.z_values[self.z_values.size - 1]),
- self.z_values[0])
+ # we use batch.info('should_bootstrap_next_state') instead of (1 - batch.game_overs()) since with n-step,
+ # we will not bootstrap for the last n-step transitions in the episode
+ tzj = np.fmax(np.fmin(batch.n_step_discounted_rewards() + batch.info('should_bootstrap_next_state') *
+ (self.ap.algorithm.discount ** self.ap.algorithm.n_step) * self.z_values[j],
+ self.z_values[-1]), self.z_values[0])
bj = (tzj - self.z_values[0])/(self.z_values[1] - self.z_values[0])
u = (np.ceil(bj)).astype(int)
l = (np.floor(bj)).astype(int)
- m[batches, l] = m[batches, l] + (distributed_q_st_plus_1[batches, target_actions, j] * (u - bj))
+ m[batches, l] += (distributional_q_st_plus_n[batches, target_actions, j] * (u - bj))
- m[batches, u] = m[batches, u] + (distributed_q_st_plus_1[batches, target_actions, j] * (bj - l))
+ m[batches, u] += (distributional_q_st_plus_n[batches, target_actions, j] * (bj - l))

# total_loss = cross entropy between actual result above and predicted result for the given action
TD_targets[batches, batch.actions()] = m
@@ -122,7 +115,7 @@ class RainbowDQNAgent(CategoricalDQNAgent):
# TODO: fix this spaghetti code
if isinstance(self.memory, PrioritizedExperienceReplay):
errors = losses[0][np.arange(batch.size), batch.actions()]
- self.memory.update_priorities(batch.info('idx'), errors)
+ self.call_memory('update_priorities', (batch.info('idx'), errors))

return total_loss, losses, unclipped_grads

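A quick numeric illustration of the new Rainbow target above (illustrative numbers only, not code from the commit): the one-step reward is replaced by the n-step discounted reward, the atoms are scaled by discount**n_step instead of discount, and the should_bootstrap_next_state flag zeroes the bootstrap term for the last n transitions of an episode.

import numpy as np

gamma, n = 0.99, 3
rewards = np.array([1.0, 1.0, 1.0])                       # r_t, r_{t+1}, r_{t+2}
n_step_reward = float(np.sum(rewards * gamma ** np.arange(n)))   # 1 + 0.99 + 0.99**2 = 2.9701
z = np.linspace(-10.0, 10.0, 51)                          # fixed C51 support

should_bootstrap = 1.0                                    # a state n steps ahead exists
tz = np.clip(n_step_reward + should_bootstrap * gamma ** n * z, z[0], z[-1])

should_bootstrap = 0.0                                    # fewer than n steps remain in the episode
tz_terminal = np.clip(n_step_reward + should_bootstrap * gamma ** n * z, z[0], z[-1])
print(tz[:3], tz_terminal[:3])                            # every entry of tz_terminal equals 2.9701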
@@ -166,6 +166,9 @@ class AlgorithmParameters(Parameters):
# intrinsic reward
self.scale_external_reward_by_intrinsic_reward_value = False

+ # n-step returns
+ self.n_step = -1 # calculate the total return (no bootstrap, by default)
+
# Distributed Coach params
self.distributed_coach_synchronization_type = None

@@ -125,6 +125,7 @@ class Middleware_LSTM_Embedding(MiddlewareEmbedding):
class Measurements(PredictionType):
pass


PlayingStepsType = Union[EnvironmentSteps, EnvironmentEpisodes, Frames]


@@ -162,7 +163,7 @@ class Transition(object):
self._state = self.state = state
self._action = self.action = action
self._reward = self.reward = reward
- self._total_return = self.total_return = None
+ self._n_step_discounted_rewards = self.n_step_discounted_rewards = None
if not next_state:
next_state = state
self._next_state = self._next_state = next_state
@@ -207,15 +208,15 @@ class Transition(object):
self._reward = val

@property
- def total_return(self):
+ def n_step_discounted_rewards(self):
- if self._total_return is None:
+ if self._n_step_discounted_rewards is None:
- raise Exception("The total_return was not filled by any of the modules between the environment and the "
+ raise Exception("The n_step_discounted_rewards were not filled by any of the modules between the "
- "agent. Make sure that you are using an episodic experience replay.")
+ "environment and the agent. Make sure that you are using an episodic experience replay.")
- return self._total_return
+ return self._n_step_discounted_rewards

- @total_return.setter
+ @n_step_discounted_rewards.setter
- def total_return(self, val):
+ def n_step_discounted_rewards(self, val):
- self._total_return = val
+ self._n_step_discounted_rewards = val

@property
def game_over(self):
@@ -322,6 +323,7 @@ class ActionInfo(object):
"""
Action info is a class that holds an action and various additional information details about it
"""

def __init__(self, action: ActionType, action_probability: float=0,
action_value: float=0., state_value: float=0., max_action_value: float=None,
action_intrinsic_reward: float=0):
@@ -359,7 +361,7 @@ class Batch(object):
self._states = {}
self._actions = None
self._rewards = None
- self._total_returns = None
+ self._n_step_discounted_rewards = None
self._game_overs = None
self._next_states = {}
self._goals = None
@@ -380,8 +382,8 @@ class Batch(object):
self._actions = self._actions[start:end]
if self._rewards is not None:
self._rewards = self._rewards[start:end]
- if self._total_returns is not None:
+ if self._n_step_discounted_rewards is not None:
- self._total_returns = self._total_returns[start:end]
+ self._n_step_discounted_rewards = self._n_step_discounted_rewards[start:end]
if self._game_overs is not None:
self._game_overs = self._game_overs[start:end]
for k, v in self._next_states.items():
@@ -402,7 +404,7 @@ class Batch(object):
self._states = {}
self._actions = None
self._rewards = None
- self._total_returns = None
+ self._n_step_discounted_rewards = None
self._game_overs = None
self._next_states = {}
self._goals = None
@@ -471,18 +473,20 @@ class Batch(object):
return np.expand_dims(self._rewards, -1)
return self._rewards

- def total_returns(self, expand_dims=False) -> np.ndarray:
+ def n_step_discounted_rewards(self, expand_dims=False) -> np.ndarray:
"""
- if the total_returns were not converted to a batch before, extract them to a batch and then return the batch
+ if the n_step_discounted_rewards were not converted to a batch before, extract them to a batch and then return
- if the total return was not filled, this will raise an exception
+ the batch
+ if the n step discounted rewards were not filled, this will raise an exception
:param expand_dims: add an extra dimension to the total_returns batch
:return: a numpy array containing all the total return values of the batch
"""
- if self._total_returns is None:
+ if self._n_step_discounted_rewards is None:
- self._total_returns = np.array([transition.total_return for transition in self.transitions])
+ self._n_step_discounted_rewards = np.array([transition.n_step_discounted_rewards for transition in
+ self.transitions])
if expand_dims:
- return np.expand_dims(self._total_returns, -1)
+ return np.expand_dims(self._n_step_discounted_rewards, -1)
- return self._total_returns
+ return self._n_step_discounted_rewards

def game_overs(self, expand_dims=False) -> np.ndarray:
"""
@@ -510,7 +514,8 @@ class Batch(object):
# addition to the current_state, so that all the inputs of the network will be filled)
for key in set(fetches).intersection(self.transitions[0].next_state.keys()):
if key not in self._next_states.keys():
- self._next_states[key] = np.array([np.array(transition.next_state[key]) for transition in self.transitions])
+ self._next_states[key] = np.array(
+ [np.array(transition.next_state[key]) for transition in self.transitions])
if expand_dims:
next_states[key] = np.expand_dims(self._next_states[key], -1)
else:
@@ -530,6 +535,16 @@ class Batch(object):
return np.expand_dims(self._goals, -1)
return self._goals

+ def info_as_list(self, key) -> list:
+ """
+ get the info and store it internally as a list, if wasn't stored before. return it as a list
+ :param expand_dims: add an extra dimension to the info batch
+ :return: a list containing all the info values of the batch corresponding to the given key
+ """
+ if key not in self._info.keys():
+ self._info[key] = [transition.info[key] for transition in self.transitions]
+ return self._info[key]

def info(self, key, expand_dims=False) -> np.ndarray:
"""
if the given info dictionary key was not converted to a batch before, extract it to a batch and then return the
@@ -537,11 +552,11 @@ class Batch(object):
:param expand_dims: add an extra dimension to the info batch
:return: a numpy array containing all the info values of the batch corresponding to the given key
"""
- if key not in self._info.keys():
+ info_list = self.info_as_list(key)
- self._info[key] = np.array([transition.info[key] for transition in self.transitions])
if expand_dims:
- return np.expand_dims(self._info[key], -1)
+ return np.expand_dims(info_list, -1)
- return self._info[key]
+ return np.array(info_list)

@property
def size(self) -> int:
@@ -572,6 +587,7 @@ class TotalStepsCounter(object):
"""
A wrapper around a dictionary counting different StepMethods steps done.
"""

def __init__(self):
self.counters = {
EnvironmentEpisodes: 0,
@@ -619,7 +635,6 @@ class Episode(object):
"""
self.transitions = []
# a num_transitions x num_transitions table with the n step return in the n'th row
- self.returns_table = None
self._length = 0
self.discount = discount
self.bootstrap_total_return_from_old_policy = bootstrap_total_return_from_old_policy
@@ -650,28 +665,48 @@ class Episode(object):
def get_first_transition(self):
return self.get_transition(0) if self.length() > 0 else None

- def update_returns(self):
+ def update_discounted_rewards(self):
if self.n_step == -1 or self.n_step > self.length():
curr_n_step = self.length()
else:
curr_n_step = self.n_step

rewards = np.array([t.reward for t in self.transitions])
rewards = rewards.astype('float')
- total_return = rewards.copy()
+ discounted_rewards = rewards.copy()
current_discount = self.discount
for i in range(1, curr_n_step):
- total_return += current_discount * np.pad(rewards[i:], (0, i), 'constant', constant_values=0)
+ discounted_rewards += current_discount * np.pad(rewards[i:], (0, i), 'constant', constant_values=0)
current_discount *= self.discount

# calculate the bootstrapped returns
if self.bootstrap_total_return_from_old_policy:
bootstraps = np.array([np.squeeze(t.info['max_action_value']) for t in self.transitions[curr_n_step:]])
- bootstrapped_return = total_return + current_discount * np.pad(bootstraps, (0, curr_n_step), 'constant',
+ bootstrapped_return = discounted_rewards + current_discount * np.pad(bootstraps, (0, curr_n_step),
- constant_values=0)
+ 'constant', constant_values=0)
- total_return = bootstrapped_return
+ discounted_rewards = bootstrapped_return

for transition_idx in range(self.length()):
- self.transitions[transition_idx].total_return = total_return[transition_idx]
+ self.transitions[transition_idx].n_step_discounted_rewards = discounted_rewards[transition_idx]

+ def update_transitions_rewards_and_bootstrap_data(self):
+ if not isinstance(self.n_step, int) or (self.n_step < 1 and self.n_step != -1):
+ raise ValueError("n-step should be an integer with value >= 1, or set to -1 for always setting to episode"
+ " length.")
+ elif self.n_step > 1:
+ curr_n_step = self.n_step if self.n_step < self.length() else self.length()
+
+ for idx, transition in enumerate(self.transitions):
+ next_n_step_transition_idx = (idx + curr_n_step)
+ if next_n_step_transition_idx < len(self.transitions):
+ # next state will now point to the n-step next state
+ transition.next_state = self.transitions[next_n_step_transition_idx].state
+ transition.info['should_bootstrap_next_state'] = True
+ else:
+ transition.next_state = self.transitions[-1].next_state
+ transition.info['should_bootstrap_next_state'] = False
+
+ self.update_discounted_rewards()

def update_actions_probabilities(self):
probability_product = 1
@@ -681,12 +716,6 @@ class Episode(object):
for transition_idx, transition in enumerate(self.transitions):
transition.info['probability_product'] = probability_product

- def get_returns_table(self):
- return self.returns_table
-
- def get_returns(self):
- return self.get_transitions_attribute('total_return')
-
def get_transitions_attribute(self, attribute_name):
if len(self.transitions) > 0 and hasattr(self.transitions[0], attribute_name):
return [getattr(t, attribute_name) for t in self.transitions]

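To make the new Episode bookkeeping concrete, here is a minimal standalone sketch of the same two steps: computing the truncated n-step discounted reward for every transition, and deciding which transitions may still bootstrap. It uses plain lists and numpy instead of the Transition/Episode classes and is illustrative only.

import numpy as np

def n_step_discounted_rewards(rewards, discount, n_step):
    # discounted sum of the next n rewards for every step, truncated at the episode end
    rewards = np.asarray(rewards, dtype=float)
    if n_step == -1 or n_step > len(rewards):
        n_step = len(rewards)
    discounted = rewards.copy()
    current_discount = discount
    for i in range(1, n_step):
        discounted += current_discount * np.pad(rewards[i:], (0, i), 'constant', constant_values=0)
        current_discount *= discount
    return discounted

def should_bootstrap_flags(episode_length, n_step):
    # the last n_step transitions have no state n steps ahead, so they must not bootstrap
    return [idx + n_step < episode_length for idx in range(episode_length)]

rewards = [1.0, 1.0, 1.0, 1.0, 1.0]
print(n_step_discounted_rewards(rewards, 0.99, 3))   # -> 2.9701, 2.9701, 2.9701, 1.99, 1.0
print(should_bootstrap_flags(len(rewards), 3))       # -> True, True, False, False, False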
@@ -27,6 +27,7 @@ class EpisodicExperienceReplayParameters(MemoryParameters):
def __init__(self):
super().__init__()
self.max_size = (MemoryGranularity.Transitions, 1000000)
+ self.n_step = -1

@property
def path(self):
@@ -39,12 +40,13 @@ class EpisodicExperienceReplay(Memory):
calculations of total return and other values that depend on the sequential behavior of the transitions
in the episode.
"""
- def __init__(self, max_size: Tuple[MemoryGranularity, int]):
+ def __init__(self, max_size: Tuple[MemoryGranularity, int]=(MemoryGranularity.Transitions, 1000000), n_step=-1):
"""
:param max_size: the maximum number of transitions or episodes to hold in the memory
"""
super().__init__(max_size)
- self._buffer = [Episode()] # list of episodes
+ self.n_step = n_step
+ self._buffer = [Episode(n_step=self.n_step)] # list of episodes
self.transitions = []
self._length = 1 # the episodic replay buffer starts with a single empty episode
self._num_transitions = 0
@@ -109,7 +111,7 @@ class EpisodicExperienceReplay(Memory):
self._remove_episode(0)

def _update_episode(self, episode: Episode) -> None:
- episode.update_returns()
+ episode.update_transitions_rewards_and_bootstrap_data()

def verify_last_episode_is_closed(self) -> None:
"""
@@ -138,7 +140,7 @@ class EpisodicExperienceReplay(Memory):
self._length += 1

# create a new Episode for the next transitions to be placed into
- self._buffer.append(Episode())
+ self._buffer.append(Episode(n_step=self.n_step))

# if update episode adds to the buffer, a new Episode needs to be ready first
# it would be better if this were less state full
@@ -158,12 +160,14 @@ class EpisodicExperienceReplay(Memory):
:param transition: a transition to store
:return: None
"""

# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
super().store(transition)

self.reader_writer_lock.lock_writing_and_reading()

if len(self._buffer) == 0:
- self._buffer.append(Episode())
+ self._buffer.append(Episode(n_step=self.n_step))
last_episode = self._buffer[-1]
last_episode.insert(transition)
self.transitions.append(transition)
@@ -284,7 +288,7 @@ class EpisodicExperienceReplay(Memory):
self.reader_writer_lock.lock_writing_and_reading()

self.transitions = []
- self._buffer = [Episode()]
+ self._buffer = [Episode(n_step=self.n_step)]
self._length = 1
self._num_transitions = 0
self._num_transitions_in_complete_episodes = 0

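Two details above are easy to miss. First, the n_step value now flows from the memory into every Episode the buffer creates, so the per-transition bookkeeping happens when an episode is closed and _update_episode() runs. Second, this is why RainbowDQNAlgorithmParameters sets store_transitions_only_when_episodes_are_terminated: a transition's n-step next_state and bootstrap flag depend on transitions that arrive later, so they can only be filled in once the episode is complete. A hedged wiring sketch, assuming the import paths used elsewhere in this repository (not code from the commit):

from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplay
from rl_coach.memories.memory import MemoryGranularity

# every Episode this buffer creates carries n_step=3, so closing an episode re-points each
# transition's next_state three steps ahead and fills transition.n_step_discounted_rewards
memory = EpisodicExperienceReplay(max_size=(MemoryGranularity.Transitions, 1000000), n_step=3)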
@@ -139,7 +139,7 @@ class EpisodicHindsightExperienceReplay(EpisodicExperienceReplay):
hindsight_transition.reward, hindsight_transition.game_over = \
self.goals_space.get_reward_for_goal_and_state(goal, hindsight_transition.next_state)

- hindsight_transition.total_return = None
+ hindsight_transition.n_step_discounted_rewards = None
episode.insert(hindsight_transition)

super().store_episode(episode)

@@ -128,14 +128,14 @@ class SegmentTree(object):
self.tree[node_idx] = new_val
self._propagate(node_idx)

- def get(self, val: float) -> Tuple[int, float, Any]:
+ def get_element_by_partial_sum(self, val: float) -> Tuple[int, float, Any]:
"""
Given a value between 0 and the tree sum, return the object which this value is in it's range.
For example, if we have 3 leaves: 10, 20, 30, and val=35, this will return the 3rd leaf, by accumulating
leaves by their order until getting to 35. This allows sampling leaves according to their proportional
probability.
:param val: a value within the range 0 and the tree sum
- :return: the index of the resulting leaf in the tree, it's probability and
+ :return: the index of the resulting leaf in the tree, its probability and
the object itself
"""
node_idx = self._retrieve(0, val)
@@ -237,12 +237,12 @@ class PrioritizedExperienceReplay(ExperienceReplay):

# sample a batch
for i in range(size):
- start_probability = segment_size * i
+ segment_start = segment_size * i
- end_probability = segment_size * (i + 1)
+ segment_end = segment_size * (i + 1)

# sample leaf and calculate its weight
- val = random.uniform(start_probability, end_probability)
+ val = random.uniform(segment_start, segment_end)
- leaf_idx, priority, transition = self.sum_tree.get(val)
+ leaf_idx, priority, transition = self.sum_tree.get_element_by_partial_sum(val)
priority /= self.sum_tree.total_value() # P(j) = p^a / sum(p^a)
weight = (self.num_transitions() * priority) ** -self.beta.current_value # (N * P(j)) ^ -beta
normalized_weight = weight / max_weight # wj = ((N * P(j)) ^ -beta) / max wi
@@ -261,7 +261,7 @@ class PrioritizedExperienceReplay(ExperienceReplay):
self.reader_writer_lock.release_writing()
return batch

- def store(self, transition: Transition) -> None:
+ def store(self, transition: Transition, lock=True) -> None:
"""
Store a new transition in the memory.
:param transition: a transition to store
@@ -270,7 +270,8 @@ class PrioritizedExperienceReplay(ExperienceReplay):
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
super().store(transition)

- self.reader_writer_lock.lock_writing_and_reading()
+ if lock:
+     self.reader_writer_lock.lock_writing_and_reading()

transition_priority = self.maximal_priority
self.sum_tree.add(transition_priority ** self.alpha, transition)
@@ -278,18 +279,21 @@ class PrioritizedExperienceReplay(ExperienceReplay):
self.max_tree.add(transition_priority, transition)
super().store(transition, False)

- self.reader_writer_lock.release_writing_and_reading()
+ if lock:
+     self.reader_writer_lock.release_writing_and_reading()

- def clean(self) -> None:
+ def clean(self, lock=True) -> None:
"""
Clean the memory by removing all the episodes
:return: None
"""
- self.reader_writer_lock.lock_writing_and_reading()
+ if lock:
+     self.reader_writer_lock.lock_writing_and_reading()

super().clean(lock=False)
self.sum_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.SUM)
self.min_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.MIN)
self.max_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.MAX)

- self.reader_writer_lock.release_writing_and_reading()
+ if lock:
+     self.reader_writer_lock.release_writing_and_reading()

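For readers unfamiliar with the renamed method: proportional prioritized sampling splits the total priority mass into equal segments, draws one uniform value inside each segment, and walks the sum tree to the leaf whose cumulative-priority range contains that value. A tiny self-contained sketch of the same lookup over a plain list (illustrative only, not the SegmentTree in this repo):

import random

def element_by_partial_sum(priorities, val):
    # index of the first element whose cumulative priority exceeds val
    running_total = 0.0
    for idx, p in enumerate(priorities):
        running_total += p
        if val < running_total:
            return idx
    return len(priorities) - 1

priorities = [10.0, 20.0, 30.0]
print(element_by_partial_sum(priorities, 35.0))      # 2 -> the third leaf, as in the docstring example

# proportional sampling: one uniform draw per equal-sized segment of the total mass
total = sum(priorities)
segment = total / 3
samples = [element_by_partial_sum(priorities, random.uniform(segment * i, segment * (i + 1)))
           for i in range(3)]
print(samples)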
@@ -63,7 +63,7 @@ env_params = GymVectorEnvironment(level='CartPole-v0')
preset_validation_params = PresetValidationParameters()
preset_validation_params.test = True
preset_validation_params.min_reward_threshold = 150
- preset_validation_params.max_episodes_to_achieve_reward = 250
+ preset_validation_params.max_episodes_to_achieve_reward = 400

graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
schedule_params=schedule_params, vis_params=VisualizationParameters(),

@@ -26,10 +26,10 @@ def test_sum_tree():
sum_tree.add(5, "5")
assert sum_tree.total_value() == 20

- assert sum_tree.get(2) == (0, 2.5, '2.5')
+ assert sum_tree.get_element_by_partial_sum(2) == (0, 2.5, '2.5')
- assert sum_tree.get(3) == (1, 5.0, '5')
+ assert sum_tree.get_element_by_partial_sum(3) == (1, 5.0, '5')
- assert sum_tree.get(10) == (2, 5.0, '5')
+ assert sum_tree.get_element_by_partial_sum(10) == (2, 5.0, '5')
- assert sum_tree.get(13) == (3, 7.5, '7.5')
+ assert sum_tree.get_element_by_partial_sum(13) == (3, 7.5, '7.5')

sum_tree.update(2, 10)
assert sum_tree.__str__() == "[25.]\n[ 7.5 17.5]\n[ 2.5 5. 10. 7.5]\n"
@@ -41,8 +41,8 @@ def test_store_and_get(buffer: SingleEpisodeBuffer):
# check that the episode is valid
episode = buffer.get(0)
assert episode.length() == 2
- assert episode.get_transition(0).total_return == 1 + 0.99
+ assert episode.get_transition(0).n_step_discounted_rewards == 1 + 0.99
- assert episode.get_transition(1).total_return == 1
+ assert episode.get_transition(1).n_step_discounted_rewards == 1
assert buffer.mean_reward() == 1

# only one episode in the replay buffer