mirror of https://github.com/gryf/coach.git (synced 2025-12-17 11:10:20 +01:00)

Commit: update nec and value optimization agents to work with recurrent middleware
@@ -1,5 +1,5 @@
 #
 # Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -204,10 +204,11 @@ class Agent(object):
         for action in self.env.actions_description:
             self.episode_running_info[action] = []
         plt.clf()
+
         if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
             for network in self.networks:
-                network.curr_rnn_c_in = network.middleware_embedder.c_init
-                network.curr_rnn_h_in = network.middleware_embedder.h_init
+                network.online_network.curr_rnn_c_in = network.online_network.middleware_embedder.c_init
+                network.online_network.curr_rnn_h_in = network.online_network.middleware_embedder.h_init
 
     def preprocess_observation(self, observation):
         """
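The hunk above fixes the per-episode reset of the recurrent state: each entry in self.networks is a wrapper, and the LSTM state (curr_rnn_c_in / curr_rnn_h_in) lives on the wrapped online_network, so the reset has to reach through it. As a rough, framework-independent sketch of the pattern (hypothetical names, plain numpy, not Coach's actual classes):

import numpy as np

class RecurrentStateHolder(object):
    """Carries LSTM state across steps within an episode and resets it at episode boundaries."""

    def __init__(self, state_size):
        # these play the role of middleware_embedder.c_init / h_init
        self.c_init = np.zeros((1, state_size), dtype=np.float32)
        self.h_init = np.zeros((1, state_size), dtype=np.float32)
        self.reset_episode()

    def reset_episode(self):
        # what the agent does for every network when a new episode starts
        self.curr_rnn_c_in = self.c_init.copy()
        self.curr_rnn_h_in = self.h_init.copy()

    def advance(self, new_c, new_h):
        # after each forward pass, the post-step state becomes the next step's input state
        self.curr_rnn_c_in, self.curr_rnn_h_in = new_c, new_h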
@@ -1,5 +1,5 @@
 #
 # Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@
 # limitations under the License.
 #
 
+import numpy as np
+
 from agents.value_optimization_agent import *
 
 
@@ -43,26 +45,34 @@ class NECAgent(ValueOptimizationAgent):
         return total_loss
 
     def choose_action(self, curr_state, phase=RunPhase.TRAIN):
-        # convert to batch so we can run it through the network
-        observation = np.expand_dims(np.array(curr_state['observation']), 0)
+        """
+        this method modifies the superclass's behavior in only 3 ways:
+
+        1) the embedding is saved and stored in self.current_episode_state_embeddings
+        2) the dnd output head is only called if it has a minimum number of entries in it.
+           ideally, the dnd head would do this on its own, but in my attempt at encoding this
+           behavior in tensorflow, I ran into problems. Would definitely be worth
+           revisiting in the future
+        3) during training, actions are saved and stored in self.current_episode_actions
+        if behaviors 1 and 2 were handled elsewhere, this could easily be implemented
+        as a wrapper around super instead of overriding this method entirely
+        """
 
         # get embedding
-        embedding = self.main_network.sess.run(self.main_network.online_network.state_embedding,
-                                               feed_dict={self.main_network.online_network.inputs[0]: observation})
-        self.current_episode_state_embeddings.append(embedding[0])
+        embedding = self.main_network.online_network.predict(
+            self.tf_input_state(curr_state),
+            outputs=self.main_network.online_network.state_embedding)
+        self.current_episode_state_embeddings.append(embedding)
 
-        # get action values
+        # TODO: support additional heads. Right now all other heads are ignored
         if self.main_network.online_network.output_heads[0].DND.has_enough_entries(self.tp.agent.number_of_knn):
             # if there are enough entries in the DND then we can query it to get the action values
-            actions_q_values = []
-            for action in range(self.action_space_size):
-                feed_dict = {
-                    self.main_network.online_network.state_embedding: embedding,
-                    self.main_network.online_network.output_heads[0].input[0]: np.array([action])
-                }
-                q_value = self.main_network.sess.run(
-                    self.main_network.online_network.output_heads[0].output, feed_dict=feed_dict)
-                actions_q_values.append(q_value[0])
+            # actions_q_values = []
+            feed_dict = {
+                self.main_network.online_network.state_embedding: [embedding],
+            }
+            actions_q_values = self.main_network.sess.run(
+                self.main_network.online_network.output_heads[0].output, feed_dict=feed_dict)
         else:
             # get only the embedding so we can insert it to the DND
             actions_q_values = [0] * self.action_space_size
@@ -70,6 +80,8 @@ class NECAgent(ValueOptimizationAgent):
         # choose action according to the exploration policy and the current phase (evaluating or training the agent)
         if phase == RunPhase.TRAIN:
             action = self.exploration_policy.get_action(actions_q_values)
+            # NOTE: this next line is not in the parent implementation
+            # NOTE: it could be implemented as a wrapper around the parent since action is returned
             self.current_episode_actions.append(action)
         else:
             action = np.argmax(actions_q_values)
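Taken together, the NEC changes make choose_action do three things: record the state embedding for the end-of-episode DND insertion, query the DND head only once it holds at least number_of_knn entries (returning zero Q-values until then), and record the chosen action during training. A loose numpy sketch of that control flow with a toy in-memory DND (hypothetical names, not the Coach API; the real head uses the inverse-distance kernel sketched after the heads.py hunk below):

import numpy as np

class TinyDND(object):
    """Toy stand-in for the per-action differentiable neural dictionary: lists of (key, value) pairs."""

    def __init__(self, num_actions):
        self.keys = [[] for _ in range(num_actions)]
        self.values = [[] for _ in range(num_actions)]

    def has_enough_entries(self, knn):
        return all(len(keys) >= knn for keys in self.keys)

    def query(self, action, embedding, knn):
        keys = np.asarray(self.keys[action])
        values = np.asarray(self.values[action])
        nearest = np.argsort(np.sum((keys - embedding) ** 2, axis=1))[:knn]
        return float(values[nearest].mean())   # simplified; the real head kernel-weights by distance

def choose_action_like_nec(embedding, dnd, num_actions, knn, episode_embeddings, episode_actions):
    episode_embeddings.append(embedding)       # (1) keep the embedding for the episode-end DND insert
    if dnd.has_enough_entries(knn):            # (2) only query a sufficiently filled DND
        q_values = np.array([dnd.query(a, embedding, knn) for a in range(num_actions)])
    else:
        q_values = np.zeros(num_actions)
    action = int(np.argmax(q_values))          # the real agent applies its exploration policy here
    episode_actions.append(action)             # (3) remember the action for the episode update
    return action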
@@ -1,5 +1,5 @@
 #
 # Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@
 # limitations under the License.
 #
 
+import numpy as np
+
 from agents.agent import *
 
 
@@ -30,15 +32,28 @@ class ValueOptimizationAgent(Agent):
     def get_q_values(self, prediction):
         return prediction
 
-    def choose_action(self, curr_state, phase=RunPhase.TRAIN):
+    def tf_input_state(self, curr_state):
+        """
+        convert curr_state into the input tensors tensorflow is expecting.
+
+        TODO: move this function into Agent and use it in as many agent implementations as possible.
+        currently, other agents will likely not work with environment measurements.
+        This will become even more important as we support more complex and varied environment states.
+        """
         # convert to batch so we can run it through the network
         observation = np.expand_dims(np.array(curr_state['observation']), 0)
         if self.tp.agent.use_measurements:
             measurements = np.expand_dims(np.array(curr_state['measurements']), 0)
-            prediction = self.main_network.online_network.predict([observation, measurements])
+            tf_input_state = [observation, measurements]
         else:
-            prediction = self.main_network.online_network.predict(observation)
+            tf_input_state = observation
+        return tf_input_state
 
+    def get_prediction(self, curr_state):
+        return self.main_network.online_network.predict(self.tf_input_state(curr_state))
+
+    def choose_action(self, curr_state, phase=RunPhase.TRAIN):
+        prediction = self.get_prediction(curr_state)
         actions_q_values = self.get_q_values(prediction)
 
         # choose action according to the exploration policy and the current phase (evaluating or training the agent)
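The refactor above splits choose_action into three hooks: tf_input_state (environment state to network inputs), get_prediction (the forward pass), and the action selection itself, so subclasses like the NEC agent can reuse the input conversion while fetching a different tensor. A toy illustration of that template-method split, not the Coach classes themselves:

import numpy as np

class ValueAgentSketch(object):
    """Toy version of the hook structure introduced above (network is anything with a predict method)."""

    def __init__(self, network, use_measurements=False):
        self.network = network
        self.use_measurements = use_measurements

    def tf_input_state(self, curr_state):
        # batch the observation (and optionally the measurements), as the diff does
        observation = np.expand_dims(np.array(curr_state['observation']), 0)
        if self.use_measurements:
            measurements = np.expand_dims(np.array(curr_state['measurements']), 0)
            return [observation, measurements]
        return observation

    def get_prediction(self, curr_state):
        # subclasses can override this to request a different output from the network
        return self.network.predict(self.tf_input_state(curr_state))

    def choose_action(self, curr_state):
        q_values = self.get_prediction(curr_state)
        return int(np.argmax(q_values))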
@@ -1,5 +1,5 @@
 #
 # Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
 #
 # Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -30,16 +30,19 @@ except ImportError:
 
 
 class NetworkWrapper(object):
+    """
+    Contains multiple networks and manages syncing and gradient updates
+    between them.
+    """
     def __init__(self, tuning_parameters, has_target, has_global, name, replicated_device=None, worker_device=None):
         """
         :param tuning_parameters:
         :type tuning_parameters: Preset
         :param has_target:
         :param has_global:
         :param name:
         :param replicated_device:
         :param worker_device:
         """
         self.tp = tuning_parameters
         self.has_target = has_target
@@ -87,7 +90,7 @@ class NetworkWrapper(object):
     def sync(self):
         """
         Initializes the weights of the networks to match each other
         :return:
         """
         self.update_online_network()
         self.update_target_network()
@@ -111,14 +114,14 @@ class NetworkWrapper(object):
     def apply_gradients_to_global_network(self):
         """
         Apply gradients from the online network on the global network
         :return:
         """
         self.global_network.apply_gradients(self.online_network.accumulated_gradients)
 
     def apply_gradients_to_online_network(self):
         """
         Apply gradients from the online network on itself
         :return:
         """
         self.online_network.apply_gradients(self.online_network.accumulated_gradients)
 
@@ -135,7 +138,7 @@ class NetworkWrapper(object):
 
     def apply_gradients_and_sync_networks(self):
         """
         Applies the gradients accumulated in the online network to the global network or to itself and syncs the
         networks if necessary
         """
         if self.global_network:
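The new NetworkWrapper docstring summarizes its role: it holds an online network, optionally a target network, and (for distributed training) a global network, and moves gradients and weights between them. A minimal sketch of the apply-and-sync idea on flat weight vectors (illustrative only, not Coach's implementation):

import numpy as np

class WrapperSketch(object):
    """Not Coach's NetworkWrapper: just the apply_gradients_and_sync_networks idea on flat weight vectors."""

    def __init__(self, dim, has_global=False, lr=0.1):
        self.lr = lr
        self.online = np.zeros(dim)
        self.target = np.zeros(dim)
        self.global_net = np.zeros(dim) if has_global else None

    def apply_gradients_and_sync(self, grads):
        if self.global_net is not None:
            # distributed case: update the shared weights, then pull them into the local online copy
            self.global_net -= self.lr * grads
            self.online = self.global_net.copy()
        else:
            # single-worker case: the online network just updates itself
            self.online -= self.lr * grads

    def update_target(self):
        # hard copy, as done when syncing online -> target
        self.target = self.online.copy()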
@@ -1,5 +1,5 @@
 #
 # Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -64,7 +64,7 @@ class TensorFlowArchitecture(Architecture):
                                                 trainable=False)
             self.lock = self.lock_counter.assign_add(1, use_locking=True)
             self.lock_init = self.lock_counter.assign(0)
 
             self.release_counter = tf.get_variable("release_counter", [], tf.int32,
                                                    initializer=tf.constant_initializer(0, dtype=tf.int32),
                                                    trainable=False)
@@ -86,6 +86,7 @@ class TensorFlowArchitecture(Architecture):
                                                       tuning_parameters.clip_gradients)
 
             # gradients of the outputs w.r.t. the inputs
+            # at the moment, this is only used by ddpg
             if len(self.outputs) == 1:
                 self.gradients_wrt_inputs = [tf.gradients(self.outputs[0], input_ph) for input_ph in self.inputs]
                 self.gradients_weights_ph = tf.placeholder('float32', self.outputs[0].shape, 'output_gradient_weights')
@@ -126,7 +127,7 @@ class TensorFlowArchitecture(Architecture):
 
     def accumulate_gradients(self, inputs, targets, additional_fetches=None):
         """
        Runs a forward pass & backward pass, clips gradients if needed and accumulates them into the accumulation
         placeholders
         :param additional_fetches: Optional tensors to fetch during gradients calculation
         :param inputs: The input batch for the network
@@ -164,6 +165,7 @@ class TensorFlowArchitecture(Architecture):
 
         # feed the lstm state if necessary
         if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
+            # we can't always assume that we are starting from scratch here, can we?
             feed_dict[self.middleware_embedder.c_in] = self.middleware_embedder.c_init
             feed_dict[self.middleware_embedder.h_in] = self.middleware_embedder.h_init
 
@@ -231,20 +233,27 @@ class TensorFlowArchitecture(Architecture):
         while self.tp.sess.run(self.release_counter) % self.tp.num_threads != 0:
             time.sleep(0.00001)
 
-    def predict(self, inputs):
+    def predict(self, inputs, outputs=None):
         """
         Run a forward pass of the network using the given input
         :param inputs: The input for the network
+        :param outputs: The outputs to fetch from the network, defaults to self.outputs
         :return: The network output
+
+        WARNING: must only be called once per state, since the LSTM assumes each call is a new time step.
         """
+
         feed_dict = dict(zip(self.inputs, force_list(inputs)))
+        if outputs is None:
+            outputs = self.outputs
+
         if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
             feed_dict[self.middleware_embedder.c_in] = self.curr_rnn_c_in
             feed_dict[self.middleware_embedder.h_in] = self.curr_rnn_h_in
-            output, (self.curr_rnn_c_in, self.curr_rnn_h_in) = self.tp.sess.run([self.outputs, self.middleware_embedder.state_out], feed_dict=feed_dict)
+            output, (self.curr_rnn_c_in, self.curr_rnn_h_in) = self.tp.sess.run([outputs, self.middleware_embedder.state_out], feed_dict=feed_dict)
         else:
-            output = self.tp.sess.run(self.outputs, feed_dict)
+            output = self.tp.sess.run(outputs, feed_dict)
 
         return squeeze_list(output)
@@ -299,7 +308,7 @@ class TensorFlowArchitecture(Architecture):
 
     def set_variable_value(self, assign_op, value, placeholder=None):
         """
         Updates the value of a variable.
         This requires having an assign operation for the variable, and a placeholder which will provide the value
         :param assign_op: an assign operation for the variable
         :param value: a value to set the variable to
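predict now accepts an optional outputs argument so a caller can fetch a specific tensor (the NEC agent uses this to read the state embedding in the same forward pass), and the new docstring warning matters because with an LSTM middleware every call consumes curr_rnn_c_in/h_in and produces the next state. A toy, self-contained illustration of why one call per environment step is the rule (not the Coach class):

import numpy as np

class StatefulPredictorSketch(object):
    """Toy stand-in for a network whose predict() advances an internal recurrent state."""

    def __init__(self, dim):
        self.curr_rnn_c_in = np.zeros(dim)
        self.curr_rnn_h_in = np.zeros(dim)

    def predict(self, x, outputs=None):
        # every call consumes the current recurrent state and stores the next one,
        # so two calls for the same environment step would silently skip a time step
        self.curr_rnn_c_in = self.curr_rnn_c_in + x
        self.curr_rnn_h_in = np.tanh(self.curr_rnn_c_in)
        fetched = {'q_values': self.curr_rnn_h_in.sum(), 'state_embedding': self.curr_rnn_h_in.copy()}
        return fetched if outputs is None else fetched[outputs]

network = StatefulPredictorSketch(dim=4)
step_result = network.predict(np.ones(4))            # one call per step: fetch everything at once,
embedding = step_result['state_embedding']           # rather than calling predict() a second time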
@@ -22,6 +22,9 @@ from configurations import InputTypes, OutputTypes, MiddlewareTypes
 
 
 class GeneralTensorFlowNetwork(TensorFlowArchitecture):
+    """
+    A generalized version of all possible networks implemented using tensorflow.
+    """
     def __init__(self, tuning_parameters, name="", global_network=None, network_is_local=True):
         self.global_network = global_network
         self.network_is_local = network_is_local
@@ -79,7 +82,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
             OutputTypes.DNDQ: DNDQHead,
             OutputTypes.NAF: NAFHead,
             OutputTypes.PPO: PPOHead,
-            OutputTypes.PPO_V : PPOVHead,
+            OutputTypes.PPO_V: PPOVHead,
             OutputTypes.CategoricalQ: CategoricalQHead,
             OutputTypes.QuantileRegressionQ: QuantileRegressionQHead
         }
@@ -1,5 +1,5 @@
 #
 # Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -67,6 +67,10 @@ class Head(object):
     def _build_module(self, input_layer):
         """
         Builds the graph of the module
+
+        This method is called early on from __call__. It is expected to store the graph
+        in self.output.
+
         :param input_layer: the input to the graph
         :return: None
         """
@@ -279,20 +283,26 @@ class DNDQHead(Head):
                                           key_error_threshold=self.DND_key_error_threshold)
 
         # Retrieve info from DND dictionary
-        self.action = tf.placeholder(tf.int8, [None], name="action")
-        self.input = self.action
+        # self.action = tf.placeholder(tf.int8, [None], name="action")
+        # self.input = self.action
+        self.output = [
+            self._q_value(input_layer, action)
+            for action in range(self.num_actions)
+        ]
+
+    def _q_value(self, input_layer, action):
         result = tf.py_func(self.DND.query,
-                            [input_layer, self.action, self.number_of_nn],
+                            [input_layer, [action], self.number_of_nn],
                             [tf.float64, tf.float64])
-        self.dnd_embeddings = tf.to_float(result[0])
-        self.dnd_values = tf.to_float(result[1])
+        dnd_embeddings = tf.to_float(result[0])
+        dnd_values = tf.to_float(result[1])
 
         # DND calculation
-        square_diff = tf.square(self.dnd_embeddings - tf.expand_dims(input_layer, 1))
+        square_diff = tf.square(dnd_embeddings - tf.expand_dims(input_layer, 1))
         distances = tf.reduce_sum(square_diff, axis=2) + [self.l2_norm_added_delta]
         weights = 1.0 / distances
         normalised_weights = weights / tf.reduce_sum(weights, axis=1, keep_dims=True)
-        self.output = tf.reduce_sum(self.dnd_values * normalised_weights, axis=1)
+        return tf.reduce_sum(dnd_values * normalised_weights, axis=1)
 
 
 class NAFHead(Head):
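In the DNDQHead change, the Q-value for each action is computed by _q_value as an inverse-distance kernel average over the k entries returned by DND.query; building one such output per action removes the action placeholder that the old graph needed. The same arithmetic in plain numpy, for a single state (a sketch of the math, not the TF graph):

import numpy as np

def dnd_q_value(query_embedding, dnd_embeddings, dnd_values, delta=1e-3):
    """Kernel-weighted Q-value, mirroring the square_diff/distances/weights lines above.

    query_embedding: (embedding_dim,) current state embedding
    dnd_embeddings:  (k, embedding_dim) the k nearest keys returned by DND.query
    dnd_values:      (k,) the values stored with those keys
    """
    square_diff = (dnd_embeddings - query_embedding) ** 2
    distances = square_diff.sum(axis=1) + delta           # l2_norm_added_delta keeps the weights finite
    weights = 1.0 / distances
    normalised_weights = weights / weights.sum()
    return float((dnd_values * normalised_weights).sum())

# example: three stored neighbours, the closest one dominates the estimate
print(dnd_q_value(np.zeros(2),
                  np.array([[0.1, 0.0], [1.0, 1.0], [2.0, 2.0]]),
                  np.array([1.0, 0.0, 0.0])))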
@@ -43,8 +43,8 @@ GET_PREFERENCES_MANUALLY=1
 INSTALL_COACH=0
 INSTALL_DASHBOARD=0
 INSTALL_GYM=0
-INSTALL_VIRTUAL_ENVIRONMENT=1
 INSTALL_NEON=0
+INSTALL_VIRTUAL_ENVIRONMENT=1
 
 # Get user preferences
 TEMP=`getopt -o cpgvrmeNndh \
@@ -202,4 +202,3 @@ else
     # GPU supported TensorFlow
     pip3 install tensorflow-gpu
 fi
-
 presets.py | 15
@@ -1,5 +1,5 @@
 #
 # Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -907,6 +907,19 @@ class Doom_Health_DQN(Preset):
         self.agent.num_steps_between_copying_online_weights_to_target = 1000
 
 
+class Pong_NEC_LSTM(Preset):
+    def __init__(self):
+        Preset.__init__(self, NEC, Atari, ExplorationParameters)
+        self.env.level = 'PongDeterministic-v4'
+        self.learning_rate = 0.001
+        self.agent.num_transitions_in_experience_replay = 1000000
+        self.agent.middleware_type = MiddlewareTypes.LSTM
+        self.exploration.initial_epsilon = 0.5
+        self.exploration.final_epsilon = 0.1
+        self.exploration.epsilon_decay_steps = 1000000
+        self.num_heatup_steps = 500
+
+
 class Pong_NEC(Preset):
     def __init__(self):
         Preset.__init__(self, NEC, Atari, ExplorationParameters)
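Pong_NEC_LSTM differs from the existing Pong_NEC preset mainly in selecting MiddlewareTypes.LSTM and in its exploration schedule; presets are plain classes whose __init__ just sets fields on the agent, environment and exploration parameter objects. A hypothetical preset following the same pattern (illustrative only; assumes the same imports presets.py already uses):

# hypothetical, modelled on Pong_NEC_LSTM above; not part of this commit
class Breakout_NEC_LSTM(Preset):
    def __init__(self):
        Preset.__init__(self, NEC, Atari, ExplorationParameters)
        self.env.level = 'BreakoutDeterministic-v4'
        self.learning_rate = 0.001
        self.agent.num_transitions_in_experience_replay = 1000000
        self.agent.middleware_type = MiddlewareTypes.LSTM
        self.exploration.initial_epsilon = 0.5
        self.exploration.final_epsilon = 0.1
        self.exploration.epsilon_decay_steps = 1000000
        self.num_heatup_steps = 500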
 utils.py | 35
@@ -1,5 +1,5 @@
 #
 # Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -180,6 +180,10 @@ def threaded_cmd_line_run(run_cmd, id=-1):
 
 
 class Signal(object):
+    """
+    Stores a stream of values and provides methods like get_mean and get_max
+    which return statistics about the accumulated values.
+    """
     def __init__(self, name):
         self.name = name
         self.sample_count = 0
@@ -190,39 +194,36 @@ class Signal(object):
         self.values = []
 
     def add_sample(self, sample):
+        """
+        :param sample: either a single value or an array of values
+        """
         self.values.append(sample)
 
+    def _get_values(self):
+        if type(self.values[0]) == np.ndarray:
+            return np.concatenate(self.values)
+        else:
+            return self.values
+
     def get_mean(self):
         if len(self.values) == 0:
             return ''
-        if type(self.values[0]) == np.ndarray:
-            return np.mean(np.concatenate(self.values))
-        else:
-            return np.mean(self.values)
+        return np.mean(self._get_values())
 
     def get_max(self):
         if len(self.values) == 0:
             return ''
-        if type(self.values[0]) == np.ndarray:
-            return np.max(np.concatenate(self.values))
-        else:
-            return np.max(self.values)
+        return np.max(self._get_values())
 
     def get_min(self):
         if len(self.values) == 0:
             return ''
-        if type(self.values[0]) == np.ndarray:
-            return np.min(np.concatenate(self.values))
-        else:
-            return np.min(self.values)
+        return np.min(self._get_values())
 
     def get_stdev(self):
         if len(self.values) == 0:
             return ''
-        if type(self.values[0]) == np.ndarray:
-            return np.std(np.concatenate(self.values))
-        else:
-            return np.std(self.values)
+        return np.std(self._get_values())
 
 
 def force_list(var):
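Signal.add_sample now documents that it accepts either a single value or an array, and the statistics methods share the _get_values helper, which concatenates array samples before reducing them. A short usage sketch (assumes the Coach repository root is on the path so utils.Signal is importable):

import numpy as np
from utils import Signal   # assumes the repository root is on sys.path

rewards = Signal('reward')
rewards.add_sample(np.array([1.0, 2.0]))      # an array of values...
rewards.add_sample(np.array([3.0]))           # ...and another; _get_values() will concatenate them
print(rewards.get_mean(), rewards.get_max())  # 2.0 3.0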