mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
fix bug in ddpg
This commit is contained in:
@@ -108,8 +108,7 @@ class TensorFlowArchitecture(Architecture):
|
||||
# gradients of the outputs w.r.t. the inputs
|
||||
# at the moment, this is only used by ddpg
|
||||
if len(self.outputs) == 1:
|
||||
# TODO: convert gradients_with_respect_to_inputs into dictionary?
|
||||
self.gradients_wrt_inputs = [tf.gradients(self.outputs[0], input_ph) for input_ph in self.inputs.values()]
|
||||
self.gradients_wrt_inputs = {name: tf.gradients(self.outputs[0], input_ph) for name, input_ph in self.inputs.items()}
|
||||
self.gradients_weights_ph = tf.placeholder('float32', self.outputs[0].shape, 'output_gradient_weights')
|
||||
self.weighted_gradients = tf.gradients(self.outputs[0], self.trainable_weights, self.gradients_weights_ph)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user