1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

fix keep_dims -> keepdims

This commit is contained in:
Zach Dwiel
2018-02-16 13:30:31 -05:00
parent 39a28aba95
commit ee6e0bdc3b
3 changed files with 10 additions and 5 deletions

View File

@@ -112,8 +112,12 @@ class PPOAgent(ActorCriticAgent):
current_values = self.critic_network.online_network.predict(current_states_batch)
targets = current_values * (1 - mix_fraction) + total_return_batch * mix_fraction
inputs = copy.copy(current_states_batch)
for input_index, input in enumerate(old_policy_values):
inputs['output_0_{}'.format(input_index)] = input
value_loss = self.critic_network.online_network.\
accumulate_gradients([current_states_batch] + old_policy_values, targets)
accumulate_gradients(inputs, targets)
self.critic_network.apply_gradients_to_online_network()
if self.tp.distributed:
self.critic_network.apply_gradients_to_global_network()