mirror of https://github.com/gryf/coach.git
fix keep_dims -> keepdims
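The title refers to the keyword argument of reduction ops: NumPy (and newer TensorFlow) spell it `keepdims`, while older TensorFlow 1.x code used the now-deprecated `keep_dims`. A minimal sketch of the rename, assuming a NumPy softmax-style reduction (the actual call sites are not visible in this extract):

    import numpy as np

    x = np.random.rand(4, 10)

    # NumPy reductions take `keepdims`, not `keep_dims`; the misspelled
    # keyword raises TypeError:
    #   np.sum(np.exp(x), axis=1, keep_dims=True)
    #
    # With keepdims=True the reduced axis is kept as size 1, so the
    # (4, 10) numerator broadcasts against the (4, 1) denominator.
    denom = np.sum(np.exp(x), axis=1, keepdims=True)
    softmax = np.exp(x) / denom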
@@ -16,7 +16,8 @@
 
 import numpy as np
 
-from agents.value_optimization_agent import *
+from agents.value_optimization_agent import ValueOptimizationAgent
+from logger import screen
 
 
 # Neural Episodic Control - https://arxiv.org/pdf/1703.01988.pdf
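The hunk above swaps a wildcard import for explicit names. When a module defines no `__all__`, `import *` also re-exports names that module merely imported, which is why `screen` then needs its own import once the wildcard goes away. A self-contained sketch of that transitivity, using hypothetical stand-in modules (not the real coach packages):

    import sys
    import types

    # Hypothetical stand-ins for `logger` and `agents.value_optimization_agent`.
    logger = types.ModuleType("fake_logger")
    logger.screen = object()
    sys.modules["fake_logger"] = logger

    agent_mod = types.ModuleType("fake_agent_mod")
    exec("from fake_logger import screen\nclass ValueOptimizationAgent: pass",
         agent_mod.__dict__)
    sys.modules["fake_agent_mod"] = agent_mod

    # With no __all__, `import *` binds every public name in the module,
    # including `screen`, which fake_agent_mod only imported:
    ns = {}
    exec("from fake_agent_mod import *", ns)
    print(sorted(k for k in ns if not k.startswith("_")))
    # ['ValueOptimizationAgent', 'screen']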
@@ -112,8 +112,12 @@ class PPOAgent(ActorCriticAgent):
         current_values = self.critic_network.online_network.predict(current_states_batch)
         targets = current_values * (1 - mix_fraction) + total_return_batch * mix_fraction
 
+        inputs = copy.copy(current_states_batch)
+        for input_index, input in enumerate(old_policy_values):
+            inputs['output_0_{}'.format(input_index)] = input
+
         value_loss = self.critic_network.online_network.\
-            accumulate_gradients([current_states_batch] + old_policy_values, targets)
+            accumulate_gradients(inputs, targets)
         self.critic_network.apply_gradients_to_online_network()
         if self.tp.distributed:
             self.critic_network.apply_gradients_to_global_network()
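This hunk keeps the mixed value target and changes only how the extra inputs reach accumulate_gradients: a dict keyed by input name ('output_0_<i>') replaces the positional list concatenated onto the states batch. A runnable sketch of both pieces with made-up shapes (names mirror the diff; the shapes and the 'observation' key are assumptions):

    import numpy as np

    rng = np.random.default_rng(0)
    current_values = rng.normal(size=(32, 1))      # critic predictions V(s)
    total_return_batch = rng.normal(size=(32, 1))  # empirical discounted returns
    mix_fraction = 0.5                             # 0.0 = keep old values, 1.0 = pure returns

    # Convex interpolation between the old value estimates and the returns,
    # exactly as in the unchanged `targets = ...` line above.
    targets = current_values * (1 - mix_fraction) + total_return_batch * mix_fraction

    # New input plumbing: each old-policy output becomes a named entry,
    # presumably matching one of the network's input placeholders.
    current_states_batch = {"observation": rng.normal(size=(32, 4))}
    old_policy_values = [rng.normal(size=(32, 1)), rng.normal(size=(32, 1))]

    inputs = dict(current_states_batch)            # same effect as copy.copy here
    for input_index, head_input in enumerate(old_policy_values):
        inputs["output_0_{}".format(input_index)] = head_input

    print(sorted(inputs))   # ['observation', 'output_0_0', 'output_0_1']

The sketch renames the loop variable to `head_input`; the diff itself uses `input`, which shadows the Python builtin.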