
Enabling-more-agents-for-Batch-RL-and-cleanup (#258)

- allowing the last training batch drawn to be smaller than batch_size
- adding support for more agents in Batch RL by adding softmax with temperature to the corresponding heads
- adding a CartPole_QR_DQN preset with a golden test
- cleanups
Gal Leibovich
2019-03-21 16:10:29 +02:00
committed by GitHub
parent abec59f367
commit 6e08c55ad5
24 changed files with 152 additions and 69 deletions
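
The "softmax with temperature" mentioned in the commit message is, conceptually, the following transformation. This is a minimal NumPy sketch for illustration only; the function name and shapes are assumptions, not the actual Coach head code:

import numpy as np

def softmax_with_temperature(logits, temperature=1.0):
    # Convert raw head outputs (e.g. Q-values) into a probability distribution.
    # temperature > 1 flattens the distribution, temperature < 1 sharpens it
    # towards the greedy action.
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled -= scaled.max(axis=-1, keepdims=True)   # subtract the max for numerical stability
    exp = np.exp(scaled)
    return exp / exp.sum(axis=-1, keepdims=True)

# Example: the same Q-values at two temperatures.
q_values = [1.0, 2.0, 3.0]
print(softmax_with_temperature(q_values, temperature=0.5))   # peaked towards the best action
print(softmax_with_temperature(q_values, temperature=5.0))   # close to uniform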


@@ -50,8 +50,9 @@ class ValueOptimizationAgent(Agent):
             actions_q_values = None
         return actions_q_values
 
-    def get_prediction(self, states):
-        return self.networks['main'].online_network.predict(self.prepare_batch_for_inference(states, 'main'))
+    def get_prediction(self, states, outputs=None):
+        return self.networks['main'].online_network.predict(self.prepare_batch_for_inference(states, 'main'),
+                                                             outputs=outputs)
 
     def update_transition_priorities_and_get_weights(self, TD_errors, batch):
         # update errors in prioritized replay buffer
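
The hunk above only threads an optional outputs argument through get_prediction, so callers can ask for specific head outputs instead of the default prediction. The same forwarding pattern, sketched with made-up class and output names (not Coach's actual API):

class ToyNetwork:
    # Stand-in for an online network with several named head outputs.
    def predict(self, inputs, outputs=None):
        heads = {'q_values': [0.1, 0.9], 'softmax': [0.25, 0.75]}
        if outputs is None:
            return heads['q_values']                  # default behaviour is unchanged
        return [heads[name] for name in outputs]      # caller selects specific heads

def get_prediction(network, states, outputs=None):
    # Mirrors the diff: the extra keyword is simply passed through to predict().
    return network.predict(states, outputs=outputs)

net = ToyNetwork()
print(get_prediction(net, states=[[0.0, 1.0]]))                       # default head
print(get_prediction(net, states=[[0.0, 1.0]], outputs=['softmax']))  # selected head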
@@ -151,17 +152,18 @@ class ValueOptimizationAgent(Agent):
         # this is fitted from the training dataset
         for epoch in range(epochs):
             loss = 0
+            total_transitions_processed = 0
             for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)):
                 batch = Batch(batch)
                 current_rewards_prediction_for_all_actions = self.networks['reward_model'].online_network.predict(batch.states(network_keys))
-                current_rewards_prediction_for_all_actions[range(batch_size), batch.actions()] = batch.rewards()
+                current_rewards_prediction_for_all_actions[range(batch.size), batch.actions()] = batch.rewards()
                 loss += self.networks['reward_model'].train_and_sync_networks(
                     batch.states(network_keys), current_rewards_prediction_for_all_actions)[0]
-                # print(self.networks['reward_model'].online_network.predict(batch.states(network_keys))[0])
+                total_transitions_processed += batch.size
 
             log = OrderedDict()
             log['Epoch'] = epoch
-            log['loss'] = loss / int(self.call_memory('num_transitions_in_complete_episodes') / batch_size)
+            log['loss'] = loss / total_transitions_processed
             screen.log_dict(log, prefix='Training Reward Model')
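
This hunk changes the epoch-loss normalization: the accumulated loss is now divided by the number of transitions actually processed rather than by an estimate of the number of full batches, which stays consistent once the final batch may be smaller than batch_size. A toy calculation with assumed numbers:

# Assumed numbers for illustration: 10 stored transitions, batch_size = 4,
# so the shuffled generator yields batches of sizes 4, 4 and 2.
num_transitions, batch_size = 10, 4
drawn_batch_sizes = [4, 4, 2]

old_denominator = int(num_transitions / batch_size)   # 2: counts only full batches
new_denominator = sum(drawn_batch_sizes)              # 10: every transition trained on

print(old_denominator, new_denominator)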