Mirror of https://github.com/gryf/coach.git, synced 2025-12-18 11:40:18 +01:00
Enabling more agents for Batch RL and cleanup (#258)
This change:
- allows the last training batch drawn to be smaller than batch_size
- adds support for more agents in BatchRL by adding softmax with temperature to the corresponding heads (sketched below)
- adds a CartPole_QR_DQN preset with a golden test
- general cleanups
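For context, "softmax with temperature" here is the usual trick of dividing action values by a temperature before normalizing, so that exploration policies which need action probabilities can consume the head's output. A minimal NumPy sketch, illustrative only and not the Coach head implementation (the function name is made up):

    import numpy as np

    def softmax_with_temperature(q_values, temperature=1.0):
        # Lower temperature -> closer to greedy/argmax,
        # higher temperature -> closer to uniform.
        scaled = q_values / temperature
        scaled -= scaled.max(axis=-1, keepdims=True)   # numerical stability
        exp = np.exp(scaled)
        return exp / exp.sum(axis=-1, keepdims=True)

    # example: (batch, actions) Q-values
    q = np.array([[1.0, 2.0, 3.0],
                  [0.5, 0.5, 4.0]])
    print(softmax_with_temperature(q, temperature=0.5))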
@@ -84,8 +84,8 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
     # prediction's format is (batch,actions,atoms)
     def get_all_q_values_for_states(self, states: StateType):
         if self.exploration_policy.requires_action_values():
-            prediction = self.get_prediction(states)
-            q_values = self.distribution_prediction_to_q_values(prediction)
+            q_values = self.get_prediction(states,
+                                           outputs=self.networks['main'].online_network.output_heads[0].q_values)
         else:
             q_values = None
         return q_values
@@ -105,9 +105,9 @@ class CategoricalDQNAgent(ValueOptimizationAgent):

         # select the optimal actions for the next state
         target_actions = np.argmax(self.distribution_prediction_to_q_values(distributional_q_st_plus_1), axis=1)
-        m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
+        m = np.zeros((batch.size, self.z_values.size))

-        batches = np.arange(self.ap.network_wrappers['main'].batch_size)
+        batches = np.arange(batch.size)

         # an alternative to the for loop. 3.7x perf improvement vs. the same code done with for looping.
         # only 10% speedup overall - leaving commented out as the code is not as clear.
@@ -120,7 +120,7 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
         # bj_ = (tzj_ - self.z_values[0]) / (self.z_values[1] - self.z_values[0])
         # u_ = (np.ceil(bj_)).astype(int)
         # l_ = (np.floor(bj_)).astype(int)
-        # m_ = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
+        # m_ = np.zeros((batch.size, self.z_values.size))
         # np.add.at(m_, [batches, l_],
         #           np.transpose(distributional_q_st_plus_1[batches, target_actions], (1, 0)) * (u_ - bj_))
         # np.add.at(m_, [batches, u_],
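The commented-out block is the vectorized form of the C51 projection of the Bellman-updated distribution back onto the fixed support, and it is cut off in this hunk. Below is a self-contained sketch of that projection; variable names loosely follow the comments, but this is an illustration under assumed shapes, not the exact in-tree code (which also keeps a for-loop version):

    import numpy as np

    def project_target_distribution(rewards, game_overs, next_dist, z_values, discount):
        # rewards, game_overs: (batch,) floats; next_dist: (batch, atoms) probabilities
        # of the target-action distribution for the next state.
        batch_size, atoms = next_dist.shape
        dz = z_values[1] - z_values[0]

        # Bellman-backed-up support, clipped to [z_min, z_max]
        tzj = np.clip(rewards[:, None]
                      + (1.0 - game_overs[:, None]) * discount * z_values[None, :],
                      z_values[0], z_values[-1])
        bj = (tzj - z_values[0]) / dz
        u = np.ceil(bj).astype(int)
        l = np.floor(bj).astype(int)

        m = np.zeros((batch_size, atoms))
        batches = np.arange(batch_size)[:, None]
        # split each atom's probability mass between its two neighbouring bins;
        # when bj lands exactly on an atom (u == l), give that atom the full mass
        np.add.at(m, (batches, l), next_dist * (u - bj) + next_dist * (u == l))
        np.add.at(m, (batches, u), next_dist * (bj - l))
        return m

    # toy usage
    atoms = 51
    z_values = np.linspace(-10.0, 10.0, atoms)
    rewards = np.array([1.0, 0.0])
    game_overs = np.array([0.0, 1.0])
    next_dist = np.random.dirichlet(np.ones(atoms), size=2)
    m = project_target_distribution(rewards, game_overs, next_dist, z_values, discount=0.99)
    print(m.sum(axis=1))   # each row still sums to ~1.0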