Mirror of https://github.com/gryf/coach.git
Enabling more agents for Batch RL and cleanup (#258)

- Allow the last training batch drawn to be smaller than batch_size.
- Add support for more agents in Batch RL by adding a softmax with temperature to the corresponding heads (sketched below).
- Add a CartPole_QR_DQN preset with a golden test.
- Cleanups.
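The "softmax with temperature" mentioned above rescales a head's logits (here, per-action values) by a temperature before normalizing, so the resulting action distribution can be made sharper or flatter. Below is a minimal NumPy sketch of the idea, not the repository's head implementation; the function name and example values are illustrative.

import numpy as np

def softmax_with_temperature(logits, temperature=1.0):
    # Divide by the temperature: T < 1 sharpens the distribution, T > 1 flattens it.
    scaled = np.asarray(logits, dtype=float) / temperature
    scaled -= scaled.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    exp = np.exp(scaled)
    return exp / exp.sum(axis=-1, keepdims=True)

# e.g. turning per-action Q-values into action probabilities
q_values = np.array([[1.0, 2.0, 0.5]])
print(softmax_with_temperature(q_values, temperature=0.5))

With temperature 1.0 this reduces to the ordinary softmax; lower temperatures concentrate probability on the highest-valued action.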
@@ -50,8 +50,9 @@ class ValueOptimizationAgent(Agent):
             actions_q_values = None
         return actions_q_values
 
-    def get_prediction(self, states):
-        return self.networks['main'].online_network.predict(self.prepare_batch_for_inference(states, 'main'))
+    def get_prediction(self, states, outputs=None):
+        return self.networks['main'].online_network.predict(self.prepare_batch_for_inference(states, 'main'),
+                                                             outputs=outputs)
 
     def update_transition_priorities_and_get_weights(self, TD_errors, batch):
         # update errors in prioritized replay buffer
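The new outputs argument above is simply forwarded to the online network's predict call; the assumption here (not stated in the diff) is that it lets a caller request specific head outputs instead of the default prediction. A minimal sketch of that optional-output pattern, with purely hypothetical names:

# Illustrative only: 'heads' and the head names are hypothetical, not Coach's API.
def predict(heads, outputs=None):
    # Return every head's output by default, or only the requested ones.
    if outputs is None:
        return list(heads.values())
    return [heads[name] for name in outputs]

heads = {'q_values': [0.1, 0.4], 'softmax': [0.43, 0.57]}
print(predict(heads))                       # all head outputs
print(predict(heads, outputs=['softmax']))  # only the requested head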
@@ -151,17 +152,18 @@ class ValueOptimizationAgent(Agent):
         # this is fitted from the training dataset
         for epoch in range(epochs):
             loss = 0
+            total_transitions_processed = 0
             for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)):
                 batch = Batch(batch)
                 current_rewards_prediction_for_all_actions = self.networks['reward_model'].online_network.predict(batch.states(network_keys))
-                current_rewards_prediction_for_all_actions[range(batch_size), batch.actions()] = batch.rewards()
+                current_rewards_prediction_for_all_actions[range(batch.size), batch.actions()] = batch.rewards()
                 loss += self.networks['reward_model'].train_and_sync_networks(
                     batch.states(network_keys), current_rewards_prediction_for_all_actions)[0]
                 # print(self.networks['reward_model'].online_network.predict(batch.states(network_keys))[0])
+                total_transitions_processed += batch.size
 
             log = OrderedDict()
             log['Epoch'] = epoch
-            log['loss'] = loss / int(self.call_memory('num_transitions_in_complete_episodes') / batch_size)
+            log['loss'] = loss / total_transitions_processed
             screen.log_dict(log, prefix='Training Reward Model')
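To see why the indexing and the loss normalization change together, consider an epoch whose transition count is not a multiple of batch_size: the final batch is smaller, so indexing its rows with range(batch_size) would go out of bounds, and dividing the summed loss by an assumed number of full batches would miscount. A minimal NumPy sketch under those assumptions (illustrative numbers, not Coach code):

import numpy as np

batch_size = 64
last_batch_size = 150 % batch_size            # 150 transitions total -> final batch has 22 rows
predictions = np.zeros((last_batch_size, 4))  # reward predictions for 4 actions
actions = np.random.randint(0, 4, size=last_batch_size)
rewards = np.random.randn(last_batch_size)

# range(batch_size) would index rows 0..63, but only 22 rows exist -> IndexError.
# Indexing with the actual batch size writes each reward into the taken action's column.
predictions[range(last_batch_size), actions] = rewards

# Normalize the logged loss by the transitions actually processed,
# not by an assumed count of full batches.
total_transitions_processed = 150
accumulated_loss = 123.0                      # hypothetical sum of per-batch losses over the epoch
print(accumulated_loss / total_transitions_processed)

This mirrors the switch from range(batch_size) to range(batch.size) and from the num_transitions_in_complete_episodes-based denominator to total_transitions_processed in the hunk above.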