Mirror of https://github.com/gryf/coach.git, synced 2026-02-14 21:15:53 +01:00
Enabling-more-agents-for-Batch-RL-and-cleanup (#258)
Allow the last training batch drawn to be smaller than batch_size, add support for more agents in Batch RL by adding softmax with temperature to the corresponding heads, add a CartPole_QR_DQN preset with a golden test, and clean up.
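The "softmax with temperature" mentioned above is what lets value-based heads emit action probabilities for Batch RL / off-policy evaluation. As a rough illustration only (the actual head in Coach is a graph op, and its exact name and wiring are not shown in this diff), a temperature-scaled softmax looks like:

```python
import numpy as np

def softmax_with_temperature(q_values: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """Turn Q-values into action probabilities; higher temperature -> closer to uniform,
    lower temperature -> closer to greedy argmax."""
    scaled = q_values / temperature
    scaled = scaled - scaled.max(axis=-1, keepdims=True)   # subtract max for numerical stability
    exp_q = np.exp(scaled)
    return exp_q / exp_q.sum(axis=-1, keepdims=True)

print(softmax_with_temperature(np.array([1.0, 2.0, 4.0]), temperature=0.5))
# greedier than temperature=2.0 on the same Q-values
```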
@@ -18,11 +18,9 @@ from typing import Union
 
 import numpy as np
 
-from rl_coach.agents.dqn_agent import DQNNetworkParameters, DQNAlgorithmParameters
+from rl_coach.agents.dqn_agent import DQNNetworkParameters, DQNAgentParameters
 from rl_coach.agents.value_optimization_agent import ValueOptimizationAgent
-from rl_coach.base_parameters import AgentParameters
 from rl_coach.exploration_policies.bootstrapped import BootstrappedParameters
-from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters
 
 
 class BootstrappedDQNNetworkParameters(DQNNetworkParameters):
@@ -32,12 +30,11 @@ class BootstrappedDQNNetworkParameters(DQNNetworkParameters):
         self.heads_parameters[0].rescale_gradient_from_head_by_factor = 1.0/self.heads_parameters[0].num_output_head_copies
 
 
-class BootstrappedDQNAgentParameters(AgentParameters):
+class BootstrappedDQNAgentParameters(DQNAgentParameters):
     def __init__(self):
-        super().__init__(algorithm=DQNAlgorithmParameters(),
-                         exploration=BootstrappedParameters(),
-                         memory=ExperienceReplayParameters(),
-                         networks={"main": BootstrappedDQNNetworkParameters()})
+        super().__init__()
+        self.exploration = BootstrappedParameters()
+        self.network_wrappers = {"main": BootstrappedDQNNetworkParameters()}
 
     @property
     def path(self):
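The parameters class above now inherits the DQN defaults (algorithm, memory, network wrappers) and overrides only what Bootstrapped DQN actually changes. A hedged, preset-style usage sketch, assuming the usual module path for this file and using the `architecture_num_q_heads` field referenced later in this diff:

```python
# Hedged usage sketch, assuming rl_coach is installed and the module path below is correct.
from rl_coach.agents.bootstrapped_dqn_agent import BootstrappedDQNAgentParameters

agent_params = BootstrappedDQNAgentParameters()
# DQN algorithm and experience-replay defaults now come from DQNAgentParameters;
# only the exploration policy and the multi-head "main" network are overridden.
agent_params.exploration.architecture_num_q_heads = 10   # number of bootstrap heads (example value)
```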
@@ -65,13 +62,14 @@ class BootstrappedDQNAgent(ValueOptimizationAgent):
         TD_targets = result[self.ap.exploration.architecture_num_q_heads:]
 
-        # add Q value samples for logging
-        self.q_values.add_sample(TD_targets)
-
         # initialize with the current prediction so that we will
         # only update the action that we have actually done in this transition
-        for i in range(self.ap.network_wrappers['main'].batch_size):
+        for i in range(batch.size):
             mask = batch[i].info['mask']
             for head_idx in range(self.ap.exploration.architecture_num_q_heads):
+                # add Q value samples for logging
+                self.q_values.add_sample(TD_targets[head_idx])
+
                 if mask[head_idx] == 1:
                     selected_action = np.argmax(next_states_online_values[head_idx][i], 0)
                     TD_targets[head_idx][i, batch.actions()[i]] = \
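This hunk shows the pattern repeated through the rest of the commit: per-transition loops now run over `batch.size` (the number of transitions actually sampled) rather than the configured `self.ap.network_wrappers['main'].batch_size`, so a final, smaller batch no longer indexes beyond the data it actually contains. A toy illustration with hypothetical sizes:

```python
# Toy illustration (hypothetical sizes): 130 transitions split with a configured batch_size of 64
transitions = list(range(130))
configured_batch_size = 64

for start in range(0, len(transitions), configured_batch_size):
    batch = transitions[start:start + configured_batch_size]
    # looping over len(batch) -- the equivalent of batch.size -- is always safe;
    # looping over configured_batch_size would overrun the last, smaller batch
    for i in range(len(batch)):
        _ = batch[i]
    print(len(batch))   # prints 64, 64, 2
```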
@@ -84,8 +84,8 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
     # prediction's format is (batch,actions,atoms)
     def get_all_q_values_for_states(self, states: StateType):
         if self.exploration_policy.requires_action_values():
-            prediction = self.get_prediction(states)
-            q_values = self.distribution_prediction_to_q_values(prediction)
+            q_values = self.get_prediction(states,
+                                           outputs=self.networks['main'].online_network.output_heads[0].q_values)
         else:
             q_values = None
         return q_values
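Above, CategoricalDQNAgent stops converting the predicted distribution to Q-values in Python and instead fetches the head's `q_values` output directly from the network. The conversion it replaces is just the expectation of the fixed support under the predicted probabilities; a NumPy sketch of that relationship (shapes follow the comment in the hunk, batch x actions x atoms):

```python
import numpy as np

def distribution_to_q_values(dist: np.ndarray, z_values: np.ndarray) -> np.ndarray:
    """dist: (batch, actions, atoms) probabilities; z_values: (atoms,) support.
    Q(s, a) = sum_j z_j * p_j(s, a)."""
    return np.dot(dist, z_values)   # -> (batch, actions)

z = np.linspace(-10.0, 10.0, 51)                     # e.g. 51 atoms on [-10, 10]
p = np.random.dirichlet(np.ones(51), size=(8, 4))    # fake distribution for batch=8, actions=4
print(distribution_to_q_values(p, z).shape)          # (8, 4)
```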
@@ -105,9 +105,9 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
 
         # select the optimal actions for the next state
         target_actions = np.argmax(self.distribution_prediction_to_q_values(distributional_q_st_plus_1), axis=1)
-        m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
+        m = np.zeros((batch.size, self.z_values.size))
 
-        batches = np.arange(self.ap.network_wrappers['main'].batch_size)
+        batches = np.arange(batch.size)
 
         # an alternative to the for loop. 3.7x perf improvement vs. the same code done with for looping.
         # only 10% speedup overall - leaving commented out as the code is not as clear.
@@ -120,7 +120,7 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
         #     bj_ = (tzj_ - self.z_values[0]) / (self.z_values[1] - self.z_values[0])
         #     u_ = (np.ceil(bj_)).astype(int)
         #     l_ = (np.floor(bj_)).astype(int)
-        #     m_ = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
+        #     m_ = np.zeros((batch.size, self.z_values.size))
         #     np.add.at(m_, [batches, l_],
         #               np.transpose(distributional_q_st_plus_1[batches, target_actions], (1, 0)) * (u_ - bj_))
         #     np.add.at(m_, [batches, u_],
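For context on what `m`, `bj_`, `l_` and `u_` are doing in this hunk and the one above: this is the standard C51 projection of the Bellman-updated atoms back onto the fixed support. A minimal NumPy sketch of that projection (a generic version, not Coach's exact code, which also omits the commented-out transpose trick):

```python
import numpy as np

def project_distribution(rewards, game_overs, next_dist, z_values, discount):
    """rewards, game_overs: (batch,); next_dist: (batch, atoms) for the chosen next actions;
    z_values: (atoms,) fixed support. Returns the projected target distribution m: (batch, atoms)."""
    batch_size, n_atoms = next_dist.shape
    delta_z = z_values[1] - z_values[0]
    batches = np.arange(batch_size)
    m = np.zeros((batch_size, n_atoms))
    for j in range(n_atoms):
        # Bellman-update atom j and clip it to the support
        tzj = np.clip(rewards + (1.0 - game_overs) * discount * z_values[j], z_values[0], z_values[-1])
        bj = (tzj - z_values[0]) / delta_z
        l, u = np.floor(bj).astype(int), np.ceil(bj).astype(int)
        # split atom j's probability mass between the two neighbouring atoms
        m[batches, l] += next_dist[:, j] * (u - bj)
        m[batches, u] += next_dist[:, j] * (bj - l)
        exact = (l == u)                      # bj landed exactly on an atom: both terms above are zero
        m[batches[exact], l[exact]] += next_dist[exact, j]
    return m

# tiny smoke test: each row of m should still sum to ~1
m = project_distribution(np.array([1.0, -0.5]), np.array([0.0, 1.0]),
                         np.full((2, 5), 0.2), np.linspace(-10, 10, 5), discount=0.99)
print(m.sum(axis=1))   # [1. 1.]
```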
@@ -207,6 +207,8 @@ class ClippedPPOAgent(ActorCriticAgent):
                        self.networks['main'].online_network.output_heads[1].likelihood_ratio,
                        self.networks['main'].online_network.output_heads[1].clipped_likelihood_ratio]
 
+        # TODO-fixme if batch.size / self.ap.network_wrappers['main'].batch_size is not an integer, we do not train on
+        # some of the data
         for i in range(int(batch.size / self.ap.network_wrappers['main'].batch_size)):
             start = i * self.ap.network_wrappers['main'].batch_size
             end = (i + 1) * self.ap.network_wrappers['main'].batch_size
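The TODO added above documents a real gap: only `int(batch.size / batch_size)` full mini-batches are iterated, so any non-divisible remainder is silently skipped that iteration. A quick arithmetic illustration with hypothetical numbers:

```python
configured_batch_size = 64      # mini-batch size from the network wrapper
rollout_size = 2050             # transitions collected this training iteration (hypothetical)

full_minibatches = rollout_size // configured_batch_size    # 32
trained_on = full_minibatches * configured_batch_size       # 2048
print(rollout_size - trained_on)                            # 2 transitions never trained on, as the TODO notes
```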
@@ -56,7 +56,7 @@ class DDQNAgent(ValueOptimizationAgent):
         # initialize with the current prediction so that we will
         # only update the action that we have actually done in this transition
         TD_errors = []
-        for i in range(self.ap.network_wrappers['main'].batch_size):
+        for i in range(batch.size):
             new_target = batch.rewards()[i] + \
                          (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * q_st_plus_1[i][selected_actions[i]]
             TD_errors.append(np.abs(new_target - TD_targets[i, batch.actions()[i]]))
@@ -146,13 +146,13 @@ class DFPAgent(Agent):
 
         network_inputs = batch.states(network_keys)
         network_inputs['goal'] = np.repeat(np.expand_dims(self.current_goal, 0),
-                                           self.ap.network_wrappers['main'].batch_size, axis=0)
+                                           batch.size, axis=0)
 
         # get the current outputs of the network
         targets = self.networks['main'].online_network.predict(network_inputs)
 
         # change the targets for the taken actions
-        for i in range(self.ap.network_wrappers['main'].batch_size):
+        for i in range(batch.size):
             targets[i, batch.actions()[i]] = batch[i].info['future_measurements'].flatten()
 
         result = self.networks['main'].train_and_sync_networks(network_inputs, targets)
@@ -86,7 +86,7 @@ class DQNAgent(ValueOptimizationAgent):
 
         # only update the action that we have actually done in this transition
         TD_errors = []
-        for i in range(self.ap.network_wrappers['main'].batch_size):
+        for i in range(batch.size):
             new_target = batch.rewards()[i] +\
                          (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * np.max(q_st_plus_1[i], 0)
             TD_errors.append(np.abs(new_target - TD_targets[i, batch.actions()[i]]))
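The DDQN hunk earlier and the DQN hunk above build their TD targets with the same per-transition loop; the only difference is how the next-state value is picked. A compact NumPy sketch of the two targets as they appear in these hunks (generic form, not Coach's exact code):

```python
import numpy as np

def dqn_target(reward, game_over, discount, q_next_target):
    # vanilla DQN: select and evaluate the next action with the target network
    return reward + (1.0 - game_over) * discount * np.max(q_next_target)

def ddqn_target(reward, game_over, discount, q_next_online, q_next_target):
    # Double DQN: select with the online network, evaluate with the target network
    selected_action = np.argmax(q_next_online)
    return reward + (1.0 - game_over) * discount * q_next_target[selected_action]

q_t = np.array([0.5, 2.0, 1.0])     # target-network Q-values for the next state
q_o = np.array([1.5, 0.1, 0.2])     # online-network Q-values for the next state
print(dqn_target(1.0, 0.0, 0.99, q_t))          # 1 + 0.99 * 2.0 = 2.98
print(ddqn_target(1.0, 0.0, 0.99, q_o, q_t))    # 1 + 0.99 * 0.5 = 1.495
```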
@@ -65,7 +65,7 @@ class MixedMonteCarloAgent(ValueOptimizationAgent):
 
         total_returns = batch.n_step_discounted_rewards()
 
-        for i in range(self.ap.network_wrappers['main'].batch_size):
+        for i in range(batch.size):
             one_step_target = batch.rewards()[i] + \
                               (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * \
                               q_st_plus_1[i][selected_actions[i]]
@@ -133,7 +133,7 @@ class NECAgent(ValueOptimizationAgent):
         TD_targets = self.networks['main'].online_network.predict(batch.states(network_keys))
         bootstrapped_return_from_old_policy = batch.n_step_discounted_rewards()
         # only update the action that we have actually done in this transition
-        for i in range(self.ap.network_wrappers['main'].batch_size):
+        for i in range(batch.size):
             TD_targets[i, batch.actions()[i]] = bootstrapped_return_from_old_policy[i]
 
         # set the gradients to fetch for the DND update
@@ -84,7 +84,7 @@ class PALAgent(ValueOptimizationAgent):
         # calculate TD error
         TD_targets = np.copy(q_st_online)
         total_returns = batch.n_step_discounted_rewards()
-        for i in range(self.ap.network_wrappers['main'].batch_size):
+        for i in range(batch.size):
             TD_targets[i, batch.actions()[i]] = batch.rewards()[i] + \
                                                 (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * \
                                                 q_st_plus_1_target[i][selected_actions[i]]
@@ -95,7 +95,7 @@ class QuantileRegressionDQNAgent(ValueOptimizationAgent):
         target_actions = np.argmax(self.get_q_values(next_state_quantiles), axis=1)
 
         # calculate the Bellman update
-        batch_idx = list(range(self.ap.network_wrappers['main'].batch_size))
+        batch_idx = list(range(batch.size))
 
         TD_targets = batch.rewards(True) + (1.0 - batch.game_overs(True)) * self.ap.algorithm.discount \
                      * next_state_quantiles[batch_idx, target_actions]
@@ -106,9 +106,9 @@ class QuantileRegressionDQNAgent(ValueOptimizationAgent):
         # calculate the cumulative quantile probabilities and reorder them to fit the sorted quantiles order
         cumulative_probabilities = np.array(range(self.ap.algorithm.atoms + 1)) / float(self.ap.algorithm.atoms)  # tau_i
         quantile_midpoints = 0.5*(cumulative_probabilities[1:] + cumulative_probabilities[:-1])  # tau^hat_i
-        quantile_midpoints = np.tile(quantile_midpoints, (self.ap.network_wrappers['main'].batch_size, 1))
+        quantile_midpoints = np.tile(quantile_midpoints, (batch.size, 1))
         sorted_quantiles = np.argsort(current_quantiles[batch_idx, batch.actions()])
-        for idx in range(self.ap.network_wrappers['main'].batch_size):
+        for idx in range(batch.size):
             quantile_midpoints[idx, :] = quantile_midpoints[idx, sorted_quantiles[idx]]
 
         # train
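The QR-DQN hunk above builds the tau-hat midpoints used by the quantile regression loss and tiles them per sample. A small numeric illustration of those two lines, using a hypothetical atoms value of 4:

```python
import numpy as np

atoms = 4                                                        # hypothetical number of quantile atoms
cumulative_probabilities = np.arange(atoms + 1) / float(atoms)   # tau_i:     [0.   0.25 0.5  0.75 1.  ]
quantile_midpoints = 0.5 * (cumulative_probabilities[1:] + cumulative_probabilities[:-1])
print(quantile_midpoints)                                        # tau^hat_i: [0.125 0.375 0.625 0.875]

# the midpoints are then tiled per sample -- with batch.size rows, not the configured batch_size --
# and reordered per row to match the sort order of the predicted quantiles for the taken action
print(np.tile(quantile_midpoints, (3, 1)).shape)                 # (3, 4)
```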
|
||||
@@ -103,9 +103,9 @@ class RainbowDQNAgent(CategoricalDQNAgent):
|
||||
|
||||
# only update the action that we have actually done in this transition (using the Double-DQN selected actions)
|
||||
target_actions = ddqn_selected_actions
|
||||
m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
|
||||
m = np.zeros((batch.size, self.z_values.size))
|
||||
|
||||
batches = np.arange(self.ap.network_wrappers['main'].batch_size)
|
||||
batches = np.arange(batch.size)
|
||||
for j in range(self.z_values.size):
|
||||
# we use batch.info('should_bootstrap_next_state') instead of (1 - batch.game_overs()) since with n-step,
|
||||
# we will not bootstrap for the last n-step transitions in the episode
|
||||
|
||||
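The comment in the Rainbow hunk explains why `batch.info('should_bootstrap_next_state')` replaces `1 - batch.game_overs()`: with n-step returns, the last n transitions of an episode have no state n steps ahead to bootstrap from. A hypothetical helper (not a Coach API) that computes such a flag illustrates the idea:

```python
import numpy as np

def should_bootstrap_flags(episode_length: int, n_step: int) -> np.ndarray:
    """Hypothetical helper: 1.0 where a non-terminal state exists n_step transitions ahead,
    0.0 for the last n_step transitions of the episode (their n-step return already reaches
    the terminal state, so there is nothing left to bootstrap from)."""
    t = np.arange(episode_length)
    return (t + n_step < episode_length).astype(float)

print(should_bootstrap_flags(episode_length=6, n_step=3))   # [1. 1. 1. 0. 0. 0.]
```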
@@ -50,8 +50,9 @@ class ValueOptimizationAgent(Agent):
             actions_q_values = None
         return actions_q_values
 
-    def get_prediction(self, states):
-        return self.networks['main'].online_network.predict(self.prepare_batch_for_inference(states, 'main'))
+    def get_prediction(self, states, outputs=None):
+        return self.networks['main'].online_network.predict(self.prepare_batch_for_inference(states, 'main'),
+                                                             outputs=outputs)
 
     def update_transition_priorities_and_get_weights(self, TD_errors, batch):
         # update errors in prioritized replay buffer
@@ -151,17 +152,18 @@ class ValueOptimizationAgent(Agent):
         # this is fitted from the training dataset
         for epoch in range(epochs):
             loss = 0
+            total_transitions_processed = 0
             for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)):
                 batch = Batch(batch)
                 current_rewards_prediction_for_all_actions = self.networks['reward_model'].online_network.predict(batch.states(network_keys))
-                current_rewards_prediction_for_all_actions[range(batch_size), batch.actions()] = batch.rewards()
+                current_rewards_prediction_for_all_actions[range(batch.size), batch.actions()] = batch.rewards()
                 loss += self.networks['reward_model'].train_and_sync_networks(
                     batch.states(network_keys), current_rewards_prediction_for_all_actions)[0]
                 # print(self.networks['reward_model'].online_network.predict(batch.states(network_keys))[0])
+                total_transitions_processed += batch.size
 
             log = OrderedDict()
             log['Epoch'] = epoch
-            log['loss'] = loss / int(self.call_memory('num_transitions_in_complete_episodes') / batch_size)
+            log['loss'] = loss / total_transitions_processed
             screen.log_dict(log, prefix='Training Reward Model')
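The reward-model training loop above now indexes by `batch.size` and divides the accumulated loss by `total_transitions_processed`, so the final, smaller batch yielded by `get_shuffled_data_generator` is both trained on and counted correctly, instead of assuming every batch is full. A hypothetical stand-in generator (not Coach's implementation) shows where that smaller tail batch comes from:

```python
import numpy as np

def shuffled_batches(data, batch_size):
    """Hypothetical stand-in for 'get_shuffled_data_generator': yields shuffled batches,
    with the last batch allowed to be smaller than batch_size."""
    indices = np.random.permutation(len(data))
    for start in range(0, len(data), batch_size):
        yield [data[i] for i in indices[start:start + batch_size]]

data = list(range(100))
sizes = [len(b) for b in shuffled_batches(data, batch_size=32)]
print(sizes)          # [32, 32, 32, 4]
print(sum(sizes))     # 100 -- normalizing the accumulated loss by this stays correct even
                      # when the tail batch is smaller than the configured batch_size
```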