
Enabling-more-agents-for-Batch-RL-and-cleanup (#258)

- Allow the last training batch drawn to be smaller than batch_size.
- Add support for more agents in Batch RL by adding softmax with temperature to the corresponding heads (sketched below).
- Add a CartPole_QR_DQN preset with a golden test.
- General cleanups.
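
A softmax with temperature maps a head's Q-values to action probabilities whose sharpness the temperature controls, presumably so that Batch RL machinery can treat a value head as a stochastic policy. A minimal NumPy sketch of the idea (the function name and default temperature are illustrative, not Coach's API):

    import numpy as np

    def softmax_with_temperature(q_values, temperature=1.0):
        # q_values: (batch, actions); temperature > 0.
        # Lower temperature -> greedier distribution; higher -> closer to uniform.
        scaled = q_values / temperature
        scaled -= scaled.max(axis=1, keepdims=True)  # stabilize exp() against overflow
        exp = np.exp(scaled)
        return exp / exp.sum(axis=1, keepdims=True)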
Author: Gal Leibovich
Committed by: GitHub
Date: 2019-03-21 16:10:29 +02:00
Parent: abec59f367
Commit: 6e08c55ad5
24 changed files with 152 additions and 69 deletions


@@ -84,8 +84,8 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
     # prediction's format is (batch,actions,atoms)
     def get_all_q_values_for_states(self, states: StateType):
         if self.exploration_policy.requires_action_values():
-            prediction = self.get_prediction(states)
-            q_values = self.distribution_prediction_to_q_values(prediction)
+            q_values = self.get_prediction(states,
+                                           outputs=self.networks['main'].online_network.output_heads[0].q_values)
         else:
             q_values = None
         return q_values
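
For reference, distribution_prediction_to_q_values reduces the (batch, actions, atoms) distribution to its expected value over the support, which is the same quantity the head's q_values output now exposes directly from the graph. A minimal NumPy sketch of that reduction (the standalone function name is illustrative):

    import numpy as np

    def distribution_to_q_values(prediction, z_values):
        # prediction: (batch, actions, atoms) probabilities over the support
        # z_values:   (atoms,) fixed support of the return distribution
        # Contracting over the atoms axis yields (batch, actions) expected returns.
        return np.dot(prediction, z_values)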
@@ -105,9 +105,9 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
         # select the optimal actions for the next state
         target_actions = np.argmax(self.distribution_prediction_to_q_values(distributional_q_st_plus_1), axis=1)
-        m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
-        batches = np.arange(self.ap.network_wrappers['main'].batch_size)
+        m = np.zeros((batch.size, self.z_values.size))
+        batches = np.arange(batch.size)
         # an alternative to the for loop. 3.7x perf improvement vs. the same code done with for looping.
         # only 10% speedup overall - leaving commented out as the code is not as clear.
@@ -120,7 +120,7 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
         # bj_ = (tzj_ - self.z_values[0]) / (self.z_values[1] - self.z_values[0])
         # u_ = (np.ceil(bj_)).astype(int)
         # l_ = (np.floor(bj_)).astype(int)
-        # m_ = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
+        # m_ = np.zeros((batch.size, self.z_values.size))
         # np.add.at(m_, [batches, l_],
         #           np.transpose(distributional_q_st_plus_1[batches, target_actions], (1, 0)) * (u_ - bj_))
         # np.add.at(m_, [batches, u_],
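
Read back from those comments, the vectorized alternative projects the Bellman-updated support onto the fixed atoms by splitting each probability mass between its two neighboring atoms with np.add.at. A self-contained sketch under assumed shapes (tzj is the clipped target support and probs the next-state distribution for the chosen actions, both (batch, atoms); names differ from the hunk's):

    import numpy as np

    def project_distribution(tzj, probs, z_values):
        # tzj, probs: (batch, atoms); z_values: (atoms,) evenly spaced support.
        batch_size, n_atoms = tzj.shape
        rows = np.arange(batch_size)[:, None]                    # (batch, 1) row indices
        bj = (tzj - z_values[0]) / (z_values[1] - z_values[0])   # fractional atom position
        u = np.ceil(bj).astype(int)
        l = np.floor(bj).astype(int)
        m = np.zeros((batch_size, n_atoms))
        # Scatter-add each mass onto its floor/ceil neighbors; np.add.at is
        # unbuffered, so repeated target indices accumulate correctly.
        np.add.at(m, (rows, l), probs * (u - bj))
        np.add.at(m, (rows, u), probs * (bj - l))
        # Caveat: when bj lands exactly on an atom, u == l and both weights are
        # zero; a production implementation must special-case integer bj.
        return m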