1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Update NEC and value optimization agents to work with recurrent middleware

This commit is contained in:
Zach Dwiel
2017-11-03 13:58:42 -07:00
parent 93a54c7e8e
commit 6c79a442f2
12 changed files with 138 additions and 72 deletions

View File

@@ -1,5 +1,5 @@
#
# Copyright (c) 2017 Intel Corporation
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -67,6 +67,10 @@ class Head(object):
def _build_module(self, input_layer):
"""
Builds the graph of the module
This method is called early on from __call__. It is expected to store the graph
in self.output.
:param input_layer: the input to the graph
:return: None
"""
@@ -279,20 +283,26 @@ class DNDQHead(Head):
key_error_threshold=self.DND_key_error_threshold)
# Retrieve info from DND dictionary
self.action = tf.placeholder(tf.int8, [None], name="action")
self.input = self.action
# self.action = tf.placeholder(tf.int8, [None], name="action")
# self.input = self.action
self.output = [
self._q_value(input_layer, action)
for action in range(self.num_actions)
]
def _q_value(self, input_layer, action):
result = tf.py_func(self.DND.query,
[input_layer, self.action, self.number_of_nn],
[input_layer, [action], self.number_of_nn],
[tf.float64, tf.float64])
self.dnd_embeddings = tf.to_float(result[0])
self.dnd_values = tf.to_float(result[1])
dnd_embeddings = tf.to_float(result[0])
dnd_values = tf.to_float(result[1])
# DND calculation
square_diff = tf.square(self.dnd_embeddings - tf.expand_dims(input_layer, 1))
square_diff = tf.square(dnd_embeddings - tf.expand_dims(input_layer, 1))
distances = tf.reduce_sum(square_diff, axis=2) + [self.l2_norm_added_delta]
weights = 1.0 / distances
normalised_weights = weights / tf.reduce_sum(weights, axis=1, keep_dims=True)
self.output = tf.reduce_sum(self.dnd_values * normalised_weights, axis=1)
return tf.reduce_sum(dnd_values * normalised_weights, axis=1)
class NAFHead(Head):