From 5cf10e5f52739c46041e232a02a5c3dacaa76f6f Mon Sep 17 00:00:00 2001 From: Zach Dwiel Date: Fri, 16 Feb 2018 20:18:03 -0500 Subject: [PATCH] fix bug in ddpg --- agents/ddpg_agent.py | 2 +- architectures/tensorflow_components/architecture.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/agents/ddpg_agent.py b/agents/ddpg_agent.py index b0eeaf6..425f1de 100644 --- a/agents/ddpg_agent.py +++ b/agents/ddpg_agent.py @@ -54,7 +54,7 @@ class DDPGAgent(ActorCriticAgent): actions_mean = self.actor_network.online_network.predict(current_states) critic_online_network = self.critic_network.online_network # TODO: convert into call to predict, current method ignores lstm middleware for example - action_gradients = self.critic_network.sess.run(critic_online_network.gradients_wrt_inputs[1], + action_gradients = self.critic_network.sess.run(critic_online_network.gradients_wrt_inputs['action'], feed_dict=critic_online_network._feed_dict({ **current_states, 'action': actions_mean, diff --git a/architectures/tensorflow_components/architecture.py b/architectures/tensorflow_components/architecture.py index 5008b32..5cc77f7 100644 --- a/architectures/tensorflow_components/architecture.py +++ b/architectures/tensorflow_components/architecture.py @@ -108,8 +108,7 @@ class TensorFlowArchitecture(Architecture): # gradients of the outputs w.r.t. the inputs # at the moment, this is only used by ddpg if len(self.outputs) == 1: - # TODO: convert gradients_with_respect_to_inputs into dictionary? - self.gradients_wrt_inputs = [tf.gradients(self.outputs[0], input_ph) for input_ph in self.inputs.values()] + self.gradients_wrt_inputs = {name: tf.gradients(self.outputs[0], input_ph) for name, input_ph in self.inputs.items()} self.gradients_weights_ph = tf.placeholder('float32', self.outputs[0].shape, 'output_gradient_weights') self.weighted_gradients = tf.gradients(self.outputs[0], self.trainable_weights, self.gradients_weights_ph)