mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

update nec and value optimization agents to work with recurrent middleware

Zach Dwiel
2017-11-03 13:58:42 -07:00
parent 93a54c7e8e
commit 6c79a442f2
12 changed files with 138 additions and 72 deletions

View File

@@ -204,10 +204,11 @@ class Agent(object):
for action in self.env.actions_description:
self.episode_running_info[action] = []
plt.clf()
if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
for network in self.networks:
network.curr_rnn_c_in = network.middleware_embedder.c_init
network.curr_rnn_h_in = network.middleware_embedder.h_init
network.online_network.curr_rnn_c_in = network.online_network.middleware_embedder.c_init
network.online_network.curr_rnn_h_in = network.online_network.middleware_embedder.h_init
def preprocess_observation(self, observation):
"""

View File

@@ -14,6 +14,8 @@
# limitations under the License.
#
import numpy as np
from agents.value_optimization_agent import *
@@ -43,26 +45,34 @@ class NECAgent(ValueOptimizationAgent):
return total_loss
def choose_action(self, curr_state, phase=RunPhase.TRAIN):
# convert to batch so we can run it through the network
observation = np.expand_dims(np.array(curr_state['observation']), 0)
"""
this method modifies the superclass's behavior in only 3 ways:
1) the embedding is saved and stored in self.current_episode_state_embeddings
2) the dnd output head is only called if it has a minimum number of entries in it
ideally, the dnd head would do this on its own, but in my attempt at encoding this
behavior in tensorflow, I ran into problems. Would definitely be worth
revisiting in the future
3) during training, actions are saved and stored in self.current_episode_actions
if behaviors 1 and 2 were handled elsewhere, this could easily be implemented
as a wrapper around super instead of overriding this method entirely
"""
# get embedding
embedding = self.main_network.sess.run(self.main_network.online_network.state_embedding,
feed_dict={self.main_network.online_network.inputs[0]: observation})
self.current_episode_state_embeddings.append(embedding[0])
embedding = self.main_network.online_network.predict(
self.tf_input_state(curr_state),
outputs=self.main_network.online_network.state_embedding)
self.current_episode_state_embeddings.append(embedding)
# get action values
# TODO: support additional heads. Right now all other heads are ignored
if self.main_network.online_network.output_heads[0].DND.has_enough_entries(self.tp.agent.number_of_knn):
# if there are enough entries in the DND then we can query it to get the action values
actions_q_values = []
for action in range(self.action_space_size):
feed_dict = {
self.main_network.online_network.state_embedding: embedding,
self.main_network.online_network.output_heads[0].input[0]: np.array([action])
}
q_value = self.main_network.sess.run(
self.main_network.online_network.output_heads[0].output, feed_dict=feed_dict)
actions_q_values.append(q_value[0])
# actions_q_values = []
feed_dict = {
self.main_network.online_network.state_embedding: [embedding],
}
actions_q_values = self.main_network.sess.run(
self.main_network.online_network.output_heads[0].output, feed_dict=feed_dict)
else:
# get only the embedding so we can insert it to the DND
actions_q_values = [0] * self.action_space_size
@@ -70,6 +80,8 @@ class NECAgent(ValueOptimizationAgent):
# choose action according to the exploration policy and the current phase (evaluating or training the agent)
if phase == RunPhase.TRAIN:
action = self.exploration_policy.get_action(actions_q_values)
# NOTE: this next line is not in the parent implementation
# NOTE: it could be implemented as a wrapper around the parent since action is returned
self.current_episode_actions.append(action)
else:
action = np.argmax(actions_q_values)
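The docstring above lists the three deviations from the parent's choose_action. A rough sketch of that control flow, using hypothetical stand-ins for the embedding, the DND and the exploration policy (none of these names are Coach's API beyond what the diff shows):

import numpy as np

def nec_choose_action(embedding, dnd, num_actions, number_of_knn,
                      exploration_policy, training,
                      episode_embeddings, episode_actions):
    # 1) remember the state embedding so it can be inserted into the DND later
    episode_embeddings.append(embedding)

    # 2) only query the DND head once it holds enough entries for a k-NN lookup
    if dnd.has_enough_entries(number_of_knn):
        q_values = dnd.query_all_actions(embedding)   # one value per action
    else:
        q_values = np.zeros(num_actions)

    # 3) explore while training (and remember the action), act greedily otherwise
    if training:
        action = exploration_policy.get_action(q_values)
        episode_actions.append(action)
    else:
        action = int(np.argmax(q_values))
    return action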

View File

@@ -14,6 +14,8 @@
# limitations under the License.
#
import numpy as np
from agents.agent import *
@@ -30,15 +32,28 @@ class ValueOptimizationAgent(Agent):
def get_q_values(self, prediction):
return prediction
def choose_action(self, curr_state, phase=RunPhase.TRAIN):
def tf_input_state(self, curr_state):
"""
convert curr_state into the input tensors tensorflow is expecting.
TODO: move this function into Agent and use it in as many agent implementations as possible;
currently, other agents will likely not work with environment measurements.
This will become even more important as we support more complex and varied environment states.
"""
# convert to batch so we can run it through the network
observation = np.expand_dims(np.array(curr_state['observation']), 0)
if self.tp.agent.use_measurements:
measurements = np.expand_dims(np.array(curr_state['measurements']), 0)
prediction = self.main_network.online_network.predict([observation, measurements])
tf_input_state = [observation, measurements]
else:
prediction = self.main_network.online_network.predict(observation)
tf_input_state = observation
return tf_input_state
def get_prediction(self, curr_state):
return self.main_network.online_network.predict(self.tf_input_state(curr_state))
def choose_action(self, curr_state, phase=RunPhase.TRAIN):
prediction = self.get_prediction(curr_state)
actions_q_values = self.get_q_values(prediction)
# choose action according to the exploration policy and the current phase (evaluating or training the agent)
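The refactor above splits choose_action into two overridable hooks: tf_input_state builds the batched network inputs and get_prediction runs the forward pass, so subclasses such as NECAgent can reuse the input handling while fetching different outputs. A condensed sketch of the resulting structure (predict here is a dummy, not Coach's network call):

import numpy as np

class ValueOptimizationSketch(object):
    use_measurements = False
    num_actions = 4

    def tf_input_state(self, curr_state):
        # batch the raw observation (and measurements, when enabled) for the network
        observation = np.expand_dims(np.array(curr_state['observation']), 0)
        if self.use_measurements:
            measurements = np.expand_dims(np.array(curr_state['measurements']), 0)
            return [observation, measurements]
        return observation

    def predict(self, inputs):
        # dummy stand-in for main_network.online_network.predict
        return np.zeros(self.num_actions)

    def get_prediction(self, curr_state):
        return self.predict(self.tf_input_state(curr_state))

    def choose_action(self, curr_state):
        q_values = self.get_prediction(curr_state)
        return int(np.argmax(q_values))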

View File

@@ -30,9 +30,12 @@ except ImportError:
class NetworkWrapper(object):
"""
Contains multiple networks and manages syncing and gradient updates
between them.
"""
def __init__(self, tuning_parameters, has_target, has_global, name, replicated_device=None, worker_device=None):
"""
:param tuning_parameters:
:type tuning_parameters: Preset
:param has_target:
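For readers unfamiliar with the wrapper described above: it typically owns an online network plus optional target and global copies and handles weight syncing between them. A toy illustration of that responsibility (the method names here are assumptions, not Coach's API):

class NetworkWrapperSketch(object):
    def __init__(self, build_network, has_target=True, has_global=False):
        self.online_network = build_network()
        self.target_network = build_network() if has_target else None
        self.global_network = build_network() if has_global else None

    def update_target_network(self):
        # hard-copy the online weights into the target network
        if self.target_network is not None:
            self.target_network.set_weights(self.online_network.get_weights())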

View File

@@ -86,6 +86,7 @@ class TensorFlowArchitecture(Architecture):
tuning_parameters.clip_gradients)
# gradients of the outputs w.r.t. the inputs
# at the moment, this is only used by ddpg
if len(self.outputs) == 1:
self.gradients_wrt_inputs = [tf.gradients(self.outputs[0], input_ph) for input_ph in self.inputs]
self.gradients_weights_ph = tf.placeholder('float32', self.outputs[0].shape, 'output_gradient_weights')
@@ -164,6 +165,7 @@ class TensorFlowArchitecture(Architecture):
# feed the lstm state if necessary
if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
# we can't always assume that we are starting from scratch here, can we?
feed_dict[self.middleware_embedder.c_in] = self.middleware_embedder.c_init
feed_dict[self.middleware_embedder.h_in] = self.middleware_embedder.h_init
@@ -231,20 +233,27 @@ class TensorFlowArchitecture(Architecture):
while self.tp.sess.run(self.release_counter) % self.tp.num_threads != 0:
time.sleep(0.00001)
def predict(self, inputs):
def predict(self, inputs, outputs=None):
"""
Run a forward pass of the network using the given input
:param inputs: The input for the network
:param outputs: The output tensors to fetch, defaults to self.outputs
:return: The network output
WARNING: must only be called once per state, since the LSTM assumes each call is a new time step.
"""
feed_dict = dict(zip(self.inputs, force_list(inputs)))
if outputs is None:
outputs = self.outputs
if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
feed_dict[self.middleware_embedder.c_in] = self.curr_rnn_c_in
feed_dict[self.middleware_embedder.h_in] = self.curr_rnn_h_in
output, (self.curr_rnn_c_in, self.curr_rnn_h_in) = self.tp.sess.run([self.outputs, self.middleware_embedder.state_out], feed_dict=feed_dict)
output, (self.curr_rnn_c_in, self.curr_rnn_h_in) = self.tp.sess.run([outputs, self.middleware_embedder.state_out], feed_dict=feed_dict)
else:
output = self.tp.sess.run(self.outputs, feed_dict)
output = self.tp.sess.run(outputs, feed_dict)
return squeeze_list(output)
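The predict() change above does two things: callers can now name which output tensors to fetch (NEC uses this to pull the state embedding), and when the middleware is an LSTM the cell/hidden state is carried from one call to the next, hence the warning about calling it exactly once per step. A small NumPy-only sketch of that state threading (the LSTM update itself is faked; only the carry/reset behaviour is the point):

import numpy as np

class RecurrentPredictorSketch(object):
    def __init__(self, state_size=8):
        self.c_init = np.zeros((1, state_size))
        self.h_init = np.zeros((1, state_size))
        self.curr_rnn_c_in = self.c_init
        self.curr_rnn_h_in = self.h_init

    def _lstm_step(self, x, c, h):
        # fake recurrent update standing in for one sess.run over the LSTM middleware
        new_c = 0.9 * c + 0.1 * np.mean(x)
        new_h = np.tanh(new_c)
        return new_h, (new_c, new_h)

    def predict(self, x):
        # every call advances the recurrent state, so call it once per environment step
        output, (self.curr_rnn_c_in, self.curr_rnn_h_in) = \
            self._lstm_step(x, self.curr_rnn_c_in, self.curr_rnn_h_in)
        return output

    def reset_state(self):
        # done at episode boundaries (see the agent.py hunk earlier in this commit)
        self.curr_rnn_c_in = self.c_init
        self.curr_rnn_h_in = self.h_init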

View File

@@ -22,6 +22,9 @@ from configurations import InputTypes, OutputTypes, MiddlewareTypes
class GeneralTensorFlowNetwork(TensorFlowArchitecture):
"""
A generalized version of all possible networks implemented using tensorflow.
"""
def __init__(self, tuning_parameters, name="", global_network=None, network_is_local=True):
self.global_network = global_network
self.network_is_local = network_is_local
@@ -79,7 +82,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
OutputTypes.DNDQ: DNDQHead,
OutputTypes.NAF: NAFHead,
OutputTypes.PPO: PPOHead,
OutputTypes.PPO_V : PPOVHead,
OutputTypes.PPO_V: PPOVHead,
OutputTypes.CategoricalQ: CategoricalQHead,
OutputTypes.QuantileRegressionQ: QuantileRegressionQHead
}

View File

@@ -67,6 +67,10 @@ class Head(object):
def _build_module(self, input_layer):
"""
Builds the graph of the module
This method is called early on from __call__. It is expected to store the graph
in self.output.
:param input_layer: the input to the graph
:return: None
"""
@@ -279,20 +283,26 @@ class DNDQHead(Head):
key_error_threshold=self.DND_key_error_threshold)
# Retrieve info from DND dictionary
self.action = tf.placeholder(tf.int8, [None], name="action")
self.input = self.action
# self.action = tf.placeholder(tf.int8, [None], name="action")
# self.input = self.action
self.output = [
self._q_value(input_layer, action)
for action in range(self.num_actions)
]
def _q_value(self, input_layer, action):
result = tf.py_func(self.DND.query,
[input_layer, self.action, self.number_of_nn],
[input_layer, [action], self.number_of_nn],
[tf.float64, tf.float64])
self.dnd_embeddings = tf.to_float(result[0])
self.dnd_values = tf.to_float(result[1])
dnd_embeddings = tf.to_float(result[0])
dnd_values = tf.to_float(result[1])
# DND calculation
square_diff = tf.square(self.dnd_embeddings - tf.expand_dims(input_layer, 1))
square_diff = tf.square(dnd_embeddings - tf.expand_dims(input_layer, 1))
distances = tf.reduce_sum(square_diff, axis=2) + [self.l2_norm_added_delta]
weights = 1.0 / distances
normalised_weights = weights / tf.reduce_sum(weights, axis=1, keep_dims=True)
self.output = tf.reduce_sum(self.dnd_values * normalised_weights, axis=1)
return tf.reduce_sum(dnd_values * normalised_weights, axis=1)
class NAFHead(Head):
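The reworked _q_value above is the standard DND estimate: an inverse-distance-weighted average of the values stored for the retrieved neighbours, with a small delta added to the squared distances to avoid division by zero. The same computation for a single query embedding in plain NumPy (shapes and the delta value are assumed for illustration: d-dimensional embeddings, k neighbours):

import numpy as np

def dnd_q_value(query_embedding, dnd_embeddings, dnd_values, delta=1e-3):
    # query_embedding: (d,), dnd_embeddings: (k, d), dnd_values: (k,)
    square_diff = np.square(dnd_embeddings - query_embedding[None, :])
    distances = np.sum(square_diff, axis=1) + delta      # l2_norm_added_delta in the diff
    weights = 1.0 / distances
    normalised_weights = weights / np.sum(weights)
    return float(np.sum(dnd_values * normalised_weights))

# toy usage: 5 neighbours with 4-dimensional embeddings
print(dnd_q_value(np.zeros(4), np.random.randn(5, 4), np.random.randn(5)))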

View File

@@ -43,8 +43,8 @@ GET_PREFERENCES_MANUALLY=1
INSTALL_COACH=0
INSTALL_DASHBOARD=0
INSTALL_GYM=0
INSTALL_VIRTUAL_ENVIRONMENT=1
INSTALL_NEON=0
INSTALL_VIRTUAL_ENVIRONMENT=1
# Get user preferences
TEMP=`getopt -o cpgvrmeNndh \
@@ -202,4 +202,3 @@ else
# GPU supported TensorFlow
pip3 install tensorflow-gpu
fi

View File

@@ -907,6 +907,19 @@ class Doom_Health_DQN(Preset):
self.agent.num_steps_between_copying_online_weights_to_target = 1000
class Pong_NEC_LSTM(Preset):
def __init__(self):
Preset.__init__(self, NEC, Atari, ExplorationParameters)
self.env.level = 'PongDeterministic-v4'
self.learning_rate = 0.001
self.agent.num_transitions_in_experience_replay = 1000000
self.agent.middleware_type = MiddlewareTypes.LSTM
self.exploration.initial_epsilon = 0.5
self.exploration.final_epsilon = 0.1
self.exploration.epsilon_decay_steps = 1000000
self.num_heatup_steps = 500
class Pong_NEC(Preset):
def __init__(self):
Preset.__init__(self, NEC, Atari, ExplorationParameters)

View File

@@ -180,6 +180,10 @@ def threaded_cmd_line_run(run_cmd, id=-1):
class Signal(object):
"""
Stores a stream of values and provides methods like get_mean and get_max
which return statistics about the accumulated values.
"""
def __init__(self, name):
self.name = name
self.sample_count = 0
@@ -190,39 +194,36 @@ class Signal(object):
self.values = []
def add_sample(self, sample):
"""
:param sample: either a single value or an array of values
"""
self.values.append(sample)
def _get_values(self):
if type(self.values[0]) == np.ndarray:
return np.concatenate(self.values)
else:
return self.values
def get_mean(self):
if len(self.values) == 0:
return ''
if type(self.values[0]) == np.ndarray:
return np.mean(np.concatenate(self.values))
else:
return np.mean(self.values)
return np.mean(self._get_values())
def get_max(self):
if len(self.values) == 0:
return ''
if type(self.values[0]) == np.ndarray:
return np.max(np.concatenate(self.values))
else:
return np.max(self.values)
return np.max(self._get_values())
def get_min(self):
if len(self.values) == 0:
return ''
if type(self.values[0]) == np.ndarray:
return np.min(np.concatenate(self.values))
else:
return np.min(self.values)
return np.min(self._get_values())
def get_stdev(self):
if len(self.values) == 0:
return ''
if type(self.values[0]) == np.ndarray:
return np.std(np.concatenate(self.values))
else:
return np.std(self.values)
return np.std(self._get_values())
def force_list(var):
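The refactor above folds the repeated scalar-vs-ndarray branching into a single _get_values helper shared by the statistics methods. A short usage example, assuming Signal is importable from Coach's utils module (samples should be consistently scalars or consistently arrays, since the type check looks only at the first sample):

import numpy as np
from utils import Signal   # module path assumed

reward_signal = Signal('episode_reward')
reward_signal.add_sample(np.array([1.0, 2.0]))   # per-step rewards from one episode
reward_signal.add_sample(np.array([3.0]))        # arrays of different lengths are fine
print(reward_signal.get_mean())                  # _get_values() concatenates them -> 2.0
print(reward_signal.get_max())                   # -> 3.0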