mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
temp commit
This commit is contained in:
@@ -250,7 +250,7 @@ class MeasurementsPredictionHead(Head):
|
||||
name='output')
|
||||
action_stream = tf.reshape(action_stream,
|
||||
(tf.shape(action_stream)[0], self.num_actions, self.multi_step_measurements_size))
|
||||
action_stream = action_stream - tf.reduce_mean(action_stream, reduction_indices=1, keep_dims=True)
|
||||
action_stream = action_stream - tf.reduce_mean(action_stream, reduction_indices=1, keepdims=True)
|
||||
|
||||
# merge to future measurements predictions
|
||||
self.output = tf.add(expectation_stream, action_stream, name='output')
|
||||
@@ -302,7 +302,7 @@ class DNDQHead(Head):
|
||||
square_diff = tf.square(dnd_embeddings - tf.expand_dims(input_layer, 1))
|
||||
distances = tf.reduce_sum(square_diff, axis=2) + [self.l2_norm_added_delta]
|
||||
weights = 1.0 / distances
|
||||
normalised_weights = weights / tf.reduce_sum(weights, axis=1, keep_dims=True)
|
||||
normalised_weights = weights / tf.reduce_sum(weights, axis=1, keepdims=True)
|
||||
return tf.reduce_sum(dnd_values * normalised_weights, axis=1)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user