Mirror of https://github.com/gryf/coach.git (synced 2025-12-18 11:40:18 +01:00)

Commit: update nec and value optimization agents to work with recurrent middleware
@@ -1,5 +1,5 @@
 #
 # Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
 #
 # Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -30,16 +30,19 @@ except ImportError:


 class NetworkWrapper(object):
     """
     Contains multiple networks and managers syncing and gradient updates
     between them.
     """
     def __init__(self, tuning_parameters, has_target, has_global, name, replicated_device=None, worker_device=None):
         """

         :param tuning_parameters:
+        :type tuning_parameters: Preset
         :param has_target:
         :param has_global:
         :param name:
         :param replicated_device:
         :param worker_device:
         """
         self.tp = tuning_parameters
         self.has_target = has_target
@@ -87,7 +90,7 @@ class NetworkWrapper(object):
     def sync(self):
         """
         Initializes the weights of the networks to match each other
         :return:
         """
         self.update_online_network()
         self.update_target_network()
@@ -111,14 +114,14 @@ class NetworkWrapper(object):
     def apply_gradients_to_global_network(self):
         """
         Apply gradients from the online network on the global network
         :return:
         """
         self.global_network.apply_gradients(self.online_network.accumulated_gradients)

     def apply_gradients_to_online_network(self):
         """
         Apply gradients from the online network on itself
         :return:
         """
         self.online_network.apply_gradients(self.online_network.accumulated_gradients)
@@ -135,7 +138,7 @@ class NetworkWrapper(object):

     def apply_gradients_and_sync_networks(self):
         """
         Applies the gradients accumulated in the online network to the global network or to itself and syncs the
         networks if necessary
         """
         if self.global_network:
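The two code paths above (a shared global network versus self-update) are the core of the wrapper. Below is a standalone sketch of that apply-then-sync ordering, using toy classes; ToyNetwork and ToyWrapper are illustrative names, not part of Coach.

    import numpy as np

    class ToyNetwork(object):
        """Stand-in for an online or global network: a single weight vector."""
        def __init__(self, dim):
            self.weights = np.zeros(dim)
            self.accumulated_gradients = np.zeros(dim)

        def accumulate_gradients(self, grads):
            self.accumulated_gradients += grads

        def apply_gradients(self, grads, lr=0.1):
            self.weights -= lr * grads

        def set_weights(self, source):
            self.weights = source.weights.copy()

    class ToyWrapper(object):
        """Mirrors the apply-then-sync ordering sketched by NetworkWrapper."""
        def __init__(self, dim, has_global):
            self.online_network = ToyNetwork(dim)
            self.global_network = ToyNetwork(dim) if has_global else None

        def apply_gradients_and_sync_networks(self):
            if self.global_network:
                # distributed case: push accumulated gradients to the shared
                # network, then pull its (now updated) weights back
                self.global_network.apply_gradients(self.online_network.accumulated_gradients)
                self.online_network.set_weights(self.global_network)
            else:
                # single-worker case: the online network updates itself
                self.online_network.apply_gradients(self.online_network.accumulated_gradients)
            self.online_network.accumulated_gradients[:] = 0

    wrapper = ToyWrapper(dim=4, has_global=True)
    wrapper.online_network.accumulate_gradients(np.ones(4))
    wrapper.apply_gradients_and_sync_networks()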
@@ -1,5 +1,5 @@
 #
 # Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -64,7 +64,7 @@ class TensorFlowArchitecture(Architecture):
                                             trainable=False)
         self.lock = self.lock_counter.assign_add(1, use_locking=True)
         self.lock_init = self.lock_counter.assign(0)

         self.release_counter = tf.get_variable("release_counter", [], tf.int32,
                                                initializer=tf.constant_initializer(0, dtype=tf.int32),
                                                trainable=False)
@@ -86,6 +86,7 @@ class TensorFlowArchitecture(Architecture):
                                                            tuning_parameters.clip_gradients)

         # gradients of the outputs w.r.t. the inputs
+        # at the moment, this is only used by ddpg
         if len(self.outputs) == 1:
             self.gradients_wrt_inputs = [tf.gradients(self.outputs[0], input_ph) for input_ph in self.inputs]
             self.gradients_weights_ph = tf.placeholder('float32', self.outputs[0].shape, 'output_gradient_weights')
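The new comment notes that these input gradients are currently only consumed by DDPG, where the critic's gradient with respect to the action input feeds the actor update. A small TensorFlow 1.x sketch (a toy graph, not Coach's network) of the tf.gradients / grad_ys pattern that gradients_wrt_inputs and gradients_weights_ph set up:

    import tensorflow as tf  # TensorFlow 1.x, matching the surrounding code

    # toy output standing in for a critic: y = sum(x^2) per batch item
    x = tf.placeholder(tf.float32, [None, 3], name='x')
    y = tf.reduce_sum(tf.square(x), axis=1)

    # gradient of the output w.r.t. the input
    dy_dx = tf.gradients(y, x)[0]

    # weighted variant: grad_ys scales each sample's contribution, which is the
    # role played by the output_gradient_weights placeholder above
    weights = tf.placeholder(tf.float32, [None], name='output_gradient_weights')
    weighted_dy_dx = tf.gradients(y, x, grad_ys=weights)[0]

    with tf.Session() as sess:
        batch = [[1.0, 2.0, 3.0]]
        print(sess.run(dy_dx, {x: batch}))                           # [[2. 4. 6.]]
        print(sess.run(weighted_dy_dx, {x: batch, weights: [0.5]}))  # [[1. 2. 3.]]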
@@ -126,7 +127,7 @@ class TensorFlowArchitecture(Architecture):

     def accumulate_gradients(self, inputs, targets, additional_fetches=None):
         """
         Runs a forward pass & backward pass, clips gradients if needed and accumulates them into the accumulation
         placeholders
         :param additional_fetches: Optional tensors to fetch during gradients calculation
         :param inputs: The input batch for the network
@@ -164,6 +165,7 @@ class TensorFlowArchitecture(Architecture):

         # feed the lstm state if necessary
         if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
+            # we can't always assume that we are starting from scratch here can we?
             feed_dict[self.middleware_embedder.c_in] = self.middleware_embedder.c_init
             feed_dict[self.middleware_embedder.h_in] = self.middleware_embedder.h_init
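The new comment questions whether a zero initial state is always the right thing to feed while accumulating gradients. For reference, here is a self-contained TensorFlow 1.x sketch of the c_in/h_in placeholder pattern this code relies on; the sizes are toy values and this is not Coach's middleware_embedder.

    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x

    num_units, feature_size = 8, 4
    observations = tf.placeholder(tf.float32, [1, None, feature_size], name='observations')
    c_in = tf.placeholder(tf.float32, [1, num_units], name='c_in')
    h_in = tf.placeholder(tf.float32, [1, num_units], name='h_in')

    cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
    state_in = tf.nn.rnn_cell.LSTMStateTuple(c_in, h_in)
    outputs, state_out = tf.nn.dynamic_rnn(cell, observations, initial_state=state_in)

    c_init = np.zeros((1, num_units), dtype=np.float32)
    h_init = np.zeros((1, num_units), dtype=np.float32)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        feed_dict = {observations: np.random.randn(1, 1, feature_size).astype(np.float32),
                     c_in: c_init,   # fresh sequence: start from the zero state
                     h_in: h_init}
        out, (c, h) = sess.run([outputs, state_out], feed_dict)
        # later single-step calls would feed (c, h) back in instead of the zeros,
        # which is what predict() below does with curr_rnn_c_in / curr_rnn_h_in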
@@ -231,20 +233,27 @@ class TensorFlowArchitecture(Architecture):
         while self.tp.sess.run(self.release_counter) % self.tp.num_threads != 0:
             time.sleep(0.00001)

-    def predict(self, inputs):
+    def predict(self, inputs, outputs=None):
         """
         Run a forward pass of the network using the given input
         :param inputs: The input for the network
+        :param outputs: The output for the network, defaults to self.outputs
         :return: The network output

         WARNING: must only call once per state since each call is assumed by LSTM to be a new time step.
         """

         feed_dict = dict(zip(self.inputs, force_list(inputs)))
+        if outputs is None:
+            outputs = self.outputs
+
         if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
             feed_dict[self.middleware_embedder.c_in] = self.curr_rnn_c_in
             feed_dict[self.middleware_embedder.h_in] = self.curr_rnn_h_in
-            output, (self.curr_rnn_c_in, self.curr_rnn_h_in) = self.tp.sess.run([self.outputs, self.middleware_embedder.state_out], feed_dict=feed_dict)
+            output, (self.curr_rnn_c_in, self.curr_rnn_h_in) = self.tp.sess.run([outputs, self.middleware_embedder.state_out], feed_dict=feed_dict)
         else:
-            output = self.tp.sess.run(self.outputs, feed_dict)
+            output = self.tp.sess.run(outputs, feed_dict)

         return squeeze_list(output)
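A hypothetical call site for the new argument (the variable names are illustrative; only predict(inputs, outputs=...) and self.outputs come from the diff). It lets an agent fetch a single tensor while, under LSTM middleware, the recurrent state is still advanced exactly once for the time step.

    # fetch only the first output tensor instead of everything in self.outputs
    q_values = network.predict(current_state, outputs=[network.outputs[0]])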
@@ -299,7 +308,7 @@ class TensorFlowArchitecture(Architecture):

     def set_variable_value(self, assign_op, value, placeholder=None):
         """
         Updates the value of a variable.
         This requires having an assign operation for the variable, and a placeholder which will provide the value
         :param assign_op: an assign operation for the variable
         :param value: a value to set the variable to
@@ -22,6 +22,9 @@ from configurations import InputTypes, OutputTypes, MiddlewareTypes


 class GeneralTensorFlowNetwork(TensorFlowArchitecture):
+    """
+    A generalized version of all possible networks implemented using tensorflow.
+    """
     def __init__(self, tuning_parameters, name="", global_network=None, network_is_local=True):
         self.global_network = global_network
         self.network_is_local = network_is_local
@@ -79,7 +82,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
             OutputTypes.DNDQ: DNDQHead,
             OutputTypes.NAF: NAFHead,
             OutputTypes.PPO: PPOHead,
-            OutputTypes.PPO_V : PPOVHead,
+            OutputTypes.PPO_V: PPOVHead,
             OutputTypes.CategoricalQ: CategoricalQHead,
             OutputTypes.QuantileRegressionQ: QuantileRegressionQHead
         }
@@ -1,5 +1,5 @@
 #
 # Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -67,6 +67,10 @@ class Head(object):
     def _build_module(self, input_layer):
         """
         Builds the graph of the module
+
+        This method is called early on from __call__. It is expected to store the graph
+        in self.output.
+
         :param input_layer: the input to the graph
         :return: None
         """
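The expanded docstring spells out the contract: __call__ invokes _build_module, which must leave its graph in self.output rather than return it. A minimal standalone sketch of that contract with toy classes (not Coach's Head hierarchy):

    import tensorflow as tf  # TensorFlow 1.x

    class ToyHead(object):
        """Illustrates the __call__ -> _build_module -> self.output contract."""
        def __call__(self, input_layer):
            self._build_module(input_layer)
            return self.output

        def _build_module(self, input_layer):
            raise NotImplementedError

    class ToyQHead(ToyHead):
        def __init__(self, num_actions):
            self.num_actions = num_actions

        def _build_module(self, input_layer):
            # build the graph and store it in self.output instead of returning it
            self.output = tf.layers.dense(input_layer, self.num_actions, name='q_values')

    middleware_output = tf.placeholder(tf.float32, [None, 64])
    q_tensor = ToyQHead(num_actions=4)(middleware_output)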
@@ -279,20 +283,26 @@ class DNDQHead(Head):
                                  key_error_threshold=self.DND_key_error_threshold)

         # Retrieve info from DND dictionary
-        self.action = tf.placeholder(tf.int8, [None], name="action")
-        self.input = self.action
+        # self.action = tf.placeholder(tf.int8, [None], name="action")
+        # self.input = self.action
+        self.output = [
+            self._q_value(input_layer, action)
+            for action in range(self.num_actions)
+        ]
+
+    def _q_value(self, input_layer, action):
         result = tf.py_func(self.DND.query,
-                            [input_layer, self.action, self.number_of_nn],
+                            [input_layer, [action], self.number_of_nn],
                             [tf.float64, tf.float64])
-        self.dnd_embeddings = tf.to_float(result[0])
-        self.dnd_values = tf.to_float(result[1])
+        dnd_embeddings = tf.to_float(result[0])
+        dnd_values = tf.to_float(result[1])

         # DND calculation
-        square_diff = tf.square(self.dnd_embeddings - tf.expand_dims(input_layer, 1))
+        square_diff = tf.square(dnd_embeddings - tf.expand_dims(input_layer, 1))
         distances = tf.reduce_sum(square_diff, axis=2) + [self.l2_norm_added_delta]
         weights = 1.0 / distances
         normalised_weights = weights / tf.reduce_sum(weights, axis=1, keep_dims=True)
-        self.output = tf.reduce_sum(self.dnd_values * normalised_weights, axis=1)
+        return tf.reduce_sum(dnd_values * normalised_weights, axis=1)


 class NAFHead(Head):
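For reference, the per-action computation that _q_value now performs, written as a self-contained NumPy sketch. The shapes and the delta value are illustrative, and fixed random arrays stand in for what DND.query would return.

    import numpy as np

    batch_size, k, embedding_size = 2, 5, 3
    l2_norm_added_delta = 0.001

    # stand-ins for the k nearest stored embeddings and their values per batch item
    dnd_embeddings = np.random.randn(batch_size, k, embedding_size)
    dnd_values = np.random.randn(batch_size, k)
    input_layer = np.random.randn(batch_size, embedding_size)

    # inverse-distance kernel over the k neighbours, mirroring the TF ops above
    square_diff = np.square(dnd_embeddings - input_layer[:, np.newaxis, :])
    distances = np.sum(square_diff, axis=2) + l2_norm_added_delta
    weights = 1.0 / distances
    normalised_weights = weights / np.sum(weights, axis=1, keepdims=True)
    q_value = np.sum(dnd_values * normalised_weights, axis=1)   # one Q per batch item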