diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..3e3aa5d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,16 @@
+.idea
+experiments
+*.pyc
+checkpoints
+_vizdoom.ini
+*.*~
+MUJOCO_LOG.TXT
+test_log.txt
+.test
+tf_logs
+bullet3
+roboschool
+*.csv
+*.doc
+*.orig
+docs/site
diff --git a/LICENSE b/LICENSE
deleted file mode 100644
index 8dada3e..0000000
--- a/LICENSE
+++ /dev/null
@@ -1,201 +0,0 @@
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "{}"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright {yyyy} {name of copyright owner}
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
diff --git a/README.md b/README.md
index a231240..93c3103 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,211 @@
-# coach
-Reinforcement Learning Coach by Intel® Nervana™ AI enables easy experimentation with state of the art Reinforcement Learning algorithms
+# Coach
+
+## Overview
+
+Coach is a Python reinforcement learning research framework containing implementations of many state-of-the-art algorithms.
+
+It exposes a set of easy-to-use APIs for experimenting with new RL algorithms, and allows simple integration of new environments to solve.
+Basic RL components (algorithms, environments, neural network architectures, exploration policies, ...) are well decoupled, so that extending and reusing existing components is fairly painless.
+
+Training an agent to solve an environment is as easy as running:
+
+```bash
+python coach.py -p CartPole_DQN -r
+```
+
+## Installation
+
+Note: Coach has been tested on Ubuntu 16.04 LTS only.
+
+Coach's installer will set up all the basics needed to get you up and running with Coach on top of [OpenAI Gym](https://github.com/openai/gym) environments. This can be done by running the following command and then following the on-screen instructions:
+
+```bash
+./install.sh
+```
+
+Coach creates a virtual environment and installs everything it needs into it, to avoid making changes to the user's system.
+
+To activate and deactivate Coach's virtual environment, run:
+
+```bash
+source coach_env/bin/activate
+```
+
+```bash
+deactivate
+```
+
+In addition to OpenAI Gym, several other environments have been tested and are supported. Please follow the instructions in the Supported Environments section below to install additional environments.
+
+### GPU Support
+
+#### TensorFlow
+
+By default, Coach's installer installs [Intel-Optimized TensorFlow](https://software.intel.com/en-us/articles/intel-optimized-tensorflow-wheel-now-available), which does not support GPU. To run Coach with GPU support, a GPU-enabled TensorFlow version must be installed. This can be done by overriding the installed TensorFlow package:
+
+```bash
+pip install tensorflow-gpu
+```
+
+## Running Coach
+
+Coach supports both TensorFlow and neon deep learning frameworks.
+
+Switching between TensorFlow and neon backends is possible by using the `-f` flag.
+
+Using TensorFlow (default): `-f tensorflow`
+
+Using neon: `-f neon`
+
+There are several available presets in `presets.py`.
+To list all the available presets, use the `-l` flag.
+
+To run a preset, use:
+
+```bash
+python coach.py -r -p <preset_name>
+```
+
+
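+For example, assuming the flags can be combined in a single command, listing the presets and then running the `CartPole_DQN` preset from above with the neon backend could look like this:
+
+```bash
+# list all the available presets
+python coach.py -l
+# run a preset using the neon backend
+python coach.py -r -p CartPole_DQN -f neon
+```
+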
+## Parallelizing an Algorithm
+
+Since the introduction of [A3C](https://arxiv.org/abs/1602.01783) in 2016, many algorithms have been shown to benefit from running multiple instances in parallel on many CPU cores. So far, these algorithms include [A3C](https://arxiv.org/abs/1602.01783), [DDPG](https://arxiv.org/pdf/1704.03073.pdf), [PPO](https://arxiv.org/abs/1707.02286), and [NAF](https://arxiv.org/pdf/1610.00633.pdf), and this is most probably only the beginning.
+
+Parallelizing an algorithm using Coach is straightforward.
+
+The following `NetworkWrapper` method parallelizes an algorithm seamlessly:
+
+```python
+network.train_and_sync_networks(current_states, targets)
+```
+
+Once a parallelized run is started, the `train_and_sync_networks` API will apply gradients from each local worker's network to the main global network, allowing for parallel training to take place.
+
+Then, all that is left is to run Coach with the `-n` flag, followed by the number of workers to run. For instance, the following command will set 16 workers to work together to train a MuJoCo Hopper:
+
+```bash
+python coach.py -p Hopper_A3C -n 16
+```
+
+
+
+## Supported Environments
+
+* OpenAI Gym
+
+ Installed by default by Coach's installer.
+
+* ViZDoom:
+
+ Follow the instructions described in the ViZDoom repository -
+
+ https://github.com/mwydmuch/ViZDoom
+
+  Additionally, Coach assumes that the environment variable VIZDOOM_ROOT points to the ViZDoom installation directory (see the example after this list).
+
+* Roboschool:
+
+ Follow the instructions described in the roboschool repository -
+
+ https://github.com/openai/roboschool
+
+* GymExtensions:
+
+ Follow the instructions described in the GymExtensions repository -
+
+ https://github.com/Breakend/gym-extensions
+
+  Additionally, add the installation directory to the PYTHONPATH environment variable (see the example after this list).
+
+* PyBullet
+
+ Follow the instructions described in the [Quick Start Guide](https://docs.google.com/document/d/10sXEhzFRSnvFcl3XxNGhnD4N2SedqwdAvK3dsihxVUA) (basically just - 'pip install pybullet')
+
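+As an example, the environment variables mentioned in the list above might be set as follows (the paths are placeholders for your local installation directories):
+
+```bash
+# hypothetical paths - replace with your actual installation directories
+export VIZDOOM_ROOT=/path/to/ViZDoom
+export PYTHONPATH=$PYTHONPATH:/path/to/gym-extensions
+```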
+
+
+## Supported Algorithms
+
+* [Deep Q Network (DQN)](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf)
+* [Double Deep Q Network (DDQN)](https://arxiv.org/pdf/1509.06461.pdf)
+* [Dueling Q Network](https://arxiv.org/abs/1511.06581)
+* [Mixed Monte Carlo (MMC)](https://arxiv.org/abs/1707.06887)
+* [Persistent Advantage Learning (PAL)](https://arxiv.org/abs/1512.04860)
+* [Distributional Deep Q Network](https://arxiv.org/abs/1707.06887)
+* [Bootstrapped Deep Q Network](https://arxiv.org/abs/1602.04621)
+* [N-Step Q Learning](https://arxiv.org/abs/1602.01783) | **Distributed**
+* [Neural Episodic Control (NEC)](https://arxiv.org/abs/1703.01988)
+* [Normalized Advantage Functions (NAF)](https://arxiv.org/abs/1603.00748) | **Distributed**
+* [Policy Gradients (PG)](http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf) | **Distributed**
+* [Actor Critic / A3C](https://arxiv.org/abs/1602.01783) | **Distributed**
+* [Deep Deterministic Policy Gradients (DDPG)](https://arxiv.org/abs/1509.02971) | **Distributed**
+* [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.02286.pdf)
+* [Clipped Proximal Policy Optimization](https://arxiv.org/pdf/1707.06347.pdf) | **Distributed**
+* [Direct Future Prediction (DFP)](https://arxiv.org/abs/1611.01779) | **Distributed**
+
+
+
+
+## Disclaimer
+
+Coach is released as reference code for research purposes. It is not an official Intel product, and the level of quality and support may not be as expected from an official product.
+Additional algorithms and environments are planned to be added to the framework. Feedback and contributions from the open source and RL research communities are more than welcome.
+
diff --git a/agents/__init__.py b/agents/__init__.py
new file mode 100644
index 0000000..c8d342a
--- /dev/null
+++ b/agents/__init__.py
@@ -0,0 +1,34 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.actor_critic_agent import *
+from agents.agent import *
+from agents.bootstrapped_dqn_agent import *
+from agents.clipped_ppo_agent import *
+from agents.ddpg_agent import *
+from agents.ddqn_agent import *
+from agents.dfp_agent import *
+from agents.dqn_agent import *
+from agents.distributional_dqn_agent import *
+from agents.mmc_agent import *
+from agents.n_step_q_agent import *
+from agents.naf_agent import *
+from agents.nec_agent import *
+from agents.pal_agent import *
+from agents.policy_gradients_agent import *
+from agents.policy_optimization_agent import *
+from agents.ppo_agent import *
+from agents.value_optimization_agent import *
diff --git a/agents/actor_critic_agent.py b/agents/actor_critic_agent.py
new file mode 100644
index 0000000..279075f
--- /dev/null
+++ b/agents/actor_critic_agent.py
@@ -0,0 +1,136 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.policy_optimization_agent import *
+from logger import *
+from utils import *
+import scipy.signal
+
+
+# Actor Critic - https://arxiv.org/abs/1602.01783
+class ActorCriticAgent(PolicyOptimizationAgent):
+    def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0, create_target_network=False):
+ PolicyOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id, create_target_network)
+ self.last_gradient_update_step_idx = 0
+ self.action_advantages = Signal('Advantages')
+ self.state_values = Signal('Values')
+ self.unclipped_grads = Signal('Grads (unclipped)')
+ self.signals.append(self.action_advantages)
+ self.signals.append(self.state_values)
+ self.signals.append(self.unclipped_grads)
+
+ # Discounting function used to calculate discounted returns.
+ def discount(self, x, gamma):
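+        # computes y[t] = x[t] + gamma * y[t+1], going backwards over x;
+        # e.g. discount(np.array([1., 1., 1.]), 0.9) returns approximately [2.71, 1.9, 1.]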
+ return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]
+
+ def get_general_advantage_estimation_values(self, rewards, values):
+        # values contains n+1 elements (v_t ... v_{t+n}), rewards contains n elements (r_t ... r_{t+n-1})
+ bootstrap_extended_rewards = np.array(rewards.tolist() + [values[-1]])
+
+        # Approximation-based calculation of GAE (mathematically correct only when Tmax = inf,
+        # although in practice it works well even for much smaller Tmax values, e.g. 20)
+ deltas = rewards + self.tp.agent.discount * values[1:] - values[:-1]
+ gae = self.discount(deltas, self.tp.agent.discount * self.tp.agent.gae_lambda)
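+        # i.e. gae[t] = sum over l >= 0 of (discount * gae_lambda)^l * deltas[t + l]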
+
+ if self.tp.agent.estimate_value_using_gae:
+ discounted_returns = np.expand_dims(gae + values[:-1], -1)
+ else:
+ discounted_returns = np.expand_dims(np.array(self.discount(bootstrap_extended_rewards,
+ self.tp.agent.discount)), 1)[:-1]
+ return gae, discounted_returns
+
+ def learn_from_batch(self, batch):
+ # batch contains a list of episodes to learn from
+ current_states, next_states, actions, rewards, game_overs, _ = self.extract_batch(batch)
+
+ # get the values for the current states
+ result = self.main_network.online_network.predict(current_states)
+ current_state_values = result[0]
+ self.state_values.add_sample(current_state_values)
+
+ # the targets for the state value estimator
+ num_transitions = len(game_overs)
+ state_value_head_targets = np.zeros((num_transitions, 1))
+
+ # estimate the advantage function
+ action_advantages = np.zeros((num_transitions, 1))
+
+ if self.policy_gradient_rescaler == PolicyGradientRescaler.A_VALUE:
+ if game_overs[-1]:
+ R = 0
+ else:
+ R = self.main_network.online_network.predict(np.expand_dims(next_states[-1], 0))[0]
+
+ for i in reversed(range(num_transitions)):
+ R = rewards[i] + self.tp.agent.discount * R
+ state_value_head_targets[i] = R
+ action_advantages[i] = R - current_state_values[i]
+
+ elif self.policy_gradient_rescaler == PolicyGradientRescaler.GAE:
+ # get bootstraps
+ bootstrapped_value = self.main_network.online_network.predict(np.expand_dims(next_states[-1], 0))[0]
+ values = np.append(current_state_values, bootstrapped_value)
+ if game_overs[-1]:
+ values[-1] = 0
+
+ # get general discounted returns table
+ gae_values, state_value_head_targets = self.get_general_advantage_estimation_values(rewards, values)
+ action_advantages = np.vstack(gae_values)
+ else:
+ screen.warning("WARNING: The requested policy gradient rescaler is not available")
+
+ action_advantages = action_advantages.squeeze(axis=-1)
+ if not self.env.discrete_controls and len(actions.shape) < 2:
+ actions = np.expand_dims(actions, -1)
+
+ # train
+ result = self.main_network.online_network.accumulate_gradients([current_states, actions],
+ [state_value_head_targets, action_advantages])
+
+ # logging
+ total_loss, losses, unclipped_grads = result[:3]
+ self.action_advantages.add_sample(action_advantages)
+ self.unclipped_grads.add_sample(unclipped_grads)
+ logger.create_signal_value('Value Loss', losses[0])
+ logger.create_signal_value('Policy Loss', losses[1])
+
+ return total_loss
+
+ def choose_action(self, curr_state, phase=RunPhase.TRAIN):
+ # convert to batch so we can run it through the network
+ observation = np.expand_dims(np.array(curr_state['observation']), 0)
+ if self.env.discrete_controls:
+ # DISCRETE
+ state_value, action_probabilities = self.main_network.online_network.predict(observation)
+ action_probabilities = action_probabilities.squeeze()
+ if phase == RunPhase.TRAIN:
+ action = self.exploration_policy.get_action(action_probabilities)
+ else:
+ action = np.argmax(action_probabilities)
+ action_info = {"action_probability": action_probabilities[action], "state_value": state_value}
+ self.entropy.add_sample(-np.sum(action_probabilities * np.log(action_probabilities)))
+ else:
+ # CONTINUOUS
+ state_value, action_values_mean, action_values_std = self.main_network.online_network.predict(observation)
+ action_values_mean = action_values_mean.squeeze()
+ action_values_std = action_values_std.squeeze()
+ if phase == RunPhase.TRAIN:
+ action = np.squeeze(np.random.randn(1, self.action_space_size) * action_values_std + action_values_mean)
+ else:
+ action = action_values_mean
+ action_info = {"action_probability": action, "state_value": state_value}
+
+ return action, action_info
diff --git a/agents/agent.py b/agents/agent.py
new file mode 100644
index 0000000..0279efd
--- /dev/null
+++ b/agents/agent.py
@@ -0,0 +1,536 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import scipy.ndimage
+import scipy.misc  # used below for scipy.misc.imresize
+import matplotlib.pyplot as plt
+import copy
+from configurations import Preset
+from collections import OrderedDict
+from utils import RunPhase, Signal, is_empty, RunningStat
+from architectures import *
+from exploration_policies import *
+from memories import *
+from memories.memory import *
+from logger import logger, screen
+import random
+import time
+import os
+import itertools
+from architectures.tensorflow_components.shared_variables import SharedRunningStats
+from six.moves import range
+
+
+class Agent:
+ def __init__(self, env, tuning_parameters, replicated_device=None, task_id=0):
+ """
+ :param env: An environment instance
+ :type env: EnvironmentWrapper
+        :param tuning_parameters: A Preset class instance with all the running parameters
+ :type tuning_parameters: Preset
+ :param replicated_device: A tensorflow device for distributed training (optional)
+ :type replicated_device: instancemethod
+        :param task_id: The current task id
+        :type task_id: int
+ """
+
+ screen.log_title("Creating agent {}".format(task_id))
+ self.task_id = task_id
+ self.sess = tuning_parameters.sess
+ self.env = tuning_parameters.env_instance = env
+
+ # i/o dimensions
+ if not tuning_parameters.env.desired_observation_width or not tuning_parameters.env.desired_observation_height:
+ tuning_parameters.env.desired_observation_width = self.env.width
+ tuning_parameters.env.desired_observation_height = self.env.height
+ self.action_space_size = tuning_parameters.env.action_space_size = self.env.action_space_size
+ self.measurements_size = tuning_parameters.env.measurements_size = self.env.measurements_size
+ if tuning_parameters.agent.use_accumulated_reward_as_measurement:
+ self.measurements_size = tuning_parameters.env.measurements_size = (self.measurements_size[0] + 1,)
+
+ # modules
+ self.memory = eval(tuning_parameters.memory + '(tuning_parameters)')
+ # self.architecture = eval(tuning_parameters.architecture)
+
+ self.has_global = replicated_device is not None
+ self.replicated_device = replicated_device
+ self.worker_device = "/job:worker/task:{}/cpu:0".format(task_id) if replicated_device is not None else "/gpu:0"
+
+ self.exploration_policy = eval(tuning_parameters.exploration.policy + '(tuning_parameters)')
+ self.evaluation_exploration_policy = eval(tuning_parameters.exploration.evaluation_policy
+ + '(tuning_parameters)')
+ self.evaluation_exploration_policy.change_phase(RunPhase.TEST)
+
+ # initialize all internal variables
+ self.tp = tuning_parameters
+ self.in_heatup = False
+ self.total_reward_in_current_episode = 0
+ self.total_steps_counter = 0
+ self.running_reward = None
+ self.training_iteration = 0
+ self.current_episode = 0
+ self.curr_state = []
+ self.current_episode_steps_counter = 0
+ self.episode_running_info = {}
+ self.last_episode_evaluation_ran = 0
+ self.running_observations = []
+ logger.set_current_time(self.current_episode)
+ self.main_network = None
+ self.networks = []
+ self.last_episode_images = []
+
+ # signals
+ self.signals = []
+ self.loss = Signal('Loss')
+ self.signals.append(self.loss)
+ self.curr_learning_rate = Signal('Learning Rate')
+ self.signals.append(self.curr_learning_rate)
+
+ if self.tp.env.normalize_observation and not self.env.is_state_type_image:
+ if not self.tp.distributed or not self.tp.agent.share_statistics_between_workers:
+ self.running_observation_stats = RunningStat((self.tp.env.desired_observation_width,))
+ self.running_reward_stats = RunningStat(())
+ else:
+ self.running_observation_stats = SharedRunningStats(self.tp, replicated_device,
+ shape=(self.tp.env.desired_observation_width,),
+ name='observation_stats')
+ self.running_reward_stats = SharedRunningStats(self.tp, replicated_device,
+ shape=(),
+ name='reward_stats')
+
+ # env is already reset at this point. Otherwise we're getting an error where you cannot
+ # reset an env which is not done
+ self.reset_game(do_not_reset_env=True)
+
+ # use seed
+ if self.tp.seed is not None:
+ random.seed(self.tp.seed)
+ np.random.seed(self.tp.seed)
+
+ def log_to_screen(self, phase):
+ # log to screen
+ if self.current_episode > 0:
+ if phase == RunPhase.TEST:
+ exploration = self.evaluation_exploration_policy.get_control_param()
+ else:
+ exploration = self.exploration_policy.get_control_param()
+ screen.log_dict(
+ OrderedDict([
+ ("Worker", self.task_id),
+ ("Episode", self.current_episode),
+ ("total reward", self.total_reward_in_current_episode),
+ ("exploration", exploration),
+ ("steps", self.total_steps_counter),
+ ("training iteration", self.training_iteration)
+ ]),
+ prefix="Heatup" if self.in_heatup else "Training" if phase == RunPhase.TRAIN else "Testing"
+ )
+
+ def update_log(self, phase=RunPhase.TRAIN):
+ """
+ Writes logging messages to screen and updates the log file with all the signal values.
+ :return: None
+ """
+ # log all the signals to file
+ logger.set_current_time(self.current_episode)
+ logger.create_signal_value('Training Iter', self.training_iteration)
+ logger.create_signal_value('In Heatup', int(self.in_heatup))
+ logger.create_signal_value('ER #Transitions', self.memory.num_transitions())
+ logger.create_signal_value('ER #Episodes', self.memory.length())
+ logger.create_signal_value('Episode Length', self.current_episode_steps_counter)
+ logger.create_signal_value('Total steps', self.total_steps_counter)
+ logger.create_signal_value("Epsilon", self.exploration_policy.get_control_param())
+ if phase == RunPhase.TRAIN:
+ logger.create_signal_value("Training Reward", self.total_reward_in_current_episode)
+ elif phase == RunPhase.TEST:
+ logger.create_signal_value('Evaluation Reward', self.total_reward_in_current_episode)
+ logger.update_wall_clock_time(self.current_episode)
+
+ for signal in self.signals:
+ logger.create_signal_value("{}/Mean".format(signal.name), signal.get_mean())
+ logger.create_signal_value("{}/Stdev".format(signal.name), signal.get_stdev())
+ logger.create_signal_value("{}/Max".format(signal.name), signal.get_max())
+ logger.create_signal_value("{}/Min".format(signal.name), signal.get_min())
+
+ # dump
+ if self.current_episode % self.tp.visualization.dump_signals_to_csv_every_x_episodes == 0:
+ logger.dump_output_csv()
+
+ def reset_game(self, do_not_reset_env=False):
+ """
+ Resets all the episodic parameters and start a new environment episode.
+ :param do_not_reset_env: A boolean that allows prevention of environment reset
+ :return: None
+ """
+
+ for signal in self.signals:
+ signal.reset()
+ self.total_reward_in_current_episode = 0
+ self.curr_state = []
+ self.last_episode_images = []
+ self.current_episode_steps_counter = 0
+ self.episode_running_info = {}
+ if not do_not_reset_env:
+ self.env.reset()
+ self.exploration_policy.reset()
+
+ # required for online plotting
+ if self.tp.visualization.plot_action_values_online:
+ if hasattr(self, 'episode_running_info') and hasattr(self.env, 'actions_description'):
+ for action in self.env.actions_description:
+ self.episode_running_info[action] = []
+ plt.clf()
+ if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
+ for network in self.networks:
+ network.curr_rnn_c_in = network.middleware_embedder.c_init
+ network.curr_rnn_h_in = network.middleware_embedder.h_init
+
+ def stack_observation(self, curr_stack, observation):
+ """
+ Adds a new observation to an existing stack of observations from previous time-steps.
+ :param curr_stack: The current observations stack.
+ :param observation: The new observation
+ :return: The updated observation stack
+ """
+
+ if curr_stack == []:
+ # starting an episode
+ curr_stack = np.vstack(np.expand_dims([observation] * self.tp.env.observation_stack_size, 0))
+ curr_stack = self.switch_axes_order(curr_stack, from_type='channels_first', to_type='channels_last')
+ else:
+ curr_stack = np.append(curr_stack, np.expand_dims(np.squeeze(observation), axis=-1), axis=-1)
+ curr_stack = np.delete(curr_stack, 0, -1)
+
+ return curr_stack
+
+ def preprocess_observation(self, observation):
+ """
+ Preprocesses the given observation.
+ For images - convert to grayscale, resize and convert to int.
+ For measurements vectors - normalize by a running average and std.
+ :param observation: The agents observation
+ :return: A processed version of the observation
+ """
+
+ if self.env.is_state_type_image:
+ # rescale
+ observation = scipy.misc.imresize(observation,
+ (self.tp.env.desired_observation_height,
+ self.tp.env.desired_observation_width),
+ interp=self.tp.rescaling_interpolation_type)
+ # rgb to y
+ if len(observation.shape) > 2 and observation.shape[2] > 1:
+ r, g, b = observation[:, :, 0], observation[:, :, 1], observation[:, :, 2]
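+                # standard ITU-R BT.601 luma coefficients for RGB -> grayscale conversion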
+ observation = 0.2989 * r + 0.5870 * g + 0.1140 * b
+
+ return observation.astype('uint8')
+ else:
+ if self.tp.env.normalize_observation:
+ # standardize the input observation using a running mean and std
+ if not self.tp.distributed or not self.tp.agent.share_statistics_between_workers:
+ self.running_observation_stats.push(observation)
+ observation = (observation - self.running_observation_stats.mean) / \
+ (self.running_observation_stats.std + 1e-15)
+ observation = np.clip(observation, -5.0, 5.0)
+ return observation
+
+ def learn_from_batch(self, batch):
+ """
+ Given a batch of transitions, calculates their target values and updates the network.
+ :param batch: A list of transitions
+ :return: The loss of the training
+ """
+ pass
+
+ def train(self):
+ """
+ A single training iteration. Sample a batch, train on it and update target networks.
+ :return: The training loss.
+ """
+ batch = self.memory.sample(self.tp.batch_size)
+ loss = self.learn_from_batch(batch)
+
+ if self.tp.learning_rate_decay_rate != 0:
+ self.curr_learning_rate.add_sample(self.tp.sess.run(self.tp.learning_rate))
+ else:
+ self.curr_learning_rate.add_sample(self.tp.learning_rate)
+
+ # update the target network of every network that has a target network
+ if self.total_steps_counter % self.tp.agent.num_steps_between_copying_online_weights_to_target == 0:
+ for network in self.networks:
+ network.update_target_network(self.tp.agent.rate_for_copying_weights_to_target)
+ logger.create_signal_value('Update Target Network', 1)
+ else:
+ logger.create_signal_value('Update Target Network', 0, overwrite=False)
+
+ return loss
+
+ def extract_batch(self, batch):
+ """
+ Extracts a single numpy array for each object in a batch of transitions (state, action, etc.)
+ :param batch: An array of transitions
+ :return: For each transition element, returns a numpy array of all the transitions in the batch
+ """
+
+ current_observations = np.array([transition.state['observation'] for transition in batch])
+ next_observations = np.array([transition.next_state['observation'] for transition in batch])
+ actions = np.array([transition.action for transition in batch])
+ rewards = np.array([transition.reward for transition in batch])
+ game_overs = np.array([transition.game_over for transition in batch])
+ total_return = np.array([transition.total_return for transition in batch])
+
+ current_states = current_observations
+ next_states = next_observations
+
+ # get the entire state including measurements if available
+ if self.tp.agent.use_measurements:
+ current_measurements = np.array([transition.state['measurements'] for transition in batch])
+ next_measurements = np.array([transition.next_state['measurements'] for transition in batch])
+ current_states = [current_observations, current_measurements]
+ next_states = [next_observations, next_measurements]
+
+ return current_states, next_states, actions, rewards, game_overs, total_return
+
+ def plot_action_values_online(self):
+ """
+ Plot an animated graph of the value of each possible action during the episode
+ :return: None
+ """
+
+ plt.clf()
+ for key, data_list in self.episode_running_info.items():
+ plt.plot(data_list, label=key)
+ plt.legend()
+ plt.pause(0.00000001)
+
+ def choose_action(self, curr_state, phase=RunPhase.TRAIN):
+ """
+        Choose an action to act with in the current episode being played. Different behavior might be exhibited when training
+ or testing.
+
+ :param curr_state: the current state to act upon.
+ :param phase: the current phase: training or testing.
+ :return: chosen action, some action value describing the action (q-value, probability, etc)
+ """
+ pass
+
+ def preprocess_reward(self, reward):
+ if self.tp.env.reward_scaling:
+ reward /= float(self.tp.env.reward_scaling)
+ if self.tp.env.reward_clipping_max:
+ reward = min(reward, self.tp.env.reward_clipping_max)
+ if self.tp.env.reward_clipping_min:
+ reward = max(reward, self.tp.env.reward_clipping_min)
+ return reward
+
+ def switch_axes_order(self, observation, from_type='channels_first', to_type='channels_last'):
+ """
+ transpose an observation axes from channels_first to channels_last or vice versa
+ :param observation: a numpy array
+ :param from_type: can be 'channels_first' or 'channels_last'
+ :param to_type: can be 'channels_first' or 'channels_last'
+ :return: a new observation with the requested axes order
+ """
+        if from_type == to_type:
+            return observation
+ assert 2 <= len(observation.shape) <= 3, 'num axes of an observation must be 2 for a vector or 3 for an image'
+ assert type(observation) == np.ndarray, 'observation must be a numpy array'
+ if len(observation.shape) == 3:
+ if from_type == 'channels_first' and to_type == 'channels_last':
+ return np.transpose(observation, (1, 2, 0))
+ elif from_type == 'channels_last' and to_type == 'channels_first':
+ return np.transpose(observation, (2, 0, 1))
+ else:
+ return np.transpose(observation, (1, 0))
+
+ def act(self, phase=RunPhase.TRAIN):
+ """
+ Take one step in the environment according to the network prediction and store the transition in memory
+ :param phase: Either Train or Test to specify if greedy actions should be used and if transitions should be stored
+ :return: A boolean value that signals an episode termination
+ """
+
+ self.total_steps_counter += 1
+ self.current_episode_steps_counter += 1
+
+ # get new action
+ action_info = {"action_probability": 1.0 / self.env.action_space_size, "action_value": 0}
+ is_first_transition_in_episode = (self.curr_state == [])
+ if is_first_transition_in_episode:
+ observation = self.preprocess_observation(self.env.observation)
+ observation = self.stack_observation([], observation)
+
+ self.curr_state = {'observation': observation}
+ if self.tp.agent.use_measurements:
+ self.curr_state['measurements'] = self.env.measurements
+ if self.tp.agent.use_accumulated_reward_as_measurement:
+ self.curr_state['measurements'] = np.append(self.curr_state['measurements'], 0)
+
+ if self.in_heatup: # we do not have a stacked curr_state yet
+ action = self.env.get_random_action()
+ else:
+ action, action_info = self.choose_action(self.curr_state, phase=phase)
+
+ # perform action
+ if type(action) == np.ndarray:
+ action = action.squeeze()
+ result = self.env.step(action)
+ shaped_reward = self.preprocess_reward(result['reward'])
+ if 'action_intrinsic_reward' in action_info.keys():
+ shaped_reward += action_info['action_intrinsic_reward']
+ self.total_reward_in_current_episode += result['reward']
+ observation = self.preprocess_observation(result['observation'])
+
+ # plot action values online
+ if self.tp.visualization.plot_action_values_online and not self.in_heatup:
+ self.plot_action_values_online()
+
+ # initialize the next state
+ observation = self.stack_observation(self.curr_state['observation'], observation)
+
+ next_state = {'observation': observation}
+ if self.tp.agent.use_measurements and 'measurements' in result.keys():
+ next_state['measurements'] = result['measurements']
+ if self.tp.agent.use_accumulated_reward_as_measurement:
+ next_state['measurements'] = np.append(next_state['measurements'], self.total_reward_in_current_episode)
+
+ # store the transition only if we are training
+ if phase == RunPhase.TRAIN:
+ transition = Transition(self.curr_state, result['action'], shaped_reward, next_state, result['done'])
+ for key in action_info.keys():
+ transition.info[key] = action_info[key]
+ if self.tp.agent.add_a_normalized_timestep_to_the_observation:
+ transition.info['timestep'] = float(self.current_episode_steps_counter) / self.env.timestep_limit
+ self.memory.store(transition)
+ elif phase == RunPhase.TEST and self.tp.visualization.dump_gifs:
+ # we store the transitions only for saving gifs
+ self.last_episode_images.append(self.env.get_rendered_image())
+
+ # update the current state for the next step
+ self.curr_state = next_state
+
+ # deal with episode termination
+ if result['done']:
+ if self.tp.visualization.dump_csv:
+ self.update_log(phase=phase)
+ self.log_to_screen(phase=phase)
+
+ if phase == RunPhase.TRAIN:
+ self.reset_game()
+
+ self.current_episode += 1
+
+ # return episode really ended
+ return result['done']
+
+ def evaluate(self, num_episodes, keep_networks_synced=False):
+ """
+ Run in an evaluation mode for several episodes. Actions will be chosen greedily.
+ :param keep_networks_synced: keep the online network in sync with the global network after every episode
+ :param num_episodes: The number of episodes to evaluate on
+ :return: None
+ """
+
+ max_reward_achieved = -float('inf')
+ average_evaluation_reward = 0
+ screen.log_title("Running evaluation")
+ self.env.change_phase(RunPhase.TEST)
+ for i in range(num_episodes):
+ # keep the online network in sync with the global network
+ if keep_networks_synced:
+ for network in self.networks:
+ network.sync()
+
+ episode_ended = False
+ while not episode_ended:
+ episode_ended = self.act(phase=RunPhase.TEST)
+
+ if self.tp.visualization.dump_gifs and self.total_reward_in_current_episode > max_reward_achieved:
+ max_reward_achieved = self.total_reward_in_current_episode
+ frame_skipping = int(5/self.tp.env.frame_skip)
+ logger.create_gif(self.last_episode_images[::frame_skipping],
+ name='score-{}'.format(max_reward_achieved), fps=10)
+
+ average_evaluation_reward += self.total_reward_in_current_episode
+ self.reset_game()
+
+ average_evaluation_reward /= float(num_episodes)
+
+ self.env.change_phase(RunPhase.TRAIN)
+ screen.log_title("Evaluation done. Average reward = {}.".format(average_evaluation_reward))
+
+ def post_training_commands(self):
+ pass
+
+ def improve(self):
+ """
+ Training algorithms wrapper. Heatup >> [ Evaluate >> Play >> Train >> Save checkpoint ]
+
+ :return: None
+ """
+
+ # synchronize the online network weights with the global network
+ for network in self.networks:
+ network.sync()
+
+ # heatup phase
+ if self.tp.num_heatup_steps != 0:
+ self.in_heatup = True
+ screen.log_title("Starting heatup {}".format(self.task_id))
+ num_steps_required_for_one_training_batch = self.tp.batch_size * self.tp.env.observation_stack_size
+ for step in range(max(self.tp.num_heatup_steps, num_steps_required_for_one_training_batch)):
+ self.act()
+
+ # training phase
+ self.in_heatup = False
+ screen.log_title("Starting training {}".format(self.task_id))
+ self.exploration_policy.change_phase(RunPhase.TRAIN)
+ training_start_time = time.time()
+ model_snapshots_periods_passed = -1
+
+ while self.training_iteration < self.tp.num_training_iterations:
+ # evaluate
+            evaluate_agent = (self.last_episode_evaluation_ran != self.current_episode) and \
+ (self.current_episode % self.tp.evaluate_every_x_episodes == 0)
+ if evaluate_agent:
+ self.last_episode_evaluation_ran = self.current_episode
+ self.evaluate(self.tp.evaluation_episodes)
+
+ # snapshot model
+ if self.tp.save_model_sec and self.tp.save_model_sec > 0 and not self.tp.distributed:
+ total_training_time = time.time() - training_start_time
+ current_snapshot_period = (int(total_training_time) // self.tp.save_model_sec)
+ if current_snapshot_period > model_snapshots_periods_passed:
+ model_snapshots_periods_passed = current_snapshot_period
+ self.main_network.save_model(model_snapshots_periods_passed)
+
+ # play and record in replay buffer
+ if self.tp.agent.step_until_collecting_full_episodes:
+ step = 0
+ while step < self.tp.agent.num_consecutive_playing_steps or self.memory.get_episode(-1).length() != 0:
+ self.act()
+ step += 1
+ else:
+ for step in range(self.tp.agent.num_consecutive_playing_steps):
+ self.act()
+
+ # train
+ if self.tp.train:
+ for step in range(self.tp.agent.num_consecutive_training_steps):
+ loss = self.train()
+ self.loss.add_sample(loss)
+ self.training_iteration += 1
+ self.post_training_commands()
+
diff --git a/agents/bootstrapped_dqn_agent.py b/agents/bootstrapped_dqn_agent.py
new file mode 100644
index 0000000..3476022
--- /dev/null
+++ b/agents/bootstrapped_dqn_agent.py
@@ -0,0 +1,58 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.value_optimization_agent import *
+
+
+# Bootstrapped DQN - https://arxiv.org/pdf/1602.04621.pdf
+class BootstrappedDQNAgent(ValueOptimizationAgent):
+ def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
+ ValueOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
+
+ def reset_game(self, do_not_reset_env=False):
+ ValueOptimizationAgent.reset_game(self, do_not_reset_env)
+ self.exploration_policy.select_head()
+
+ def learn_from_batch(self, batch):
+ current_states, next_states, actions, rewards, game_overs, _ = self.extract_batch(batch)
+
+ # for the action we actually took, the error is:
+ # TD error = r + discount*max(q_st_plus_1) - q_st
+ # for all other actions, the error is 0
+ q_st_plus_1 = self.main_network.target_network.predict(next_states)
+        # initialize the targets with the current prediction so that the error for all non-selected actions is zero
+ TD_targets = self.main_network.online_network.predict(current_states)
+
+ # only update the action that we have actually done in this transition
+ for i in range(self.tp.batch_size):
+ mask = batch[i].info['mask']
+ for head_idx in range(self.tp.exploration.architecture_num_q_heads):
+ if mask[head_idx] == 1:
+ TD_targets[head_idx][i, actions[i]] = rewards[i] + \
+ (1.0 - game_overs[i]) * self.tp.agent.discount * np.max(
+ q_st_plus_1[head_idx][i], 0)
+
+ result = self.main_network.train_and_sync_networks(current_states, TD_targets)
+
+ total_loss = result[0]
+
+ return total_loss
+
+ def act(self, phase=RunPhase.TRAIN):
+ ValueOptimizationAgent.act(self, phase)
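+        # sample a bootstrap mask - each Q head will see this transition with the configured sharing probability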
+ mask = np.random.binomial(1, self.tp.exploration.bootstrapped_data_sharing_probability,
+ self.tp.exploration.architecture_num_q_heads)
+ self.memory.update_last_transition_info({'mask': mask})
diff --git a/agents/clipped_ppo_agent.py b/agents/clipped_ppo_agent.py
new file mode 100644
index 0000000..cadfc0b
--- /dev/null
+++ b/agents/clipped_ppo_agent.py
@@ -0,0 +1,210 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.actor_critic_agent import *
+from random import shuffle
+import tensorflow as tf
+
+
+# Clipped Proximal Policy Optimization - https://arxiv.org/abs/1707.06347
+class ClippedPPOAgent(ActorCriticAgent):
+ def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
+ ActorCriticAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id,
+ create_target_network=True)
+ # signals definition
+ self.value_loss = Signal('Value Loss')
+ self.signals.append(self.value_loss)
+ self.policy_loss = Signal('Policy Loss')
+ self.signals.append(self.policy_loss)
+ self.total_kl_divergence_during_training_process = 0.0
+ self.unclipped_grads = Signal('Grads (unclipped)')
+ self.signals.append(self.unclipped_grads)
+ self.value_targets = Signal('Value Targets')
+ self.signals.append(self.value_targets)
+ self.kl_divergence = Signal('KL Divergence')
+ self.signals.append(self.kl_divergence)
+
+ def fill_advantages(self, batch):
+ current_states, next_states, actions, rewards, game_overs, total_return = self.extract_batch(batch)
+
+ current_state_values = self.main_network.online_network.predict([current_states])[0]
+ current_state_values = current_state_values.squeeze()
+ self.state_values.add_sample(current_state_values)
+
+ # calculate advantages
+ advantages = []
+ value_targets = []
+ if self.policy_gradient_rescaler == PolicyGradientRescaler.A_VALUE:
+ advantages = total_return - current_state_values
+ elif self.policy_gradient_rescaler == PolicyGradientRescaler.GAE:
+ # get bootstraps
+ episode_start_idx = 0
+ advantages = np.array([])
+ value_targets = np.array([])
+ for idx, game_over in enumerate(game_overs):
+ if game_over:
+ # get advantages for the rollout
+ value_bootstrapping = np.zeros((1,))
+ rollout_state_values = np.append(current_state_values[episode_start_idx:idx+1], value_bootstrapping)
+
+ rollout_advantages, gae_based_value_targets = \
+ self.get_general_advantage_estimation_values(rewards[episode_start_idx:idx+1],
+ rollout_state_values)
+ episode_start_idx = idx + 1
+ advantages = np.append(advantages, rollout_advantages)
+ value_targets = np.append(value_targets, gae_based_value_targets)
+ else:
+ screen.warning("WARNING: The requested policy gradient rescaler is not available")
+
+ # standardize
+ advantages = (advantages - np.mean(advantages)) / np.std(advantages)
+
+ for transition, advantage, value_target in zip(batch, advantages, value_targets):
+ transition.info['advantage'] = advantage
+ transition.info['gae_based_value_target'] = value_target
+
+ self.action_advantages.add_sample(advantages)
+
+ def train_network(self, dataset, epochs):
+ loss = []
+ for j in range(epochs):
+ loss = {
+ 'total_loss': [],
+ 'policy_losses': [],
+ 'unclipped_grads': [],
+ 'fetch_result': []
+ }
+ shuffle(dataset)
+ for i in range(int(len(dataset) / self.tp.batch_size)):
+ batch = dataset[i * self.tp.batch_size:(i + 1) * self.tp.batch_size]
+ current_states, _, actions, _, _, total_return = self.extract_batch(batch)
+
+ advantages = np.array([t.info['advantage'] for t in batch])
+ gae_based_value_targets = np.array([t.info['gae_based_value_target'] for t in batch])
+ if not self.tp.env_instance.discrete_controls and len(actions.shape) == 1:
+ actions = np.expand_dims(actions, -1)
+
+ # get old policy probabilities and distribution
+ result = self.main_network.target_network.predict([current_states])
+ old_policy_distribution = result[1:]
+
+ # calculate gradients and apply on both the local policy network and on the global policy network
+ fetches = [self.main_network.online_network.output_heads[1].kl_divergence,
+ self.main_network.online_network.output_heads[1].entropy]
+
+ total_return = np.expand_dims(total_return, -1)
+ value_targets = gae_based_value_targets if self.tp.agent.estimate_value_using_gae else total_return
+ total_loss, policy_losses, unclipped_grads, fetch_result =\
+ self.main_network.online_network.accumulate_gradients(
+ [current_states] + [actions] + old_policy_distribution,
+ [total_return, advantages], additional_fetches=fetches)
+
+ self.value_targets.add_sample(value_targets)
+ if self.tp.distributed:
+ self.main_network.apply_gradients_to_global_network()
+ self.main_network.update_online_network()
+ else:
+ self.main_network.apply_gradients_to_online_network()
+
+ self.main_network.online_network.reset_accumulated_gradients()
+
+ loss['total_loss'].append(total_loss)
+ loss['policy_losses'].append(policy_losses)
+ loss['unclipped_grads'].append(unclipped_grads)
+ loss['fetch_result'].append(fetch_result)
+
+ self.unclipped_grads.add_sample(unclipped_grads)
+
+ for key in loss.keys():
+ loss[key] = np.mean(loss[key], 0)
+
+ if self.tp.learning_rate_decay_rate != 0:
+ curr_learning_rate = self.tp.sess.run(self.tp.learning_rate)
+ self.curr_learning_rate.add_sample(curr_learning_rate)
+ else:
+ curr_learning_rate = self.tp.learning_rate
+
+ # log training parameters
+ screen.log_dict(
+ OrderedDict([
+ ("Surrogate loss", loss['policy_losses'][0]),
+ ("KL divergence", loss['fetch_result'][0]),
+ ("Entropy", loss['fetch_result'][1]),
+ ("training epoch", j),
+ ("learning_rate", curr_learning_rate)
+ ]),
+ prefix="Policy training"
+ )
+
+ self.total_kl_divergence_during_training_process = loss['fetch_result'][0]
+ self.entropy.add_sample(loss['fetch_result'][1])
+ self.kl_divergence.add_sample(loss['fetch_result'][0])
+ return policy_losses
+
+ def post_training_commands(self):
+
+ # clean memory
+ self.memory.clean()
+
+ def train(self):
+ self.main_network.sync()
+
+ dataset = self.memory.transitions
+
+ self.fill_advantages(dataset)
+
+ # take only the requested number of steps
+ dataset = dataset[:self.tp.agent.num_consecutive_playing_steps]
+
+ if self.tp.distributed and self.tp.agent.share_statistics_between_workers:
+ self.running_observation_stats.push(np.array([t.state['observation'] for t in dataset]))
+
+ losses = self.train_network(dataset, 10)
+ self.value_loss.add_sample(losses[0])
+ self.policy_loss.add_sample(losses[1])
+ self.update_log() # should be done in order to update the data that has been accumulated * while not playing *
+ return np.append(losses[0], losses[1])
+
+ def choose_action(self, curr_state, phase=RunPhase.TRAIN):
+ # convert to batch so we can run it through the network
+ observation = curr_state['observation']
+ observation = np.expand_dims(np.array(observation), 0)
+
+ if self.env.discrete_controls:
+ # DISCRETE
+ _, action_values = self.main_network.online_network.predict(observation)
+ action_values = action_values.squeeze()
+
+ if phase == RunPhase.TRAIN:
+ action = self.exploration_policy.get_action(action_values)
+ else:
+ action = np.argmax(action_values)
+ action_info = {"action_probability": action_values[action]}
+ # self.entropy.add_sample(-np.sum(action_values * np.log(action_values)))
+ else:
+ # CONTINUOUS
+ _, action_values_mean, action_values_std = self.main_network.online_network.predict(observation)
+ action_values_mean = action_values_mean.squeeze()
+ action_values_std = action_values_std.squeeze()
+ if phase == RunPhase.TRAIN:
+ action = np.squeeze(np.random.randn(1, self.action_space_size) * action_values_std + action_values_mean)
+ # if self.current_episode % 5 == 0 and self.current_episode_steps_counter < 5:
+ # print action
+ else:
+ action = action_values_mean
+ action_info = {"action_probability": action_values_mean}
+
+ return action, action_info
diff --git a/agents/ddpg_agent.py b/agents/ddpg_agent.py
new file mode 100644
index 0000000..f5d0275
--- /dev/null
+++ b/agents/ddpg_agent.py
@@ -0,0 +1,104 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.actor_critic_agent import *
+from configurations import *
+
+
+# Deep Deterministic Policy Gradients Network - https://arxiv.org/pdf/1509.02971.pdf
+class DDPGAgent(ActorCriticAgent):
+ def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
+ ActorCriticAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id,
+ create_target_network=True)
+ # define critic network
+ self.critic_network = self.main_network
+ # self.networks.append(self.critic_network)
+
+ # define actor network
+ tuning_parameters.agent.input_types = [InputTypes.Observation]
+ tuning_parameters.agent.output_types = [OutputTypes.Pi]
+ self.actor_network = NetworkWrapper(tuning_parameters, True, self.has_global, 'actor',
+ self.replicated_device, self.worker_device)
+ self.networks.append(self.actor_network)
+
+ self.q_values = Signal("Q")
+ self.signals.append(self.q_values)
+
+ def learn_from_batch(self, batch):
+ current_states, next_states, actions, rewards, game_overs, _ = self.extract_batch(batch)
+
+        # TD target = r + discount * Q_target(s', mu_target(s'))
+ next_actions = self.actor_network.target_network.predict([next_states])
+ q_st_plus_1 = self.critic_network.target_network.predict([next_states, next_actions])
+ TD_targets = np.expand_dims(rewards, -1) + \
+ (1.0 - np.expand_dims(game_overs, -1)) * self.tp.agent.discount * q_st_plus_1
+
+ # get the gradients of the critic output with respect to the action
+ actions_mean = self.actor_network.online_network.predict(current_states)
+ critic_online_network = self.critic_network.online_network
+ action_gradients = self.critic_network.sess.run(critic_online_network.gradients_wrt_inputs[1],
+ feed_dict={
+ critic_online_network.inputs[0]: current_states,
+ critic_online_network.inputs[1]: actions_mean,
+ })[0]
+
+ # train the critic
+ if len(actions.shape) == 1:
+ actions = np.expand_dims(actions, -1)
+ result = self.critic_network.train_and_sync_networks([current_states, actions], TD_targets)
+ total_loss = result[0]
+
+ # apply the gradients from the critic to the actor
+ actor_online_network = self.actor_network.online_network
+ gradients = self.actor_network.sess.run(actor_online_network.weighted_gradients,
+ feed_dict={
+ actor_online_network.gradients_weights_ph: -action_gradients,
+ actor_online_network.inputs[0]: current_states
+ })
+ if self.actor_network.has_global:
+ self.actor_network.global_network.apply_gradients(gradients)
+ self.actor_network.update_online_network()
+ else:
+ self.actor_network.online_network.apply_gradients(gradients)
+
+ return total_loss
+
+ def train(self):
+ return Agent.train(self)
+
+ def choose_action(self, curr_state, phase=RunPhase.TRAIN):
+ assert not self.env.discrete_controls, 'DDPG works only for continuous control problems'
+ # convert to batch so we can run it through the network
+ observation = np.expand_dims(np.array(curr_state['observation']), 0)
+ result = self.actor_network.online_network.predict(observation)
+ action_values = result[0].squeeze()
+
+ if phase == RunPhase.TRAIN:
+ action = self.exploration_policy.get_action(action_values)
+ else:
+ action = action_values
+
+ action = np.clip(action, self.env.action_space_low, self.env.action_space_high)
+
+ # get q value
+ action_batch = np.expand_dims(action, 0)
+ if type(action) != np.ndarray:
+ action_batch = np.array([[action]])
+ q_value = self.critic_network.online_network.predict([observation, action_batch])[0]
+ self.q_values.add_sample(q_value)
+ action_info = {"action_value": q_value}
+
+ return action, action_info
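+
+# --- Editor's note: the sketch below is illustrative only and is not part of the original file. ---
+# learn_from_batch above builds the critic target y = r + (1 - done) * gamma * Q_target(s', mu_target(s'))
+# and then feeds the critic's action-gradient dQ/da back through the actor. The helper isolates
+# the target computation; names and shapes are assumptions for the example.
+import numpy as np
+
+
+def _ddpg_critic_targets(rewards, game_overs, discount, q_next):
+    """rewards, game_overs: shape (batch,); q_next = Q_target(s', mu_target(s')): shape (batch, 1)."""
+    return np.expand_dims(rewards, -1) + (1.0 - np.expand_dims(game_overs, -1)) * discount * q_next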
diff --git a/agents/ddqn_agent.py b/agents/ddqn_agent.py
new file mode 100644
index 0000000..838ae3f
--- /dev/null
+++ b/agents/ddqn_agent.py
@@ -0,0 +1,42 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.value_optimization_agent import *
+
+
+# Double DQN - https://arxiv.org/abs/1509.06461
+class DDQNAgent(ValueOptimizationAgent):
+ def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
+ ValueOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
+
+ def learn_from_batch(self, batch):
+ current_states, next_states, actions, rewards, game_overs, _ = self.extract_batch(batch)
+
+ selected_actions = np.argmax(self.main_network.online_network.predict(next_states), 1)
+ q_st_plus_1 = self.main_network.target_network.predict(next_states)
+ TD_targets = self.main_network.online_network.predict(current_states)
+
+ # initialize with the current prediction so that we will
+ # only update the action that we have actually done in this transition
+ for i in range(self.tp.batch_size):
+ TD_targets[i, actions[i]] = rewards[i] \
+ + (1.0 - game_overs[i]) * self.tp.agent.discount * q_st_plus_1[i][
+ selected_actions[i]]
+
+ result = self.main_network.train_and_sync_networks(current_states, TD_targets)
+ total_loss = result[0]
+
+ return total_loss
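+
+# --- Editor's note: the sketch below is illustrative only and is not part of the original file. ---
+# The loop above implements the Double DQN target: the online network selects the greedy next
+# action while the target network evaluates it. A vectorized equivalent over a batch of plain
+# NumPy arrays (names are assumptions for the example):
+import numpy as np
+
+
+def _double_dqn_targets(q_online_curr, q_online_next, q_target_next, actions, rewards, game_overs, discount):
+    """q_* arrays: shape (batch, num_actions); actions: int array; rewards, game_overs: shape (batch,)."""
+    targets = np.copy(q_online_curr)
+    selected = np.argmax(q_online_next, axis=1)
+    batch_idx = np.arange(len(actions))
+    targets[batch_idx, actions] = rewards + (1.0 - game_overs) * discount * q_target_next[batch_idx, selected]
+    return targets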
diff --git a/agents/dfp_agent.py b/agents/dfp_agent.py
new file mode 100644
index 0000000..2205aa6
--- /dev/null
+++ b/agents/dfp_agent.py
@@ -0,0 +1,83 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.agent import *
+
+
+# Direct Future Prediction Agent - http://vladlen.info/papers/learning-to-act.pdf
+class DFPAgent(Agent):
+ def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
+ Agent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
+ self.current_goal = self.tp.agent.goal_vector
+ self.main_network = NetworkWrapper(tuning_parameters, False, self.has_global, 'main',
+ self.replicated_device, self.worker_device)
+ self.networks.append(self.main_network)
+
+ def learn_from_batch(self, batch):
+ current_states, next_states, actions, rewards, game_overs, total_returns = self.extract_batch(batch)
+
+ # create the inputs for the network
+        network_input = current_states
+        network_input.append(np.repeat(np.expand_dims(self.current_goal, 0), self.tp.batch_size, 0))
+
+        # get the current outputs of the network
+        targets = self.main_network.online_network.predict(network_input)
+
+ # change the targets for the taken actions
+ for i in range(self.tp.batch_size):
+ targets[i, actions[i]] = batch[i].info['future_measurements'].flatten()
+
+        result = self.main_network.train_and_sync_networks(network_input, targets)
+ total_loss = result[0]
+
+ return total_loss
+
+ def choose_action(self, curr_state, phase=RunPhase.TRAIN):
+ # convert to batch so we can run it through the network
+ observation = np.expand_dims(np.array(curr_state['observation']), 0)
+ measurements = np.expand_dims(np.array(curr_state['measurements']), 0)
+ goal = np.expand_dims(self.current_goal, 0)
+
+ # predict the future measurements
+ measurements_future_prediction = self.main_network.online_network.predict([observation, measurements, goal])[0]
+ action_values = np.zeros((self.action_space_size,))
+ num_steps_used_for_objective = len(self.tp.agent.future_measurements_weights)
+
+        # calculate the score of each action by multiplying its future measurements with the goal vector
+ for action_idx in range(self.action_space_size):
+ action_measurements = measurements_future_prediction[action_idx]
+ action_measurements = np.reshape(action_measurements,
+ (self.tp.agent.num_predicted_steps_ahead, self.measurements_size[0]))
+ future_steps_values = np.dot(action_measurements, self.current_goal)
+ action_values[action_idx] = np.dot(future_steps_values[-num_steps_used_for_objective:],
+ self.tp.agent.future_measurements_weights)
+
+ # choose action according to the exploration policy and the current phase (evaluating or training the agent)
+ if phase == RunPhase.TRAIN:
+ action = self.exploration_policy.get_action(action_values)
+ else:
+ action = np.argmax(action_values)
+
+ action_values = action_values.squeeze()
+
+ # store information for plotting interactively (actual plotting is done in agent)
+ if self.tp.visualization.plot_action_values_online:
+ for idx, action_name in enumerate(self.env.actions_description):
+ self.episode_running_info[action_name].append(action_values[idx])
+
+ action_info = {"action_probability": 0, "action_value": action_values[action]}
+
+ return action, action_info
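+
+# --- Editor's note: the sketch below is illustrative only and is not part of the original file. ---
+# choose_action above scores each action by projecting its predicted future measurements onto the
+# goal vector and weighting the last few prediction steps. The helper shows that scoring for a
+# single action; names are assumptions for the example.
+import numpy as np
+
+
+def _dfp_action_score(predicted_measurements, goal, step_weights, num_steps, measurements_size):
+    """predicted_measurements: flat array of length num_steps * measurements_size."""
+    per_step = predicted_measurements.reshape(num_steps, measurements_size)
+    future_step_values = per_step.dot(goal)                   # one scalar per predicted step
+    return future_step_values[-len(step_weights):].dot(step_weights)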
diff --git a/agents/distributional_dqn_agent.py b/agents/distributional_dqn_agent.py
new file mode 100644
index 0000000..d7c0088
--- /dev/null
+++ b/agents/distributional_dqn_agent.py
@@ -0,0 +1,60 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.value_optimization_agent import *
+
+
+# Distributional Deep Q Network - https://arxiv.org/pdf/1707.06887.pdf
+class DistributionalDQNAgent(ValueOptimizationAgent):
+ def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
+ ValueOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
+ self.z_values = np.linspace(self.tp.agent.v_min, self.tp.agent.v_max, self.tp.agent.atoms)
+
+ # prediction's format is (batch,actions,atoms)
+ def get_q_values(self, prediction):
+ return np.dot(prediction, self.z_values)
+
+ def learn_from_batch(self, batch):
+ current_states, next_states, actions, rewards, game_overs, _ = self.extract_batch(batch)
+
+ # for the action we actually took, the error is calculated by the atoms distribution
+ # for all other actions, the error is 0
+ distributed_q_st_plus_1 = self.main_network.target_network.predict(next_states)
+        # initialize the targets with the current prediction so that we will
+        # only update the action that we have actually done in this transition
+        TD_targets = self.main_network.online_network.predict(current_states)
+
+ target_actions = np.argmax(self.get_q_values(distributed_q_st_plus_1), axis=1)
+ m = np.zeros((self.tp.batch_size, self.z_values.size))
+
+ batches = np.arange(self.tp.batch_size)
+ for j in range(self.z_values.size):
+ tzj = np.fmax(np.fmin(rewards + (1.0 - game_overs) * self.tp.agent.discount * self.z_values[j],
+ self.z_values[self.z_values.size - 1]),
+ self.z_values[0])
+ bj = (tzj - self.z_values[0])/(self.z_values[1] - self.z_values[0])
+ u = (np.ceil(bj)).astype(int)
+ l = (np.floor(bj)).astype(int)
+ m[batches, l] = m[batches, l] + (distributed_q_st_plus_1[batches, target_actions, j] * (u - bj))
+ m[batches, u] = m[batches, u] + (distributed_q_st_plus_1[batches, target_actions, j] * (bj - l))
+ # total_loss = cross entropy between actual result above and predicted result for the given action
+ TD_targets[batches, actions] = m
+
+ result = self.main_network.train_and_sync_networks(current_states, TD_targets)
+ total_loss = result[0]
+
+ return total_loss
+
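+# --- Editor's note: the sketch below is illustrative only and is not part of the original file. ---
+# learn_from_batch above projects the target distribution onto the fixed support z_values
+# (the C51 projection): each atom z_j is mapped to Tz_j = clip(r + gamma * z_j, z_min, z_max)
+# and its probability mass is split between the two neighbouring support atoms. The helper
+# mirrors that loop for a single transition; names are assumptions for the example.
+import numpy as np
+
+
+def _project_distribution(next_probs, reward, game_over, discount, z_values):
+    """next_probs: target-network probabilities over atoms for the chosen next action, shape (num_atoms,)."""
+    m = np.zeros_like(z_values)
+    delta_z = z_values[1] - z_values[0]
+    for j, z_j in enumerate(z_values):
+        tz_j = np.clip(reward + (1.0 - game_over) * discount * z_j, z_values[0], z_values[-1])
+        b_j = (tz_j - z_values[0]) / delta_z
+        l, u = int(np.floor(b_j)), int(np.ceil(b_j))
+        m[l] += next_probs[j] * (u - b_j)
+        m[u] += next_probs[j] * (b_j - l)
+    return m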
diff --git a/agents/dqn_agent.py b/agents/dqn_agent.py
new file mode 100644
index 0000000..70c0c7d
--- /dev/null
+++ b/agents/dqn_agent.py
@@ -0,0 +1,43 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.value_optimization_agent import *
+
+
+# Deep Q Network - https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf
+class DQNAgent(ValueOptimizationAgent):
+ def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
+ ValueOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
+
+ def learn_from_batch(self, batch):
+ current_states, next_states, actions, rewards, game_overs, _ = self.extract_batch(batch)
+
+ # for the action we actually took, the error is:
+ # TD error = r + discount*max(q_st_plus_1) - q_st
+ # for all other actions, the error is 0
+ q_st_plus_1 = self.main_network.target_network.predict(next_states)
+        # initialize the targets with the current prediction so that we will
+        # only update the action that we have actually done in this transition
+        TD_targets = self.main_network.online_network.predict(current_states)
+
+ for i in range(self.tp.batch_size):
+ TD_targets[i, actions[i]] = rewards[i] + (1.0 - game_overs[i]) * self.tp.agent.discount * np.max(
+ q_st_plus_1[i], 0)
+
+ result = self.main_network.train_and_sync_networks(current_states, TD_targets)
+ total_loss = result[0]
+
+ return total_loss
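+
+# --- Editor's note: the sketch below is illustrative only and is not part of the original file. ---
+# The loop above writes the one-step target r + (1 - done) * gamma * max_a' Q_target(s', a') into
+# the entry of the taken action and leaves every other entry at the online prediction, so their
+# error is zero. A vectorized equivalent (names are assumptions for the example):
+import numpy as np
+
+
+def _dqn_targets(q_online_curr, q_target_next, actions, rewards, game_overs, discount):
+    """q_* arrays: shape (batch, num_actions); actions: int array; rewards, game_overs: shape (batch,)."""
+    targets = np.copy(q_online_curr)
+    batch_idx = np.arange(len(actions))
+    targets[batch_idx, actions] = rewards + (1.0 - game_overs) * discount * np.max(q_target_next, axis=1)
+    return targets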
diff --git a/agents/mmc_agent.py b/agents/mmc_agent.py
new file mode 100644
index 0000000..2b5a2cb
--- /dev/null
+++ b/agents/mmc_agent.py
@@ -0,0 +1,42 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.value_optimization_agent import *
+
+
+class MixedMonteCarloAgent(ValueOptimizationAgent):
+ def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
+ ValueOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
+ self.mixing_rate = tuning_parameters.agent.monte_carlo_mixing_rate
+
+ def learn_from_batch(self, batch):
+ current_states, next_states, actions, rewards, game_overs, total_return = self.extract_batch(batch)
+
+ TD_targets = self.main_network.online_network.predict(current_states)
+ selected_actions = np.argmax(self.main_network.online_network.predict(next_states), 1)
+ q_st_plus_1 = self.main_network.target_network.predict(next_states)
+ # initialize with the current prediction so that we will
+ # only update the action that we have actually done in this transition
+ for i in range(self.tp.batch_size):
+ one_step_target = rewards[i] + (1.0 - game_overs[i]) * self.tp.agent.discount * q_st_plus_1[i][
+ selected_actions[i]]
+ monte_carlo_target = total_return[i]
+ TD_targets[i, actions[i]] = (1 - self.mixing_rate) * one_step_target + self.mixing_rate * monte_carlo_target
+
+ result = self.main_network.train_and_sync_networks(current_states, TD_targets)
+ total_loss = result[0]
+
+ return total_loss
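+
+# --- Editor's note: the comment below is illustrative only and is not part of the original file. ---
+# The target above is a convex mix of the one-step bootstrapped target and the Monte Carlo return:
+#   target = (1 - mixing_rate) * one_step_target + mixing_rate * total_return
+# e.g. with mixing_rate = 0.1, one_step_target = 2.0 and total_return = 5.0 the target is 2.3.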
diff --git a/agents/n_step_q_agent.py b/agents/n_step_q_agent.py
new file mode 100644
index 0000000..0746523
--- /dev/null
+++ b/agents/n_step_q_agent.py
@@ -0,0 +1,85 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.value_optimization_agent import *
+from agents.policy_optimization_agent import *
+from logger import *
+from utils import *
+import scipy.signal
+
+
+# N Step Q Learning Agent - https://arxiv.org/abs/1602.01783
+class NStepQAgent(ValueOptimizationAgent, PolicyOptimizationAgent):
+ def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
+ ValueOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id, create_target_network=True)
+ self.last_gradient_update_step_idx = 0
+ self.q_values = Signal('Q Values')
+ self.unclipped_grads = Signal('Grads (unclipped)')
+ self.signals.append(self.q_values)
+ self.signals.append(self.unclipped_grads)
+
+ def learn_from_batch(self, batch):
+ # batch contains a list of episodes to learn from
+ current_states, next_states, actions, rewards, game_overs, _ = self.extract_batch(batch)
+
+ # get the values for the current states
+ state_value_head_targets = self.main_network.online_network.predict(current_states)
+
+ # the targets for the state value estimator
+ num_transitions = len(game_overs)
+
+ if self.tp.agent.targets_horizon == '1-Step':
+ # 1-Step Q learning
+ q_st_plus_1 = self.main_network.target_network.predict(next_states)
+
+            for i in reversed(range(num_transitions)):
+ state_value_head_targets[i][actions[i]] = \
+ rewards[i] + (1.0 - game_overs[i]) * self.tp.agent.discount * np.max(q_st_plus_1[i], 0)
+
+ elif self.tp.agent.targets_horizon == 'N-Step':
+ # N-Step Q learning
+ if game_overs[-1]:
+ R = 0
+ else:
+ R = np.max(self.main_network.target_network.predict(np.expand_dims(next_states[-1], 0)))
+
+            for i in reversed(range(num_transitions)):
+ R = rewards[i] + self.tp.agent.discount * R
+ state_value_head_targets[i][actions[i]] = R
+
+ else:
+            raise ValueError('The available values for targets_horizon are: 1-Step, N-Step')
+
+ # train
+ result = self.main_network.online_network.accumulate_gradients([current_states], [state_value_head_targets])
+
+ # logging
+ total_loss, losses, unclipped_grads = result[:3]
+ self.unclipped_grads.add_sample(unclipped_grads)
+ logger.create_signal_value('Value Loss', losses[0])
+
+ return total_loss
+
+ def train(self):
+ # update the target network of every network that has a target network
+ if self.total_steps_counter % self.tp.agent.num_steps_between_copying_online_weights_to_target == 0:
+ for network in self.networks:
+ network.update_target_network(self.tp.agent.rate_for_copying_weights_to_target)
+ logger.create_signal_value('Update Target Network', 1)
+ else:
+ logger.create_signal_value('Update Target Network', 0, overwrite=False)
+
+ return PolicyOptimizationAgent.train(self)
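+
+# --- Editor's note: the sketch below is illustrative only and is not part of the original file. ---
+# In the 'N-Step' branch above, the return is accumulated backwards through the rollout: R starts
+# from the target network's bootstrap value at the last state (or 0 on a terminal state) and is
+# folded in as R = r_i + gamma * R for each earlier step. A standalone sketch:
+import numpy as np
+
+
+def _n_step_returns(rewards, bootstrap_value, discount):
+    """rewards: array ordered from the first to the last step of the rollout."""
+    returns = np.zeros(len(rewards))
+    R = bootstrap_value
+    for i in reversed(range(len(rewards))):
+        R = rewards[i] + discount * R
+        returns[i] = R
+    return returns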
diff --git a/agents/naf_agent.py b/agents/naf_agent.py
new file mode 100644
index 0000000..3ef563e
--- /dev/null
+++ b/agents/naf_agent.py
@@ -0,0 +1,75 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.value_optimization_agent import *
+
+
+# Normalized Advantage Functions - https://arxiv.org/pdf/1603.00748.pdf
+class NAFAgent(ValueOptimizationAgent):
+ def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
+ ValueOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
+ self.l_values = Signal("L")
+ self.a_values = Signal("Advantage")
+ self.mu_values = Signal("Action")
+ self.v_values = Signal("V")
+ self.signals += [self.l_values, self.a_values, self.mu_values, self.v_values]
+
+ def learn_from_batch(self, batch):
+ current_states, next_states, actions, rewards, game_overs, _ = self.extract_batch(batch)
+
+ # TD error = r + discount*v_st_plus_1 - q_st
+ v_st_plus_1 = self.main_network.sess.run(self.main_network.target_network.output_heads[0].V,
+ feed_dict={self.main_network.target_network.inputs[0]: next_states})
+ TD_targets = np.expand_dims(rewards, -1) + (1.0 - np.expand_dims(game_overs, -1)) * self.tp.agent.discount * v_st_plus_1
+
+ if len(actions.shape) == 1:
+ actions = np.expand_dims(actions, -1)
+
+ result = self.main_network.train_and_sync_networks([current_states, actions], TD_targets)
+ total_loss = result[0]
+
+ return total_loss
+
+ def choose_action(self, curr_state, phase=RunPhase.TRAIN):
+ assert not self.env.discrete_controls, 'NAF works only for continuous control problems'
+
+ # convert to batch so we can run it through the network
+ observation = np.expand_dims(np.array(curr_state['observation']), 0)
+ naf_head = self.main_network.online_network.output_heads[0]
+ action_values = self.main_network.sess.run(naf_head.mu,
+ feed_dict={self.main_network.online_network.inputs[0]: observation})
+ if phase == RunPhase.TRAIN:
+ action = self.exploration_policy.get_action(action_values)
+ else:
+ action = action_values
+
+ Q, L, A, mu, V = self.main_network.sess.run(
+ [naf_head.Q, naf_head.L, naf_head.A, naf_head.mu, naf_head.V],
+ feed_dict={
+ self.main_network.online_network.inputs[0]: observation,
+ self.main_network.online_network.inputs[1]: action_values
+ }
+ )
+
+ # store the q values statistics for logging
+ self.q_values.add_sample(Q)
+ self.l_values.add_sample(L)
+ self.a_values.add_sample(A)
+ self.mu_values.add_sample(mu)
+ self.v_values.add_sample(V)
+
+ action_value = {"action_value": Q}
+ return action, action_value
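+
+# --- Editor's note: the sketch below is illustrative only and is not part of the original file. ---
+# According to the NAF paper linked above, the Q/L/A/mu/V heads fetched here follow the
+# decomposition Q(s, a) = V(s) + A(s, a) with A(s, a) = -0.5 * (a - mu(s))^T P(s) (a - mu(s))
+# and P = L L^T, L lower-triangular. A standalone NumPy sketch of that decomposition:
+import numpy as np
+
+
+def _naf_q_value(action, mu, L, V):
+    """action, mu: shape (action_dim,); L: lower-triangular (action_dim, action_dim); V: scalar."""
+    P = L.dot(L.T)
+    diff = action - mu
+    return V - 0.5 * diff.dot(P).dot(diff)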
diff --git a/agents/nec_agent.py b/agents/nec_agent.py
new file mode 100644
index 0000000..4e724c9
--- /dev/null
+++ b/agents/nec_agent.py
@@ -0,0 +1,104 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.value_optimization_agent import *
+
+
+# Neural Episodic Control - https://arxiv.org/pdf/1703.01988.pdf
+class NECAgent(ValueOptimizationAgent):
+ def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
+ ValueOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id,
+ create_target_network=False)
+ self.current_episode_state_embeddings = []
+ self.current_episode_actions = []
+ self.training_started = False
+
+ def learn_from_batch(self, batch):
+ if not self.main_network.online_network.output_heads[0].DND.has_enough_entries(self.tp.agent.number_of_knn):
+ return 0
+ else:
+ if not self.training_started:
+ self.training_started = True
+ screen.log_title("Finished collecting initial entries in DND. Starting to train network...")
+
+ current_states, next_states, actions, rewards, game_overs, total_return = self.extract_batch(batch)
+ result = self.main_network.train_and_sync_networks([current_states, actions], total_return)
+ total_loss = result[0]
+
+ return total_loss
+
+ def choose_action(self, curr_state, phase=RunPhase.TRAIN):
+ # convert to batch so we can run it through the network
+ observation = np.expand_dims(np.array(curr_state['observation']), 0)
+
+ # get embedding
+ embedding = self.main_network.sess.run(self.main_network.online_network.state_embedding,
+ feed_dict={self.main_network.online_network.inputs[0]: observation})
+ self.current_episode_state_embeddings.append(embedding[0])
+
+ # get action values
+ if self.main_network.online_network.output_heads[0].DND.has_enough_entries(self.tp.agent.number_of_knn):
+ # if there are enough entries in the DND then we can query it to get the action values
+ actions_q_values = []
+ for action in range(self.action_space_size):
+ feed_dict = {
+ self.main_network.online_network.state_embedding: embedding,
+ self.main_network.online_network.output_heads[0].input[0]: np.array([action])
+ }
+ q_value = self.main_network.sess.run(
+ self.main_network.online_network.output_heads[0].output, feed_dict=feed_dict)
+ actions_q_values.append(q_value[0])
+ else:
+ # get only the embedding so we can insert it to the DND
+ actions_q_values = [0] * self.action_space_size
+
+ # choose action according to the exploration policy and the current phase (evaluating or training the agent)
+ if phase == RunPhase.TRAIN:
+ action = self.exploration_policy.get_action(actions_q_values)
+ self.current_episode_actions.append(action)
+ else:
+ action = np.argmax(actions_q_values)
+
+ # store the q values statistics for logging
+ self.q_values.add_sample(actions_q_values)
+
+ # store information for plotting interactively (actual plotting is done in agent)
+ if self.tp.visualization.plot_action_values_online:
+ for idx, action_name in enumerate(self.env.actions_description):
+ self.episode_running_info[action_name].append(actions_q_values[idx])
+
+ action_value = {"action_value": actions_q_values[action]}
+ return action, action_value
+
+ def reset_game(self, do_not_reset_env=False):
+ ValueOptimizationAgent.reset_game(self, do_not_reset_env)
+
+ # make sure we already have at least one episode
+ if self.memory.num_complete_episodes() >= 1 and not self.in_heatup:
+ # get the last full episode that we have collected
+ episode = self.memory.get(-2)
+ returns = []
+ for i in range(episode.length()):
+ returns.append(episode.get_transition(i).total_return)
+            # handle the case where heatup ends in the middle of an episode: the episode taken
+            # from the ER is then a complete one, whereas the other statistics collected here
+            # are gathered only during training, so we keep only the matching tail of the returns
+ returns = returns[-len(self.current_episode_actions):]
+ self.main_network.online_network.output_heads[0].DND.add(self.current_episode_state_embeddings,
+ self.current_episode_actions, returns)
+
+ self.current_episode_state_embeddings = []
+ self.current_episode_actions = []
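+
+# --- Editor's note: the sketch below is illustrative only and is not part of the original file. ---
+# Conceptually (per the NEC paper linked above), the DND queried in choose_action estimates
+# Q(s, a) as a kernel-weighted average of the returns stored for the k nearest embeddings.
+# The helper below is our own standalone approximation of that lookup, not the DND's actual code.
+import numpy as np
+
+
+def _dnd_q_estimate(query_embedding, stored_embeddings, stored_returns, k, delta=1e-3):
+    """stored_embeddings: shape (n, d); stored_returns: shape (n,); returns a scalar Q estimate."""
+    distances = np.linalg.norm(stored_embeddings - query_embedding, axis=1)
+    nearest = np.argsort(distances)[:k]
+    kernel = 1.0 / (distances[nearest] ** 2 + delta)          # inverse squared-distance kernel
+    weights = kernel / kernel.sum()
+    return float(np.dot(weights, stored_returns[nearest]))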
diff --git a/agents/pal_agent.py b/agents/pal_agent.py
new file mode 100644
index 0000000..68ff675
--- /dev/null
+++ b/agents/pal_agent.py
@@ -0,0 +1,65 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.value_optimization_agent import *
+
+
+# Persistent Advantage Learning - https://arxiv.org/pdf/1512.04860.pdf
+class PALAgent(ValueOptimizationAgent):
+ def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
+ ValueOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
+ self.alpha = tuning_parameters.agent.pal_alpha
+ self.persistent = tuning_parameters.agent.persistent_advantage_learning
+ self.monte_carlo_mixing_rate = tuning_parameters.agent.monte_carlo_mixing_rate
+
+ def learn_from_batch(self, batch):
+ current_states, next_states, actions, rewards, game_overs, total_return = self.extract_batch(batch)
+
+ selected_actions = np.argmax(self.main_network.online_network.predict(next_states), 1)
+
+ # next state values
+ q_st_plus_1_target = self.main_network.target_network.predict(next_states)
+ v_st_plus_1_target = np.max(q_st_plus_1_target, 1)
+
+ # current state values according to online network
+ q_st_online = self.main_network.online_network.predict(current_states)
+
+ # current state values according to target network
+ q_st_target = self.main_network.target_network.predict(current_states)
+ v_st_target = np.max(q_st_target, 1)
+
+ # calculate TD error
+ TD_targets = np.copy(q_st_online)
+ for i in range(self.tp.batch_size):
+ TD_targets[i, actions[i]] = rewards[i] + (1.0 - game_overs[i]) * self.tp.agent.discount * \
+ q_st_plus_1_target[i][selected_actions[i]]
+ advantage_learning_update = v_st_target[i] - q_st_target[i, actions[i]]
+ next_advantage_learning_update = v_st_plus_1_target[i] - q_st_plus_1_target[i, selected_actions[i]]
+ # Persistent Advantage Learning or Regular Advantage Learning
+ if self.persistent:
+ TD_targets[i, actions[i]] -= self.alpha * min(advantage_learning_update, next_advantage_learning_update)
+ else:
+ TD_targets[i, actions[i]] -= self.alpha * advantage_learning_update
+
+ # mixing monte carlo updates
+ monte_carlo_target = total_return[i]
+ TD_targets[i, actions[i]] = (1 - self.monte_carlo_mixing_rate) * TD_targets[i, actions[i]] \
+ + self.monte_carlo_mixing_rate * monte_carlo_target
+
+ result = self.main_network.train_and_sync_networks(current_states, TD_targets)
+ total_loss = result[0]
+
+ return total_loss
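+
+# --- Editor's note: the sketch below is illustrative only and is not part of the original file. ---
+# The loop above subtracts an advantage-learning correction from the Double-DQN-style target:
+#   AL:  target -= alpha * (V_target(s) - Q_target(s, a))
+#   PAL: target -= alpha * min(V_target(s) - Q_target(s, a), V_target(s') - Q_target(s', a'*))
+# and then mixes the result with the Monte Carlo return. The per-sample correction in isolation:
+def _pal_correction(alpha, v_curr, q_curr, v_next, q_next, persistent):
+    advantage_gap = v_curr - q_curr
+    next_advantage_gap = v_next - q_next
+    return alpha * (min(advantage_gap, next_advantage_gap) if persistent else advantage_gap)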
diff --git a/agents/policy_gradients_agent.py b/agents/policy_gradients_agent.py
new file mode 100644
index 0000000..bf873d1
--- /dev/null
+++ b/agents/policy_gradients_agent.py
@@ -0,0 +1,87 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.policy_optimization_agent import *
+import numpy as np
+from logger import *
+import tensorflow as tf
+import matplotlib.pyplot as plt
+
+from utils import *
+
+
+class PolicyGradientsAgent(PolicyOptimizationAgent):
+ def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
+ PolicyOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
+
+ self.last_gradient_update_step_idx = 0
+
+ def learn_from_batch(self, batch):
+ # batch contains a list of episodes to learn from
+ current_states, next_states, actions, rewards, game_overs, total_returns = self.extract_batch(batch)
+
+ for i in reversed(range(len(total_returns))):
+ if self.policy_gradient_rescaler == PolicyGradientRescaler.TOTAL_RETURN:
+ total_returns[i] = total_returns[0]
+ elif self.policy_gradient_rescaler == PolicyGradientRescaler.FUTURE_RETURN:
+ # just take the total return as it is
+ pass
+ elif self.policy_gradient_rescaler == PolicyGradientRescaler.FUTURE_RETURN_NORMALIZED_BY_EPISODE:
+ # we can get a single transition episode while playing Doom Basic, causing the std to be 0
+ if self.std_discounted_return != 0:
+ total_returns[i] = (total_returns[i] - self.mean_discounted_return) / self.std_discounted_return
+ else:
+ total_returns[i] = 0
+ elif self.policy_gradient_rescaler == PolicyGradientRescaler.FUTURE_RETURN_NORMALIZED_BY_TIMESTEP:
+ total_returns[i] -= self.mean_return_over_multiple_episodes[i]
+ else:
+ screen.warning("WARNING: The requested policy gradient rescaler is not available")
+
+ targets = total_returns
+ if not self.env.discrete_controls and len(actions.shape) < 2:
+ actions = np.expand_dims(actions, -1)
+
+ logger.create_signal_value('Returns Variance', np.std(total_returns), self.task_id)
+ logger.create_signal_value('Returns Mean', np.mean(total_returns), self.task_id)
+
+ result = self.main_network.online_network.accumulate_gradients([current_states, actions], targets)
+ total_loss = result[0]
+
+ return total_loss
+
+ def choose_action(self, curr_state, phase=RunPhase.TRAIN):
+ # convert to batch so we can run it through the network
+ observation = np.expand_dims(np.array(curr_state['observation']), 0)
+ if self.env.discrete_controls:
+ # DISCRETE
+ action_values = self.main_network.online_network.predict(observation).squeeze()
+ if phase == RunPhase.TRAIN:
+ action = self.exploration_policy.get_action(action_values)
+ else:
+ action = np.argmax(action_values)
+ action_value = {"action_probability": action_values[action]}
+ self.entropy.add_sample(-np.sum(action_values * np.log(action_values)))
+ else:
+ # CONTINUOUS
+ result = self.main_network.online_network.predict(observation)
+ action_values = result[0].squeeze()
+ if phase == RunPhase.TRAIN:
+ action = self.exploration_policy.get_action(action_values)
+ else:
+ action = action_values
+ action_value = {}
+
+ return action, action_value
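+
+# --- Editor's note: the sketch below is illustrative only and is not part of the original file. ---
+# The FUTURE_RETURN_NORMALIZED_BY_EPISODE rescaler above standardizes each future return with the
+# episode's discounted-return statistics, i.e. (G_i - mean) / std, falling back to 0 when std is 0:
+def _standardize_return(future_return, mean_discounted_return, std_discounted_return):
+    if std_discounted_return == 0:
+        return 0.0
+    return (future_return - mean_discounted_return) / std_discounted_return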
diff --git a/agents/policy_optimization_agent.py b/agents/policy_optimization_agent.py
new file mode 100644
index 0000000..c64dbab
--- /dev/null
+++ b/agents/policy_optimization_agent.py
@@ -0,0 +1,121 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.agent import *
+from memories.memory import Episode
+
+
+class PolicyGradientRescaler(Enum):
+ TOTAL_RETURN = 0
+ FUTURE_RETURN = 1
+ FUTURE_RETURN_NORMALIZED_BY_EPISODE = 2
+ FUTURE_RETURN_NORMALIZED_BY_TIMESTEP = 3 # baselined
+ Q_VALUE = 4
+ A_VALUE = 5
+ TD_RESIDUAL = 6
+ DISCOUNTED_TD_RESIDUAL = 7
+ GAE = 8
+
+
+class PolicyOptimizationAgent(Agent):
+ def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0, create_target_network=False):
+ Agent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
+ self.main_network = NetworkWrapper(tuning_parameters, create_target_network, self.has_global, 'main',
+ self.replicated_device, self.worker_device)
+ self.networks.append(self.main_network)
+
+ self.policy_gradient_rescaler = PolicyGradientRescaler().get(self.tp.agent.policy_gradient_rescaler)
+
+ # statistics for variance reduction
+ self.last_gradient_update_step_idx = 0
+ self.max_episode_length = 100000
+ self.mean_return_over_multiple_episodes = np.zeros(self.max_episode_length)
+ self.num_episodes_where_step_has_been_seen = np.zeros(self.max_episode_length)
+ self.entropy = Signal('Entropy')
+ self.signals.append(self.entropy)
+
+ def log_to_screen(self, phase):
+ # log to screen
+ if self.current_episode > 0:
+ screen.log_dict(
+ OrderedDict([
+ ("Worker", self.task_id),
+ ("Episode", self.current_episode),
+ ("total reward", self.total_reward_in_current_episode),
+ ("steps", self.total_steps_counter),
+ ("training iteration", self.training_iteration)
+ ]),
+ prefix="Heatup" if self.in_heatup else "Training" if phase == RunPhase.TRAIN else "Testing"
+ )
+
+ def update_episode_statistics(self, episode):
+ episode_discounted_returns = []
+ for i in range(episode.length()):
+ transition = episode.get_transition(i)
+ episode_discounted_returns.append(transition.total_return)
+ self.num_episodes_where_step_has_been_seen[i] += 1
+ self.mean_return_over_multiple_episodes[i] -= self.mean_return_over_multiple_episodes[i] / \
+ self.num_episodes_where_step_has_been_seen[i]
+ self.mean_return_over_multiple_episodes[i] += transition.total_return / \
+ self.num_episodes_where_step_has_been_seen[i]
+ self.mean_discounted_return = np.mean(episode_discounted_returns)
+ self.std_discounted_return = np.std(episode_discounted_returns)
+
+ def train(self):
+ if self.memory.length() == 0:
+ return 0
+
+ episode = self.memory.get_episode(0)
+
+ # check if we should calculate gradients or skip
+ episode_ended = self.memory.num_complete_episodes() >= 1
+ num_steps_passed_since_last_update = episode.length() - self.last_gradient_update_step_idx
+ is_t_max_steps_passed = num_steps_passed_since_last_update >= self.tp.agent.num_steps_between_gradient_updates
+ if not (is_t_max_steps_passed or episode_ended):
+ return 0
+
+ total_loss = 0
+ if num_steps_passed_since_last_update > 0:
+
+ # we need to update the returns of the episode until now
+ episode.update_returns(self.tp.agent.discount)
+
+            # get t_max transitions, or fewer if we reached a terminal state.
+            # this will be used for both actor-critic and vanilla PG.
+            # in order to get full episodes, vanilla PG will set the end_idx to a very big value.
+ transitions = []
+ start_idx = self.last_gradient_update_step_idx
+ end_idx = episode.length()
+
+ for idx in range(start_idx, end_idx):
+ transitions.append(episode.get_transition(idx))
+ self.last_gradient_update_step_idx = end_idx
+
+ # update the statistics for the variance reduction techniques
+ if self.tp.agent.type == 'PolicyGradientsAgent':
+ self.update_episode_statistics(episode)
+
+ # accumulate the gradients and apply them once in every apply_gradients_every_x_episodes episodes
+ total_loss = self.learn_from_batch(transitions)
+ if self.current_episode % self.tp.agent.apply_gradients_every_x_episodes == 0:
+ self.main_network.apply_gradients_and_sync_networks()
+
+ # move the pointer to the next episode start and discard the episode. we use it only once
+ if episode_ended:
+ self.memory.remove_episode(0)
+ self.last_gradient_update_step_idx = 0
+
+ return total_loss
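+
+# --- Editor's note: the sketch below is illustrative only and is not part of the original file. ---
+# update_episode_statistics above keeps, for every timestep index, a running mean of the return
+# observed at that index across episodes. Written as a subtraction followed by an addition there,
+# it is equivalent to the incremental update mean_n = mean_{n-1} + (x_n - mean_{n-1}) / n:
+def _update_running_mean(current_mean, num_seen, new_value):
+    """Mean after observing new_value as the num_seen-th sample (num_seen >= 1)."""
+    return current_mean + (new_value - current_mean) / num_seen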
diff --git a/agents/ppo_agent.py b/agents/ppo_agent.py
new file mode 100644
index 0000000..ee1c84f
--- /dev/null
+++ b/agents/ppo_agent.py
@@ -0,0 +1,274 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.actor_critic_agent import *
+from random import shuffle
+import tensorflow as tf
+
+
+# Proximal Policy Optimization - https://arxiv.org/pdf/1707.02286.pdf
+class PPOAgent(ActorCriticAgent):
+ def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
+ ActorCriticAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id,
+ create_target_network=True)
+ self.critic_network = self.main_network
+
+ # define the policy network
+ tuning_parameters.agent.input_types = [InputTypes.Observation]
+ tuning_parameters.agent.output_types = [OutputTypes.PPO]
+ tuning_parameters.agent.optimizer_type = 'Adam'
+ tuning_parameters.agent.l2_regularization = 0
+ self.policy_network = NetworkWrapper(tuning_parameters, True, self.has_global, 'policy',
+ self.replicated_device, self.worker_device)
+ self.networks.append(self.policy_network)
+
+ # operations for changing the kl coefficient
+ self.kl_coefficient = tf.placeholder('float', name='kl_coefficient')
+ self.increase_kl_coefficient = tf.assign(self.policy_network.online_network.output_heads[0].kl_coefficient,
+ self.kl_coefficient * 1.5)
+ self.decrease_kl_coefficient = tf.assign(self.policy_network.online_network.output_heads[0].kl_coefficient,
+ self.kl_coefficient / 1.5)
+
+ # signals definition
+ self.value_loss = Signal('Value Loss')
+ self.signals.append(self.value_loss)
+ self.policy_loss = Signal('Policy Loss')
+ self.signals.append(self.policy_loss)
+ self.kl_divergence = Signal('KL Divergence')
+ self.signals.append(self.kl_divergence)
+ self.total_kl_divergence_during_training_process = 0.0
+ self.unclipped_grads = Signal('Grads (unclipped)')
+ self.signals.append(self.unclipped_grads)
+
+ def fill_advantages(self, batch):
+ current_states, next_states, actions, rewards, game_overs, total_return = self.extract_batch(batch)
+
+ # * Found not to have any impact *
+ # current_states_with_timestep = self.concat_state_and_timestep(batch)
+
+ current_state_values = self.critic_network.online_network.predict([current_states]).squeeze()
+
+ # calculate advantages
+ advantages = []
+ if self.policy_gradient_rescaler == PolicyGradientRescaler.A_VALUE:
+ advantages = total_return - current_state_values
+ elif self.policy_gradient_rescaler == PolicyGradientRescaler.GAE:
+ # get bootstraps
+ episode_start_idx = 0
+ advantages = np.array([])
+ # current_state_values[game_overs] = 0
+ for idx, game_over in enumerate(game_overs):
+ if game_over:
+ # get advantages for the rollout
+ value_bootstrapping = np.zeros((1,))
+ rollout_state_values = np.append(current_state_values[episode_start_idx:idx+1], value_bootstrapping)
+
+ rollout_advantages, _ = \
+ self.get_general_advantage_estimation_values(rewards[episode_start_idx:idx+1],
+ rollout_state_values)
+ episode_start_idx = idx + 1
+ advantages = np.append(advantages, rollout_advantages)
+ else:
+ screen.warning("WARNING: The requested policy gradient rescaler is not available")
+
+ # standardize
+ advantages = (advantages - np.mean(advantages)) / np.std(advantages)
+
+ for transition, advantage in zip(self.memory.transitions, advantages):
+ transition.info['advantage'] = advantage
+
+ self.action_advantages.add_sample(advantages)
+
+ def train_value_network(self, dataset, epochs):
+ loss = []
+ current_states, _, _, _, _, total_return = self.extract_batch(dataset)
+
+ # * Found not to have any impact *
+ # add a timestep to the observation
+ # current_states_with_timestep = self.concat_state_and_timestep(dataset)
+
+ total_return = np.expand_dims(total_return, -1)
+ mix_fraction = self.tp.agent.value_targets_mix_fraction
+ for j in range(epochs):
+ batch_size = len(dataset)
+ if self.critic_network.online_network.optimizer_type != 'LBFGS':
+ batch_size = self.tp.batch_size
+ for i in range(len(dataset) // batch_size):
+ # split to batches for first order optimization techniques
+ current_states_batch = current_states[i * batch_size:(i + 1) * batch_size]
+ total_return_batch = total_return[i * batch_size:(i + 1) * batch_size]
+ old_policy_values = force_list(self.critic_network.target_network.predict(
+ [current_states_batch]).squeeze())
+ if self.critic_network.online_network.optimizer_type != 'LBFGS':
+ targets = total_return_batch
+ else:
+ current_values = self.critic_network.online_network.predict([current_states_batch])
+ targets = current_values * (1 - mix_fraction) + total_return_batch * mix_fraction
+
+ value_loss = self.critic_network.online_network.\
+ accumulate_gradients([current_states_batch] + old_policy_values, targets)
+ self.critic_network.apply_gradients_to_online_network()
+ if self.tp.distributed:
+ self.critic_network.apply_gradients_to_global_network()
+ self.critic_network.online_network.reset_accumulated_gradients()
+
+ loss.append([value_loss[0]])
+ loss = np.mean(loss, 0)
+ return loss
+
+ def concat_state_and_timestep(self, dataset):
+ current_states_with_timestep = [np.append(transition.state['observation'], transition.info['timestep'])
+ for transition in dataset]
+ current_states_with_timestep = np.expand_dims(current_states_with_timestep, -1)
+ return current_states_with_timestep
+
+ def train_policy_network(self, dataset, epochs):
+ loss = []
+ for j in range(epochs):
+ loss = {
+ 'total_loss': [],
+ 'policy_losses': [],
+ 'unclipped_grads': [],
+ 'fetch_result': []
+ }
+ #shuffle(dataset)
+ for i in range(len(dataset) // self.tp.batch_size):
+ batch = dataset[i * self.tp.batch_size:(i + 1) * self.tp.batch_size]
+ current_states, _, actions, _, _, total_return = self.extract_batch(batch)
+ advantages = np.array([t.info['advantage'] for t in batch])
+ if not self.tp.env_instance.discrete_controls and len(actions.shape) == 1:
+ actions = np.expand_dims(actions, -1)
+
+ # get old policy probabilities and distribution
+ old_policy = force_list(self.policy_network.target_network.predict([current_states]))
+
+ # calculate gradients and apply on both the local policy network and on the global policy network
+ fetches = [self.policy_network.online_network.output_heads[0].kl_divergence,
+ self.policy_network.online_network.output_heads[0].entropy]
+
+ total_loss, policy_losses, unclipped_grads, fetch_result =\
+ self.policy_network.online_network.accumulate_gradients(
+ [current_states, actions] + old_policy, [advantages], additional_fetches=fetches)
+
+ self.policy_network.apply_gradients_to_online_network()
+ if self.tp.distributed:
+ self.policy_network.apply_gradients_to_global_network()
+
+ self.policy_network.online_network.reset_accumulated_gradients()
+
+ loss['total_loss'].append(total_loss)
+ loss['policy_losses'].append(policy_losses)
+ loss['unclipped_grads'].append(unclipped_grads)
+ loss['fetch_result'].append(fetch_result)
+
+ self.unclipped_grads.add_sample(unclipped_grads)
+
+ for key in loss.keys():
+ loss[key] = np.mean(loss[key], 0)
+
+ if self.tp.learning_rate_decay_rate != 0:
+ curr_learning_rate = self.tp.sess.run(self.tp.learning_rate)
+ self.curr_learning_rate.add_sample(curr_learning_rate)
+ else:
+ curr_learning_rate = self.tp.learning_rate
+
+ # log training parameters
+ screen.log_dict(
+ OrderedDict([
+ ("Surrogate loss", loss['policy_losses'][0]),
+ ("KL divergence", loss['fetch_result'][0]),
+ ("Entropy", loss['fetch_result'][1]),
+ ("training epoch", j),
+ ("learning_rate", curr_learning_rate)
+ ]),
+ prefix="Policy training"
+ )
+
+ self.total_kl_divergence_during_training_process = loss['fetch_result'][0]
+ self.entropy.add_sample(loss['fetch_result'][1])
+ self.kl_divergence.add_sample(loss['fetch_result'][0])
+ return loss['total_loss']
+
+ def update_kl_coefficient(self):
+ # John Schulman takes the mean kl divergence only over the last epoch which is strange but we will follow
+ # his implementation for now because we know it works well
+ screen.log_title("KL = {}".format(self.total_kl_divergence_during_training_process))
+
+ # update kl coefficient
+ kl_target = self.tp.agent.target_kl_divergence
+ kl_coefficient = self.tp.sess.run(self.policy_network.online_network.output_heads[0].kl_coefficient)
+ if self.total_kl_divergence_during_training_process > 1.3 * kl_target:
+ # kl too high => increase regularization
+ self.tp.sess.run(self.increase_kl_coefficient, feed_dict={self.kl_coefficient: kl_coefficient})
+ elif self.total_kl_divergence_during_training_process < 0.7 * kl_target:
+ # kl too low => decrease regularization
+ self.tp.sess.run(self.decrease_kl_coefficient, feed_dict={self.kl_coefficient: kl_coefficient})
+ screen.log_title("KL penalty coefficient change = {} -> {}".format(
+ kl_coefficient, self.tp.sess.run(self.policy_network.online_network.output_heads[0].kl_coefficient)))
+
+ def post_training_commands(self):
+ if self.tp.agent.use_kl_regularization:
+ self.update_kl_coefficient()
+
+ # clean memory
+ self.memory.clean()
+
+ def train(self):
+ self.policy_network.sync()
+ self.critic_network.sync()
+
+ dataset = self.memory.transitions
+
+ self.fill_advantages(dataset)
+
+ # take only the requested number of steps
+ dataset = dataset[:self.tp.agent.num_consecutive_playing_steps]
+
+ value_loss = self.train_value_network(dataset, 1)
+ policy_loss = self.train_policy_network(dataset, 10)
+
+ self.value_loss.add_sample(value_loss)
+ self.policy_loss.add_sample(policy_loss)
+ self.update_log() # should be done in order to update the data that has been accumulated * while not playing *
+ return np.append(value_loss, policy_loss)
+
+ def choose_action(self, curr_state, phase=RunPhase.TRAIN):
+ # convert to batch so we can run it through the network
+ observation = curr_state['observation']
+ observation = np.expand_dims(np.array(observation), 0)
+
+ if self.env.discrete_controls:
+ # DISCRETE
+ action_values = self.policy_network.online_network.predict(observation).squeeze()
+
+ if phase == RunPhase.TRAIN:
+ action = self.exploration_policy.get_action(action_values)
+ else:
+ action = np.argmax(action_values)
+ action_info = {"action_probability": action_values[action]}
+ # self.entropy.add_sample(-np.sum(action_values * np.log(action_values)))
+ else:
+ # CONTINUOUS
+ action_values_mean, action_values_std = self.policy_network.online_network.predict(observation)
+ action_values_mean = action_values_mean.squeeze()
+ action_values_std = action_values_std.squeeze()
+ if phase == RunPhase.TRAIN:
+ action = np.squeeze(np.random.randn(1, self.action_space_size) * action_values_std + action_values_mean)
+ else:
+ action = action_values_mean
+ action_info = {"action_probability": action_values_mean}
+
+ return action, action_info
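+
+# --- Editor's note: the sketch below is illustrative only and is not part of the original file. ---
+# update_kl_coefficient above implements the adaptive KL penalty: when the measured KL divergence
+# drifts more than 30% above or below the target, the penalty coefficient is scaled up or down by
+# a factor of 1.5. The rule in isolation:
+def _adapt_kl_coefficient(kl_coefficient, measured_kl, kl_target):
+    if measured_kl > 1.3 * kl_target:
+        return kl_coefficient * 1.5      # KL too high => increase regularization
+    if measured_kl < 0.7 * kl_target:
+        return kl_coefficient / 1.5      # KL too low => decrease regularization
+    return kl_coefficient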
diff --git a/agents/value_optimization_agent.py b/agents/value_optimization_agent.py
new file mode 100644
index 0000000..f348333
--- /dev/null
+++ b/agents/value_optimization_agent.py
@@ -0,0 +1,64 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.agent import *
+
+
+class ValueOptimizationAgent(Agent):
+ def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0, create_target_network=True):
+ Agent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
+ self.main_network = NetworkWrapper(tuning_parameters, create_target_network, self.has_global, 'main',
+ self.replicated_device, self.worker_device)
+ self.networks.append(self.main_network)
+ self.q_values = Signal("Q")
+ self.signals.append(self.q_values)
+
+ # Algorithms for which q_values are calculated from predictions will override this function
+ def get_q_values(self, prediction):
+ return prediction
+
+ def choose_action(self, curr_state, phase=RunPhase.TRAIN):
+ # convert to batch so we can run it through the network
+ observation = np.expand_dims(np.array(curr_state['observation']), 0)
+ if self.tp.agent.use_measurements:
+ measurements = np.expand_dims(np.array(curr_state['measurements']), 0)
+ prediction = self.main_network.online_network.predict([observation, measurements])
+ else:
+ prediction = self.main_network.online_network.predict(observation)
+
+ actions_q_values = self.get_q_values(prediction)
+
+ # choose action according to the exploration policy and the current phase (evaluating or training the agent)
+ if phase == RunPhase.TRAIN:
+ action = self.exploration_policy.get_action(actions_q_values)
+ else:
+ action = self.evaluation_exploration_policy.get_action(actions_q_values)
+
+ # this is for bootstrapped dqn
+ if type(actions_q_values) == list and len(actions_q_values) > 0:
+ actions_q_values = actions_q_values[self.exploration_policy.selected_head]
+ actions_q_values = actions_q_values.squeeze()
+
+ # store the q values statistics for logging
+ self.q_values.add_sample(actions_q_values)
+
+ # store information for plotting interactively (actual plotting is done in agent)
+ if self.tp.visualization.plot_action_values_online:
+ for idx, action_name in enumerate(self.env.actions_description):
+ self.episode_running_info[action_name].append(actions_q_values[idx])
+
+ action_value = {"action_value": actions_q_values[action]}
+ return action, action_value
diff --git a/architectures/__init__.py b/architectures/__init__.py
new file mode 100644
index 0000000..cbf2ac5
--- /dev/null
+++ b/architectures/__init__.py
@@ -0,0 +1,31 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from architectures.architecture import *
+from logger import failed_imports
+try:
+ from architectures.tensorflow_components.general_network import *
+ from architectures.tensorflow_components.architecture import *
+except ImportError:
+ failed_imports.append("TensorFlow")
+
+try:
+ from architectures.neon_components.general_network import *
+ from architectures.neon_components.architecture import *
+except ImportError:
+ failed_imports.append("Neon")
+
+from architectures.network_wrapper import *
\ No newline at end of file
diff --git a/architectures/architecture.py b/architectures/architecture.py
new file mode 100644
index 0000000..247b3e3
--- /dev/null
+++ b/architectures/architecture.py
@@ -0,0 +1,70 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from configurations import Preset
+
+
+class Architecture:
+ def __init__(self, tuning_parameters, name=""):
+ """
+        :param tuning_parameters: A Preset class instance with all the running parameters
+        :type tuning_parameters: Preset
+        :param name: The name of the network
+        :type name: string
+ """
+ self.batch_size = tuning_parameters.batch_size
+ self.input_depth = tuning_parameters.env.observation_stack_size
+ self.input_height = tuning_parameters.env.desired_observation_height
+ self.input_width = tuning_parameters.env.desired_observation_width
+ self.num_actions = tuning_parameters.env.action_space_size
+ self.measurements_size = tuning_parameters.env.measurements_size \
+ if tuning_parameters.env.measurements_size else 0
+ self.learning_rate = tuning_parameters.learning_rate
+ self.optimizer = None
+ self.name = name
+ self.tp = tuning_parameters
+
+ def get_model(self, tuning_parameters):
+ """
+ :param tuning_parameters: A Preset class instance with all the running parameters
+ :type tuning_parameters: Preset
+ :return: A model
+ """
+ pass
+
+ def predict(self, inputs):
+ pass
+
+ def train_on_batch(self, inputs, targets):
+ pass
+
+ def get_weights(self):
+ pass
+
+ def set_weights(self, weights, rate=1.0):
+ pass
+
+ def reset_accumulated_gradients(self):
+ pass
+
+ def accumulate_gradients(self, inputs, targets):
+ pass
+
+ def apply_and_reset_gradients(self, gradients):
+ pass
+
+ def apply_gradients(self, gradients):
+ pass
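+
+
+# The methods above define the contract implemented by the framework-specific back-ends
+# (see tensorflow_components/architecture.py and neon_components/architecture.py). A
+# typical training step, sketched with illustrative variable names, is:
+#
+#     loss = network.accumulate_gradients(batch_inputs, batch_targets)
+#     network.apply_and_reset_gradients(network.accumulated_gradients)
+#
+# which is exactly what the concrete train_on_batch() implementations do internally.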
diff --git a/architectures/neon_components/__init__.py b/architectures/neon_components/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/architectures/neon_components/architecture.py b/architectures/neon_components/architecture.py
new file mode 100644
index 0000000..de600c1
--- /dev/null
+++ b/architectures/neon_components/architecture.py
@@ -0,0 +1,129 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+import copy
+from ngraph.frontends.neon import *
+import ngraph as ng
+from architectures.architecture import *
+import numpy as np
+from utils import *
+
+
+class NeonArchitecture(Architecture):
+ def __init__(self, tuning_parameters, name="", global_network=None, network_is_local=True):
+ Architecture.__init__(self, tuning_parameters, name)
+ assert tuning_parameters.agent.neon_support, 'Neon is not supported for this agent'
+ self.clip_error = tuning_parameters.clip_gradients
+ self.total_loss = None
+ self.epoch = 0
+ self.inputs = []
+ self.outputs = []
+ self.targets = []
+ self.losses = []
+
+ self.transformer = tuning_parameters.sess
+ self.network = self.get_model(tuning_parameters)
+ self.accumulated_gradients = []
+
+ # training and inference ops
+ train_output = ng.sequential([
+ self.optimizer(self.total_loss),
+ self.total_loss
+ ])
+ placeholders = self.inputs + self.targets
+ self.train_op = self.transformer.add_computation(
+ ng.computation(
+ train_output, *placeholders
+ )
+ )
+ self.predict_op = self.transformer.add_computation(
+ ng.computation(
+ self.outputs, self.inputs[0]
+ )
+ )
+
+ # update weights from array op
+ self.weights = [ng.placeholder(w.axes) for w in self.total_loss.variables()]
+ self.set_weights_ops = []
+ for target_variable, variable in zip(self.total_loss.variables(), self.weights):
+ self.set_weights_ops.append(self.transformer.add_computation(
+ ng.computation(
+ ng.assign(target_variable, variable), variable
+ )
+ ))
+
+ # get weights op
+ self.get_variables = self.transformer.add_computation(
+ ng.computation(
+ self.total_loss.variables()
+ )
+ )
+
+ def predict(self, inputs):
+ batch_size = inputs.shape[0]
+
+ # move batch axis to the end
+ inputs = inputs.swapaxes(0, -1)
+ prediction = self.predict_op(inputs) # TODO: problem with multiple inputs
+
+        if not isinstance(prediction, tuple):
+            prediction = (prediction,)
+
+ # process all the outputs from the network
+ output = []
+ for p in prediction:
+ output.append(p.transpose()[:batch_size].copy())
+
+ # if there is only one output then we don't need a list
+ if len(output) == 1:
+ output = output[0]
+ return output
+
+ def train_on_batch(self, inputs, targets):
+ loss = self.accumulate_gradients(inputs, targets)
+ self.apply_and_reset_gradients(self.accumulated_gradients)
+ return loss
+
+ def get_weights(self):
+ return self.get_variables()
+
+ def set_weights(self, weights, rate=1.0):
+ if rate != 1:
+ current_weights = self.get_weights()
+ updated_weights = [(1 - rate) * t + rate * o for t, o in zip(current_weights, weights)]
+ else:
+ updated_weights = weights
+ for update_function, variable in zip(self.set_weights_ops, updated_weights):
+ update_function(variable)
+
+ def accumulate_gradients(self, inputs, targets):
+        # Neon doesn't currently allow separating the gradient calculation from the gradient
+        # apply operation, so gradient accumulation is not available here. Instead, we run a
+        # full training iteration.
+ inputs = force_list(inputs)
+ targets = force_list(targets)
+
+ for idx, input in enumerate(inputs):
+ inputs[idx] = input.swapaxes(0, -1)
+
+ for idx, target in enumerate(targets):
+ targets[idx] = np.rollaxis(target, 0, len(target.shape))
+
+ all_inputs = inputs + targets
+
+ loss = np.mean(self.train_op(*all_inputs))
+
+ return [loss]
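+
+
+# A minimal pure-NumPy/pure-Python sketch of the blending rule used by set_weights() above
+# (the function and argument names are illustrative, not part of the Coach API): with
+# rate=1.0 the new weights are copied verbatim, while smaller rates move the current
+# weights only part of the way towards the new ones.
+def _soft_update_reference(current_weights, new_weights, rate=1.0):
+    return [(1.0 - rate) * current + rate * new
+            for current, new in zip(current_weights, new_weights)]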
diff --git a/architectures/neon_components/embedders.py b/architectures/neon_components/embedders.py
new file mode 100644
index 0000000..ccfd772
--- /dev/null
+++ b/architectures/neon_components/embedders.py
@@ -0,0 +1,88 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import ngraph.frontends.neon as neon
+import ngraph as ng
+from ngraph.util.names import name_scope
+
+
+class InputEmbedder:
+ def __init__(self, input_size, batch_size=None, activation_function=neon.Rectlin(), name="embedder"):
+ self.name = name
+ self.input_size = input_size
+ self.batch_size = batch_size
+ self.activation_function = activation_function
+ self.weights_init = neon.GlorotInit()
+ self.biases_init = neon.ConstantInit()
+ self.input = None
+ self.output = None
+
+ def __call__(self, prev_input_placeholder=None):
+ with name_scope(self.get_name()):
+ # create the input axes
+ axes = []
+ if len(self.input_size) == 2:
+ axis_names = ['H', 'W']
+ else:
+ axis_names = ['C', 'H', 'W']
+ for axis_size, axis_name in zip(self.input_size, axis_names):
+ axes.append(ng.make_axis(axis_size, name=axis_name))
+ batch_axis_full = ng.make_axis(self.batch_size, name='N')
+ input_axes = ng.make_axes(axes)
+
+ if prev_input_placeholder is None:
+ self.input = ng.placeholder(input_axes + [batch_axis_full])
+ else:
+ self.input = prev_input_placeholder
+ self._build_module()
+
+ return self.input, self.output(self.input)
+
+ def _build_module(self):
+ pass
+
+ def get_name(self):
+ return self.name
+
+
+class ImageEmbedder(InputEmbedder):
+ def __init__(self, input_size, batch_size=None, input_rescaler=255.0, activation_function=neon.Rectlin(), name="embedder"):
+ InputEmbedder.__init__(self, input_size, batch_size, activation_function, name)
+ self.input_rescaler = input_rescaler
+
+ def _build_module(self):
+ # image observation
+ self.output = neon.Sequential([
+ neon.Preprocess(functor=lambda x: x / self.input_rescaler),
+ neon.Convolution((8, 8, 32), strides=4, activation=self.activation_function,
+ filter_init=self.weights_init, bias_init=self.biases_init),
+ neon.Convolution((4, 4, 64), strides=2, activation=self.activation_function,
+ filter_init=self.weights_init, bias_init=self.biases_init),
+ neon.Convolution((3, 3, 64), strides=1, activation=self.activation_function,
+ filter_init=self.weights_init, bias_init=self.biases_init)
+ ])
+
+
+class VectorEmbedder(InputEmbedder):
+ def __init__(self, input_size, batch_size=None, activation_function=neon.Rectlin(), name="embedder"):
+ InputEmbedder.__init__(self, input_size, batch_size, activation_function, name)
+
+ def _build_module(self):
+ # vector observation
+ self.output = neon.Sequential([
+ neon.Affine(nout=256, activation=self.activation_function,
+ weight_init=self.weights_init, bias_init=self.biases_init)
+ ])
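+
+
+# Note on the calling convention: an embedder instance is called, optionally with a
+# previously created placeholder so that several networks can reuse the same input, and
+# it returns the input placeholder together with the embedding op. Image inputs are given
+# C/H/W axes plus a batch axis N, while vector inputs get H/W axes plus the batch axis.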
diff --git a/architectures/neon_components/general_network.py b/architectures/neon_components/general_network.py
new file mode 100644
index 0000000..4bae454
--- /dev/null
+++ b/architectures/neon_components/general_network.py
@@ -0,0 +1,191 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from architectures.neon_components.embedders import *
+from architectures.neon_components.heads import *
+from architectures.neon_components.middleware import *
+from architectures.neon_components.architecture import *
+from configurations import InputTypes, OutputTypes, MiddlewareTypes
+
+
+class GeneralNeonNetwork(NeonArchitecture):
+ def __init__(self, tuning_parameters, name="", global_network=None, network_is_local=True):
+ self.global_network = global_network
+ self.network_is_local = network_is_local
+ self.num_heads_per_network = 1 if tuning_parameters.agent.use_separate_networks_per_head else \
+ len(tuning_parameters.agent.output_types)
+ self.num_networks = 1 if not tuning_parameters.agent.use_separate_networks_per_head else \
+ len(tuning_parameters.agent.output_types)
+ self.input_embedders = []
+ self.output_heads = []
+ self.activation_function = self.get_activation_function(
+ tuning_parameters.agent.hidden_layers_activation_function)
+
+ NeonArchitecture.__init__(self, tuning_parameters, name, global_network, network_is_local)
+
+ def get_activation_function(self, activation_function_string):
+ activation_functions = {
+ 'relu': neon.Rectlin(),
+ 'tanh': neon.Tanh(),
+ 'sigmoid': neon.Logistic(),
+ 'elu': neon.Explin(),
+ 'none': None
+ }
+ assert activation_function_string in activation_functions.keys(), \
+ "Activation function must be one of the following {}".format(activation_functions.keys())
+ return activation_functions[activation_function_string]
+
+ def get_input_embedder(self, embedder_type):
+ # the observation can be either an image or a vector
+ def get_observation_embedding(with_timestep=False):
+ if self.input_height > 1:
+ return ImageEmbedder((self.input_depth, self.input_height, self.input_width), self.batch_size,
+ name="observation")
+ else:
+ return VectorEmbedder((self.input_depth, self.input_width + int(with_timestep)), self.batch_size,
+ name="observation")
+
+ input_mapping = {
+ InputTypes.Observation: get_observation_embedding(),
+ InputTypes.Measurements: VectorEmbedder(self.measurements_size, self.batch_size, name="measurements"),
+ InputTypes.GoalVector: VectorEmbedder(self.measurements_size, self.batch_size, name="goal_vector"),
+ InputTypes.Action: VectorEmbedder((self.num_actions,), self.batch_size, name="action"),
+ InputTypes.TimedObservation: get_observation_embedding(with_timestep=True),
+ }
+ return input_mapping[embedder_type]
+
+ def get_middleware_embedder(self, middleware_type):
+ return {MiddlewareTypes.LSTM: None, # LSTM over Neon is currently not supported in Coach
+ MiddlewareTypes.FC: FC_Embedder}.get(middleware_type)(self.activation_function)
+
+ def get_output_head(self, head_type, head_idx, loss_weight=1.):
+ output_mapping = {
+ OutputTypes.Q: QHead,
+ OutputTypes.DuelingQ: DuelingQHead,
+ OutputTypes.V: None, # Policy Optimization algorithms over Neon are currently not supported in Coach
+ OutputTypes.Pi: None, # Policy Optimization algorithms over Neon are currently not supported in Coach
+ OutputTypes.MeasurementsPrediction: None, # DFP over Neon is currently not supported in Coach
+ OutputTypes.DNDQ: None, # NEC over Neon is currently not supported in Coach
+ OutputTypes.NAF: None, # NAF over Neon is currently not supported in Coach
+ OutputTypes.PPO: None, # PPO over Neon is currently not supported in Coach
+ OutputTypes.PPO_V: None # PPO over Neon is currently not supported in Coach
+ }
+ return output_mapping[head_type](self.tp, head_idx, loss_weight, self.network_is_local)
+
+ def get_model(self, tuning_parameters):
+ """
+        :param tuning_parameters: A Preset class instance with all the running parameters
+ :type tuning_parameters: Preset
+ :return: A model
+ """
+ assert len(self.tp.agent.input_types) > 0, "At least one input type should be defined"
+ assert len(self.tp.agent.output_types) > 0, "At least one output type should be defined"
+ assert self.tp.agent.middleware_type is not None, "Exactly one middleware type should be defined"
+ assert len(self.tp.agent.loss_weights) > 0, "At least one loss weight should be defined"
+ assert len(self.tp.agent.output_types) == len(self.tp.agent.loss_weights), \
+ "Number of loss weights should match the number of output types"
+ local_network_in_distributed_training = self.global_network is not None and self.network_is_local
+
+ tuning_parameters.activation_function = self.activation_function
+ done_creating_input_placeholders = False
+
+ for network_idx in range(self.num_networks):
+ with name_scope('network_{}'.format(network_idx)):
+ ####################
+ # Input Embeddings #
+ ####################
+
+ state_embedding = []
+ for idx, input_type in enumerate(self.tp.agent.input_types):
+ # get the class of the input embedder
+ self.input_embedders.append(self.get_input_embedder(input_type))
+
+ # in the case each head uses a different network, we still reuse the input placeholders
+ prev_network_input_placeholder = self.inputs[idx] if done_creating_input_placeholders else None
+
+ # create the input embedder instance and store the input placeholder and the embedding
+ input_placeholder, embedding = self.input_embedders[-1](prev_network_input_placeholder)
+ if len(self.inputs) < len(self.tp.agent.input_types):
+ self.inputs.append(input_placeholder)
+ state_embedding.append(embedding)
+
+ done_creating_input_placeholders = True
+
+ ##############
+ # Middleware #
+ ##############
+
+ state_embedding = ng.concat_along_axis(state_embedding, state_embedding[0].axes[0]) \
+ if len(state_embedding) > 1 else state_embedding[0]
+ self.middleware_embedder = self.get_middleware_embedder(self.tp.agent.middleware_type)
+ _, self.state_embedding = self.middleware_embedder(state_embedding)
+
+ ################
+ # Output Heads #
+ ################
+
+ for head_idx in range(self.num_heads_per_network):
+ for head_copy_idx in range(self.tp.agent.num_output_head_copies):
+ if self.tp.agent.use_separate_networks_per_head:
+                            # if we use separate networks per head, then the head type corresponds to the network idx
+ head_type_idx = network_idx
+ else:
+ # if we use a single network with multiple heads, then the head type is the current head idx
+ head_type_idx = head_idx
+ self.output_heads.append(self.get_output_head(self.tp.agent.output_types[head_type_idx],
+ head_copy_idx,
+ self.tp.agent.loss_weights[head_type_idx]))
+ if self.network_is_local:
+ output, target_placeholder, input_placeholder = self.output_heads[-1](self.state_embedding)
+ self.targets.extend(target_placeholder)
+ else:
+ output, input_placeholder = self.output_heads[-1](self.state_embedding)
+
+ self.outputs.extend(output)
+ self.inputs.extend(input_placeholder)
+
+ # Losses
+ self.losses = []
+ for output_head in self.output_heads:
+ self.losses += output_head.loss
+ self.total_loss = sum(self.losses)
+
+ # Learning rate
+ if self.tp.learning_rate_decay_rate != 0:
+ raise Exception("learning rate decay is not supported in neon")
+
+ # Optimizer
+ if local_network_in_distributed_training and \
+ hasattr(self.tp.agent, "shared_optimizer") and self.tp.agent.shared_optimizer:
+ # distributed training and this is the local network instantiation
+ self.optimizer = self.global_network.optimizer
+ else:
+ if tuning_parameters.agent.optimizer_type == 'Adam':
+ self.optimizer = neon.Adam(
+ learning_rate=tuning_parameters.learning_rate,
+ gradient_clip_norm=tuning_parameters.clip_gradients
+ )
+ elif tuning_parameters.agent.optimizer_type == 'RMSProp':
+ self.optimizer = neon.RMSProp(
+ learning_rate=tuning_parameters.learning_rate,
+ gradient_clip_norm=tuning_parameters.clip_gradients,
+ decay_rate=0.9,
+ epsilon=0.01
+ )
+ elif tuning_parameters.agent.optimizer_type == 'LBFGS':
+ raise Exception("LBFGS optimizer is not supported in neon")
+ else:
+ raise Exception("{} is not a valid optimizer type".format(tuning_parameters.agent.optimizer_type))
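+
+
+# Illustrative helper (not used by the class above) that makes the head layout computed in
+# GeneralNeonNetwork.__init__ explicit: either one network per output type, or a single
+# network carrying all of the heads.
+def _head_layout_reference(num_output_types, use_separate_networks_per_head):
+    if use_separate_networks_per_head:
+        return num_output_types, 1  # (num_networks, num_heads_per_network)
+    return 1, num_output_types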
diff --git a/architectures/neon_components/heads.py b/architectures/neon_components/heads.py
new file mode 100644
index 0000000..41f302a
--- /dev/null
+++ b/architectures/neon_components/heads.py
@@ -0,0 +1,194 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import ngraph as ng
+from ngraph.util.names import name_scope
+import ngraph.frontends.neon as neon
+import numpy as np
+from utils import force_list
+from architectures.neon_components.losses import *
+
+
+class Head:
+ def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
+ self.head_idx = head_idx
+ self.name = "head"
+ self.output = []
+ self.loss = []
+ self.loss_type = []
+ self.regularizations = []
+ self.loss_weight = force_list(loss_weight)
+ self.weights_init = neon.GlorotInit()
+ self.biases_init = neon.ConstantInit()
+ self.target = []
+ self.input = []
+ self.is_local = is_local
+ self.batch_size = tuning_parameters.batch_size
+
+ def __call__(self, input_layer):
+ """
+ Wrapper for building the module graph including scoping and loss creation
+ :param input_layer: the input to the graph
+ :return: the output of the last layer and the target placeholder
+ """
+ with name_scope(self.get_name()):
+ self._build_module(input_layer)
+
+ self.output = force_list(self.output)
+ self.target = force_list(self.target)
+ self.input = force_list(self.input)
+ self.loss_type = force_list(self.loss_type)
+ self.loss = force_list(self.loss)
+ self.regularizations = force_list(self.regularizations)
+ if self.is_local:
+ self.set_loss()
+
+ if self.is_local:
+ return self.output, self.target, self.input
+ else:
+ return self.output, self.input
+
+ def _build_module(self, input_layer):
+ """
+ Builds the graph of the module
+ :param input_layer: the input to the graph
+ :return: None
+ """
+ pass
+
+ def get_name(self):
+ """
+ Get a formatted name for the module
+ :return: the formatted name
+ """
+ return '{}_{}'.format(self.name, self.head_idx)
+
+ def set_loss(self):
+ """
+        Creates a target placeholder and a loss function for each loss type, and appends
+        the head's regularization losses
+ :return: None
+ """
+ # add losses and target placeholder
+ for idx in range(len(self.loss_type)):
+ # output_axis = ng.make_axis(self.num_actions, name='q_values')
+ batch_axis_full = ng.make_axis(self.batch_size, name='N')
+ target = ng.placeholder(ng.make_axes([self.output[0].axes[0], batch_axis_full]))
+ self.target.append(target)
+ loss = self.loss_type[idx](self.target[-1], self.output[idx],
+ weights=self.loss_weight[idx], scope=self.get_name())
+ self.loss.append(loss)
+
+ # add regularizations
+ for regularization in self.regularizations:
+ self.loss.append(regularization)
+
+
+class QHead(Head):
+ def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
+ Head.__init__(self, tuning_parameters, head_idx, loss_weight, is_local)
+ self.name = 'q_values_head'
+ self.num_actions = tuning_parameters.env_instance.action_space_size
+ if tuning_parameters.agent.replace_mse_with_huber_loss:
+ raise Exception("huber loss is not supported in neon")
+ else:
+ self.loss_type = mean_squared_error
+
+ def _build_module(self, input_layer):
+ # Standard Q Network
+ self.output = neon.Sequential([
+ neon.Affine(nout=self.num_actions,
+ weight_init=self.weights_init, bias_init=self.biases_init)
+ ])(input_layer)
+
+
+class DuelingQHead(QHead):
+ def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
+ QHead.__init__(self, tuning_parameters, head_idx, loss_weight, is_local)
+
+ def _build_module(self, input_layer):
+ # Dueling Network
+ # state value tower - V
+ output_axis = ng.make_axis(self.num_actions, name='q_values')
+
+ state_value = neon.Sequential([
+ neon.Affine(nout=256, activation=neon.Rectlin(),
+ weight_init=self.weights_init, bias_init=self.biases_init),
+ neon.Affine(nout=1,
+ weight_init=self.weights_init, bias_init=self.biases_init)
+ ])(input_layer)
+
+ # action advantage tower - A
+ action_advantage_unnormalized = neon.Sequential([
+ neon.Affine(nout=256, activation=neon.Rectlin(),
+ weight_init=self.weights_init, bias_init=self.biases_init),
+ neon.Affine(axes=output_axis,
+ weight_init=self.weights_init, bias_init=self.biases_init)
+ ])(input_layer)
+ action_advantage = action_advantage_unnormalized - ng.mean(action_advantage_unnormalized)
+
+ repeated_state_value = ng.expand_dims(ng.slice_along_axis(state_value, state_value.axes[0], 0), output_axis, 0)
+
+ # merge to state-action value function Q
+ self.output = repeated_state_value + action_advantage
+
+
+class MeasurementsPredictionHead(Head):
+ def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
+ Head.__init__(self, tuning_parameters, head_idx, loss_weight, is_local)
+ self.name = 'future_measurements_head'
+ self.num_actions = tuning_parameters.env_instance.action_space_size
+ self.num_measurements = tuning_parameters.env.measurements_size[0] \
+ if tuning_parameters.env.measurements_size else 0
+ self.num_prediction_steps = tuning_parameters.agent.num_predicted_steps_ahead
+ self.multi_step_measurements_size = self.num_measurements * self.num_prediction_steps
+ if tuning_parameters.agent.replace_mse_with_huber_loss:
+ raise Exception("huber loss is not supported in neon")
+ else:
+ self.loss_type = mean_squared_error
+
+ def _build_module(self, input_layer):
+ # This is almost exactly the same as Dueling Network but we predict the future measurements for each action
+
+        multistep_measurements_size = self.multi_step_measurements_size
+
+ # actions expectation tower (expectation stream) - E
+ with name_scope("expectation_stream"):
+ expectation_stream = neon.Sequential([
+ neon.Affine(nout=256, activation=neon.Rectlin(),
+ weight_init=self.weights_init, bias_init=self.biases_init),
+ neon.Affine(nout=multistep_measurements_size,
+ weight_init=self.weights_init, bias_init=self.biases_init)
+ ])(input_layer)
+
+ # action fine differences tower (action stream) - A
+ with name_scope("action_stream"):
+ action_stream_unnormalized = neon.Sequential([
+ neon.Affine(nout=256, activation=neon.Rectlin(),
+ weight_init=self.weights_init, bias_init=self.biases_init),
+ neon.Affine(nout=self.num_actions * multistep_measurements_size,
+ weight_init=self.weights_init, bias_init=self.biases_init),
+ neon.Reshape((self.num_actions, multistep_measurements_size))
+ ])(input_layer)
+ action_stream = action_stream_unnormalized - ng.mean(action_stream_unnormalized)
+
+        # mirror the dueling head: broadcast the expectation stream over an actions axis
+        output_axis = ng.make_axis(self.num_actions, name='actions')
+        repeated_expectation_stream = ng.slice_along_axis(expectation_stream, expectation_stream.axes[0], 0)
+        repeated_expectation_stream = ng.expand_dims(repeated_expectation_stream, output_axis, 0)
+
+ # merge to future measurements predictions
+ self.output = repeated_expectation_stream + action_stream
+
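+
+# A plain-NumPy sketch of the dueling aggregation used by DuelingQHead above (illustrative
+# names, single state): the advantage vector is centred by its mean and added to the
+# scalar state value, i.e. Q(s, a) = V(s) + (A(s, a) - mean(A(s, .))).
+def _dueling_q_reference(state_value, action_advantages):
+    return state_value + (action_advantages - np.mean(action_advantages))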
diff --git a/architectures/neon_components/losses.py b/architectures/neon_components/losses.py
new file mode 100644
index 0000000..26e8644
--- /dev/null
+++ b/architectures/neon_components/losses.py
@@ -0,0 +1,28 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import ngraph as ng
+import ngraph.frontends.neon as neon
+from ngraph.util.names import name_scope
+import numpy as np
+
+
+def mean_squared_error(targets, outputs, weights=1.0, scope=""):
+ with name_scope(scope):
+ # TODO: reduce mean over the action axis
+ loss = ng.squared_L2(targets - outputs)
+ weighted_loss = loss * weights
+ return weighted_loss
diff --git a/architectures/neon_components/middleware.py b/architectures/neon_components/middleware.py
new file mode 100644
index 0000000..4ace29e
--- /dev/null
+++ b/architectures/neon_components/middleware.py
@@ -0,0 +1,50 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import ngraph as ng
+import ngraph.frontends.neon as neon
+from ngraph.util.names import name_scope
+import numpy as np
+
+
+class MiddlewareEmbedder:
+ def __init__(self, activation_function=neon.Rectlin(), name="middleware_embedder"):
+ self.name = name
+ self.input = None
+ self.output = None
+ self.weights_init = neon.GlorotInit()
+ self.biases_init = neon.ConstantInit()
+ self.activation_function = activation_function
+
+ def __call__(self, input_layer):
+ with name_scope(self.get_name()):
+ self.input = input_layer
+ self._build_module()
+
+ return self.input, self.output(self.input)
+
+ def _build_module(self):
+ pass
+
+ def get_name(self):
+ return self.name
+
+
+class FC_Embedder(MiddlewareEmbedder):
+ def _build_module(self):
+ self.output = neon.Sequential([
+ neon.Affine(nout=512, activation=self.activation_function,
+ weight_init=self.weights_init, bias_init=self.biases_init)])
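+
+
+# The middleware embedder is the shared trunk between the input embedders and the output
+# heads. For the Neon back-end only this fully connected variant is available; the LSTM
+# middleware is mapped to None in general_network.py.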
diff --git a/architectures/network_wrapper.py b/architectures/network_wrapper.py
new file mode 100644
index 0000000..f35e0aa
--- /dev/null
+++ b/architectures/network_wrapper.py
@@ -0,0 +1,179 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from collections import OrderedDict
+from configurations import Preset, Frameworks
+from logger import *
+try:
+ import tensorflow as tf
+ from architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork
+except ImportError:
+ failed_imports.append("TensorFlow")
+
+try:
+ from architectures.neon_components.general_network import GeneralNeonNetwork
+except ImportError:
+ failed_imports.append("Neon")
+
+
+class NetworkWrapper:
+ def __init__(self, tuning_parameters, has_target, has_global, name, replicated_device=None, worker_device=None):
+ """
+        :param tuning_parameters: a Preset class instance with all the running parameters
+        :type tuning_parameters: Preset
+        :param has_target: should this wrapper create a target network?
+        :param has_global: should this wrapper create a global network shared between workers?
+        :param name: the name of the network
+        :param replicated_device: the device on which to place the global network
+        :param worker_device: the device on which to place the online and target networks
+ """
+ self.tp = tuning_parameters
+ self.has_target = has_target
+ self.has_global = has_global
+ self.name = name
+ self.sess = tuning_parameters.sess
+
+ if self.tp.framework == Frameworks.TensorFlow:
+ general_network = GeneralTensorFlowNetwork
+ elif self.tp.framework == Frameworks.Neon:
+ general_network = GeneralNeonNetwork
+ else:
+ raise Exception("{} Framework is not supported".format(Frameworks().to_string(self.tp.framework)))
+
+ # Global network - the main network shared between threads
+ self.global_network = None
+ if self.has_global:
+ with tf.device(replicated_device):
+ self.global_network = general_network(tuning_parameters, '{}/global'.format(name),
+ network_is_local=False)
+
+ # Online network - local copy of the main network used for playing
+ self.online_network = None
+ with tf.device(worker_device):
+ self.online_network = general_network(tuning_parameters, '{}/online'.format(name),
+ self.global_network, network_is_local=True)
+
+ # Target network - a local, slow updating network used for stabilizing the learning
+ self.target_network = None
+ if self.has_target:
+ with tf.device(worker_device):
+ self.target_network = general_network(tuning_parameters, '{}/target'.format(name),
+ network_is_local=True)
+
+ if not self.tp.distributed and self.tp.framework == Frameworks.TensorFlow:
+ self.model_saver = tf.train.Saver()
+ if self.tp.sess and self.tp.checkpoint_restore_dir:
+ checkpoint = tf.train.latest_checkpoint(self.tp.checkpoint_restore_dir)
+ screen.log_title("Loading checkpoint: {}".format(checkpoint))
+ self.model_saver.restore(self.tp.sess, checkpoint)
+
+ def sync(self):
+ """
+ Initializes the weights of the networks to match each other
+ :return:
+ """
+ self.update_online_network()
+ self.update_target_network()
+
+ def update_target_network(self, rate=1.0):
+ """
+ Copy weights: online network >>> target network
+ :param rate: the rate of copying the weights - 1 for copying exactly
+ """
+ if self.target_network:
+ self.target_network.set_weights(self.online_network.get_weights(), rate)
+
+ def update_online_network(self, rate=1.0):
+ """
+ Copy weights: global network >>> online network
+ :param rate: the rate of copying the weights - 1 for copying exactly
+ """
+ if self.global_network:
+ self.online_network.set_weights(self.global_network.get_weights(), rate)
+
+ def apply_gradients_to_global_network(self):
+ """
+ Apply gradients from the online network on the global network
+ :return:
+ """
+ self.global_network.apply_gradients(self.online_network.accumulated_gradients)
+
+ def apply_gradients_to_online_network(self):
+ """
+ Apply gradients from the online network on itself
+ :return:
+ """
+ self.online_network.apply_gradients(self.online_network.accumulated_gradients)
+
+ def train_and_sync_networks(self, inputs, targets):
+ """
+        A generic training function that enables multi-threaded training using a global network if necessary.
+ :param inputs: The inputs for the network.
+ :param targets: The targets corresponding to the given inputs
+ :return: The loss of the training iteration
+ """
+ result = self.online_network.accumulate_gradients(inputs, targets)
+ self.apply_gradients_and_sync_networks()
+ return result
+
+ def apply_gradients_and_sync_networks(self):
+ """
+ Applies the gradients accumulated in the online network to the global network or to itself and syncs the
+ networks if necessary
+ """
+ if self.global_network:
+ self.apply_gradients_to_global_network()
+ self.online_network.reset_accumulated_gradients()
+ self.update_online_network()
+ else:
+ self.online_network.apply_and_reset_gradients(self.online_network.accumulated_gradients)
+
+ def get_local_variables(self):
+ """
+ Get all the variables that are local to the thread
+ :return: a list of all the variables that are local to the thread
+ """
+ local_variables = [v for v in tf.global_variables() if self.online_network.name in v.name]
+ if self.has_target:
+ local_variables += [v for v in tf.global_variables() if self.target_network.name in v.name]
+ return local_variables
+
+ def get_global_variables(self):
+ """
+ Get all the variables that are shared between threads
+ :return: a list of all the variables that are shared between threads
+ """
+ global_variables = [v for v in tf.global_variables() if self.global_network.name in v.name]
+ return global_variables
+
+ def set_session(self, sess):
+ self.sess = sess
+ self.online_network.sess = sess
+ if self.global_network:
+ self.global_network.sess = sess
+ if self.target_network:
+ self.target_network.sess = sess
+
+ def save_model(self, model_id):
+ saved_model_path = self.model_saver.save(self.tp.sess, os.path.join(self.tp.save_model_dir,
+ str(model_id) + '.ckpt'))
+ screen.log_dict(
+ OrderedDict([
+ ("Saving model", saved_model_path),
+ ]),
+ prefix="Checkpoint"
+ )
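+
+# Typical (illustrative) usage from an agent, assuming `tp` is a fully populated Preset and
+# `step` / `target_update_interval` are the agent's own bookkeeping:
+#
+#     main_network = NetworkWrapper(tp, has_target=True, has_global=False, name='main')
+#     loss = main_network.train_and_sync_networks(batch_inputs, batch_targets)
+#     if step % target_update_interval == 0:
+#         main_network.update_target_network()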
diff --git a/architectures/tensorflow_components/__init__.py b/architectures/tensorflow_components/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/architectures/tensorflow_components/architecture.py b/architectures/tensorflow_components/architecture.py
new file mode 100644
index 0000000..6ae0241
--- /dev/null
+++ b/architectures/tensorflow_components/architecture.py
@@ -0,0 +1,290 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from architectures.architecture import Architecture
+import tensorflow as tf
+from utils import force_list, squeeze_list
+from configurations import Preset, MiddlewareTypes
+import numpy as np
+import time
+
+
+class TensorFlowArchitecture(Architecture):
+ def __init__(self, tuning_parameters, name="", global_network=None, network_is_local=True):
+ """
+ :param tuning_parameters: The parameters used for running the algorithm
+ :type tuning_parameters: Preset
+ :param name: The name of the network
+ """
+ Architecture.__init__(self, tuning_parameters, name)
+ self.middleware_embedder = None
+ self.network_is_local = network_is_local
+ assert tuning_parameters.agent.tensorflow_support, 'TensorFlow is not supported for this agent'
+ self.sess = tuning_parameters.sess
+ self.inputs = []
+ self.outputs = []
+ self.targets = []
+ self.losses = []
+ self.total_loss = None
+ self.trainable_weights = []
+ self.weights_placeholders = []
+ self.curr_rnn_c_in = None
+ self.curr_rnn_h_in = None
+ self.gradients_wrt_inputs = []
+
+ self.optimizer_type = self.tp.agent.optimizer_type
+ if self.tp.seed is not None:
+ tf.set_random_seed(self.tp.seed)
+ with tf.variable_scope(self.name, initializer=tf.contrib.layers.xavier_initializer()):
+ self.global_step = tf.contrib.framework.get_or_create_global_step()
+
+ # build the network
+ self.get_model(tuning_parameters)
+
+ # model weights
+ self.trainable_weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)
+
+ # locks for synchronous training
+ if self.tp.distributed and not self.tp.agent.async_training and not self.network_is_local:
+ self.lock_counter = tf.get_variable("lock_counter", [], tf.int32,
+ initializer=tf.constant_initializer(0, dtype=tf.int32),
+ trainable=False)
+ self.lock = self.lock_counter.assign_add(1, use_locking=True)
+ self.lock_init = self.lock_counter.assign(0)
+
+ self.release_counter = tf.get_variable("release_counter", [], tf.int32,
+ initializer=tf.constant_initializer(0, dtype=tf.int32),
+ trainable=False)
+ self.release = self.release_counter.assign_add(1, use_locking=True)
+ self.release_init = self.release_counter.assign(0)
+
+ # local network does the optimization so we need to create all the ops we are going to use to optimize
+ for idx, var in enumerate(self.trainable_weights):
+ placeholder = tf.placeholder(tf.float32, shape=var.get_shape(), name=str(idx) + '_holder')
+ self.weights_placeholders.append(placeholder)
+ self.update_weights_from_list = [weights.assign(holder) for holder, weights in
+ zip(self.weights_placeholders, self.trainable_weights)]
+
+ # gradients ops
+ self.tensor_gradients = tf.gradients(self.total_loss, self.trainable_weights)
+ self.gradients_norm = tf.global_norm(self.tensor_gradients)
+ if self.tp.clip_gradients is not None and self.tp.clip_gradients != 0:
+ self.clipped_grads, self.grad_norms = tf.clip_by_global_norm(self.tensor_gradients,
+ tuning_parameters.clip_gradients)
+
+ # gradients of the outputs w.r.t. the inputs
+ if len(self.outputs) == 1:
+ self.gradients_wrt_inputs = [tf.gradients(self.outputs[0], input_ph) for input_ph in self.inputs]
+ self.gradients_weights_ph = tf.placeholder('float32', self.outputs[0].shape, 'output_gradient_weights')
+ self.weighted_gradients = tf.gradients(self.outputs[0], self.trainable_weights, self.gradients_weights_ph)
+
+ # L2 regularization
+ if self.tp.agent.l2_regularization != 0:
+ self.l2_regularization = [tf.add_n([tf.nn.l2_loss(v) for v in self.trainable_weights])
+ * self.tp.agent.l2_regularization]
+ tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, self.l2_regularization)
+
+ self.inc_step = self.global_step.assign_add(1)
+
+ # defining the optimization process (for LBFGS we have less control over the optimizer)
+ if self.optimizer_type != 'LBFGS':
+ # no global network, this is a plain simple centralized training
+ self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
+ zip(self.weights_placeholders, self.trainable_weights), global_step=self.global_step)
+
+ # initialize or restore model
+ if not self.tp.distributed:
+ self.init_op = tf.global_variables_initializer()
+
+ if self.sess:
+ self.sess.run(self.init_op)
+
+ self.accumulated_gradients = None
+
+ def reset_accumulated_gradients(self):
+ """
+ Reset the gradients accumulation placeholder
+ """
+ if self.accumulated_gradients is None:
+ self.accumulated_gradients = self.tp.sess.run(self.trainable_weights)
+
+ for ix, grad in enumerate(self.accumulated_gradients):
+ self.accumulated_gradients[ix] = grad * 0
+
+ def accumulate_gradients(self, inputs, targets, additional_fetches=None):
+ """
+ Runs a forward pass & backward pass, clips gradients if needed and accumulates them into the accumulation
+ placeholders
+ :param additional_fetches: Optional tensors to fetch during gradients calculation
+ :param inputs: The input batch for the network
+ :param targets: The targets corresponding to the input batch
+ :return: A list containing the total loss and the individual network heads losses
+ """
+
+ if self.accumulated_gradients is None:
+ self.reset_accumulated_gradients()
+
+ # feed inputs
+ if additional_fetches is None:
+ additional_fetches = []
+ inputs = force_list(inputs)
+
+ feed_dict = dict(zip(self.inputs, inputs))
+
+ # feed targets
+ targets = force_list(targets)
+ for placeholder_idx, target in enumerate(targets):
+ feed_dict[self.targets[placeholder_idx]] = target
+
+ if self.optimizer_type != 'LBFGS':
+ # set the fetches
+ fetches = [self.gradients_norm]
+ if self.tp.clip_gradients:
+ fetches.append(self.clipped_grads)
+ else:
+ fetches.append(self.tensor_gradients)
+ fetches += [self.total_loss, self.losses]
+ if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
+ fetches.append(self.middleware_embedder.state_out)
+ additional_fetches_start_idx = len(fetches)
+ fetches += additional_fetches
+
+ # feed the lstm state if necessary
+ if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
+ feed_dict[self.middleware_embedder.c_in] = self.middleware_embedder.c_init
+ feed_dict[self.middleware_embedder.h_in] = self.middleware_embedder.h_init
+
+ # get grads
+ result = self.tp.sess.run(fetches, feed_dict=feed_dict)
+
+ # extract the fetches
+ norm_unclipped_grads, grads, total_loss, losses = result[:4]
+ if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
+ (self.curr_rnn_c_in, self.curr_rnn_h_in) = result[4]
+ fetched_tensors = []
+ if len(additional_fetches) > 0:
+ fetched_tensors = result[additional_fetches_start_idx:]
+
+ # accumulate the gradients
+ for idx, grad in enumerate(grads):
+ self.accumulated_gradients[idx] += grad
+
+ return total_loss, losses, norm_unclipped_grads, fetched_tensors
+
+ else:
+ self.optimizer.minimize(session=self.tp.sess, feed_dict=feed_dict)
+
+ return [0]
+
+ def apply_and_reset_gradients(self, gradients, scaler=1.):
+ """
+ Applies the given gradients to the network weights and resets the accumulation placeholder
+ :param gradients: The gradients to use for the update
+ :param scaler: A scaling factor that allows rescaling the gradients before applying them
+ """
+ self.apply_gradients(gradients, scaler)
+ self.reset_accumulated_gradients()
+
+ def apply_gradients(self, gradients, scaler=1.):
+ """
+ Applies the given gradients to the network weights
+ :param gradients: The gradients to use for the update
+ :param scaler: A scaling factor that allows rescaling the gradients before applying them
+ """
+ if self.tp.agent.async_training or not self.tp.distributed:
+ if hasattr(self, 'global_step') and not self.network_is_local:
+ self.tp.sess.run(self.inc_step)
+
+ if self.optimizer_type != 'LBFGS':
+
+ # lock barrier
+ if hasattr(self, 'lock_counter'):
+ self.tp.sess.run(self.lock)
+ while self.tp.sess.run(self.lock_counter) % self.tp.num_threads != 0:
+ time.sleep(0.00001)
+ # rescale the gradients so that they average out with the gradients from the other workers
+ scaler /= float(self.tp.num_threads)
+
+ # apply gradients
+ if scaler != 1.:
+ for gradient in gradients:
+                    gradient *= scaler
+ feed_dict = dict(zip(self.weights_placeholders, gradients))
+ _ = self.tp.sess.run(self.update_weights_from_batch_gradients, feed_dict=feed_dict)
+
+ # release barrier
+ if hasattr(self, 'release_counter'):
+ self.tp.sess.run(self.release)
+ while self.tp.sess.run(self.release_counter) % self.tp.num_threads != 0:
+ time.sleep(0.00001)
+
+ def predict(self, inputs):
+ """
+ Run a forward pass of the network using the given input
+ :param inputs: The input for the network
+ :return: The network output
+ """
+
+ feed_dict = dict(zip(self.inputs, force_list(inputs)))
+ if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
+ feed_dict[self.middleware_embedder.c_in] = self.curr_rnn_c_in
+ feed_dict[self.middleware_embedder.h_in] = self.curr_rnn_h_in
+ output, (self.curr_rnn_c_in, self.curr_rnn_h_in) = self.tp.sess.run([self.outputs, self.middleware_embedder.state_out], feed_dict=feed_dict)
+ else:
+ output = self.tp.sess.run(self.outputs, feed_dict)
+
+ return squeeze_list(output)
+
+ def train_on_batch(self, inputs, targets, scaler=1., additional_fetches=None):
+ """
+ Given a batch of examples and targets, runs a forward pass & backward pass and then applies the gradients
+ :param additional_fetches: Optional tensors to fetch during the training process
+ :param inputs: The input for the network
+ :param targets: The targets corresponding to the input batch
+ :param scaler: A scaling factor that allows rescaling the gradients before applying them
+ :return: The loss of the network
+ """
+ if additional_fetches is None:
+ additional_fetches = []
+        additional_fetches = force_list(additional_fetches)
+ loss = self.accumulate_gradients(inputs, targets, additional_fetches=additional_fetches)
+ self.apply_and_reset_gradients(self.accumulated_gradients, scaler)
+ return loss
+
+ def get_weights(self):
+ """
+ :return: a list of tensors containing the network weights for each layer
+ """
+ return self.trainable_weights
+
+ def set_weights(self, weights, new_rate=1.0):
+ """
+ Sets the network weights from the given list of weights tensors
+ """
+ feed_dict = {}
+ old_weights, new_weights = self.tp.sess.run([self.get_weights(), weights])
+ for placeholder_idx, new_weight in enumerate(new_weights):
+ feed_dict[self.weights_placeholders[placeholder_idx]]\
+ = new_rate * new_weight + (1 - new_rate) * old_weights[placeholder_idx]
+ self.tp.sess.run(self.update_weights_from_list, feed_dict)
+
+ def write_graph_to_logdir(self, summary_dir):
+ """
+ Writes the tensorflow graph to the logdir for tensorboard visualization
+ :param summary_dir: the path to the logdir
+ """
+ summary_writer = tf.summary.FileWriter(summary_dir)
+ summary_writer.add_graph(self.sess.graph)
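+
+
+def _accumulate_gradients_reference(gradient_batches):
+    """Plain-NumPy sketch of the accumulation performed by accumulate_gradients() above.
+
+    Each element of `gradient_batches` (an illustrative name) is a list of per-variable
+    gradient arrays from one forward/backward pass; the accumulator simply sums them, and
+    apply_and_reset_gradients() later applies the sum in a single optimizer step.
+    """
+    accumulated = [np.zeros_like(gradient) for gradient in gradient_batches[0]]
+    for gradients in gradient_batches:
+        for accumulator, gradient in zip(accumulated, gradients):
+            accumulator += gradient
+    return accumulated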
diff --git a/architectures/tensorflow_components/embedders.py b/architectures/tensorflow_components/embedders.py
new file mode 100644
index 0000000..83a14c7
--- /dev/null
+++ b/architectures/tensorflow_components/embedders.py
@@ -0,0 +1,73 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tensorflow as tf
+
+
+class InputEmbedder:
+ def __init__(self, input_size, activation_function=tf.nn.relu, name="embedder"):
+ self.name = name
+ self.input_size = input_size
+ self.activation_function = activation_function
+ self.input = None
+ self.output = None
+
+ def __call__(self, prev_input_placeholder=None):
+ with tf.variable_scope(self.get_name()):
+ if prev_input_placeholder is None:
+ self.input = tf.placeholder("float", shape=(None,) + self.input_size, name=self.get_name())
+ else:
+ self.input = prev_input_placeholder
+ self._build_module()
+
+ return self.input, self.output
+
+ def _build_module(self):
+ pass
+
+ def get_name(self):
+ return self.name
+
+
+class ImageEmbedder(InputEmbedder):
+ def __init__(self, input_size, input_rescaler=255.0, activation_function=tf.nn.relu, name="embedder"):
+ InputEmbedder.__init__(self, input_size, activation_function, name)
+ self.input_rescaler = input_rescaler
+
+ def _build_module(self):
+ # image observation
+ rescaled_observation_stack = self.input / self.input_rescaler
+ self.observation_conv1 = tf.layers.conv2d(rescaled_observation_stack,
+ filters=32, kernel_size=(8, 8), strides=(4, 4),
+ activation=self.activation_function, data_format='channels_last')
+ self.observation_conv2 = tf.layers.conv2d(self.observation_conv1,
+ filters=64, kernel_size=(4, 4), strides=(2, 2),
+ activation=self.activation_function, data_format='channels_last')
+ self.observation_conv3 = tf.layers.conv2d(self.observation_conv2,
+ filters=64, kernel_size=(3, 3), strides=(1, 1),
+ activation=self.activation_function, data_format='channels_last')
+
+ self.output = tf.contrib.layers.flatten(self.observation_conv3)
+
+
+class VectorEmbedder(InputEmbedder):
+ def __init__(self, input_size, activation_function=tf.nn.relu, name="embedder"):
+ InputEmbedder.__init__(self, input_size, activation_function, name)
+
+ def _build_module(self):
+ # vector observation
+ input_layer = tf.contrib.layers.flatten(self.input)
+ self.output = tf.layers.dense(input_layer, 256, activation=self.activation_function)
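+
+
+def _conv_output_length(size, kernel_size, stride):
+    """Output length of a 'valid' convolution along one spatial dimension.
+
+    Reference sketch for the ImageEmbedder above (tf.layers.conv2d defaults to 'valid'
+    padding): for an assumed 84x84 observation the spatial size shrinks
+    84 -> (84 - 8) // 4 + 1 = 20 -> (20 - 4) // 2 + 1 = 9 -> (9 - 3) // 1 + 1 = 7,
+    so the flattened embedding has 7 * 7 * 64 = 3136 features.
+    """
+    return (size - kernel_size) // stride + 1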
diff --git a/architectures/tensorflow_components/general_network.py b/architectures/tensorflow_components/general_network.py
new file mode 100644
index 0000000..a3ff5f1
--- /dev/null
+++ b/architectures/tensorflow_components/general_network.py
@@ -0,0 +1,190 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from architectures.tensorflow_components.embedders import *
+from architectures.tensorflow_components.heads import *
+from architectures.tensorflow_components.middleware import *
+from architectures.tensorflow_components.architecture import *
+from configurations import InputTypes, OutputTypes, MiddlewareTypes
+
+
+class GeneralTensorFlowNetwork(TensorFlowArchitecture):
+ def __init__(self, tuning_parameters, name="", global_network=None, network_is_local=True):
+ self.global_network = global_network
+ self.network_is_local = network_is_local
+ self.num_heads_per_network = 1 if tuning_parameters.agent.use_separate_networks_per_head else \
+ len(tuning_parameters.agent.output_types)
+ self.num_networks = 1 if not tuning_parameters.agent.use_separate_networks_per_head else \
+ len(tuning_parameters.agent.output_types)
+ self.input_embedders = []
+ self.output_heads = []
+ self.activation_function = self.get_activation_function(
+ tuning_parameters.agent.hidden_layers_activation_function)
+
+ TensorFlowArchitecture.__init__(self, tuning_parameters, name, global_network, network_is_local)
+
+ def get_activation_function(self, activation_function_string):
+ activation_functions = {
+ 'relu': tf.nn.relu,
+ 'tanh': tf.nn.tanh,
+ 'sigmoid': tf.nn.sigmoid,
+ 'elu': tf.nn.elu,
+ 'none': None
+ }
+ assert activation_function_string in activation_functions.keys(), \
+ "Activation function must be one of the following {}".format(activation_functions.keys())
+ return activation_functions[activation_function_string]
+
+ def get_input_embedder(self, embedder_type):
+ # the observation can be either an image or a vector
+ def get_observation_embedding(with_timestep=False):
+ if self.input_height > 1:
+ return ImageEmbedder((self.input_height, self.input_width, self.input_depth), name="observation")
+ else:
+ return VectorEmbedder((self.input_width + int(with_timestep), self.input_depth), name="observation")
+
+ input_mapping = {
+ InputTypes.Observation: get_observation_embedding(),
+ InputTypes.Measurements: VectorEmbedder(self.measurements_size, name="measurements"),
+ InputTypes.GoalVector: VectorEmbedder(self.measurements_size, name="goal_vector"),
+ InputTypes.Action: VectorEmbedder((self.num_actions,), name="action"),
+ InputTypes.TimedObservation: get_observation_embedding(with_timestep=True),
+ }
+ return input_mapping[embedder_type]
+
+ def get_middleware_embedder(self, middleware_type):
+ return {MiddlewareTypes.LSTM: LSTM_Embedder,
+ MiddlewareTypes.FC: FC_Embedder}.get(middleware_type)(self.activation_function)
+
+ def get_output_head(self, head_type, head_idx, loss_weight=1.):
+ output_mapping = {
+ OutputTypes.Q: QHead,
+ OutputTypes.DuelingQ: DuelingQHead,
+ OutputTypes.V: VHead,
+ OutputTypes.Pi: PolicyHead,
+ OutputTypes.MeasurementsPrediction: MeasurementsPredictionHead,
+ OutputTypes.DNDQ: DNDQHead,
+ OutputTypes.NAF: NAFHead,
+ OutputTypes.PPO: PPOHead,
+            OutputTypes.PPO_V: PPOVHead,
+ OutputTypes.DistributionalQ: DistributionalQHead
+ }
+ return output_mapping[head_type](self.tp, head_idx, loss_weight, self.network_is_local)
+
+ def get_model(self, tuning_parameters):
+ """
+        :param tuning_parameters: A Preset class instance with all the running parameters
+ :type tuning_parameters: Preset
+ :return: A model
+ """
+ assert len(self.tp.agent.input_types) > 0, "At least one input type should be defined"
+ assert len(self.tp.agent.output_types) > 0, "At least one output type should be defined"
+ assert self.tp.agent.middleware_type is not None, "Exactly one middleware type should be defined"
+ assert len(self.tp.agent.loss_weights) > 0, "At least one loss weight should be defined"
+ assert len(self.tp.agent.output_types) == len(self.tp.agent.loss_weights), \
+ "Number of loss weights should match the number of output types"
+ local_network_in_distributed_training = self.global_network is not None and self.network_is_local
+
+ tuning_parameters.activation_function = self.activation_function
+ done_creating_input_placeholders = False
+
+ for network_idx in range(self.num_networks):
+ with tf.variable_scope('network_{}'.format(network_idx)):
+ ####################
+ # Input Embeddings #
+ ####################
+
+ state_embedding = []
+ for idx, input_type in enumerate(self.tp.agent.input_types):
+ # get the class of the input embedder
+ self.input_embedders.append(self.get_input_embedder(input_type))
+
+ # in the case each head uses a different network, we still reuse the input placeholders
+ prev_network_input_placeholder = self.inputs[idx] if done_creating_input_placeholders else None
+
+ # create the input embedder instance and store the input placeholder and the embedding
+ input_placeholder, embedding = self.input_embedders[-1](prev_network_input_placeholder)
+ if len(self.inputs) < len(self.tp.agent.input_types):
+ self.inputs.append(input_placeholder)
+ state_embedding.append(embedding)
+
+ done_creating_input_placeholders = True
+
+ ##############
+ # Middleware #
+ ##############
+
+ state_embedding = tf.concat(state_embedding, axis=-1) if len(state_embedding) > 1 else state_embedding[0]
+ self.middleware_embedder = self.get_middleware_embedder(self.tp.agent.middleware_type)
+ _, self.state_embedding = self.middleware_embedder(state_embedding)
+
+ ################
+ # Output Heads #
+ ################
+
+ for head_idx in range(self.num_heads_per_network):
+ for head_copy_idx in range(self.tp.agent.num_output_head_copies):
+ if self.tp.agent.use_separate_networks_per_head:
+                            # if we use separate networks per head, then the head type corresponds to the network idx
+ head_type_idx = network_idx
+ else:
+ # if we use a single network with multiple heads, then the head type is the current head idx
+ head_type_idx = head_idx
+ self.output_heads.append(self.get_output_head(self.tp.agent.output_types[head_type_idx],
+ head_copy_idx,
+ self.tp.agent.loss_weights[head_type_idx]))
+
+ if self.tp.agent.stop_gradients_from_head[head_idx]:
+ head_input = tf.stop_gradient(self.state_embedding)
+ else:
+ head_input = self.state_embedding
+
+ # build the head
+ if self.network_is_local:
+ output, target_placeholder, input_placeholder = self.output_heads[-1](head_input)
+ self.targets.extend(target_placeholder)
+ else:
+ output, input_placeholder = self.output_heads[-1](head_input)
+
+ self.outputs.extend(output)
+ self.inputs.extend(input_placeholder)
+
+ # Losses
+ self.losses = tf.losses.get_losses(self.name)
+ self.losses += tf.losses.get_regularization_losses(self.name)
+ self.total_loss = tf.losses.compute_weighted_loss(self.losses, scope=self.name)
+
+ # Learning rate
+ if self.tp.learning_rate_decay_rate != 0:
+ self.tp.learning_rate = tf.train.exponential_decay(
+ self.tp.learning_rate, self.global_step, decay_steps=self.tp.learning_rate_decay_steps,
+ decay_rate=self.tp.learning_rate_decay_rate, staircase=True)
+
+ # Optimizer
+ if local_network_in_distributed_training and \
+ hasattr(self.tp.agent, "shared_optimizer") and self.tp.agent.shared_optimizer:
+ # distributed training and this is the local network instantiation
+ self.optimizer = self.global_network.optimizer
+ else:
+ if tuning_parameters.agent.optimizer_type == 'Adam':
+ self.optimizer = tf.train.AdamOptimizer(learning_rate=tuning_parameters.learning_rate)
+ elif tuning_parameters.agent.optimizer_type == 'RMSProp':
+ self.optimizer = tf.train.RMSPropOptimizer(self.tp.learning_rate, decay=0.9, epsilon=0.01)
+ elif tuning_parameters.agent.optimizer_type == 'LBFGS':
+ self.optimizer = tf.contrib.opt.ScipyOptimizerInterface(self.total_loss, method='L-BFGS-B',
+ options={'maxiter': 25})
+ else:
+ raise Exception("{} is not a valid optimizer type".format(tuning_parameters.agent.optimizer_type))
diff --git a/architectures/tensorflow_components/heads.py b/architectures/tensorflow_components/heads.py
new file mode 100644
index 0000000..05bb3e6
--- /dev/null
+++ b/architectures/tensorflow_components/heads.py
@@ -0,0 +1,481 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tensorflow as tf
+import numpy as np
+from utils import force_list
+
+
+# Used to initialize weights for policy and value output layers
+def normalized_columns_initializer(std=1.0):
+ def _initializer(shape, dtype=None, partition_info=None):
+ out = np.random.randn(*shape).astype(np.float32)
+ out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
+ return tf.constant(out)
+ return _initializer
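+# Example usage (illustrative; `embedding` and `num_actions` are placeholders for the head's input and
+# the action count) - this mirrors how the value and policy heads below scale their output layers:
+#
+#     value = tf.layers.dense(embedding, 1, kernel_initializer=normalized_columns_initializer(1.0))
+#     policy = tf.layers.dense(embedding, num_actions,
+#                              kernel_initializer=normalized_columns_initializer(0.01))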
+
+
+class Head:
+ def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
+ self.head_idx = head_idx
+ self.name = "head"
+ self.output = []
+ self.loss = []
+ self.loss_type = []
+ self.regularizations = []
+ self.loss_weight = force_list(loss_weight)
+ self.target = []
+ self.input = []
+ self.is_local = is_local
+
+ def __call__(self, input_layer):
+ """
+ Wrapper for building the module graph, including scoping and loss creation
+ :param input_layer: the input to the graph
+ :return: the output, target placeholders and input placeholders for a local network,
+ or the output and input placeholders for a global network
+ """
+ with tf.variable_scope(self.get_name(), initializer=tf.contrib.layers.xavier_initializer()):
+ self._build_module(input_layer)
+
+ self.output = force_list(self.output)
+ self.target = force_list(self.target)
+ self.input = force_list(self.input)
+ self.loss_type = force_list(self.loss_type)
+ self.loss = force_list(self.loss)
+ self.regularizations = force_list(self.regularizations)
+ if self.is_local:
+ self.set_loss()
+
+ if self.is_local:
+ return self.output, self.target, self.input
+ else:
+ return self.output, self.input
+
+ def _build_module(self, input_layer):
+ """
+ Builds the graph of the module
+ :param input_layer: the input to the graph
+ :return: None
+ """
+ pass
+
+ def get_name(self):
+ """
+ Get a formatted name for the module
+ :return: the formatted name
+ """
+ return '{}_{}'.format(self.name, self.head_idx)
+
+ def set_loss(self):
+ """
+ Creates a target placeholder and a loss function for each loss type, and appends any regularization terms
+ :return: None
+ """
+ # add losses and target placeholder
+ for idx in range(len(self.loss_type)):
+ target = tf.placeholder('float', self.output[idx].shape, '{}_target'.format(self.get_name()))
+ self.target.append(target)
+ loss = self.loss_type[idx](self.target[-1], self.output[idx],
+ weights=self.loss_weight[idx], scope=self.get_name())
+ self.loss.append(loss)
+
+ # add regularizations
+ for regularization in self.regularizations:
+ self.loss.append(regularization)
+
+
+class QHead(Head):
+ def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
+ Head.__init__(self, tuning_parameters, head_idx, loss_weight, is_local)
+ self.name = 'q_values_head'
+ self.num_actions = tuning_parameters.env_instance.action_space_size
+ if tuning_parameters.agent.replace_mse_with_huber_loss:
+ self.loss_type = tf.losses.huber_loss
+ else:
+ self.loss_type = tf.losses.mean_squared_error
+
+ def _build_module(self, input_layer):
+ # Standard Q Network
+ self.output = tf.layers.dense(input_layer, self.num_actions, name='output')
+
+
+class DuelingQHead(QHead):
+ def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
+ QHead.__init__(self, tuning_parameters, head_idx, loss_weight, is_local)
+
+ def _build_module(self, input_layer):
+ # state value tower - V
+ with tf.variable_scope("state_value"):
+ state_value = tf.layers.dense(input_layer, 256, activation=tf.nn.relu)
+ state_value = tf.layers.dense(state_value, 1)
+ # state_value = tf.expand_dims(state_value, axis=-1)
+
+ # action advantage tower - A
+ with tf.variable_scope("action_advantage"):
+ action_advantage = tf.layers.dense(input_layer, 256, activation=tf.nn.relu)
+ action_advantage = tf.layers.dense(action_advantage, self.num_actions)
+ # subtract the per-sample mean advantage over the actions axis
+ action_advantage = action_advantage - tf.reduce_mean(action_advantage, axis=1, keep_dims=True)
+
+ # merge to state-action value function Q
+ self.output = tf.add(state_value, action_advantage, name='output')
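+ # i.e. the dueling aggregation Q(s, a) = V(s) + (A(s, a) - mean_a' A(s, a')), which keeps the advantage
+ # stream identifiable by forcing its mean over the actions to zero (Wang et al., "Dueling Network
+ # Architectures for Deep Reinforcement Learning")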
+
+
+class VHead(Head):
+ def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
+ Head.__init__(self, tuning_parameters, head_idx, loss_weight, is_local)
+ self.name = 'v_values_head'
+ if tuning_parameters.agent.replace_mse_with_huber_loss:
+ self.loss_type = tf.losses.huber_loss
+ else:
+ self.loss_type = tf.losses.mean_squared_error
+
+ def _build_module(self, input_layer):
+ # Standard V Network
+ self.output = tf.layers.dense(input_layer, 1, name='output',
+ kernel_initializer=normalized_columns_initializer(1.0))
+
+
+class PolicyHead(Head):
+ def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
+ Head.__init__(self, tuning_parameters, head_idx, loss_weight, is_local)
+ self.name = 'policy_values_head'
+ self.num_actions = tuning_parameters.env_instance.action_space_size
+ self.output_scale = np.max(tuning_parameters.env_instance.action_space_abs_range)
+ self.discrete_controls = tuning_parameters.env_instance.discrete_controls
+ self.exploration_policy = tuning_parameters.exploration.policy
+ self.exploration_variance = 2*self.output_scale*tuning_parameters.exploration.initial_noise_variance_percentage
+ if not self.discrete_controls and not self.output_scale:
+ raise ValueError("For continuous controls, an output scale for the network must be specified")
+ self.beta = tuning_parameters.agent.beta_entropy
+
+ def _build_module(self, input_layer):
+ eps = 1e-15
+ if self.discrete_controls:
+ self.actions = tf.placeholder(tf.int32, [None], name="actions")
+ else:
+ self.actions = tf.placeholder(tf.float32, [None, self.num_actions], name="actions")
+ self.input = [self.actions]
+
+ # Policy Head
+ if self.discrete_controls:
+ policy_values = tf.layers.dense(input_layer, self.num_actions)
+ self.policy_mean = tf.nn.softmax(policy_values, name="policy")
+
+ # define the distributions for the policy and the old policy
+ self.policy_distribution = tf.contrib.distributions.Categorical(probs=self.policy_mean)
+ self.output = self.policy_mean
+ else:
+ # mean
+ policy_values_mean = tf.layers.dense(input_layer, self.num_actions, activation=tf.nn.tanh)
+ self.policy_mean = tf.multiply(policy_values_mean, self.output_scale, name='output_mean')
+
+ self.output = [self.policy_mean]
+
+ # std
+ if self.exploration_policy == 'ContinuousEntropy':
+ policy_values_std = tf.layers.dense(input_layer, self.num_actions,
+ kernel_initializer=normalized_columns_initializer(0.01))
+ self.policy_std = tf.nn.softplus(policy_values_std, name='output_variance') + eps
+
+ self.output.append(self.policy_std)
+
+ else:
+ self.policy_std = tf.constant(self.exploration_variance, dtype='float32', shape=(self.num_actions,))
+
+ # define the distributions for the policy and the old policy
+ self.policy_distribution = tf.contrib.distributions.MultivariateNormalDiag(self.policy_mean,
+ self.policy_std)
+
+ if self.is_local:
+ # add entropy regularization
+ if self.beta:
+ self.entropy = tf.reduce_mean(self.policy_distribution.entropy())
+ self.regularizations = -tf.multiply(self.beta, self.entropy, name='entropy_regularization')
+ tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, self.regularizations)
+
+ # calculate loss
+ self.action_log_probs_wrt_policy = self.policy_distribution.log_prob(self.actions)
+ self.advantages = tf.placeholder(tf.float32, [None], name="advantages")
+ self.target = self.advantages
+ self.loss = -tf.reduce_mean(self.action_log_probs_wrt_policy * self.advantages)
+ tf.losses.add_loss(self.loss_weight[0] * self.loss)
+
+
+class MeasurementsPredictionHead(Head):
+ def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
+ Head.__init__(self, tuning_parameters, head_idx, loss_weight, is_local)
+ self.name = 'future_measurements_head'
+ self.num_actions = tuning_parameters.env_instance.action_space_size
+ self.num_measurements = tuning_parameters.env.measurements_size[0] \
+ if tuning_parameters.env.measurements_size else 0
+ self.num_prediction_steps = tuning_parameters.agent.num_predicted_steps_ahead
+ self.multi_step_measurements_size = self.num_measurements * self.num_prediction_steps
+ if tuning_parameters.agent.replace_mse_with_huber_loss:
+ self.loss_type = tf.losses.huber_loss
+ else:
+ self.loss_type = tf.losses.mean_squared_error
+
+ def _build_module(self, input_layer):
+ # This is almost exactly the same as Dueling Network but we predict the future measurements for each action
+ # actions expectation tower (expectation stream) - E
+ with tf.variable_scope("expectation_stream"):
+ expectation_stream = tf.layers.dense(input_layer, 256, activation=tf.nn.elu)
+ expectation_stream = tf.layers.dense(expectation_stream, self.multi_step_measurements_size)
+ expectation_stream = tf.expand_dims(expectation_stream, axis=1)
+
+ # action fine differences tower (action stream) - A
+ with tf.variable_scope("action_stream"):
+ action_stream = tf.layers.dense(input_layer, 256, activation=tf.nn.elu)
+ action_stream = tf.layers.dense(action_stream, self.num_actions * self.multi_step_measurements_size)
+ action_stream = tf.reshape(action_stream,
+ (tf.shape(action_stream)[0], self.num_actions, self.multi_step_measurements_size))
+ action_stream = action_stream - tf.reduce_mean(action_stream, reduction_indices=1, keep_dims=True)
+
+ # merge to future measurements predictions
+ self.output = tf.add(expectation_stream, action_stream, name='output')
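+ # shapes, assuming a batch of N samples: expectation_stream is [N, 1, multi_step_measurements_size] and
+ # action_stream is [N, num_actions, multi_step_measurements_size], so the addition broadcasts the shared
+ # expectation over all actions, mirroring the dueling decomposition for future measurement predictions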
+
+
+class DNDQHead(Head):
+ def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
+ Head.__init__(self, tuning_parameters, head_idx, loss_weight, is_local)
+ self.name = 'dnd_q_values_head'
+ self.num_actions = tuning_parameters.env_instance.action_space_size
+ self.DND_size = tuning_parameters.agent.dnd_size
+ self.DND_key_error_threshold = tuning_parameters.agent.DND_key_error_threshold
+ self.l2_norm_added_delta = tuning_parameters.agent.l2_norm_added_delta
+ self.new_value_shift_coefficient = tuning_parameters.agent.new_value_shift_coefficient
+ self.number_of_nn = tuning_parameters.agent.number_of_knn
+ if tuning_parameters.agent.replace_mse_with_huber_loss:
+ self.loss_type = tf.losses.huber_loss
+ else:
+ self.loss_type = tf.losses.mean_squared_error
+
+ def _build_module(self, input_layer):
+ # DND based Q head
+ from memories import differentiable_neural_dictionary
+ self.DND = differentiable_neural_dictionary.QDND(
+ self.DND_size, input_layer.get_shape()[-1], self.num_actions, self.new_value_shift_coefficient,
+ key_error_threshold=self.DND_key_error_threshold)
+
+ # Retrieve info from DND dictionary
+ self.action = tf.placeholder(tf.int8, [None], name="action")
+ self.input = self.action
+ result = tf.py_func(self.DND.query,
+ [input_layer, self.action, self.number_of_nn],
+ [tf.float64, tf.float64])
+ self.dnd_embeddings = tf.to_float(result[0])
+ self.dnd_values = tf.to_float(result[1])
+
+ # DND calculation
+ square_diff = tf.square(self.dnd_embeddings - tf.expand_dims(input_layer, 1))
+ distances = tf.reduce_sum(square_diff, axis=2) + [self.l2_norm_added_delta]
+ weights = 1.0 / distances
+ normalised_weights = weights / tf.reduce_sum(weights, axis=1, keep_dims=True)
+ self.output = tf.reduce_sum(self.dnd_values * normalised_weights, axis=1)
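+ # i.e. a NEC-style kernel regression over the retrieved neighbours:
+ # Q(s, a) = sum_i w_i * v_i, with w_i proportional to 1 / (||h - h_i||^2 + delta),
+ # where h is the current embedding, (h_i, v_i) are the returned DND entries and delta = l2_norm_added_delta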
+
+
+class NAFHead(Head):
+ def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
+ Head.__init__(self, tuning_parameters, head_idx, loss_weight, is_local)
+ self.name = 'naf_q_values_head'
+ self.num_actions = tuning_parameters.env_instance.action_space_size
+ self.output_scale = np.max(tuning_parameters.env_instance.action_space_abs_range)
+ if tuning_parameters.agent.replace_mse_with_huber_loss:
+ self.loss_type = tf.losses.huber_loss
+ else:
+ self.loss_type = tf.losses.mean_squared_error
+
+ def _build_module(self, input_layer):
+ # NAF
+ self.action = tf.placeholder(tf.float32, [None, self.num_actions], name="action")
+ self.input = self.action
+
+ # V Head
+ self.V = tf.layers.dense(input_layer, 1, name='V')
+
+ # mu Head
+ mu_unscaled = tf.layers.dense(input_layer, self.num_actions, activation=tf.nn.tanh, name='mu_unscaled')
+ self.mu = tf.multiply(mu_unscaled, self.output_scale, name='mu')
+
+ # A Head
+ # l_vector is a vector that includes a lower-triangular matrix values
+ self.l_vector = tf.layers.dense(input_layer, (self.num_actions * (self.num_actions + 1)) // 2, name='l_vector')
+
+ # Convert l to a lower triangular matrix and exponentiate its diagonal
+
+ i = 0
+ columns = []
+ for col in range(self.num_actions):
+ start_row = col
+ num_non_zero_elements = self.num_actions - start_row
+ zeros_column_part = tf.zeros_like(self.l_vector[:, 0:start_row])
+ diag_element = tf.expand_dims(tf.exp(self.l_vector[:, i]), 1)
+ non_zeros_non_diag_column_part = self.l_vector[:, (i + 1):(i + num_non_zero_elements)]
+ columns.append(tf.concat([zeros_column_part, diag_element, non_zeros_non_diag_column_part], axis=1))
+ i += num_non_zero_elements
+ self.L = tf.transpose(tf.stack(columns, axis=1), (0, 2, 1))
+
+ # P = L*L^T
+ self.P = tf.matmul(self.L, tf.transpose(self.L, (0, 2, 1)))
+
+ # A = -1/2 * (u - mu)^T * P * (u - mu)
+ action_diff = tf.expand_dims(self.action - self.mu, -1)
+ a_matrix_form = -0.5 * tf.matmul(tf.transpose(action_diff, (0, 2, 1)), tf.matmul(self.P, action_diff))
+ self.A = tf.reshape(a_matrix_form, [-1, 1])
+
+ # Q Head
+ self.Q = tf.add(self.V, self.A, name='Q')
+
+ self.output = self.Q
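+ # summary of the NAF decomposition built above (Gu et al., "Continuous Deep Q-Learning with
+ # Model-based Acceleration"):
+ # Q(s, u) = V(s) + A(s, u), with A(s, u) = -1/2 * (u - mu(s))^T * P(s) * (u - mu(s)) and P(s) = L(s) * L(s)^T,
+ # where L(s) is lower-triangular with an exponentiated diagonal so that P(s) stays positive-definite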
+
+
+class PPOHead(Head):
+ def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
+ Head.__init__(self, tuning_parameters, head_idx, loss_weight, is_local)
+ self.name = 'ppo_head'
+ self.num_actions = tuning_parameters.env_instance.action_space_size
+ self.discrete_controls = tuning_parameters.env_instance.discrete_controls
+ self.output_scale = np.max(tuning_parameters.env_instance.action_space_abs_range)
+ self.kl_coefficient = tf.Variable(tuning_parameters.agent.initial_kl_coefficient,
+ trainable=False, name='kl_coefficient')
+ self.kl_cutoff = 2*tuning_parameters.agent.target_kl_divergence
+ self.high_kl_penalty_coefficient = tuning_parameters.agent.high_kl_penalty_coefficient
+ self.clip_likelihood_ratio_using_epsilon = tuning_parameters.agent.clip_likelihood_ratio_using_epsilon
+ self.use_kl_regularization = tuning_parameters.agent.use_kl_regularization
+ self.beta = tuning_parameters.agent.beta_entropy
+
+ def _build_module(self, input_layer):
+ eps = 1e-15
+
+ if self.discrete_controls:
+ self.actions = tf.placeholder(tf.int32, [None], name="actions")
+ else:
+ self.actions = tf.placeholder(tf.float32, [None, self.num_actions], name="actions")
+ self.old_policy_mean = tf.placeholder(tf.float32, [None, self.num_actions], "old_policy_mean")
+ self.old_policy_std = tf.placeholder(tf.float32, [None, self.num_actions], "old_policy_std")
+
+ # Policy Head
+ if self.discrete_controls:
+ self.input = [self.actions, self.old_policy_mean]
+ policy_values = tf.layers.dense(input_layer, self.num_actions)
+ self.policy_mean = tf.nn.softmax(policy_values, name="policy")
+
+ # define the distributions for the policy and the old policy
+ self.policy_distribution = tf.contrib.distributions.Categorical(probs=self.policy_mean)
+ self.old_policy_distribution = tf.contrib.distributions.Categorical(probs=self.old_policy_mean)
+
+ self.output = self.policy_mean
+ else:
+ self.input = [self.actions, self.old_policy_mean, self.old_policy_std]
+ self.policy_mean = tf.layers.dense(input_layer, self.num_actions, name='policy_mean')
+ self.policy_logstd = tf.Variable(np.zeros((1, self.num_actions)), dtype='float32')
+ self.policy_std = tf.tile(tf.exp(self.policy_logstd), [tf.shape(input_layer)[0], 1], name='policy_std')
+
+ # define the distributions for the policy and the old policy
+ self.policy_distribution = tf.contrib.distributions.MultivariateNormalDiag(self.policy_mean,
+ self.policy_std)
+ self.old_policy_distribution = tf.contrib.distributions.MultivariateNormalDiag(self.old_policy_mean,
+ self.old_policy_std)
+
+ self.output = [self.policy_mean, self.policy_std]
+
+ self.action_probs_wrt_policy = tf.exp(self.policy_distribution.log_prob(self.actions))
+ self.action_probs_wrt_old_policy = tf.exp(self.old_policy_distribution.log_prob(self.actions))
+ self.entropy = tf.reduce_mean(self.policy_distribution.entropy())
+
+ # add kl divergence regularization
+ self.kl_divergence = tf.reduce_mean(tf.contrib.distributions.kl_divergence(self.old_policy_distribution,
+ self.policy_distribution))
+ if self.use_kl_regularization:
+ # no clipping => use kl regularization
+ self.weighted_kl_divergence = tf.multiply(self.kl_coefficient, self.kl_divergence)
+ self.regularizations = self.weighted_kl_divergence + self.high_kl_penalty_coefficient * \
+ tf.square(tf.maximum(0.0, self.kl_divergence - self.kl_cutoff))
+ tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, self.regularizations)
+
+ # calculate surrogate loss
+ self.advantages = tf.placeholder(tf.float32, [None], name="advantages")
+ self.target = self.advantages
+ self.likelihood_ratio = self.action_probs_wrt_policy / self.action_probs_wrt_old_policy
+ if self.clip_likelihood_ratio_using_epsilon is not None:
+ max_value = 1 + self.clip_likelihood_ratio_using_epsilon
+ min_value = 1 - self.clip_likelihood_ratio_using_epsilon
+ self.clipped_likelihood_ratio = tf.clip_by_value(self.likelihood_ratio, min_value, max_value)
+ self.scaled_advantages = tf.minimum(self.likelihood_ratio * self.advantages,
+ self.clipped_likelihood_ratio * self.advantages)
+ else:
+ self.scaled_advantages = self.likelihood_ratio * self.advantages
+ # the minus sign turns maximizing the surrogate objective into a minimization problem for the optimizer
+ self.surrogate_loss = -tf.reduce_mean(self.scaled_advantages)
+ if self.is_local:
+ # add entropy regularization
+ if self.beta:
+ self.entropy = tf.reduce_mean(self.policy_distribution.entropy())
+ self.regularizations = -tf.multiply(self.beta, self.entropy, name='entropy_regularization')
+ tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, self.regularizations)
+
+ self.loss = self.surrogate_loss
+ tf.losses.add_loss(self.loss)
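+ # the surrogate above is the clipped PPO objective (Schulman et al., "Proximal Policy Optimization
+ # Algorithms"): L_CLIP = E[min(r_t * A_t, clip(r_t, 1 - epsilon, 1 + epsilon) * A_t)],
+ # with r_t = pi(a_t | s_t) / pi_old(a_t | s_t); when clip_likelihood_ratio_using_epsilon is None,
+ # the unclipped surrogate (optionally with the KL penalty regularization) is used instead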
+
+
+class PPOVHead(Head):
+ def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
+ Head.__init__(self, tuning_parameters, head_idx, loss_weight, is_local)
+ self.name = 'ppo_v_head'
+ self.clip_likelihood_ratio_using_epsilon = tuning_parameters.agent.clip_likelihood_ratio_using_epsilon
+
+ def _build_module(self, input_layer):
+ self.old_policy_value = tf.placeholder(tf.float32, [None], "old_policy_values")
+ self.input = [self.old_policy_value]
+ self.output = tf.layers.dense(input_layer, 1, name='output',
+ kernel_initializer=normalized_columns_initializer(1.0))
+ self.target = self.total_return = tf.placeholder(tf.float32, [None], name="total_return")
+
+ # squeeze the value prediction from [batch, 1] to [batch] so it broadcasts correctly against the targets
+ value_prediction = tf.squeeze(self.output, axis=-1)
+ value_loss_1 = tf.square(value_prediction - self.target)
+ value_loss_2 = tf.square(self.old_policy_value +
+ tf.clip_by_value(value_prediction - self.old_policy_value,
+ -self.clip_likelihood_ratio_using_epsilon,
+ self.clip_likelihood_ratio_using_epsilon) - self.target)
+ self.vf_loss = tf.reduce_mean(tf.maximum(value_loss_1, value_loss_2))
+ self.loss = self.vf_loss
+ tf.losses.add_loss(self.loss)
+
+
+class DistributionalQHead(Head):
+ def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
+ Head.__init__(self, tuning_parameters, head_idx, loss_weight, is_local)
+ self.name = 'distributional_dqn_head'
+ self.num_actions = tuning_parameters.env_instance.action_space_size
+ self.num_atoms = tuning_parameters.agent.atoms
+
+ def _build_module(self, input_layer):
+ self.actions = tf.placeholder(tf.int32, [None], name="actions")
+ self.input = [self.actions]
+
+ values_distribution = tf.layers.dense(input_layer, self.num_actions * self.num_atoms)
+ values_distribution = tf.reshape(values_distribution, (tf.shape(values_distribution)[0], self.num_actions, self.num_atoms))
+ # softmax on atoms dimension
+ self.output = tf.nn.softmax(values_distribution)
+
+ # calculate cross entropy loss
+ self.distributions = tf.placeholder(tf.float32, shape=(None, self.num_actions, self.num_atoms), name="distributions")
+ self.target = self.distributions
+ self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.target, logits=values_distribution)
+ tf.losses.add_loss(self.loss)
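+ # the `distributions` placeholder is expected to hold the Bellman-projected categorical targets
+ # (one distribution over the num_atoms support points per action), as in the C51 / categorical DQN
+ # formulation, so the cross-entropy above corresponds to that algorithm's distributional loss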
+
diff --git a/architectures/tensorflow_components/middleware.py b/architectures/tensorflow_components/middleware.py
new file mode 100644
index 0000000..0d28bc3
--- /dev/null
+++ b/architectures/tensorflow_components/middleware.py
@@ -0,0 +1,65 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tensorflow as tf
+import numpy as np
+
+
+class MiddlewareEmbedder:
+ def __init__(self, activation_function=tf.nn.relu, name="middleware_embedder"):
+ self.name = name
+ self.input = None
+ self.output = None
+ self.activation_function = activation_function
+
+ def __call__(self, input_layer):
+ with tf.variable_scope(self.get_name()):
+ self.input = input_layer
+ self._build_module()
+
+ return self.input, self.output
+
+ def _build_module(self):
+ pass
+
+ def get_name(self):
+ return self.name
+
+
+class LSTM_Embedder(MiddlewareEmbedder):
+ def _build_module(self):
+
+ middleware = tf.layers.dense(self.input, 512, activation=self.activation_function)
+ lstm_cell = tf.contrib.rnn.BasicLSTMCell(256, state_is_tuple=True)
+ self.c_init = np.zeros((1, lstm_cell.state_size.c), np.float32)
+ self.h_init = np.zeros((1, lstm_cell.state_size.h), np.float32)
+ self.state_init = [self.c_init, self.h_init]
+ self.c_in = tf.placeholder(tf.float32, [1, lstm_cell.state_size.c])
+ self.h_in = tf.placeholder(tf.float32, [1, lstm_cell.state_size.h])
+ self.state_in = (self.c_in, self.h_in)
+ rnn_in = tf.expand_dims(middleware, [0])
+ step_size = tf.shape(middleware)[:1]
+ state_in = tf.contrib.rnn.LSTMStateTuple(self.c_in, self.h_in)
+ lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
+ lstm_cell, rnn_in, initial_state=state_in, sequence_length=step_size, time_major=False)
+ lstm_c, lstm_h = lstm_state
+ self.state_out = (lstm_c[:1, :], lstm_h[:1, :])
+ self.output = tf.reshape(lstm_outputs, [-1, 256])
+
+
+class FC_Embedder(MiddlewareEmbedder):
+ def _build_module(self):
+ self.output = tf.layers.dense(self.input, 512, activation=self.activation_function)
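+
+
+# Example of how an embedder is wired into a network (illustrative; `state_embedding` stands for the
+# concatenated outputs of the input embedders that the general network passes in):
+#
+#     middleware = FC_Embedder(activation_function=tf.nn.relu)
+#     _, embedding = middleware(state_embedding)  # __call__ returns (input, output)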
diff --git a/architectures/tensorflow_components/shared_variables.py b/architectures/tensorflow_components/shared_variables.py
new file mode 100644
index 0000000..8ee623e
--- /dev/null
+++ b/architectures/tensorflow_components/shared_variables.py
@@ -0,0 +1,81 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tensorflow as tf
+import numpy as np
+
+
+class SharedRunningStats(object):
+ def __init__(self, tuning_parameters, replicated_device, epsilon=1e-2, shape=(), name=""):
+ self.tp = tuning_parameters
+ with tf.device(replicated_device):
+ with tf.variable_scope(name):
+ self._sum = tf.get_variable(
+ dtype=tf.float64,
+ shape=shape,
+ initializer=tf.constant_initializer(0.0),
+ name="running_sum", trainable=False)
+ self._sum_squared = tf.get_variable(
+ dtype=tf.float64,
+ shape=shape,
+ initializer=tf.constant_initializer(epsilon),
+ name="running_sum_squared", trainable=False)
+ self._count = tf.get_variable(
+ dtype=tf.float64,
+ shape=(),
+ initializer=tf.constant_initializer(epsilon),
+ name="count", trainable=False)
+
+ self._shape = shape
+ self._mean = tf.to_float(self._sum / self._count)
+ self._std = tf.sqrt(tf.maximum(tf.to_float(self._sum_squared / self._count) - tf.square(self._mean), 1e-2))
+
+ self.new_sum = tf.placeholder(shape=self.shape, dtype=tf.float64, name='sum')
+ self.new_sum_squared = tf.placeholder(shape=self.shape, dtype=tf.float64, name='var')
+ self.newcount = tf.placeholder(shape=[], dtype=tf.float64, name='count')
+
+ self._inc_sum = tf.assign_add(self._sum, self.new_sum, use_locking=True)
+ self._inc_sum_squared = tf.assign_add(self._sum_squared, self.new_sum_squared, use_locking=True)
+ self._inc_count = tf.assign_add(self._count, self.newcount, use_locking=True)
+
+ def push(self, x):
+ x = x.astype('float64')
+ self.tp.sess.run([self._inc_sum, self._inc_sum_squared, self._inc_count],
+ feed_dict={
+ self.new_sum: x.sum(axis=0).ravel(),
+ self.new_sum_squared: np.square(x).sum(axis=0).ravel(),
+ self.newcount: np.array(len(x), dtype='float64')
+ })
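+ # the accumulated sums yield the running statistics used above:
+ # mean = running_sum / count and std = sqrt(running_sum_squared / count - mean^2) (clipped from below
+ # for numerical stability), so workers only need to push per-batch sums rather than raw samples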
+
+ @property
+ def n(self):
+ return self.tp.sess.run(self._count)
+
+ @property
+ def mean(self):
+ return self.tp.sess.run(self._mean)
+
+ @property
+ def var(self):
+ return self.std ** 2
+
+ @property
+ def std(self):
+ return self.tp.sess.run(self._std)
+
+ @property
+ def shape(self):
+ return self._shape
\ No newline at end of file
diff --git a/coach.py b/coach.py
new file mode 100644
index 0000000..97565c2
--- /dev/null
+++ b/coach.py
@@ -0,0 +1,316 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys, inspect, re
+import os
+import json
+import presets
+from presets import *
+from utils import set_gpu, list_all_classes_in_module
+from architectures import *
+from environments import *
+from agents import *
+from utils import *
+from logger import screen, logger
+import argparse
+from subprocess import Popen
+import datetime
+
+if len(set(failed_imports)) > 0:
+ screen.warning("Warning: failed to import the following packages - {}".format(', '.join(set(failed_imports))))
+
+time_started = datetime.datetime.now()
+cur_time = time_started.time()
+cur_date = time_started.date()
+
+
+def get_experiment_path(general_experiments_path):
+ if not os.path.exists(general_experiments_path):
+ os.makedirs(general_experiments_path)
+ experiment_path = os.path.join(general_experiments_path, '{}_{}_{}-{}_{}'
+ .format(logger.two_digits(cur_date.day), logger.two_digits(cur_date.month),
+ cur_date.year, logger.two_digits(cur_time.hour),
+ logger.two_digits(cur_time.minute)))
+ i = 0
+ while True:
+ if os.path.exists(experiment_path):
+ experiment_path = os.path.join(general_experiments_path, '{}_{}_{}-{}_{}_{}'
+ .format(logger.two_digits(cur_date.day), logger.two_digits(cur_date.month),
+ cur_date.year, logger.two_digits(cur_time.hour),
+ logger.two_digits(cur_time.minute), i))
+ i += 1
+ else:
+ os.makedirs(experiment_path)
+ return experiment_path
+
+
+def set_framework(framework_type):
+ # choosing neural network framework
+ framework = Frameworks().get(framework_type)
+ sess = None
+ if framework == Frameworks.TensorFlow:
+ import tensorflow as tf
+ config = tf.ConfigProto()
+ config.allow_soft_placement = True
+ config.gpu_options.allow_growth = True
+ config.gpu_options.per_process_gpu_memory_fraction = 0.2
+ sess = tf.Session(config=config)
+ elif framework == Frameworks.Neon:
+ import ngraph as ng
+ sess = ng.transformers.make_transformer()
+ screen.log_title("Using {} framework".format(Frameworks().to_string(framework)))
+ return sess
+
+
+def check_input_and_fill_run_dict(parser):
+ args = parser.parse_args()
+
+ # if no arg is given
+ if len(sys.argv) == 1:
+ parser.print_help()
+ exit(0)
+
+ # list available presets
+ if args.list:
+ presets_lists = list_all_classes_in_module(presets)
+ screen.log_title("Available Presets:")
+ for preset in presets_lists:
+ print(preset)
+ sys.exit(0)
+
+ # check inputs
+ try:
+ num_workers = int(re.match("^\d+$", args.num_workers).group(0))
+ except (ValueError, AttributeError):
+ screen.error("Parameter num_workers should be an integer.")
+ exit(1)
+
+ preset_names = list_all_classes_in_module(presets)
+ if args.preset is not None and args.preset not in preset_names:
+ screen.error("A non-existing preset was selected. ")
+ exit(1)
+
+ if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir):
+ screen.error("The requested checkpoint folder to load from does not exist. ")
+ exit(1)
+
+ if args.save_model_sec is not None:
+ try:
+ args.save_model_sec = int(args.save_model_sec)
+ except ValueError:
+ screen.error("Parameter save_model_sec should be an integer.")
+ exit(1)
+
+ if args.preset is None and (args.agent_type is None or args.environment_type is None
+ or args.exploration_policy_type is None):
+ screen.error('When no preset is given for Coach to run, the user is expected to input the desired agent_type,'
+ ' environment_type and exploration_policy_type to assemble a preset. '
+ '\nAt least one of these parameters was not given.')
+ exit(1)
+
+ experiment_name = args.experiment_name
+
+ if args.experiment_name == '':
+ experiment_name = screen.ask_input("Please enter an experiment name: ")
+
+ experiment_name = experiment_name.replace(" ", "_")
+ match = re.match("^$|^\w{1,100}$", experiment_name)
+
+ if match is None:
+ screen.error('Experiment name must be composed only of alphanumeric letters and underscores and should not be '
+ 'longer than 100 characters.')
+ exit(1)
+ experiment_path = os.path.join('./experiments/', match.group(0))
+ experiment_path = get_experiment_path(experiment_path)
+
+ # fill run_dict
+ run_dict = dict()
+ run_dict['agent_type'] = args.agent_type
+ run_dict['environment_type'] = args.environment_type
+ run_dict['exploration_policy_type'] = args.exploration_policy_type
+ run_dict['preset'] = args.preset
+ run_dict['custom_parameter'] = args.custom_parameter
+ run_dict['experiment_path'] = experiment_path
+ run_dict['framework'] = Frameworks().get(args.framework)
+
+ # multi-threading parameters
+ run_dict['num_threads'] = num_workers
+
+ # checkpoints
+ run_dict['save_model_sec'] = args.save_model_sec
+ run_dict['save_model_dir'] = experiment_path if args.save_model_sec is not None else None
+ run_dict['checkpoint_restore_dir'] = args.checkpoint_restore_dir
+
+ # visualization
+ run_dict['visualization.dump_gifs'] = args.dump_gifs
+ run_dict['visualization.render'] = args.render
+
+ return args, run_dict
+
+
+def run_dict_to_json(_run_dict, task_id=''):
+ if task_id != '':
+ json_path = os.path.join(_run_dict['experiment_path'], 'run_dict_worker{}.json'.format(task_id))
+ else:
+ json_path = os.path.join(_run_dict['experiment_path'], 'run_dict.json')
+
+ with open(json_path, 'w') as outfile:
+ json.dump(_run_dict, outfile, indent=2)
+
+ return json_path
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-p', '--preset',
+ help="(string) Name of a preset to run (as configured in presets.py)",
+ default=None,
+ type=str)
+ parser.add_argument('-l', '--list',
+ help="(flag) List all available presets",
+ action='store_true')
+ parser.add_argument('-e', '--experiment_name',
+ help="(string) Experiment name to be used to store the results.",
+ default='',
+ type=str)
+ parser.add_argument('-r', '--render',
+ help="(flag) Render environment",
+ action='store_true')
+ parser.add_argument('-f', '--framework',
+ help="(string) Neural network framework. Available values: tensorflow, neon",
+ default='tensorflow',
+ type=str)
+ parser.add_argument('-n', '--num_workers',
+ help="(int) Number of workers for multi-process based agents, e.g. A3C",
+ default='1',
+ type=str)
+ parser.add_argument('-v', '--verbose',
+ help="(flag) Don't suppress TensorFlow debug prints.",
+ action='store_true')
+ parser.add_argument('-s', '--save_model_sec',
+ help="(int) Time in seconds between saving checkpoints of the model.",
+ default=None,
+ type=int)
+ parser.add_argument('-crd', '--checkpoint_restore_dir',
+ help='(string) Path to a folder containing a checkpoint to restore the model from.',
+ type=str)
+ parser.add_argument('-dg', '--dump_gifs',
+ help="(flag) Enable the gif saving functionality.",
+ action='store_true')
+ parser.add_argument('-at', '--agent_type',
+ help="(string) Choose an agent type class to override on top of the selected preset. "
+ "If no preset is defined, a preset can be set from the command-line by combining settings "
+ "which are set by using --agent_type, --experiment_type, --environemnt_type",
+ default=None,
+ type=str)
+ parser.add_argument('-et', '--environment_type',
+ help="(string) Choose an environment type class to override on top of the selected preset."
+ "If no preset is defined, a preset can be set from the command-line by combining settings "
+ "which are set by using --agent_type, --experiment_type, --environemnt_type",
+ default=None,
+ type=str)
+ parser.add_argument('-ept', '--exploration_policy_type',
+ help="(string) Choose an exploration policy type class to override on top of the selected "
+ "preset."
+ "If no preset is defined, a preset can be set from the command-line by combining settings "
+ "which are set by using --agent_type, --experiment_type, --environemnt_type"
+ ,
+ default=None,
+ type=str)
+ parser.add_argument('-cp', '--custom_parameter',
+ help="(string) Semicolon separated parameters used to override specific parameters on top of"
+ " the selected preset (or on top of the command-line assembled one). "
+ "Whenever a parameter value is a string, it should be inputted as '\\\"string\\\"'. "
+ "For ex.: "
+ "\"visualization.render=False; num_training_iterations=500; optimizer='rmsprop'\"",
+ default=None,
+ type=str)
+
+ args, run_dict = check_input_and_fill_run_dict(parser)
+
+ # turn TF debug prints off
+ if not args.verbose and args.framework.lower() == 'tensorflow':
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+
+ # set the logger dump directory for this experiment
+ logger.set_dump_dir(run_dict['experiment_path'], add_timestamp=True)
+
+ # Single-threaded runs
+ if run_dict['num_threads'] == 1:
+ # set tuning parameters
+ json_run_dict_path = run_dict_to_json(run_dict)
+ tuning_parameters = json_to_preset(json_run_dict_path)
+ tuning_parameters.sess = set_framework(args.framework)
+
+ # Single-thread runs
+ tuning_parameters.task_index = 0
+ env_instance = create_environment(tuning_parameters)
+ agent = eval(tuning_parameters.agent.type + '(env_instance, tuning_parameters)')
+ agent.improve()
+
+ # Multi-threaded runs
+ else:
+ assert args.framework.lower() == 'tensorflow', "Distributed training works only with TensorFlow"
+
+ # set parameter server and workers addresses
+ ps_hosts = "localhost:{}".format(get_open_port())
+ worker_hosts = ",".join(["localhost:{}".format(get_open_port()) for i in range(run_dict['num_threads'] + 1)])
+
+ # Make sure to disable GPU so that all the workers will use the CPU
+ set_cpu()
+
+ # create a parameter server
+ Popen(["python",
+ "./parallel_actor.py",
+ "--ps_hosts={}".format(ps_hosts),
+ "--worker_hosts={}".format(worker_hosts),
+ "--job_name=ps"])
+
+ screen.log_title("*** Distributed Training ***")
+ time.sleep(1)
+
+ # create N training workers and 1 evaluating worker
+ workers = []
+
+ for i in range(run_dict['num_threads'] + 1):
+ run_dict['task_id'] = i
+ # the last worker is used as the evaluation worker
+ if i == run_dict['num_threads']:
+ run_dict['evaluate_only'] = True
+ run_dict['visualization.render'] = args.render
+ else:
+ run_dict['evaluate_only'] = False
+ run_dict['visualization.render'] = False # in a parallel setting, only the evaluation worker renders
+
+ json_run_dict_path = run_dict_to_json(run_dict, i)
+ workers_args = ["python", "./parallel_actor.py",
+ "--ps_hosts={}".format(ps_hosts),
+ "--worker_hosts={}".format(worker_hosts),
+ "--job_name=worker",
+ "--load_json={}".format(json_run_dict_path)]
+
+ p = Popen(workers_args)
+
+ if i != run_dict['num_threads']:
+ workers.append(p)
+ else:
+ evaluation_worker = p
+
+ # wait for all workers
+ [w.wait() for w in workers]
+ evaluation_worker.kill()
+
+
diff --git a/configurations.py b/configurations.py
new file mode 100644
index 0000000..aff718a
--- /dev/null
+++ b/configurations.py
@@ -0,0 +1,532 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from utils import Enum
+import json
+from logger import screen, logger
+
+
+class Frameworks(Enum):
+ TensorFlow = 1
+ Neon = 2
+
+
+class InputTypes:
+ Observation = 1
+ Measurements = 2
+ GoalVector = 3
+ Action = 4
+ TimedObservation = 5
+
+
+class OutputTypes:
+ Q = 1
+ DuelingQ = 2
+ V = 3
+ Pi = 4
+ MeasurementsPrediction = 5
+ DNDQ = 6
+ NAF = 7
+ PPO = 8
+ PPO_V = 9
+ DistributionalQ = 10
+
+
+class MiddlewareTypes:
+ LSTM = 1
+ FC = 2
+
+
+class AgentParameters:
+ agent = ''
+
+ # Architecture parameters
+ input_types = [InputTypes.Observation]
+ output_types = [OutputTypes.Q]
+ middleware_type = MiddlewareTypes.FC
+ loss_weights = [1.0]
+ stop_gradients_from_head = [False]
+ num_output_head_copies = 1
+ use_measurements = False
+ use_accumulated_reward_as_measurement = False
+ add_a_normalized_timestep_to_the_observation = False
+ l2_regularization = 0
+ hidden_layers_activation_function = 'relu'
+ optimizer_type = 'Adam'
+ async_training = False
+ use_separate_networks_per_head = False
+
+ # Agent parameters
+ num_consecutive_playing_steps = 1
+ num_consecutive_training_steps = 1
+ bootstrap_total_return_from_old_policy = False
+ n_step = -1
+ num_episodes_in_experience_replay = 200
+ num_transitions_in_experience_replay = None
+ discount = 0.99
+ policy_gradient_rescaler = 'A_VALUE'
+ apply_gradients_every_x_episodes = 5
+ beta_entropy = 0
+ num_steps_between_gradient_updates = 20000 # t_max
+ num_steps_between_copying_online_weights_to_target = 1000
+ rate_for_copying_weights_to_target = 1.0
+ monte_carlo_mixing_rate = 0.1
+ gae_lambda = 0.96
+ step_until_collecting_full_episodes = False
+ targets_horizon = 'N-Step'
+ replace_mse_with_huber_loss = False
+
+ # PPO related params
+ target_kl_divergence = 0.01
+ initial_kl_coefficient = 1.0
+ high_kl_penalty_coefficient = 1000
+ value_targets_mix_fraction = 0.1
+ clip_likelihood_ratio_using_epsilon = None
+ use_kl_regularization = True
+ estimate_value_using_gae = False
+
+ # DFP related params
+ num_predicted_steps_ahead = 6
+ goal_vector = [1.0, 1.0]
+ future_measurements_weights = [0.5, 0.5, 1.0]
+
+ # NEC related params
+ dnd_size = 500000
+ l2_norm_added_delta = 0.001
+ new_value_shift_coefficient = 0.1
+ number_of_knn = 50
+ DND_key_error_threshold = 0.01
+
+ # Framework support
+ neon_support = False
+ tensorflow_support = True
+
+ # distributed agents params
+ shared_optimizer = True
+ share_statistics_between_workers = True
+
+
+class EnvironmentParameters:
+ type = 'Doom'
+ level = 'basic'
+ observation_stack_size = 4
+ frame_skip = 4
+ desired_observation_width = 76
+ desired_observation_height = 60
+ normalize_observation = False
+ reward_scaling = 1.0
+ reward_clipping_min = None
+ reward_clipping_max = None
+
+
+class ExplorationParameters:
+ # Exploration policies
+ policy = 'EGreedy'
+ evaluation_policy = 'Greedy'
+ # -- bootstrap dqn parameters
+ bootstrapped_data_sharing_probability = 0.5
+ architecture_num_q_heads = 1
+ # -- dropout approximation of thompson sampling parameters
+ dropout_discard_probability = 0
+ initial_keep_probability = 0.0 # unused
+ final_keep_probability = 0.99 # unused
+ keep_probability_decay_steps = 50000 # unused
+ # -- epsilon greedy parameters
+ initial_epsilon = 0.5
+ final_epsilon = 0.01
+ epsilon_decay_steps = 50000
+ evaluation_epsilon = 0.05
+ # -- epsilon greedy at end of episode parameters
+ average_episode_length_over_num_episodes = 20
+ # -- boltzmann softmax parameters
+ initial_temperature = 100.0
+ final_temperature = 1.0
+ temperature_decay_steps = 50000
+ # -- additive noise
+ initial_noise_variance_percentage = 0.1
+ final_noise_variance_percentage = 0.1
+ noise_variance_decay_steps = 1
+ # -- Ornstein-Uhlenbeck process
+ mu = 0
+ theta = 0.15
+ sigma = 0.3
+ dt = 0.01
+
+
+class GeneralParameters:
+ train = True
+ framework = Frameworks.TensorFlow
+ threads = 1
+ sess = None
+
+ # distributed training options
+ num_threads = 1
+ synchronize_over_num_threads = 1
+ distributed = False
+
+ # Agent blocks
+ memory = 'EpisodicExperienceReplay'
+ architecture = 'GeneralTensorFlowNetwork'
+
+ # General parameters
+ clip_gradients = None
+ kl_divergence_constraint = 100000
+ num_training_iterations = 10000000000
+ num_heatup_steps = 1000
+ batch_size = 32
+ save_model_sec = None
+ save_model_dir = None
+ checkpoint_restore_dir = None
+ learning_rate = 0.00025
+ learning_rate_decay_rate = 0
+ learning_rate_decay_steps = 0
+ evaluation_episodes = 5
+ evaluate_every_x_episodes = 1000000
+ rescaling_interpolation_type = 'bilinear'
+
+ # setting a seed will only work for non-parallel algorithms. Parallel algorithms add uncontrollable noise in
+ # the form of different workers starting at different times, and getting different assignments of CPU
+ # time from the OS.
+ seed = None
+
+ checkpoints_path = ''
+
+ # Testing parameters
+ test = False
+ test_min_return_threshold = 0
+ test_max_step_threshold = 1
+ test_num_workers = 1
+
+
+class VisualizationParameters:
+ # Visualization parameters
+ record_video_every = 1000
+ video_path = '/home/llt_lab/temp/breakout-videos'
+ plot_action_values_online = False
+ show_saliency_maps_every_num_episodes = 1000000000
+ print_summary = False
+ dump_csv = True
+ dump_signals_to_csv_every_x_episodes = 10
+ render = False
+ dump_gifs = True
+
+
+class Roboschool(EnvironmentParameters):
+ type = 'Gym'
+ frame_skip = 1
+ observation_stack_size = 1
+ desired_observation_height = None
+ desired_observation_width = None
+
+
+class GymVectorObservation(EnvironmentParameters):
+ type = 'Gym'
+ frame_skip = 1
+ observation_stack_size = 1
+ desired_observation_height = None
+ desired_observation_width = None
+
+
+class Bullet(EnvironmentParameters):
+ type = 'Bullet'
+ frame_skip = 1
+ observation_stack_size = 1
+ desired_observation_height = None
+ desired_observation_width = None
+
+
+class Atari(EnvironmentParameters):
+ type = 'Gym'
+ frame_skip = 1
+ observation_stack_size = 4
+ desired_observation_height = 84
+ desired_observation_width = 84
+ reward_clipping_max = 1.0
+ reward_clipping_min = -1.0
+
+
+class Doom(EnvironmentParameters):
+ type = 'Doom'
+ frame_skip = 4
+ observation_stack_size = 3
+ desired_observation_height = 60
+ desired_observation_width = 76
+
+
+class NStepQ(AgentParameters):
+ type = 'NStepQAgent'
+ input_types = [InputTypes.Observation]
+ output_types = [OutputTypes.Q]
+ loss_weights = [1.0]
+ optimizer_type = 'Adam'
+ num_steps_between_copying_online_weights_to_target = 1000
+ num_episodes_in_experience_replay = 2
+ apply_gradients_every_x_episodes = 1
+ num_steps_between_gradient_updates = 20 # this is called t_max in all the papers
+ hidden_layers_activation_function = 'elu'
+ targets_horizon = 'N-Step'
+ async_training = True
+ shared_optimizer = True
+
+
+class DQN(AgentParameters):
+ type = 'DQNAgent'
+ input_types = [InputTypes.Observation]
+ output_types = [OutputTypes.Q]
+ loss_weights = [1.0]
+ optimizer_type = 'Adam'
+ num_steps_between_copying_online_weights_to_target = 1000
+ neon_support = True
+ async_training = True
+ shared_optimizer = True
+
+
+class DDQN(DQN):
+ type = 'DDQNAgent'
+
+
+class DuelingDQN(DQN):
+ type = 'DQNAgent'
+ output_types = [OutputTypes.DuelingQ]
+
+
+class BootstrappedDQN(DQN):
+ type = 'BootstrappedDQNAgent'
+ num_output_head_copies = 10
+
+
+class DistributionalDQN(DQN):
+ type = 'DistributionalDQNAgent'
+ output_types = [OutputTypes.DistributionalQ]
+ v_min = -10.0
+ v_max = 10.0
+ atoms = 51
+
+
+class NEC(AgentParameters):
+ type = 'NECAgent'
+ optimizer_type = 'RMSProp'
+ input_types = [InputTypes.Observation]
+ output_types = [OutputTypes.DNDQ]
+ loss_weights = [1.0]
+ dnd_size = 500000
+ l2_norm_added_delta = 0.001
+ new_value_shift_coefficient = 0.1
+ number_of_knn = 50
+ n_step = 100
+ bootstrap_total_return_from_old_policy = True
+ DND_key_error_threshold = 0.1
+
+
+class ActorCritic(AgentParameters):
+ type = 'ActorCriticAgent'
+ input_types = [InputTypes.Observation]
+ output_types = [OutputTypes.V, OutputTypes.Pi]
+ loss_weights = [0.5, 1.0]
+ stop_gradients_from_head = [False, False]
+ num_episodes_in_experience_replay = 2
+ policy_gradient_rescaler = 'A_VALUE'
+ hidden_layers_activation_function = 'elu'
+ apply_gradients_every_x_episodes = 5
+ beta_entropy = 0
+ num_steps_between_gradient_updates = 5000 # this is called t_max in all the papers
+ gae_lambda = 0.96
+ shared_optimizer = True
+ estimate_value_using_gae = False
+ async_training = True
+
+
+class PolicyGradient(AgentParameters):
+ type = 'PolicyGradientsAgent'
+ input_types = [InputTypes.Observation]
+ output_types = [OutputTypes.Pi]
+ loss_weights = [1.0]
+ num_episodes_in_experience_replay = 2
+ policy_gradient_rescaler = 'FUTURE_RETURN_NORMALIZED_BY_TIMESTEP'
+ apply_gradients_every_x_episodes = 5
+ beta_entropy = 0
+ num_steps_between_gradient_updates = 20000 # this is called t_max in all the papers
+ async_training = True
+
+
+class DDPG(AgentParameters):
+ type = 'DDPGAgent'
+ input_types = [InputTypes.Observation, InputTypes.Action]
+ output_types = [OutputTypes.V] # V is used because we only want a single Q value
+ loss_weights = [1.0]
+ hidden_layers_activation_function = 'relu'
+ num_episodes_in_experience_replay = 10000
+ num_steps_between_copying_online_weights_to_target = 1
+ rate_for_copying_weights_to_target = 0.001
+ shared_optimizer = True
+ async_training = True
+
+
+class DDDPG(AgentParameters):
+ type = 'DDPGAgent'
+ input_types = [InputTypes.Observation, InputTypes.Action]
+ output_types = [OutputTypes.V] # V is used because we only want a single Q value
+ loss_weights = [1.0]
+ hidden_layers_activation_function = 'relu'
+ num_episodes_in_experience_replay = 10000
+ num_steps_between_copying_online_weights_to_target = 10
+ rate_for_copying_weights_to_target = 1
+ shared_optimizer = True
+ async_training = True
+
+
+class NAF(AgentParameters):
+ type = 'NAFAgent'
+ input_types = [InputTypes.Observation]
+ output_types = [OutputTypes.NAF]
+ loss_weights = [1.0]
+ hidden_layers_activation_function = 'tanh'
+ num_consecutive_training_steps = 5
+ num_steps_between_copying_online_weights_to_target = 1
+ rate_for_copying_weights_to_target = 0.001
+ optimizer_type = 'RMSProp'
+ async_training = True
+
+
+class PPO(AgentParameters):
+ type = 'PPOAgent'
+ input_types = [InputTypes.Observation]
+ output_types = [OutputTypes.V]
+ loss_weights = [1.0]
+ hidden_layers_activation_function = 'tanh'
+ num_episodes_in_experience_replay = 1000000
+ policy_gradient_rescaler = 'A_VALUE'
+ gae_lambda = 0.96
+ target_kl_divergence = 0.01
+ initial_kl_coefficient = 1.0
+ high_kl_penalty_coefficient = 1000
+ add_a_normalized_timestep_to_the_observation = True
+ l2_regularization = 0  # 1e-3
+ value_targets_mix_fraction = 0.1
+ async_training = True
+ estimate_value_using_gae = True
+ step_until_collecting_full_episodes = True
+
+
+class ClippedPPO(AgentParameters):
+ type = 'ClippedPPOAgent'
+ input_types = [InputTypes.Observation]
+ output_types = [OutputTypes.V, OutputTypes.PPO]
+ loss_weights = [0.5, 1.0]
+ stop_gradients_from_head = [False, False]
+ hidden_layers_activation_function = 'tanh'
+ num_episodes_in_experience_replay = 1000000
+ policy_gradient_rescaler = 'GAE'
+ gae_lambda = 0.95
+ target_kl_divergence = 0.01
+ initial_kl_coefficient = 1.0
+ high_kl_penalty_coefficient = 1000
+ add_a_normalized_timestep_to_the_observation = False
+ l2_regularization = 1e-3
+ value_targets_mix_fraction = 0.1
+ clip_likelihood_ratio_using_epsilon = 0.2
+ async_training = False
+ use_kl_regularization = False
+ estimate_value_using_gae = True
+ batch_size = 64
+ use_separate_networks_per_head = True
+ step_until_collecting_full_episodes = True
+ beta_entropy = 0.01
+
+
+class DFP(AgentParameters):
+ type = 'DFPAgent'
+ input_types = [InputTypes.Observation, InputTypes.Measurements, InputTypes.GoalVector]
+ output_types = [OutputTypes.MeasurementsPrediction]
+ loss_weights = [1.0]
+ use_measurements = True
+ num_predicted_steps_ahead = 6
+ goal_vector = [1.0, 1.0]
+ future_measurements_weights = [0.5, 0.5, 1.0]
+ async_training = True
+
+
+class MMC(AgentParameters):
+ type = 'MixedMonteCarloAgent'
+ input_types = [InputTypes.Observation]
+ output_types = [OutputTypes.Q]
+ loss_weights = [1.0]
+ num_steps_between_copying_online_weights_to_target = 1000
+ monte_carlo_mixing_rate = 0.1
+ neon_support = True
+
+
+class PAL(AgentParameters):
+ type = 'PALAgent'
+ input_types = [InputTypes.Observation]
+ output_types = [OutputTypes.Q]
+ loss_weights = [1.0]
+ pal_alpha = 0.9
+ persistent_advantage_learning = False
+ num_steps_between_copying_online_weights_to_target = 1000
+ neon_support = True
+
+
+class EGreedyExploration(ExplorationParameters):
+ policy = 'EGreedy'
+ initial_epsilon = 0.5
+ final_epsilon = 0.01
+ epsilon_decay_steps = 50000
+ evaluation_epsilon = 0.05
+ initial_noise_variance_percentage = 0.1
+ final_noise_variance_percentage = 0.1
+ noise_variance_decay_steps = 50000
+
+
+class BootstrappedDQNExploration(ExplorationParameters):
+ policy = 'Bootstrapped'
+ architecture_num_q_heads = 10
+ bootstrapped_data_sharing_probability = 0.1
+
+
+class OUExploration(ExplorationParameters):
+ policy = 'OUProcess'
+ mu = 0
+ theta = 0.15
+ sigma = 0.3
+ dt = 0.01
+
+
+class AdditiveNoiseExploration(ExplorationParameters):
+ policy = 'AdditiveNoise'
+ initial_noise_variance_percentage = 0.1
+ final_noise_variance_percentage = 0.1
+ noise_variance_decay_steps = 50000
+
+
+class EntropyExploration(ExplorationParameters):
+ policy = 'ContinuousEntropy'
+
+
+class CategoricalExploration(ExplorationParameters):
+ policy = 'Categorical'
+
+
+class Preset(GeneralParameters):
+ def __init__(self, agent, env, exploration, visualization=VisualizationParameters):
+ """
+ :type agent: AgentParameters
+ :type env: EnvironmentParameters
+ :type exploration: ExplorationParameters
+ :type visualization: VisualizationParameters
+ """
+ self.visualization = visualization
+ self.agent = agent
+ self.env = env
+ self.exploration = exploration
+
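+# A concrete preset is typically defined by subclassing Preset with one of the parameter classes above,
+# for example (illustrative only; the actual presets live in presets.py):
+#
+#     class Doom_Basic_DQN(Preset):
+#         def __init__(self):
+#             Preset.__init__(self, DQN, Doom, EGreedyExploration)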
diff --git a/dashboard.py b/dashboard.py
new file mode 100644
index 0000000..6d0a77f
--- /dev/null
+++ b/dashboard.py
@@ -0,0 +1,880 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+To run Coach Dashboard, run the following command:
+python dashboard.py
+"""
+
+from utils import *
+import os
+import datetime
+
+import sys
+import wx
+import random
+import pandas as pd
+from pandas.io.common import EmptyDataError
+import numpy as np
+from bokeh.palettes import Dark2
+from bokeh.layouts import row, column, widgetbox, Spacer
+from bokeh.models import ColumnDataSource, Range1d, LinearAxis, HoverTool, WheelZoomTool, PanTool
+from bokeh.models.widgets import RadioButtonGroup, MultiSelect, Button, Select, Slider, Div, CheckboxGroup
+from bokeh.models.glyphs import Patch
+from bokeh.plotting import figure, show, curdoc
+from utils import force_list
+from utils import squeeze_list
+from itertools import cycle
+from os import listdir
+from os.path import isfile, join, isdir, basename
+from enum import Enum
+
+
+class DialogApp(wx.App):
+ def getFileDialog(self):
+ with wx.FileDialog(None, "Open CSV file", wildcard="CSV files (*.csv)|*.csv",
+ style=wx.FD_OPEN | wx.FD_FILE_MUST_EXIST | wx.FD_CHANGE_DIR | wx.FD_MULTIPLE) as fileDialog:
+ if fileDialog.ShowModal() == wx.ID_CANCEL:
+ return None # the user changed their mind
+ else:
+ # Proceed loading the file chosen by the user
+ return fileDialog.GetPaths()
+
+ def getDirDialog(self):
+ with wx.DirDialog(None, "Choose input directory", "",
+ style=wx.DD_DEFAULT_STYLE | wx.DD_DIR_MUST_EXIST) as dirDialog:
+ if dirDialog.ShowModal() == wx.ID_CANCEL:
+ return None # the user changed their mind
+ else:
+ # Proceed loading the dir chosen by the user
+ return dirDialog.GetPath()
+
+
+class Signal:
+ def __init__(self, name, parent):
+ self.name = name
+ self.full_name = "{}/{}".format(parent.filename, self.name)
+ self.selected = False
+ self.color = random.choice(Dark2[8])
+ self.line = None
+ self.bands = None
+ self.bokeh_source = parent.bokeh_source
+ self.min_val = 0
+ self.max_val = 0
+ self.axis = 'default'
+ self.sub_signals = []
+ for name in self.bokeh_source.data.keys():
+ if (len(name.split('/')) == 1 and name == self.name) or '/'.join(name.split('/')[:-1]) == self.name:
+ self.sub_signals.append(name)
+ if len(self.sub_signals) > 1:
+ self.mean_signal = squeeze_list([name for name in self.sub_signals if 'Mean' in name.split('/')[-1]])
+ self.stdev_signal = squeeze_list([name for name in self.sub_signals if 'Stdev' in name.split('/')[-1]])
+ self.min_signal = squeeze_list([name for name in self.sub_signals if 'Min' in name.split('/')[-1]])
+ self.max_signal = squeeze_list([name for name in self.sub_signals if 'Max' in name.split('/')[-1]])
+ else:
+ self.mean_signal = squeeze_list(self.name)
+ self.stdev_signal = None
+ self.min_signal = None
+ self.max_signal = None
+ self.has_bollinger_bands = False
+ if self.mean_signal and self.stdev_signal and self.min_signal and self.max_signal:
+ self.has_bollinger_bands = True
+ self.show_bollinger_bands = False
+ self.bollinger_bands_source = None
+ self.update_range()
+
+ def set_selected(self, val):
+ global current_color
+ if self.selected != val:
+ self.selected = val
+ if self.line:
+ self.color = Dark2[8][current_color]
+ current_color = (current_color + 1) % len(Dark2[8])
+ self.line.glyph.line_color = self.color
+ self.line.visible = self.selected
+ if self.bands:
+ self.bands.glyph.fill_color = self.color
+ self.bands.visible = self.selected and self.show_bollinger_bands
+ elif self.selected:
+ # lazy plotting - plot only when selected for the first time
+ show_spinner()
+ self.color = Dark2[8][current_color]
+ current_color = (current_color + 1) % len(Dark2[8])
+ if self.has_bollinger_bands:
+ self.set_bands_source()
+ self.create_bands()
+ self.line = plot.line('index', self.mean_signal, source=self.bokeh_source,
+ line_color=self.color, line_width=2)
+ self.line.visible = True
+ hide_spinner()
+
+ def set_dash(self, dash):
+ self.line.glyph.line_dash = dash
+
+ def create_bands(self):
+ self.bands = plot.patch(x='band_x', y='band_y', source=self.bollinger_bands_source,
+ color=self.color, fill_alpha=0.4, alpha=0.1, line_width=0)
+ self.bands.visible = self.show_bollinger_bands
+ # self.min_line = plot.line('index', self.min_signal, source=self.bokeh_source,
+ # line_color=self.color, line_width=3, line_dash="4 4")
+ # self.max_line = plot.line('index', self.max_signal, source=self.bokeh_source,
+ # line_color=self.color, line_width=3, line_dash="4 4")
+ # self.min_line.visible = self.show_bollinger_bands
+ # self.max_line.visible = self.show_bollinger_bands
+
+ def set_bands_source(self):
+ x_ticks = self.bokeh_source.data['index']
+ mean_values = self.bokeh_source.data[self.mean_signal]
+ stdev_values = self.bokeh_source.data[self.stdev_signal]
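+        # build a closed polygon for the bands patch: x runs forward and then backward,
+        # while y traces mean - stdev followed by the reversed mean + stdev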
+ band_x = np.append(x_ticks, x_ticks[::-1])
+ band_y = np.append(mean_values - stdev_values, mean_values[::-1] + stdev_values[::-1])
+ source_data = {'band_x': band_x, 'band_y': band_y}
+ if self.bollinger_bands_source:
+ self.bollinger_bands_source.data = source_data
+ else:
+ self.bollinger_bands_source = ColumnDataSource(source_data)
+
+ def change_bollinger_bands_state(self, new_state):
+ self.show_bollinger_bands = new_state
+ if self.bands and self.selected:
+ self.bands.visible = new_state
+ # self.min_line.visible = new_state
+ # self.max_line.visible = new_state
+
+ def update_range(self):
+ self.min_val = np.min(self.bokeh_source.data[self.mean_signal])
+ self.max_val = np.max(self.bokeh_source.data[self.mean_signal])
+
+ def set_axis(self, axis):
+ self.axis = axis
+ self.line.y_range_name = axis
+
+ def toggle_axis(self):
+ if self.axis == 'default':
+ self.set_axis('secondary')
+ else:
+ self.set_axis('default')
+
+
+class SignalsFileBase:
+ def __init__(self):
+ self.full_csv_path = ""
+ self.dir = ""
+ self.filename = ""
+ self.signals_averaging_window = 1
+ self.show_bollinger_bands = False
+ self.csv = None
+ self.bokeh_source = None
+ self.bokeh_source_orig = None
+ self.last_modified = None
+ self.signals = {}
+ self.separate_files = False
+
+ def load_csv(self):
+ pass
+
+ def update_source_and_signals(self):
+ # create bokeh data sources
+ self.bokeh_source_orig = ColumnDataSource(self.csv)
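+        # the 'index' column mirrors the currently selected x axis column
+        # (episode / total steps / wall-clock time), so every line is plotted against it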
+ self.bokeh_source_orig.data['index'] = self.bokeh_source_orig.data[x_axis]
+
+ if self.bokeh_source is None:
+ self.bokeh_source = ColumnDataSource(self.csv)
+ else:
+ # self.bokeh_source.data = self.bokeh_source_orig.data
+ # smooth the data if necessary
+ self.change_averaging_window(self.signals_averaging_window, force=True)
+
+ # create all the signals
+ if len(self.signals.keys()) == 0:
+ self.signals = {}
+ unique_signal_names = []
+ for name in self.csv.columns:
+ if len(name.split('/')) == 1:
+ unique_signal_names.append(name)
+ else:
+ unique_signal_names.append('/'.join(name.split('/')[:-1]))
+ unique_signal_names = list(set(unique_signal_names))
+ for signal_name in unique_signal_names:
+ self.signals[signal_name] = Signal(signal_name, self)
+
+ def load(self):
+ self.load_csv()
+ self.update_source_and_signals()
+
+ def reload_data(self, signals):
+        # this function is a workaround to force reloading the data of all the signals,
+        # since bokeh does not refresh a line if its data doesn't change
+ self.change_averaging_window(self.signals_averaging_window + 1, force=True)
+ self.change_averaging_window(self.signals_averaging_window - 1, force=True)
+
+ def change_averaging_window(self, new_size, force=False, signals=None):
+ if force or self.signals_averaging_window != new_size:
+ self.signals_averaging_window = new_size
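+            # smooth every signal column with a simple moving average (box filter of
+            # length new_size); the last new_size samples are trimmed from all columns
+            # to drop the edge samples distorted by the zero padding of 'same'-mode convolution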
+ win = np.ones(new_size) / new_size
+ temp_data = self.bokeh_source_orig.data.copy()
+ for col in self.bokeh_source.data.keys():
+ if col == 'index' or col in x_axis_options \
+ or (signals and not any(col in signal for signal in signals)):
+ temp_data[col] = temp_data[col][:-new_size]
+ continue
+ temp_data[col] = np.convolve(self.bokeh_source_orig.data[col], win, mode='same')[:-new_size]
+ self.bokeh_source.data = temp_data
+
+ # smooth bollinger bands
+ for signal in self.signals.values():
+ if signal.has_bollinger_bands:
+ signal.set_bands_source()
+
+ def hide_all_signals(self):
+ for signal_name in self.signals.keys():
+ self.set_signal_selection(signal_name, False)
+
+ def set_signal_selection(self, signal_name, val):
+ self.signals[signal_name].set_selected(val)
+
+ def change_bollinger_bands_state(self, new_state):
+ self.show_bollinger_bands = new_state
+ for signal in self.signals.values():
+ signal.change_bollinger_bands_state(new_state)
+
+ def file_was_modified_on_disk(self):
+ pass
+
+ def get_range_of_selected_signals_on_axis(self, axis):
+ max_val = -float('inf')
+ min_val = float('inf')
+ for signal in self.signals.values():
+ if signal.selected and signal.axis == axis:
+ max_val = max(max_val, signal.max_val)
+ min_val = min(min_val, signal.min_val)
+ return min_val, max_val
+
+ def get_selected_signals(self):
+ signals = []
+ for signal in self.signals.values():
+ if signal.selected:
+ signals.append(signal)
+ return signals
+
+ def show_files_separately(self, val):
+ pass
+
+
+class SignalsFile(SignalsFileBase):
+ def __init__(self, csv_path, load=True):
+ SignalsFileBase.__init__(self)
+ self.full_csv_path = csv_path
+ self.dir, self.filename, _ = break_file_path(csv_path)
+ if load:
+ self.load()
+
+ def load_csv(self):
+ # load csv and fix sparse data.
+ # csv can be in the middle of being written so we use try - except
+ self.csv = None
+ while self.csv is None:
+ try:
+ self.csv = pd.read_csv(self.full_csv_path)
+ break
+ except EmptyDataError:
+ self.csv = None
+ continue
+ self.csv = self.csv.interpolate()
+ self.csv.fillna(value=0, inplace=True)
+
+ self.last_modified = os.path.getmtime(self.full_csv_path)
+
+ def file_was_modified_on_disk(self):
+ return self.last_modified != os.path.getmtime(self.full_csv_path)
+
+
+class SignalsFilesGroup(SignalsFileBase):
+ def __init__(self, csv_paths):
+ SignalsFileBase.__init__(self)
+ self.full_csv_paths = csv_paths
+ self.signals_files = []
+ if len(csv_paths) == 1 and os.path.isdir(csv_paths[0]):
+ self.signals_files = [SignalsFile(str(file), load=False) for file in add_directory_csv_files(csv_paths[0])]
+ else:
+ for csv_path in csv_paths:
+ if os.path.isdir(csv_path):
+ self.signals_files.append(SignalsFilesGroup(add_directory_csv_files(csv_path)))
+ else:
+ self.signals_files.append(SignalsFile(str(csv_path), load=False))
+ self.dir = os.path.dirname(os.path.commonprefix(csv_paths))
+ self.filename = '{} - Group({})'.format(basename(self.dir), len(self.signals_files))
+ self.load()
+
+ def load_csv(self):
+ corrupted_files_idx = []
+ for idx, signal_file in enumerate(self.signals_files):
+ signal_file.load_csv()
+ if not all(option in signal_file.csv.keys() for option in x_axis_options):
+ print("Warning: {} file seems to be corrupted and does contain the necessary columns "
+ "and will not be rendered".format(signal_file.filename))
+ corrupted_files_idx.append(idx)
+
+        # delete from the end so that earlier indices stay valid
+        for file_idx in sorted(corrupted_files_idx, reverse=True):
+            del self.signals_files[file_idx]
+
+ # get the stats of all the columns
+ csv_group = pd.concat([signals_file.csv for signals_file in self.signals_files])
+ columns_to_remove = [s for s in csv_group.columns if '/Stdev' in s] + \
+ [s for s in csv_group.columns if '/Min' in s] + \
+ [s for s in csv_group.columns if '/Max' in s]
+ for col in columns_to_remove:
+ del csv_group[col]
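+        # rows that share the same positional index across the different runs are grouped
+        # together, so mean/stdev/min/max are computed per row over all the grouped files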
+ csv_group = csv_group.groupby(csv_group.index)
+ self.csv_mean = csv_group.mean()
+ self.csv_mean.columns = [s + '/Mean' for s in self.csv_mean.columns]
+ self.csv_stdev = csv_group.std()
+ self.csv_stdev.columns = [s + '/Stdev' for s in self.csv_stdev.columns]
+ self.csv_min = csv_group.min()
+ self.csv_min.columns = [s + '/Min' for s in self.csv_min.columns]
+ self.csv_max = csv_group.max()
+ self.csv_max.columns = [s + '/Max' for s in self.csv_max.columns]
+
+ # get the indices from the file with the least number of indices and which is not an evaluation worker
+ file_with_min_indices = self.signals_files[0]
+ for signals_file in self.signals_files:
+ if signals_file.csv.shape[0] < file_with_min_indices.csv.shape[0] and \
+ 'Training reward' in signals_file.csv.keys():
+ file_with_min_indices = signals_file
+ self.index_columns = file_with_min_indices.csv[x_axis_options]
+
+ # concat the stats and the indices columns
+ num_rows = file_with_min_indices.csv.shape[0]
+ self.csv = pd.concat([self.index_columns, self.csv_mean.head(num_rows), self.csv_stdev.head(num_rows),
+ self.csv_min.head(num_rows), self.csv_max.head(num_rows)], axis=1)
+
+ # remove the stat columns for the indices columns
+ columns_to_remove = [s + '/Mean' for s in x_axis_options] + \
+ [s + '/Stdev' for s in x_axis_options] + \
+ [s + '/Min' for s in x_axis_options] + \
+ [s + '/Max' for s in x_axis_options]
+ for col in columns_to_remove:
+ del self.csv[col]
+
+ # remove NaNs
+ # self.csv.fillna(value=0, inplace=True)
+ for key in self.csv.keys():
+ if 'Stdev' in key and 'Evaluation' not in key:
+ self.csv[key] = self.csv[key].fillna(value=0)
+
+ for signal_file in self.signals_files:
+ signal_file.update_source_and_signals()
+
+ def change_averaging_window(self, new_size, force=False, signals=None):
+ for signal_file in self.signals_files:
+ signal_file.change_averaging_window(new_size, force, signals)
+ SignalsFileBase.change_averaging_window(self, new_size, force, signals)
+
+ def set_signal_selection(self, signal_name, val):
+ self.show_files_separately(self.separate_files)
+ SignalsFileBase.set_signal_selection(self, signal_name, val)
+
+ def file_was_modified_on_disk(self):
+ for signal_file in self.signals_files:
+ if signal_file.file_was_modified_on_disk():
+ return True
+ return False
+
+ def show_files_separately(self, val):
+ self.separate_files = val
+ for signal in self.signals.values():
+ if signal.selected:
+ if val:
+ signal.set_dash("4 4")
+ else:
+ signal.set_dash("")
+ for signal_file in self.signals_files:
+ try:
+ if val:
+ signal_file.set_signal_selection(signal.name, signal.selected)
+ else:
+ signal_file.set_signal_selection(signal.name, False)
+                except KeyError:
+                    # the signal may not exist in every file of the group - ignore it
+                    pass
+
+
+class RunType(Enum):
+ SINGLE_FOLDER_SINGLE_FILE = 1
+ SINGLE_FOLDER_MULTIPLE_FILES = 2
+ MULTIPLE_FOLDERS_SINGLE_FILES = 3
+ MULTIPLE_FOLDERS_MULTIPLE_FILES = 4
+ UNKNOWN = 0
+
+
+class FolderType(Enum):
+ SINGLE_FILE = 1
+ MULTIPLE_FILES = 2
+ MULTIPLE_FOLDERS = 3
+ EMPTY = 4
+
+dialog = DialogApp()
+
+# read data
+patches = {}
+signals_files = {}
+selected_file = None
+x_axis = 'Episode #'
+x_axis_options = ['Episode #', 'Total steps', 'Wall-Clock Time']
+current_color = 0
+
+# spinner
+root_dir = os.path.dirname(os.path.abspath(__file__))
+with open(os.path.join(root_dir, 'spinner.css'), 'r') as f:
+ spinner_style = """""".format(f.read())
+spinner_html = """"""
+spinner = Div(text="""""")
+
+# file refresh time placeholder
+refresh_info = Div(text="""""", width=210)
+
+# legend
+div = Div(text="""""")
+legend = widgetbox([div])
+
+# create figures
+plot = figure(plot_width=800, plot_height=800,
+ tools='pan,box_zoom,wheel_zoom,crosshair,reset,save',
+ toolbar_location='above', x_axis_label='Episodes')
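+# a secondary y range exposed as a right-hand axis; signals can be moved onto it with Signal.toggle_axis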
+plot.extra_y_ranges = {"secondary": Range1d(start=-100, end=200)}
+plot.add_layout(LinearAxis(y_range_name="secondary"), 'right')
+
+
+def update_axis_range(name, range_placeholder):
+ max_val = -float('inf')
+ min_val = float('inf')
+ for signals_file in signals_files.values():
+ curr_min_val, curr_max_val = signals_file.get_range_of_selected_signals_on_axis(name)
+ max_val = max(max_val, curr_max_val)
+ min_val = min(min_val, curr_min_val)
+    if min_val != float('inf'):
+        # pad the y range by 10% on each side so the selected lines don't touch the plot edges
+        values_span = max_val - min_val
+        range_placeholder.start = min_val - 0.1 * values_span
+        range_placeholder.end = max_val + 0.1 * values_span
+
+
+# update axes ranges
+def update_ranges():
+ update_axis_range('default', plot.y_range)
+ update_axis_range('secondary', plot.extra_y_ranges['secondary'])
+
+
+def get_all_selected_signals():
+ signals = []
+ for signals_file in signals_files.values():
+ signals += signals_file.get_selected_signals()
+ return signals
+
+
+# update legend using the legend text dictionary
+def update_legend():
+ legend_text = """"""
+ selected_signals = get_all_selected_signals()
+ for signal in selected_signals:
+ side_sign = "<" if signal.axis == 'default' else ">"
+ legend_text += """
+Coach Dashboard
+""")
+
+# landing page
+landing_page_description = Div(text="""Start by selecting an experiment file or directory to open:
+""")
+center = Div(text="""""")
+center_buttons = Div(text="""""", width=0)
+landing_page = column(center,
+ title,
+ landing_page_description,
+ row(center_buttons),
+ row(file_selection_button, sizing_mode='scale_width'),
+ row(group_selection_button, sizing_mode='scale_width'),
+ sizing_mode='scale_width')
+
+# main layout of the document
+layout = row(file_selection_button, files_selector_spacer, group_selection_button, width=300)
+layout = column(layout, files_selector)
+layout = column(layout, row(refresh_info, unload_file_button))
+layout = column(layout, data_selector)
+layout = column(layout, x_axis_selector_title)
+layout = column(layout, x_axis_selector)
+layout = column(layout, group_cb)
+layout = column(layout, toggle_second_axis_button)
+layout = column(layout, averaging_slider)
+layout = column(layout, legend)
+layout = row(layout, plot)
+layout = column(title, layout)
+layout = column(layout, spinner)
+
+doc = curdoc()
+doc.add_root(landing_page)
+
+doc.add_periodic_callback(reload_all_files, 20000)
+plot.y_range = Range1d(0, 100)
+plot.extra_y_ranges['secondary'] = Range1d(0, 100)
+
+# show load file dialog immediately on start
+#doc.add_timeout_callback(load_files, 1000)
+
+if __name__ == "__main__":
+ # find an open port and run the server
+ import socket
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ port = 12345
+ while True:
+ try:
+ s.bind(("127.0.0.1", port))
+ break
+ except socket.error as e:
+ if e.errno == 98:
+ port += 1
+ s.close()
+ os.system('bokeh serve --show dashboard.py --port {}'.format(port))
diff --git a/debug_utils.py b/debug_utils.py
new file mode 100644
index 0000000..11db9fc
--- /dev/null
+++ b/debug_utils.py
@@ -0,0 +1,50 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+def show_observation_stack(stack, channels_last=False):
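+    """
+    Display the frames of an observation stack side by side in a single matplotlib figure.
+
+    `stack` may be a list of 2D frames, a 3D array shaped (stack, height, width), or a
+    4D array shaped (batch, stack, height, width) - in the 4D case only the first batch
+    item is shown. Set channels_last=True when the stack dimension is the last axis.
+    """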
+ if isinstance(stack, list): # is list
+ stack_size = len(stack)
+ elif len(stack.shape) == 3:
+ stack_size = stack.shape[0] # is numpy array
+ elif len(stack.shape) == 4:
+ stack_size = stack.shape[1] # ignore batch dimension
+ stack = stack[0]
+    else:
+        assert False, "unsupported observation stack with shape {}".format(stack.shape)
+
+ if channels_last:
+ stack = np.transpose(stack, (2, 0, 1))
+ stack_size = stack.shape[0]
+
+ for i in range(stack_size):
+ plt.subplot(1, stack_size, i + 1)
+ plt.imshow(stack[i], cmap='gray')
+
+ plt.show()
+
+
+def show_diff_between_two_observations(observation1, observation2):
+ plt.imshow(observation1 - observation2, cmap='gray')
+ plt.show()
+
+
+def plot_grayscale_observation(observation):
+ plt.imshow(observation, cmap='gray')
+ plt.show()
diff --git a/docs/README.txt b/docs/README.txt
new file mode 100644
index 0000000..a60dd37
--- /dev/null
+++ b/docs/README.txt
@@ -0,0 +1,14 @@
+Installation
+============
+1. install mkdocs by following the instructions here -
+   http://www.mkdocs.org/#installation
+2. install the math extension for mkdocs:
+   sudo -E pip install python-markdown-math
+3. install the material theme:
+   sudo -E pip install mkdocs-material
+
+To build the documentation website run:
+- mkdocs build
+- python fix_index.py
+
+This will create a folder named 'site' which contains the documentation website.
diff --git a/docs/docs/algorithms/design_imgs/ac.png b/docs/docs/algorithms/design_imgs/ac.png
new file mode 100644
index 0000000..1b9b87d
Binary files /dev/null and b/docs/docs/algorithms/design_imgs/ac.png differ
diff --git a/docs/docs/algorithms/design_imgs/bs_dqn.png b/docs/docs/algorithms/design_imgs/bs_dqn.png
new file mode 100644
index 0000000..36a4b44
Binary files /dev/null and b/docs/docs/algorithms/design_imgs/bs_dqn.png differ
diff --git a/docs/docs/algorithms/design_imgs/ddpg.png b/docs/docs/algorithms/design_imgs/ddpg.png
new file mode 100644
index 0000000..da1e597
Binary files /dev/null and b/docs/docs/algorithms/design_imgs/ddpg.png differ
diff --git a/docs/docs/algorithms/design_imgs/dfp.png b/docs/docs/algorithms/design_imgs/dfp.png
new file mode 100644
index 0000000..85e356c
Binary files /dev/null and b/docs/docs/algorithms/design_imgs/dfp.png differ
diff --git a/docs/docs/algorithms/design_imgs/distributional_dqn.png b/docs/docs/algorithms/design_imgs/distributional_dqn.png
new file mode 100644
index 0000000..83f5e50
Binary files /dev/null and b/docs/docs/algorithms/design_imgs/distributional_dqn.png differ
diff --git a/docs/docs/algorithms/design_imgs/dqn.png b/docs/docs/algorithms/design_imgs/dqn.png
new file mode 100644
index 0000000..9b79101
Binary files /dev/null and b/docs/docs/algorithms/design_imgs/dqn.png differ
diff --git a/docs/docs/algorithms/design_imgs/dueling_dqn.png b/docs/docs/algorithms/design_imgs/dueling_dqn.png
new file mode 100644
index 0000000..d2f4b0f
Binary files /dev/null and b/docs/docs/algorithms/design_imgs/dueling_dqn.png differ
diff --git a/docs/docs/algorithms/design_imgs/naf.png b/docs/docs/algorithms/design_imgs/naf.png
new file mode 100644
index 0000000..c7ba1bc
Binary files /dev/null and b/docs/docs/algorithms/design_imgs/naf.png differ
diff --git a/docs/docs/algorithms/design_imgs/nec.png b/docs/docs/algorithms/design_imgs/nec.png
new file mode 100644
index 0000000..26c8237
Binary files /dev/null and b/docs/docs/algorithms/design_imgs/nec.png differ
diff --git a/docs/docs/algorithms/design_imgs/pg.png b/docs/docs/algorithms/design_imgs/pg.png
new file mode 100644
index 0000000..4b47883
Binary files /dev/null and b/docs/docs/algorithms/design_imgs/pg.png differ
diff --git a/docs/docs/algorithms/design_imgs/ppo.png b/docs/docs/algorithms/design_imgs/ppo.png
new file mode 100644
index 0000000..2d128d8
Binary files /dev/null and b/docs/docs/algorithms/design_imgs/ppo.png differ
diff --git a/docs/docs/algorithms/other/dfp.md b/docs/docs/algorithms/other/dfp.md
new file mode 100644
index 0000000..0b8985a
--- /dev/null
+++ b/docs/docs/algorithms/other/dfp.md
@@ -0,0 +1,23 @@
+> Action space: Discrete
+
+[Paper](https://arxiv.org/abs/1611.01779)
+
+## Network Structure
+
+
+[image placeholder: DFP network structure diagram]
+
+[figure placeholders: "Displaying Bollinger Bands", "Displaying All The Workers",
+"Comparing Several Algorithms According to the Time Passed",
+"Comparing Several Algorithms According to the Number of Episodes Played"]
diff --git a/docs/docs/mdx_math.py b/docs/docs/mdx_math.py
new file mode 100644
index 0000000..fe28d11
--- /dev/null
+++ b/docs/docs/mdx_math.py
@@ -0,0 +1,80 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# -*- coding: utf-8 -*-
+
+'''
+Math extension for Python-Markdown
+==================================
+
+Adds support for displaying math formulas using [MathJax](http://www.mathjax.org/).
+
+Author: 2015, Dmitry Shachnev