diff --git a/agents/actor_critic_agent.py b/agents/actor_critic_agent.py
index ed35ee6..5d628b6 100644
--- a/agents/actor_critic_agent.py
+++ b/agents/actor_critic_agent.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation 
+# Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/agents/agent.py b/agents/agent.py
index bd34d16..81ea924 100644
--- a/agents/agent.py
+++ b/agents/agent.py
@@ -204,10 +204,11 @@ class Agent(object):
         for action in self.env.actions_description:
             self.episode_running_info[action] = []
         plt.clf()
+
         if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
             for network in self.networks:
-                network.curr_rnn_c_in = network.middleware_embedder.c_init
-                network.curr_rnn_h_in = network.middleware_embedder.h_init
+                network.online_network.curr_rnn_c_in = network.online_network.middleware_embedder.c_init
+                network.online_network.curr_rnn_h_in = network.online_network.middleware_embedder.h_init
 
     def preprocess_observation(self, observation):
         """
diff --git a/agents/nec_agent.py b/agents/nec_agent.py
index e8ac535..85b2855 100644
--- a/agents/nec_agent.py
+++ b/agents/nec_agent.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation 
+# Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@
 # limitations under the License.
 #
 
+import numpy as np
+
 from agents.value_optimization_agent import *
 
 
@@ -43,26 +45,34 @@ class NECAgent(ValueOptimizationAgent):
         return total_loss
 
     def choose_action(self, curr_state, phase=RunPhase.TRAIN):
-        # convert to batch so we can run it through the network
-        observation = np.expand_dims(np.array(curr_state['observation']), 0)
+        """
+        This method modifies the superclass's behavior in only 3 ways:
+
+        1) the embedding is saved and stored in self.current_episode_state_embeddings
+        2) the DND output head is only called if it has a minimum number of entries in it.
+           Ideally, the DND head would do this on its own, but my attempt at encoding this
+           behavior in tensorflow ran into problems. It would definitely be worth
+           revisiting in the future.
+        3) during training, actions are saved and stored in self.current_episode_actions
+
+        If behaviors 1 and 2 were handled elsewhere, this could easily be implemented
+        as a wrapper around super() instead of overriding this method entirely.
+        """
         # get embedding
-        embedding = self.main_network.sess.run(self.main_network.online_network.state_embedding,
-                                               feed_dict={self.main_network.online_network.inputs[0]: observation})
-        self.current_episode_state_embeddings.append(embedding[0])
+        embedding = self.main_network.online_network.predict(
+            self.tf_input_state(curr_state),
+            outputs=self.main_network.online_network.state_embedding)
+        self.current_episode_state_embeddings.append(embedding)
 
-        # get action values
+        # TODO: support additional heads. Right now all other heads are ignored.
         if self.main_network.online_network.output_heads[0].DND.has_enough_entries(self.tp.agent.number_of_knn):
             # if there are enough entries in the DND then we can query it to get the action values
-            actions_q_values = []
-            for action in range(self.action_space_size):
-                feed_dict = {
-                    self.main_network.online_network.state_embedding: embedding,
-                    self.main_network.online_network.output_heads[0].input[0]: np.array([action])
-                }
-                q_value = self.main_network.sess.run(
-                    self.main_network.online_network.output_heads[0].output, feed_dict=feed_dict)
-                actions_q_values.append(q_value[0])
+            feed_dict = {
+                self.main_network.online_network.state_embedding: [embedding],
+            }
+            actions_q_values = self.main_network.sess.run(
+                self.main_network.online_network.output_heads[0].output, feed_dict=feed_dict)
         else:
             # get only the embedding so we can insert it to the DND
             actions_q_values = [0] * self.action_space_size
@@ -70,6 +80,8 @@ class NECAgent(ValueOptimizationAgent):
         # choose action according to the exploration policy and the current phase (evaluating or training the agent)
         if phase == RunPhase.TRAIN:
             action = self.exploration_policy.get_action(actions_q_values)
+            # NOTE: this next line is not in the parent implementation
+            # NOTE: it could be implemented as a wrapper around the parent, since action is returned
             self.current_episode_actions.append(action)
         else:
             action = np.argmax(actions_q_values)
diff --git a/agents/value_optimization_agent.py b/agents/value_optimization_agent.py
index f348333..0684e34 100644
--- a/agents/value_optimization_agent.py
+++ b/agents/value_optimization_agent.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation 
+# Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@
 # limitations under the License.
 #
 
+import numpy as np
+
 from agents.agent import *
 
 
@@ -30,15 +32,28 @@ class ValueOptimizationAgent(Agent):
     def get_q_values(self, prediction):
         return prediction
 
-    def choose_action(self, curr_state, phase=RunPhase.TRAIN):
+    def tf_input_state(self, curr_state):
+        """
+        Convert curr_state into the input tensors that tensorflow is expecting.
+
+        TODO: move this function into Agent and use it in as many agent implementations as possible.
+        Currently, other agents will likely not work with environment measurements.
+        This will become even more important as we support more complex and varied environment states.
+ """ # convert to batch so we can run it through the network observation = np.expand_dims(np.array(curr_state['observation']), 0) if self.tp.agent.use_measurements: measurements = np.expand_dims(np.array(curr_state['measurements']), 0) - prediction = self.main_network.online_network.predict([observation, measurements]) + tf_input_state = [observation, measurements] else: - prediction = self.main_network.online_network.predict(observation) + tf_input_state = observation + return tf_input_state + def get_prediction(self, curr_state): + return self.main_network.online_network.predict(self.tf_input_state(curr_state)) + + def choose_action(self, curr_state, phase=RunPhase.TRAIN): + prediction = self.get_prediction(curr_state) actions_q_values = self.get_q_values(prediction) # choose action according to the exploration policy and the current phase (evaluating or training the agent) diff --git a/architectures/neon_components/general_network.py b/architectures/neon_components/general_network.py index 4bae454..8bc9f7d 100644 --- a/architectures/neon_components/general_network.py +++ b/architectures/neon_components/general_network.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2017 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/architectures/network_wrapper.py b/architectures/network_wrapper.py index bbe6c59..0867d8db 100644 --- a/architectures/network_wrapper.py +++ b/architectures/network_wrapper.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2017 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -30,16 +30,19 @@ except ImportError: class NetworkWrapper(object): + """ + Contains multiple networks and managers syncing and gradient updates + between them. 
+ """ def __init__(self, tuning_parameters, has_target, has_global, name, replicated_device=None, worker_device=None): """ - - :param tuning_parameters: + :param tuning_parameters: :type tuning_parameters: Preset - :param has_target: - :param has_global: - :param name: - :param replicated_device: - :param worker_device: + :param has_target: + :param has_global: + :param name: + :param replicated_device: + :param worker_device: """ self.tp = tuning_parameters self.has_target = has_target @@ -87,7 +90,7 @@ class NetworkWrapper(object): def sync(self): """ Initializes the weights of the networks to match each other - :return: + :return: """ self.update_online_network() self.update_target_network() @@ -111,14 +114,14 @@ class NetworkWrapper(object): def apply_gradients_to_global_network(self): """ Apply gradients from the online network on the global network - :return: + :return: """ self.global_network.apply_gradients(self.online_network.accumulated_gradients) def apply_gradients_to_online_network(self): """ Apply gradients from the online network on itself - :return: + :return: """ self.online_network.apply_gradients(self.online_network.accumulated_gradients) @@ -135,7 +138,7 @@ class NetworkWrapper(object): def apply_gradients_and_sync_networks(self): """ - Applies the gradients accumulated in the online network to the global network or to itself and syncs the + Applies the gradients accumulated in the online network to the global network or to itself and syncs the networks if necessary """ if self.global_network: diff --git a/architectures/tensorflow_components/architecture.py b/architectures/tensorflow_components/architecture.py index 276a122..dc86049 100644 --- a/architectures/tensorflow_components/architecture.py +++ b/architectures/tensorflow_components/architecture.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2017 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -64,7 +64,7 @@ class TensorFlowArchitecture(Architecture): trainable=False) self.lock = self.lock_counter.assign_add(1, use_locking=True) self.lock_init = self.lock_counter.assign(0) - + self.release_counter = tf.get_variable("release_counter", [], tf.int32, initializer=tf.constant_initializer(0, dtype=tf.int32), trainable=False) @@ -86,6 +86,7 @@ class TensorFlowArchitecture(Architecture): tuning_parameters.clip_gradients) # gradients of the outputs w.r.t. the inputs + # at the moment, this is only used by ddpg if len(self.outputs) == 1: self.gradients_wrt_inputs = [tf.gradients(self.outputs[0], input_ph) for input_ph in self.inputs] self.gradients_weights_ph = tf.placeholder('float32', self.outputs[0].shape, 'output_gradient_weights') @@ -126,7 +127,7 @@ class TensorFlowArchitecture(Architecture): def accumulate_gradients(self, inputs, targets, additional_fetches=None): """ - Runs a forward pass & backward pass, clips gradients if needed and accumulates them into the accumulation + Runs a forward pass & backward pass, clips gradients if needed and accumulates them into the accumulation placeholders :param additional_fetches: Optional tensors to fetch during gradients calculation :param inputs: The input batch for the network @@ -164,6 +165,7 @@ class TensorFlowArchitecture(Architecture): # feed the lstm state if necessary if self.tp.agent.middleware_type == MiddlewareTypes.LSTM: + # we can't always assume that we are starting from scratch here can we? 
             feed_dict[self.middleware_embedder.c_in] = self.middleware_embedder.c_init
             feed_dict[self.middleware_embedder.h_in] = self.middleware_embedder.h_init
 
@@ -231,20 +233,27 @@ class TensorFlowArchitecture(Architecture):
         while self.tp.sess.run(self.release_counter) % self.tp.num_threads != 0:
             time.sleep(0.00001)
 
-    def predict(self, inputs):
+    def predict(self, inputs, outputs=None):
         """
         Run a forward pass of the network using the given input
         :param inputs: The input for the network
+        :param outputs: The output tensors to fetch from the network; defaults to self.outputs
         :return: The network output
+
+        WARNING: must only be called once per state, since the LSTM middleware assumes each call is a new time step
         """
         feed_dict = dict(zip(self.inputs, force_list(inputs)))
+        if outputs is None:
+            outputs = self.outputs
+
         if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
             feed_dict[self.middleware_embedder.c_in] = self.curr_rnn_c_in
             feed_dict[self.middleware_embedder.h_in] = self.curr_rnn_h_in
-            output, (self.curr_rnn_c_in, self.curr_rnn_h_in) = self.tp.sess.run([self.outputs, self.middleware_embedder.state_out], feed_dict=feed_dict)
+
+            output, (self.curr_rnn_c_in, self.curr_rnn_h_in) = self.tp.sess.run([outputs, self.middleware_embedder.state_out], feed_dict=feed_dict)
         else:
-            output = self.tp.sess.run(self.outputs, feed_dict)
+            output = self.tp.sess.run(outputs, feed_dict)
 
         return squeeze_list(output)
@@ -299,7 +308,7 @@ class TensorFlowArchitecture(Architecture):
 
     def set_variable_value(self, assign_op, value, placeholder=None):
         """
-        Updates the value of a variable. 
+        Updates the value of a variable.
         This requires having an assign operation for the variable, and a placeholder which will provide the value
         :param assign_op: an assign operation for the variable
         :param value: a value to set the variable to
diff --git a/architectures/tensorflow_components/general_network.py b/architectures/tensorflow_components/general_network.py
index 842ad66..5258976 100644
--- a/architectures/tensorflow_components/general_network.py
+++ b/architectures/tensorflow_components/general_network.py
@@ -22,6 +22,9 @@ from configurations import InputTypes, OutputTypes, MiddlewareTypes
 
 
 class GeneralTensorFlowNetwork(TensorFlowArchitecture):
+    """
+    A generalized version of all possible networks implemented using tensorflow.
+    """
     def __init__(self, tuning_parameters, name="", global_network=None, network_is_local=True):
         self.global_network = global_network
         self.network_is_local = network_is_local
@@ -79,7 +82,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
             OutputTypes.DNDQ: DNDQHead,
             OutputTypes.NAF: NAFHead,
             OutputTypes.PPO: PPOHead,
-            OutputTypes.PPO_V : PPOVHead,
+            OutputTypes.PPO_V: PPOVHead,
             OutputTypes.CategoricalQ: CategoricalQHead,
             OutputTypes.QuantileRegressionQ: QuantileRegressionQHead
         }
diff --git a/architectures/tensorflow_components/heads.py b/architectures/tensorflow_components/heads.py
index 91fd192..b2d8e49 100644
--- a/architectures/tensorflow_components/heads.py
+++ b/architectures/tensorflow_components/heads.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation 
+# Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -67,6 +67,10 @@ class Head(object):
     def _build_module(self, input_layer):
         """
         Builds the graph of the module
+
+        This method is called early on from __call__. It is expected to store the graph
+        in self.output.
+
         :param input_layer: the input to the graph
         :return: None
         """
@@ -279,20 +283,26 @@ class DNDQHead(Head):
                                key_error_threshold=self.DND_key_error_threshold)
 
         # Retrieve info from DND dictionary
-        self.action = tf.placeholder(tf.int8, [None], name="action")
-        self.input = self.action
+        self.output = [
+            self._q_value(input_layer, action)
+            for action in range(self.num_actions)
+        ]
+
+    def _q_value(self, input_layer, action):
         result = tf.py_func(self.DND.query,
-                            [input_layer, self.action, self.number_of_nn],
+                            [input_layer, [action], self.number_of_nn],
                             [tf.float64, tf.float64])
-        self.dnd_embeddings = tf.to_float(result[0])
-        self.dnd_values = tf.to_float(result[1])
+        dnd_embeddings = tf.to_float(result[0])
+        dnd_values = tf.to_float(result[1])
 
         # DND calculation
-        square_diff = tf.square(self.dnd_embeddings - tf.expand_dims(input_layer, 1))
+        square_diff = tf.square(dnd_embeddings - tf.expand_dims(input_layer, 1))
         distances = tf.reduce_sum(square_diff, axis=2) + [self.l2_norm_added_delta]
         weights = 1.0 / distances
         normalised_weights = weights / tf.reduce_sum(weights, axis=1, keep_dims=True)
-        self.output = tf.reduce_sum(self.dnd_values * normalised_weights, axis=1)
+        return tf.reduce_sum(dnd_values * normalised_weights, axis=1)
 
 
 class NAFHead(Head):
diff --git a/install.sh b/install.sh
index 4f2a775..8563f05 100755
--- a/install.sh
+++ b/install.sh
@@ -43,8 +43,8 @@ GET_PREFERENCES_MANUALLY=1
 INSTALL_COACH=0
 INSTALL_DASHBOARD=0
 INSTALL_GYM=0
-INSTALL_VIRTUAL_ENVIRONMENT=1
 INSTALL_NEON=0
+INSTALL_VIRTUAL_ENVIRONMENT=1
 
 # Get user preferences
 TEMP=`getopt -o cpgvrmeNndh \
@@ -202,4 +202,3 @@ else
     # GPU supported TensorFlow
     pip3 install tensorflow-gpu
 fi
-
diff --git a/presets.py b/presets.py
index dc3b5ec..c0225ad 100644
--- a/presets.py
+++ b/presets.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation 
+# Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -907,6 +907,19 @@ class Doom_Health_DQN(Preset):
         self.agent.num_steps_between_copying_online_weights_to_target = 1000
 
 
+class Pong_NEC_LSTM(Preset):
+    def __init__(self):
+        Preset.__init__(self, NEC, Atari, ExplorationParameters)
+        self.env.level = 'PongDeterministic-v4'
+        self.learning_rate = 0.001
+        self.agent.num_transitions_in_experience_replay = 1000000
+        self.agent.middleware_type = MiddlewareTypes.LSTM
+        self.exploration.initial_epsilon = 0.5
+        self.exploration.final_epsilon = 0.1
+        self.exploration.epsilon_decay_steps = 1000000
+        self.num_heatup_steps = 500
+
+
 class Pong_NEC(Preset):
     def __init__(self):
         Preset.__init__(self, NEC, Atari, ExplorationParameters)
diff --git a/utils.py b/utils.py
index db97994..d598722 100644
--- a/utils.py
+++ b/utils.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation 
+# Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -180,6 +180,10 @@ def threaded_cmd_line_run(run_cmd, id=-1):
 
 
 class Signal(object):
+    """
+    Stores a stream of values and provides methods such as get_mean and get_max
+    which return statistics about the accumulated values.
+ """ def __init__(self, name): self.name = name self.sample_count = 0 @@ -190,39 +194,36 @@ class Signal(object): self.values = [] def add_sample(self, sample): + """ + :param sample: either a single value or an array of values + """ self.values.append(sample) + def _get_values(self): + if type(self.values[0]) == np.ndarray: + return np.concatenate(self.values) + else: + return self.values + def get_mean(self): if len(self.values) == 0: return '' - if type(self.values[0]) == np.ndarray: - return np.mean(np.concatenate(self.values)) - else: - return np.mean(self.values) + return np.mean(self._get_values()) def get_max(self): if len(self.values) == 0: return '' - if type(self.values[0]) == np.ndarray: - return np.max(np.concatenate(self.values)) - else: - return np.max(self.values) + return np.max(self._get_values()) def get_min(self): if len(self.values) == 0: return '' - if type(self.values[0]) == np.ndarray: - return np.min(np.concatenate(self.values)) - else: - return np.min(self.values) + return np.min(self._get_values()) def get_stdev(self): if len(self.values) == 0: return '' - if type(self.values[0]) == np.ndarray: - return np.std(np.concatenate(self.values)) - else: - return np.std(self.values) + return np.std(self._get_values()) def force_list(var):