diff --git a/agents/actor_critic_agent.py b/agents/actor_critic_agent.py
index e2dc916..729e67f 100644
--- a/agents/actor_critic_agent.py
+++ b/agents/actor_critic_agent.py
@@ -101,9 +101,7 @@ class ActorCriticAgent(PolicyOptimizationAgent):
             actions = np.expand_dims(actions, -1)
 
         # train
-        inputs = copy.copy(current_states)
-        inputs['output_1_0'] = actions
-        result = self.main_network.online_network.accumulate_gradients(inputs,
+        result = self.main_network.online_network.accumulate_gradients({**current_states, 'output_1_0': actions},
                                                                        [state_value_head_targets, action_advantages])
 
         # logging
diff --git a/agents/qr_dqn_agent.py b/agents/qr_dqn_agent.py
index 08ecdac..8888d18 100644
--- a/agents/qr_dqn_agent.py
+++ b/agents/qr_dqn_agent.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation 
+# Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -56,8 +56,11 @@ class QuantileRegressionDQNAgent(ValueOptimizationAgent):
             quantile_midpoints[idx, :] = quantile_midpoints[idx, sorted_quantiles[idx]]
 
         # train
-        result = self.main_network.train_and_sync_networks([current_states, actions_locations, quantile_midpoints], TD_targets)
+        result = self.main_network.train_and_sync_networks({
+            **current_states,
+            'output_0_0': actions_locations,
+            'output_0_1': quantile_midpoints,
+        }, TD_targets)
         total_loss = result[0]
 
         return total_loss
-