1
0
mirror of https://github.com/gryf/coach.git synced 2026-04-20 06:33:31 +02:00

ACER algorithm (#184)

* initial ACER commit

* Code cleanup + several fixes

* Q-retrace bug fix + small clean-ups

* added documentation for acer

* ACER benchmarks

* update benchmarks table

* Add nightly running of golden and trace tests. (#202)

Resolves #200

* comment out nightly trace tests until values reset.

* remove redundant observe ignore (#168)

* ensure nightly test env containers exist. (#205)

Also bump integration test timeout

* wxPython removal (#207)

Replacing wxPython with Python's Tkinter.
Also removing the option to choose multiple files as it is unused and causes errors, and fixing the load file/directory spinner.

* Create CONTRIBUTING.md (#210)

* Create CONTRIBUTING.md.  Resolves #188

* run nightly golden tests sequentially. (#217)

Should reduce resource requirements and potential CPU contention but increases
overall execution time.

* tests: added new setup configuration + test args (#211)

- added utils for future tests and conftest
- added test args

* new docs build

* golden test update
This commit is contained in:
shadiendrawis
2019-02-20 23:52:34 +02:00
committed by GitHub
parent 7253f511ed
commit 2b5d1dabe6
175 changed files with 2327 additions and 664 deletions
+1
View File
@@ -34,6 +34,7 @@ The environments that were used for testing include:
|**[Bootstrapped DQN](bootstrapped_dqn)**| ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Atari | |
|**[QR-DQN](qr_dqn)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Atari | |
|**[A3C](a3c)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Atari, Mujoco | |
|**[ACER](acer)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Atari | |
|**[Clipped PPO](clipped_ppo)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Mujoco | |
|**[DDPG](ddpg)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Mujoco | |
|**[NEC](nec)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Atari | |
+28
View File
@@ -0,0 +1,28 @@
# ACER
Each experiment uses 3 seeds.
The parameters used for ACER are the same parameters as described in the [original paper](https://arxiv.org/abs/1611.01224), except for the optimizer (changed to ADAM) and learning rate (1e-4) used.
### Breakout ACER - 16 workers
```bash
coach -p Atari_ACER -lvl breakout -n 16
```
<img src="breakout_acer_16_workers.png" alt="Breakout ACER" width="800"/>
### Space Invaders ACER - 16 workers
```bash
coach -p Atari_ACER -lvl space_invaders -n 16
```
<img src="space_invaders_acer_16_workers.png" alt="Space Invaders ACER" width="800"/>
### Pong ACER - 16 workers
```bash
coach -p Atari_ACER -lvl pong -n 16
```
<img src="pong_acer_16_workers.png" alt="Pong ACER" width="800"/>
Binary file not shown.

After

Width:  |  Height:  |  Size: 66 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 60 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 65 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 50 KiB

After

Width:  |  Height:  |  Size: 51 KiB

+4 -2
View File
@@ -176,7 +176,8 @@
<div itemprop="articleBody">
<h1>All modules for which code is available</h1>
<ul><li><a href="rl_coach/agents/actor_critic_agent.html">rl_coach.agents.actor_critic_agent</a></li>
<ul><li><a href="rl_coach/agents/acer_agent.html">rl_coach.agents.acer_agent</a></li>
<li><a href="rl_coach/agents/actor_critic_agent.html">rl_coach.agents.actor_critic_agent</a></li>
<li><a href="rl_coach/agents/agent.html">rl_coach.agents.agent</a></li>
<li><a href="rl_coach/agents/bc_agent.html">rl_coach.agents.bc_agent</a></li>
<li><a href="rl_coach/agents/categorical_dqn_agent.html">rl_coach.agents.categorical_dqn_agent</a></li>
@@ -288,7 +289,8 @@
<script type="text/javascript" src="../_static/jquery.js"></script>
<script type="text/javascript" src="../_static/underscore.js"></script>
<script type="text/javascript" src="../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -0,0 +1,431 @@
<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>rl_coach.agents.acer_agent &mdash; Reinforcement Learning Coach 0.11.0 documentation</title>
<link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../../_static/css/custom.css" type="text/css" />
<link rel="index" title="Index" href="../../../genindex.html" />
<link rel="search" title="Search" href="../../../search.html" />
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
<script src="../../../_static/js/modernizr.min.js"></script>
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search">
<a href="../../../index.html" class="icon icon-home"> Reinforcement Learning Coach
<img src="../../../_static/dark_logo.png" class="logo" alt="Logo"/>
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption"><span class="caption-text">Intro</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../dist_usage.html">Usage - Distributed Coach</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../features/index.html">Features</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../selecting_an_algorithm.html">Selecting an Algorithm</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../dashboard.html">Coach Dashboard</a></li>
</ul>
<p class="caption"><span class="caption-text">Design</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../design/control_flow.html">Control Flow</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../design/network.html">Network Design</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../design/horizontal_scaling.html">Distributed Coach - Horizontal Scale-Out</a></li>
</ul>
<p class="caption"><span class="caption-text">Contributing</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_agent.html">Adding a New Agent</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_env.html">Adding a New Environment</a></li>
</ul>
<p class="caption"><span class="caption-text">Components</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../components/agents/index.html">Agents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/architectures/index.html">Architectures</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/data_stores/index.html">Data Stores</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/environments/index.html">Environments</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/exploration_policies/index.html">Exploration Policies</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/filters/index.html">Filters</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/memories/index.html">Memories</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/memory_backends/index.html">Memory Backends</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/orchestrators/index.html">Orchestrators</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/core_types.html">Core Types</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/spaces.html">Spaces</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/additional_parameters.html">Additional Parameters</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../../index.html">Reinforcement Learning Coach</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../../index.html">Docs</a> &raquo;</li>
<li><a href="../../index.html">Module code</a> &raquo;</li>
<li>rl_coach.agents.acer_agent</li>
<li class="wy-breadcrumbs-aside">
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<h1>Source code for rl_coach.agents.acer_agent</h1><div class="highlight"><pre>
<span></span><span class="c1">#</span>
<span class="c1"># Copyright (c) 2017 Intel Corporation</span>
<span class="c1">#</span>
<span class="c1"># Licensed under the Apache License, Version 2.0 (the &quot;License&quot;);</span>
<span class="c1"># you may not use this file except in compliance with the License.</span>
<span class="c1"># You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
<span class="c1"># distributed under the License is distributed on an &quot;AS IS&quot; BASIS,</span>
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="c1">#</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">rl_coach.agents.policy_optimization_agent</span> <span class="k">import</span> <span class="n">PolicyOptimizationAgent</span>
<span class="kn">from</span> <span class="nn">rl_coach.architectures.embedder_parameters</span> <span class="k">import</span> <span class="n">InputEmbedderParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.architectures.head_parameters</span> <span class="k">import</span> <span class="n">ACERPolicyHeadParameters</span><span class="p">,</span> <span class="n">QHeadParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.architectures.middleware_parameters</span> <span class="k">import</span> <span class="n">FCMiddlewareParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.base_parameters</span> <span class="k">import</span> <span class="n">AlgorithmParameters</span><span class="p">,</span> <span class="n">NetworkParameters</span><span class="p">,</span> <span class="n">AgentParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">Batch</span>
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.categorical</span> <span class="k">import</span> <span class="n">CategoricalParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.memories.episodic.episodic_experience_replay</span> <span class="k">import</span> <span class="n">EpisodicExperienceReplayParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">DiscreteActionSpace</span>
<span class="kn">from</span> <span class="nn">rl_coach.utils</span> <span class="k">import</span> <span class="n">eps</span><span class="p">,</span> <span class="n">last_sample</span>
<div class="viewcode-block" id="ACERAlgorithmParameters"><a class="viewcode-back" href="../../../components/agents/policy_optimization/acer.html#rl_coach.agents.acer_agent.ACERAlgorithmParameters">[docs]</a><span class="k">class</span> <span class="nc">ACERAlgorithmParameters</span><span class="p">(</span><span class="n">AlgorithmParameters</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> :param num_steps_between_gradient_updates: (int)</span>
<span class="sd"> Every num_steps_between_gradient_updates transitions will be considered as a single batch and use for</span>
<span class="sd"> accumulating gradients. This is also the number of steps used for bootstrapping according to the n-step formulation.</span>
<span class="sd"> :param ratio_of_replay: (int)</span>
<span class="sd"> The number of off-policy training iterations in each ACER iteration.</span>
<span class="sd"> :param num_transitions_to_start_replay: (int)</span>
<span class="sd"> Number of environment steps until ACER starts to train off-policy from the experience replay.</span>
<span class="sd"> This emulates a heat-up phase where the agents learns only on-policy until there are enough transitions in</span>
<span class="sd"> the experience replay to start the off-policy training.</span>
<span class="sd"> :param rate_for_copying_weights_to_target: (float)</span>
<span class="sd"> The rate of the exponential moving average for the average policy which is used for the trust region optimization.</span>
<span class="sd"> The target network in this algorithm is used as the average policy.</span>
<span class="sd"> :param importance_weight_truncation: (float)</span>
<span class="sd"> The clipping constant for the importance weight truncation (not used in the Q-retrace calculation).</span>
<span class="sd"> :param use_trust_region_optimization: (bool)</span>
<span class="sd"> If set to True, the gradients of the network will be modified with a term dependant on the KL divergence between</span>
<span class="sd"> the average policy and the current one, to bound the change of the policy during the network update.</span>
<span class="sd"> :param max_KL_divergence: (float)</span>
<span class="sd"> The upper bound parameter for the trust region optimization, use_trust_region_optimization needs to be set true</span>
<span class="sd"> for this parameter to have an effect.</span>
<span class="sd"> :param beta_entropy: (float)</span>
<span class="sd"> An entropy regulaization term can be added to the loss function in order to control exploration. This term</span>
<span class="sd"> is weighted using the beta value defined by beta_entropy.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">apply_gradients_every_x_episodes</span> <span class="o">=</span> <span class="mi">5</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_steps_between_gradient_updates</span> <span class="o">=</span> <span class="mi">5000</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ratio_of_replay</span> <span class="o">=</span> <span class="mi">4</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_transitions_to_start_replay</span> <span class="o">=</span> <span class="mi">10000</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rate_for_copying_weights_to_target</span> <span class="o">=</span> <span class="mf">0.99</span>
<span class="bp">self</span><span class="o">.</span><span class="n">importance_weight_truncation</span> <span class="o">=</span> <span class="mf">10.0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">use_trust_region_optimization</span> <span class="o">=</span> <span class="kc">True</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_KL_divergence</span> <span class="o">=</span> <span class="mf">1.0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beta_entropy</span> <span class="o">=</span> <span class="mi">0</span></div>
<span class="k">class</span> <span class="nc">ACERNetworkParameters</span><span class="p">(</span><span class="n">NetworkParameters</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">input_embedders_parameters</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;observation&#39;</span><span class="p">:</span> <span class="n">InputEmbedderParameters</span><span class="p">()}</span>
<span class="bp">self</span><span class="o">.</span><span class="n">middleware_parameters</span> <span class="o">=</span> <span class="n">FCMiddlewareParameters</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">QHeadParameters</span><span class="p">(</span><span class="n">loss_weight</span><span class="o">=</span><span class="mf">0.5</span><span class="p">),</span> <span class="n">ACERPolicyHeadParameters</span><span class="p">(</span><span class="n">loss_weight</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_type</span> <span class="o">=</span> <span class="s1">&#39;Adam&#39;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">async_training</span> <span class="o">=</span> <span class="kc">True</span>
<span class="bp">self</span><span class="o">.</span><span class="n">clip_gradients</span> <span class="o">=</span> <span class="mf">40.0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">create_target_network</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">class</span> <span class="nc">ACERAgentParameters</span><span class="p">(</span><span class="n">AgentParameters</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">algorithm</span><span class="o">=</span><span class="n">ACERAlgorithmParameters</span><span class="p">(),</span>
<span class="n">exploration</span><span class="o">=</span><span class="p">{</span><span class="n">DiscreteActionSpace</span><span class="p">:</span> <span class="n">CategoricalParameters</span><span class="p">()},</span>
<span class="n">memory</span><span class="o">=</span><span class="n">EpisodicExperienceReplayParameters</span><span class="p">(),</span>
<span class="n">networks</span><span class="o">=</span><span class="p">{</span><span class="s2">&quot;main&quot;</span><span class="p">:</span> <span class="n">ACERNetworkParameters</span><span class="p">()})</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="s1">&#39;rl_coach.agents.acer_agent:ACERAgent&#39;</span>
<span class="c1"># Actor-Critic with Experience Replay - https://arxiv.org/abs/1611.01224</span>
<span class="k">class</span> <span class="nc">ACERAgent</span><span class="p">(</span><span class="n">PolicyOptimizationAgent</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s1">&#39;LevelManager&#39;</span><span class="p">,</span> <span class="s1">&#39;CompositeAgent&#39;</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">)</span>
<span class="c1"># signals definition</span>
<span class="bp">self</span><span class="o">.</span><span class="n">q_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">&#39;Q Loss&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">policy_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">&#39;Policy Loss&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">probability_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">&#39;Probability Loss&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias_correction_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">&#39;Bias Correction Loss&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">unclipped_grads</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">&#39;Grads (unclipped)&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">V_Values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">&#39;Values&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kl_divergence</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">&#39;KL Divergence&#39;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">_learn_from_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
<span class="n">fetches</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">probability_loss</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">bias_correction_loss</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">kl_divergence</span><span class="p">]</span>
<span class="c1"># batch contains a list of transitions to learn from</span>
<span class="n">network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
<span class="c1"># get the values for the current states</span>
<span class="n">Q_values</span><span class="p">,</span> <span class="n">policy_prob</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))</span>
<span class="n">avg_policy_prob</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">target_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">current_state_values</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">policy_prob</span> <span class="o">*</span> <span class="n">Q_values</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">actions</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()</span>
<span class="n">num_transitions</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">size</span>
<span class="n">Q_head_targets</span> <span class="o">=</span> <span class="n">Q_values</span>
<span class="n">Q_i</span> <span class="o">=</span> <span class="n">Q_values</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_transitions</span><span class="p">),</span> <span class="n">actions</span><span class="p">]</span>
<span class="n">mu</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">&#39;all_action_probabilities&#39;</span><span class="p">)</span>
<span class="n">rho</span> <span class="o">=</span> <span class="n">policy_prob</span> <span class="o">/</span> <span class="p">(</span><span class="n">mu</span> <span class="o">+</span> <span class="n">eps</span><span class="p">)</span>
<span class="n">rho_i</span> <span class="o">=</span> <span class="n">rho</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">),</span> <span class="n">actions</span><span class="p">]</span>
<span class="n">rho_bar</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">minimum</span><span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">rho_i</span><span class="p">)</span>
<span class="k">if</span> <span class="n">batch</span><span class="o">.</span><span class="n">game_overs</span><span class="p">()[</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span>
<span class="n">Qret</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">last_sample</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">next_states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">)))</span>
<span class="n">Qret</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">result</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">result</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">reversed</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">num_transitions</span><span class="p">)):</span>
<span class="n">Qret</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">rewards</span><span class="p">()[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">discount</span> <span class="o">*</span> <span class="n">Qret</span>
<span class="n">Q_head_targets</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">actions</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span> <span class="o">=</span> <span class="n">Qret</span>
<span class="n">Qret</span> <span class="o">=</span> <span class="n">rho_bar</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="n">Qret</span> <span class="o">-</span> <span class="n">Q_i</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="o">+</span> <span class="n">current_state_values</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">Q_retrace</span> <span class="o">=</span> <span class="n">Q_head_targets</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_transitions</span><span class="p">),</span> <span class="n">actions</span><span class="p">]</span>
<span class="c1"># train</span>
<span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">train_and_sync_networks</span><span class="p">({</span><span class="o">**</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">),</span>
<span class="s1">&#39;output_1_0&#39;</span><span class="p">:</span> <span class="n">actions</span><span class="p">,</span>
<span class="s1">&#39;output_1_1&#39;</span><span class="p">:</span> <span class="n">rho</span><span class="p">,</span>
<span class="s1">&#39;output_1_2&#39;</span><span class="p">:</span> <span class="n">rho_i</span><span class="p">,</span>
<span class="s1">&#39;output_1_3&#39;</span><span class="p">:</span> <span class="n">Q_values</span><span class="p">,</span>
<span class="s1">&#39;output_1_4&#39;</span><span class="p">:</span> <span class="n">Q_retrace</span><span class="p">,</span>
<span class="s1">&#39;output_1_5&#39;</span><span class="p">:</span> <span class="n">avg_policy_prob</span><span class="p">},</span>
<span class="p">[</span><span class="n">Q_head_targets</span><span class="p">,</span> <span class="n">current_state_values</span><span class="p">],</span>
<span class="n">additional_fetches</span><span class="o">=</span><span class="n">fetches</span><span class="p">)</span>
<span class="k">for</span> <span class="n">network</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
<span class="n">network</span><span class="o">.</span><span class="n">update_target_network</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">rate_for_copying_weights_to_target</span><span class="p">)</span>
<span class="c1"># logging</span>
<span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span><span class="p">,</span> <span class="n">fetch_result</span> <span class="o">=</span> <span class="n">result</span><span class="p">[:</span><span class="mi">4</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">q_loss</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">losses</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">policy_loss</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">losses</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">probability_loss</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">fetch_result</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias_correction_loss</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">fetch_result</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">unclipped_grads</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">unclipped_grads</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">V_Values</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">current_state_values</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kl_divergence</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">fetch_result</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span>
<span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span>
<span class="k">def</span> <span class="nf">learn_from_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
<span class="c1"># perform on-policy training iteration</span>
<span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_learn_from_batch</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">ratio_of_replay</span> <span class="o">&gt;</span> <span class="mi">0</span> \
<span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">num_transitions</span><span class="p">()</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">num_transitions_to_start_replay</span><span class="p">:</span>
<span class="n">n</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">poisson</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">ratio_of_replay</span><span class="p">)</span>
<span class="c1"># perform n off-policy training iterations</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
<span class="n">new_batch</span> <span class="o">=</span> <span class="n">Batch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">&#39;sample&#39;</span><span class="p">,</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">num_steps_between_gradient_updates</span><span class="p">,</span> <span class="kc">True</span><span class="p">)))</span>
<span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_learn_from_batch</span><span class="p">(</span><span class="n">new_batch</span><span class="p">)</span>
<span class="n">total_loss</span> <span class="o">+=</span> <span class="n">result</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">losses</span> <span class="o">+=</span> <span class="n">result</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">unclipped_grads</span> <span class="o">+=</span> <span class="n">result</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
<span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span>
<span class="k">def</span> <span class="nf">get_prediction</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">):</span>
<span class="n">tf_input_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_batch_for_inference</span><span class="p">(</span><span class="n">states</span><span class="p">,</span> <span class="s2">&quot;main&quot;</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">tf_input_state</span><span class="p">)[</span><span class="mi">1</span><span class="p">:]</span> <span class="c1"># index 0 is the state value</span>
</pre></div>
</div>
</div>
<footer>
<hr/>
<div role="contentinfo">
<p>
&copy; Copyright 2018, Intel AI Lab
</p>
</div>
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/js/theme.js"></script>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>
@@ -401,7 +401,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
+20 -22
View File
@@ -789,35 +789,35 @@
<span class="sd"> :return: boolean: True if we should start a training phase</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">should_update</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_should_train_helper</span><span class="p">()</span>
<span class="n">should_update</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_should_update</span><span class="p">()</span>
<span class="n">step_method</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">num_consecutive_playing_steps</span>
<span class="n">steps</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">num_consecutive_playing_steps</span>
<span class="k">if</span> <span class="n">should_update</span><span class="p">:</span>
<span class="k">if</span> <span class="n">step_method</span><span class="o">.</span><span class="vm">__class__</span> <span class="o">==</span> <span class="n">EnvironmentEpisodes</span><span class="p">:</span>
<span class="k">if</span> <span class="n">steps</span><span class="o">.</span><span class="vm">__class__</span> <span class="o">==</span> <span class="n">EnvironmentEpisodes</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">last_training_phase_step</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode</span>
<span class="k">if</span> <span class="n">step_method</span><span class="o">.</span><span class="vm">__class__</span> <span class="o">==</span> <span class="n">EnvironmentSteps</span><span class="p">:</span>
<span class="k">if</span> <span class="n">steps</span><span class="o">.</span><span class="vm">__class__</span> <span class="o">==</span> <span class="n">EnvironmentSteps</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">last_training_phase_step</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_steps_counter</span>
<span class="k">return</span> <span class="n">should_update</span>
<span class="k">def</span> <span class="nf">_should_train_helper</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">_should_update</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="n">wait_for_full_episode</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">act_for_full_episodes</span>
<span class="n">step_method</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">num_consecutive_playing_steps</span>
<span class="n">steps</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">num_consecutive_playing_steps</span>
<span class="k">if</span> <span class="n">step_method</span><span class="o">.</span><span class="vm">__class__</span> <span class="o">==</span> <span class="n">EnvironmentEpisodes</span><span class="p">:</span>
<span class="n">should_update</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">current_episode</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_training_phase_step</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="n">step_method</span><span class="o">.</span><span class="n">num_steps</span>
<span class="k">if</span> <span class="n">steps</span><span class="o">.</span><span class="vm">__class__</span> <span class="o">==</span> <span class="n">EnvironmentEpisodes</span><span class="p">:</span>
<span class="n">should_update</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">current_episode</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_training_phase_step</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="n">steps</span><span class="o">.</span><span class="n">num_steps</span>
<span class="n">should_update</span> <span class="o">=</span> <span class="n">should_update</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">&#39;length&#39;</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span>
<span class="k">elif</span> <span class="n">step_method</span><span class="o">.</span><span class="vm">__class__</span> <span class="o">==</span> <span class="n">EnvironmentSteps</span><span class="p">:</span>
<span class="n">should_update</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">total_steps_counter</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_training_phase_step</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="n">step_method</span><span class="o">.</span><span class="n">num_steps</span>
<span class="k">elif</span> <span class="n">steps</span><span class="o">.</span><span class="vm">__class__</span> <span class="o">==</span> <span class="n">EnvironmentSteps</span><span class="p">:</span>
<span class="n">should_update</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">total_steps_counter</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_training_phase_step</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="n">steps</span><span class="o">.</span><span class="n">num_steps</span>
<span class="n">should_update</span> <span class="o">=</span> <span class="n">should_update</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">&#39;num_transitions&#39;</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span>
<span class="k">if</span> <span class="n">wait_for_full_episode</span><span class="p">:</span>
<span class="n">should_update</span> <span class="o">=</span> <span class="n">should_update</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode_buffer</span><span class="o">.</span><span class="n">is_complete</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;The num_consecutive_playing_steps parameter should be either &quot;</span>
<span class="s2">&quot;EnvironmentSteps or Episodes. Instead it is </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">step_method</span><span class="o">.</span><span class="vm">__class__</span><span class="p">))</span>
<span class="s2">&quot;EnvironmentSteps or Episodes. Instead it is </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">steps</span><span class="o">.</span><span class="vm">__class__</span><span class="p">))</span>
<span class="k">return</span> <span class="n">should_update</span>
@@ -942,7 +942,8 @@
<span class="c1"># informed action</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># before choosing an action, first use the pre_network_filter to filter out the current state</span>
<span class="n">curr_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">run_pre_network_filter_for_inference</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">curr_state</span><span class="p">)</span>
<span class="n">update_filter_internal_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span>
<span class="n">curr_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">run_pre_network_filter_for_inference</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">curr_state</span><span class="p">,</span> <span class="n">update_filter_internal_state</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">curr_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">curr_state</span>
@@ -952,15 +953,18 @@
<span class="k">return</span> <span class="n">filtered_action_info</span></div>
<div class="viewcode-block" id="Agent.run_pre_network_filter_for_inference"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent.run_pre_network_filter_for_inference">[docs]</a> <span class="k">def</span> <span class="nf">run_pre_network_filter_for_inference</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">StateType</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">StateType</span><span class="p">:</span>
<div class="viewcode-block" id="Agent.run_pre_network_filter_for_inference"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent.run_pre_network_filter_for_inference">[docs]</a> <span class="k">def</span> <span class="nf">run_pre_network_filter_for_inference</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">StateType</span><span class="p">,</span> <span class="n">update_filter_internal_state</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>\
<span class="o">-&gt;</span> <span class="n">StateType</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Run filters which where defined for being applied right before using the state for inference.</span>
<span class="sd"> :param state: The state to run the filters on</span>
<span class="sd"> :param update_filter_internal_state: Should update the filter&#39;s internal state - should not update when evaluating</span>
<span class="sd"> :return: The filtered state</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">dummy_env_response</span> <span class="o">=</span> <span class="n">EnvResponse</span><span class="p">(</span><span class="n">next_state</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">reward</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">game_over</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">dummy_env_response</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">next_state</span></div>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">dummy_env_response</span><span class="p">,</span>
<span class="n">update_internal_state</span><span class="o">=</span><span class="n">update_filter_internal_state</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">next_state</span></div>
<div class="viewcode-block" id="Agent.get_state_embedding"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent.get_state_embedding">[docs]</a> <span class="k">def</span> <span class="nf">get_state_embedding</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="nb">dict</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
@@ -1153,13 +1157,6 @@
<span class="sd"> :return:</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="c1"># if we are in the first step in the episode, then we don&#39;t have a a next state and a reward and thus no</span>
<span class="c1"># transition yet, and therefore we don&#39;t need to store anything in the memory.</span>
<span class="c1"># also we did not reach the goal yet.</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode_steps_counter</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="c1"># initialize the current state</span>
<span class="k">return</span> <span class="n">transition</span><span class="o">.</span><span class="n">game_over</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># sum up the total shaped reward</span>
<span class="bp">self</span><span class="o">.</span><span class="n">total_shaped_reward_in_current_episode</span> <span class="o">+=</span> <span class="n">transition</span><span class="o">.</span><span class="n">reward</span>
<span class="bp">self</span><span class="o">.</span><span class="n">total_reward_in_current_episode</span> <span class="o">+=</span> <span class="n">transition</span><span class="o">.</span><span class="n">reward</span>
@@ -1254,7 +1251,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
+2 -1
View File
@@ -296,7 +296,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -370,7 +370,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
+2 -1
View File
@@ -302,7 +302,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -505,7 +505,7 @@
<span class="bp">self</span><span class="o">.</span><span class="n">update_log</span><span class="p">()</span>
<span class="k">return</span> <span class="kc">None</span>
<span class="k">def</span> <span class="nf">run_pre_network_filter_for_inference</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">StateType</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">run_pre_network_filter_for_inference</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">StateType</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="n">dummy_env_response</span> <span class="o">=</span> <span class="n">EnvResponse</span><span class="p">(</span><span class="n">next_state</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">reward</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">game_over</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">dummy_env_response</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="o">=</span><span class="kc">False</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">next_state</span>
@@ -550,7 +550,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -431,7 +431,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
+4 -2
View File
@@ -264,7 +264,8 @@
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> :param num_predicted_steps_ahead: (int)</span>
<span class="sd"> Number of future steps to predict measurements for. The future steps won&#39;t be sequential, but rather jump</span>
<span class="sd"> in multiples of 2. For example, if num_predicted_steps_ahead = 3, then the steps will be: t+1, t+2, t+4</span>
<span class="sd"> in multiples of 2. For example, if num_predicted_steps_ahead = 3, then the steps will be: t+1, t+2, t+4.</span>
<span class="sd"> The predicted steps will be [t + 2**i for i in range(num_predicted_steps_ahead)]</span>
<span class="sd"> :param goal_vector: (List[float])</span>
<span class="sd"> The goal vector will weight each of the measurements to form an optimization goal. The vector should have</span>
@@ -463,7 +464,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
+2 -1
View File
@@ -315,7 +315,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
+2 -1
View File
@@ -294,7 +294,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -361,7 +361,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
+2 -1
View File
@@ -342,7 +342,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
+2 -1
View File
@@ -424,7 +424,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
+2 -1
View File
@@ -322,7 +322,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -344,7 +344,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
+2 -1
View File
@@ -607,7 +607,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -335,7 +335,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -347,7 +347,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -313,7 +313,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -452,7 +452,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -201,17 +201,6 @@
<span class="kn">from</span> <span class="nn">rl_coach.saver</span> <span class="k">import</span> <span class="n">SaverCollection</span>
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">SpacesDefinition</span>
<span class="kn">from</span> <span class="nn">rl_coach.utils</span> <span class="k">import</span> <span class="n">force_list</span>
<span class="k">try</span><span class="p">:</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span>
<span class="kn">from</span> <span class="nn">rl_coach.architectures.tensorflow_components.general_network</span> <span class="k">import</span> <span class="n">GeneralTensorFlowNetwork</span>
<span class="k">except</span> <span class="ne">ImportError</span><span class="p">:</span>
<span class="n">failed_imports</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s2">&quot;tensorflow&quot;</span><span class="p">)</span>
<span class="k">try</span><span class="p">:</span>
<span class="kn">import</span> <span class="nn">mxnet</span> <span class="k">as</span> <span class="nn">mx</span>
<span class="kn">from</span> <span class="nn">rl_coach.architectures.mxnet_components.general_network</span> <span class="k">import</span> <span class="n">GeneralMxnetNetwork</span>
<span class="k">except</span> <span class="ne">ImportError</span><span class="p">:</span>
<span class="n">failed_imports</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s2">&quot;mxnet&quot;</span><span class="p">)</span>
<div class="viewcode-block" id="NetworkWrapper"><a class="viewcode-back" href="../../../components/architectures/index.html#rl_coach.architectures.network_wrapper.NetworkWrapper">[docs]</a><span class="k">class</span> <span class="nc">NetworkWrapper</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
@@ -233,15 +222,19 @@
<span class="bp">self</span><span class="o">.</span><span class="n">sess</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">network_parameters</span><span class="o">.</span><span class="n">framework</span> <span class="o">==</span> <span class="n">Frameworks</span><span class="o">.</span><span class="n">tensorflow</span><span class="p">:</span>
<span class="k">if</span> <span class="s2">&quot;tensorflow&quot;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">failed_imports</span><span class="p">:</span>
<span class="n">general_network</span> <span class="o">=</span> <span class="n">GeneralTensorFlowNetwork</span><span class="o">.</span><span class="n">construct</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">try</span><span class="p">:</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span>
<span class="k">except</span> <span class="ne">ImportError</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">&#39;Install tensorflow before using it as framework&#39;</span><span class="p">)</span>
<span class="kn">from</span> <span class="nn">rl_coach.architectures.tensorflow_components.general_network</span> <span class="k">import</span> <span class="n">GeneralTensorFlowNetwork</span>
<span class="n">general_network</span> <span class="o">=</span> <span class="n">GeneralTensorFlowNetwork</span><span class="o">.</span><span class="n">construct</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">network_parameters</span><span class="o">.</span><span class="n">framework</span> <span class="o">==</span> <span class="n">Frameworks</span><span class="o">.</span><span class="n">mxnet</span><span class="p">:</span>
<span class="k">if</span> <span class="s2">&quot;mxnet&quot;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">failed_imports</span><span class="p">:</span>
<span class="n">general_network</span> <span class="o">=</span> <span class="n">GeneralMxnetNetwork</span><span class="o">.</span><span class="n">construct</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">try</span><span class="p">:</span>
<span class="kn">import</span> <span class="nn">mxnet</span> <span class="k">as</span> <span class="nn">mx</span>
<span class="k">except</span> <span class="ne">ImportError</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">&#39;Install mxnet before using it as framework&#39;</span><span class="p">)</span>
<span class="kn">from</span> <span class="nn">rl_coach.architectures.mxnet_components.general_network</span> <span class="k">import</span> <span class="n">GeneralMxnetNetwork</span>
<span class="n">general_network</span> <span class="o">=</span> <span class="n">GeneralMxnetNetwork</span><span class="o">.</span><span class="n">construct</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s2">&quot;</span><span class="si">{}</span><span class="s2"> Framework is not supported&quot;</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">Frameworks</span><span class="p">()</span><span class="o">.</span><span class="n">to_string</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">network_parameters</span><span class="o">.</span><span class="n">framework</span><span class="p">)))</span>
@@ -475,7 +468,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
+2 -1
View File
@@ -839,7 +839,8 @@
<script type="text/javascript" src="../../_static/jquery.js"></script>
<script type="text/javascript" src="../../_static/underscore.js"></script>
<script type="text/javascript" src="../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
+9 -8
View File
@@ -278,6 +278,10 @@
<span class="k">pass</span>
<span class="k">class</span> <span class="nc">Measurements</span><span class="p">(</span><span class="n">PredictionType</span><span class="p">):</span>
<span class="k">pass</span>
<span class="k">class</span> <span class="nc">InputEmbedding</span><span class="p">(</span><span class="n">Embedding</span><span class="p">):</span>
<span class="k">pass</span>
@@ -306,10 +310,6 @@
<span class="k">pass</span>
<span class="k">class</span> <span class="nc">Measurements</span><span class="p">(</span><span class="n">PredictionType</span><span class="p">):</span>
<span class="k">pass</span>
<span class="n">PlayingStepsType</span> <span class="o">=</span> <span class="n">Union</span><span class="p">[</span><span class="n">EnvironmentSteps</span><span class="p">,</span> <span class="n">EnvironmentEpisodes</span><span class="p">,</span> <span class="n">Frames</span><span class="p">]</span>
@@ -509,12 +509,12 @@
<span class="sd"> Action info is a class that holds an action and various additional information details about it</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action</span><span class="p">:</span> <span class="n">ActionType</span><span class="p">,</span> <span class="n">action_probability</span><span class="p">:</span> <span class="nb">float</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action</span><span class="p">:</span> <span class="n">ActionType</span><span class="p">,</span> <span class="n">all_action_probabilities</span><span class="p">:</span> <span class="nb">float</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="n">action_value</span><span class="p">:</span> <span class="nb">float</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">state_value</span><span class="p">:</span> <span class="nb">float</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">max_action_value</span><span class="p">:</span> <span class="nb">float</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">action_intrinsic_reward</span><span class="p">:</span> <span class="nb">float</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> :param action: the action</span>
<span class="sd"> :param action_probability: the probability that the action was given when selecting it</span>
<span class="sd"> :param all_action_probabilities: the probability that the action was given when selecting it</span>
<span class="sd"> :param action_value: the state-action value (Q value) of the action</span>
<span class="sd"> :param state_value: the state value (V value) of the state where the action was taken</span>
<span class="sd"> :param max_action_value: in case this is an action that was selected randomly, this is the value of the action</span>
@@ -524,7 +524,7 @@
<span class="sd"> selection</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">action</span> <span class="o">=</span> <span class="n">action</span>
<span class="bp">self</span><span class="o">.</span><span class="n">action_probability</span> <span class="o">=</span> <span class="n">action_probability</span>
<span class="bp">self</span><span class="o">.</span><span class="n">all_action_probabilities</span> <span class="o">=</span> <span class="n">all_action_probabilities</span>
<span class="bp">self</span><span class="o">.</span><span class="n">action_value</span> <span class="o">=</span> <span class="n">action_value</span>
<span class="bp">self</span><span class="o">.</span><span class="n">state_value</span> <span class="o">=</span> <span class="n">state_value</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">max_action_value</span><span class="p">:</span>
@@ -1084,7 +1084,8 @@
<script type="text/javascript" src="../../_static/jquery.js"></script>
<script type="text/javascript" src="../../_static/underscore.js"></script>
<script type="text/javascript" src="../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -501,7 +501,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -384,7 +384,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -683,7 +683,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -414,7 +414,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -483,7 +483,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -240,10 +240,11 @@
<span class="n">logger</span><span class="o">.</span><span class="n">screen</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="s2">&quot;No level has been selected. Please select a level using the -lvl command line flag, &quot;</span>
<span class="s2">&quot;or change the level in the preset. </span><span class="se">\n</span><span class="s2">The available levels are: </span><span class="se">\n</span><span class="si">{}</span><span class="s2">&quot;</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s1">&#39;, &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="nb">sorted</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">levels</span><span class="o">.</span><span class="n">keys</span><span class="p">()))),</span> <span class="n">crash</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">selected_level</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">levels</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
<span class="n">selected_level</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">selected_level</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span>
<span class="k">if</span> <span class="n">selected_level</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">levels</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
<span class="n">logger</span><span class="o">.</span><span class="n">screen</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="s2">&quot;The selected level (</span><span class="si">{}</span><span class="s2">) is not part of the available levels (</span><span class="si">{}</span><span class="s2">)&quot;</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">selected_level</span><span class="p">,</span> <span class="s1">&#39;, &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">levels</span><span class="o">.</span><span class="n">keys</span><span class="p">())),</span> <span class="n">crash</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">levels</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">selected_level</span><span class="p">]</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">selected_level</span><span class="p">,</span> <span class="s1">&#39;, &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">levels</span><span class="o">.</span><span class="n">keys</span><span class="p">())),</span> <span class="n">crash</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">levels</span><span class="p">[</span><span class="n">selected_level</span><span class="p">]</span>
<span class="c1"># class SingleLevelPerPhase(LevelSelection):</span>
@@ -717,7 +718,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -559,6 +559,11 @@
<span class="n">num_actions</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">n</span><span class="p">,</span>
<span class="n">descriptions</span><span class="o">=</span><span class="n">actions_description</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="n">screen</span><span class="o">.</span><span class="n">error</span><span class="p">((</span>
<span class="s2">&quot;Failed to instantiate gym environment class </span><span class="si">{}</span><span class="s2"> due to unsupported &quot;</span>
<span class="s2">&quot;action space </span><span class="si">{}</span><span class="s2">. Expected BoxActionSpace or DiscreteActionSpace.&quot;</span>
<span class="p">)</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">env_class</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">action_space</span><span class="p">),</span> <span class="n">crash</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">human_control</span><span class="p">:</span>
<span class="c1"># TODO: add this to the action space</span>
@@ -741,7 +746,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -466,7 +466,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -268,7 +268,7 @@
<span class="n">action_values_mean</span> <span class="o">=</span> <span class="n">action_values</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
<span class="c1"># step the noise schedule</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">==</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TRAIN</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">noise_percentage_schedule</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="c1"># the second element of the list is assumed to be the standard deviation</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_values</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">action_values</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
@@ -318,7 +318,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -280,7 +280,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -303,7 +303,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -269,7 +269,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -253,7 +253,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -330,7 +330,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -299,7 +299,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -266,7 +266,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -301,7 +301,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -198,7 +198,6 @@
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">rl_coach.agents.dqn_agent</span> <span class="k">import</span> <span class="n">DQNAgentParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.architectures.layers</span> <span class="k">import</span> <span class="n">NoisyNetDense</span>
<span class="kn">from</span> <span class="nn">rl_coach.base_parameters</span> <span class="k">import</span> <span class="n">AgentParameters</span><span class="p">,</span> <span class="n">NetworkParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">ActionSpace</span><span class="p">,</span> <span class="n">BoxActionSpace</span><span class="p">,</span> <span class="n">DiscreteActionSpace</span>
@@ -303,7 +302,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -272,7 +272,7 @@
<span class="n">action_values_mean</span> <span class="o">=</span> <span class="n">action_values</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
<span class="c1"># step the noise schedule</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">==</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TRAIN</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">noise_percentage_schedule</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="c1"># the second element of the list is assumed to be the standard deviation</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_values</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">action_values</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
@@ -325,7 +325,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -307,7 +307,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -288,7 +288,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -288,7 +288,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -296,7 +296,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -249,7 +249,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -277,7 +277,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -274,7 +274,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -262,7 +262,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -309,7 +309,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -282,7 +282,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -267,6 +267,7 @@
<span class="k">def</span> <span class="nf">restore_state_from_checkpoint</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">checkpoint_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">checkpoint_prefix</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">running_observation_stats</span><span class="o">.</span><span class="n">restore_state_from_checkpoint</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">,</span> <span class="n">checkpoint_prefix</span><span class="p">)</span></div>
</pre></div>
</div>
@@ -304,7 +305,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -296,7 +296,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -288,7 +288,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -314,7 +314,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -266,7 +266,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -264,7 +264,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -323,7 +323,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -280,7 +280,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -269,7 +269,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -297,7 +297,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -259,7 +259,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -440,7 +440,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -255,16 +255,25 @@
<span class="k">def</span> <span class="nf">num_transitions_in_complete_episodes</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_num_transitions_in_complete_episodes</span>
<span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">size</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Transition</span><span class="p">]:</span>
<span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">is_consecutive_transitions</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Transition</span><span class="p">]:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sample a batch of transitions form the replay buffer. If the requested size is larger than the number</span>
<span class="sd"> Sample a batch of transitions from the replay buffer. If the requested size is larger than the number</span>
<span class="sd"> of samples available in the replay buffer then the batch will return empty.</span>
<span class="sd"> :param size: the size of the batch to sample</span>
<span class="sd"> :param is_consecutive_transitions: if set True, samples a batch of consecutive transitions.</span>
<span class="sd"> :return: a batch (list) of selected transitions from the replay buffer</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">lock_writing</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_complete_episodes</span><span class="p">()</span> <span class="o">&gt;=</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">if</span> <span class="n">is_consecutive_transitions</span><span class="p">:</span>
<span class="n">episode_idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_complete_episodes</span><span class="p">())</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">[</span><span class="n">episode_idx</span><span class="p">]</span><span class="o">.</span><span class="n">length</span><span class="p">()</span> <span class="o">&lt;=</span> <span class="n">size</span><span class="p">:</span>
<span class="n">batch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">[</span><span class="n">episode_idx</span><span class="p">]</span><span class="o">.</span><span class="n">transitions</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">transition_idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">[</span><span class="n">episode_idx</span><span class="p">]</span><span class="o">.</span><span class="n">length</span><span class="p">())</span>
<span class="n">batch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">[</span><span class="n">episode_idx</span><span class="p">]</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="n">transition_idx</span><span class="o">-</span><span class="n">size</span><span class="p">:</span><span class="n">transition_idx</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">transitions_idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_transitions_in_complete_episodes</span><span class="p">(),</span> <span class="n">size</span><span class="o">=</span><span class="n">size</span><span class="p">)</span>
<span class="n">batch</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">transitions_idx</span><span class="p">]</span>
@@ -523,7 +532,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -363,7 +363,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -288,7 +288,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -248,7 +248,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -388,7 +388,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -511,7 +511,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -455,7 +455,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -514,7 +514,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -268,7 +268,8 @@
<script type="text/javascript" src="../../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -627,7 +627,8 @@
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
+28 -11
View File
@@ -299,7 +299,8 @@
<div class="viewcode-block" id="Space.contains"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.Space.contains">[docs]</a> <span class="k">def</span> <span class="nf">contains</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">val</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Checks if the given value matches the space definition in terms of shape and values</span>
<span class="sd"> Checks if value is contained by this space. The shape must match and</span>
<span class="sd"> all of the values must be within the low and high bounds.</span>
<span class="sd"> :param val: a value to check</span>
<span class="sd"> :return: True / False depending on if the val matches the space definition</span>
@@ -314,16 +315,16 @@
<span class="k">return</span> <span class="kc">False</span>
<span class="k">return</span> <span class="kc">True</span></div>
<div class="viewcode-block" id="Space.is_valid_index"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.Space.is_valid_index">[docs]</a> <span class="k">def</span> <span class="nf">is_valid_index</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">point</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<div class="viewcode-block" id="Space.is_valid_index"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.Space.is_valid_index">[docs]</a> <span class="k">def</span> <span class="nf">is_valid_index</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Checks if a given multidimensional point is within the bounds of the shape of the space</span>
<span class="sd"> Checks if a given multidimensional index is within the bounds of the shape of the space</span>
<span class="sd"> :param point: a multidimensional point</span>
<span class="sd"> :return: True if the point is within the shape of the space. False otherwise</span>
<span class="sd"> :param index: a multidimensional index</span>
<span class="sd"> :return: True if the index is within the shape of the space. False otherwise</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">point</span><span class="p">)</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_dimensions</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">index</span><span class="p">)</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_dimensions</span><span class="p">:</span>
<span class="k">return</span> <span class="kc">False</span>
<span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">point</span> <span class="o">&lt;</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_dimensions</span><span class="p">))</span> <span class="ow">or</span> <span class="n">np</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">point</span> <span class="o">&gt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">):</span>
<span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">index</span> <span class="o">&lt;</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_dimensions</span><span class="p">))</span> <span class="ow">or</span> <span class="n">np</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">index</span> <span class="o">&gt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">):</span>
<span class="k">return</span> <span class="kc">False</span>
<span class="k">return</span> <span class="kc">True</span></div>
@@ -338,7 +339,21 @@
<span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">low</span> <span class="o">==</span> <span class="o">-</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">)</span> <span class="ow">or</span> <span class="n">np</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">high</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">):</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">low</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">high</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></div></div>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">low</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">high</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></div>
<span class="k">def</span> <span class="nf">val_matches_space_definition</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">val</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="n">screen</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;Space.val_matches_space_definition will be deprecated soon. Use &quot;</span>
<span class="s2">&quot;contains instead.&quot;</span>
<span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">contains</span><span class="p">(</span><span class="n">val</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">is_point_in_space_shape</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">point</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="n">screen</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;Space.is_point_in_space_shape will be deprecated soon. Use &quot;</span>
<span class="s2">&quot;is_valid_index instead.&quot;</span>
<span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_valid_index</span><span class="p">(</span><span class="n">point</span><span class="p">)</span></div>
<span class="k">class</span> <span class="nc">RewardSpace</span><span class="p">(</span><span class="n">Space</span><span class="p">):</span>
@@ -568,7 +583,8 @@
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">actions</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">sample_with_info</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">ActionInfo</span><span class="p">:</span>
<span class="k">return</span> <span class="n">ActionInfo</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sample</span><span class="p">(),</span> <span class="n">action_probability</span><span class="o">=</span><span class="mf">1.</span> <span class="o">/</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">high</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">low</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
<span class="k">return</span> <span class="n">ActionInfo</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sample</span><span class="p">(),</span>
<span class="n">all_action_probabilities</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">full</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">actions</span><span class="p">),</span> <span class="mf">1.</span> <span class="o">/</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">high</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">low</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)))</span>
<span class="k">def</span> <span class="nf">get_description</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">descriptions</span><span class="p">)</span> <span class="o">==</span> <span class="nb">list</span> <span class="ow">and</span> <span class="mi">0</span> <span class="o">&lt;=</span> <span class="n">action</span> <span class="o">&lt;</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">descriptions</span><span class="p">):</span>
@@ -615,7 +631,7 @@
<span class="k">return</span> <span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">actions</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">sample_with_info</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">ActionInfo</span><span class="p">:</span>
<span class="k">return</span> <span class="n">ActionInfo</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sample</span><span class="p">(),</span> <span class="n">action_probability</span><span class="o">=</span><span class="mf">1.</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">actions</span><span class="p">))</span>
<span class="k">return</span> <span class="n">ActionInfo</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sample</span><span class="p">(),</span> <span class="n">all_action_probabilities</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">full</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">actions</span><span class="p">),</span> <span class="mf">1.</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">actions</span><span class="p">)))</span>
<span class="k">def</span> <span class="nf">get_description</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">action</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)[</span><span class="mi">0</span><span class="p">]))</span> <span class="o">+</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">action</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]))</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span> <span class="ow">or</span> \
@@ -856,7 +872,8 @@
<script type="text/javascript" src="../../_static/jquery.js"></script>
<script type="text/javascript" src="../../_static/underscore.js"></script>
<script type="text/javascript" src="../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
@@ -14,6 +14,7 @@ A detailed description of those algorithms can be found by navigating to each of
:caption: Agents
policy_optimization/ac
policy_optimization/acer
imitation/bc
value_optimization/bs_dqn
value_optimization/categorical_dqn
@@ -32,8 +32,8 @@ Training the network
Given a batch of transitions, run them through the network to get the current predictions of the future measurements
per action, and set them as the initial targets for training the network. For each transition
:math:`(s_t,a_t,r_t,s_{t+1} )` in the batch, the target of the network for the action that was taken, is the actual
measurements that were seen in time-steps :math:`t+1,t+2,t+4,t+8,t+16` and :math:`t+32`.
For the actions that were not taken, the targets are the current values.
measurements that were seen in time-steps :math:`t+1,t+2,t+4,t+8,t+16` and :math:`t+32`.
For the actions that were not taken, the targets are the current values.
.. autoclass:: rl_coach.agents.dfp_agent.DFPAlgorithmParameters
@@ -0,0 +1,60 @@
ACER
============
**Actions space:** Discrete
**References:** `Sample Efficient Actor-Critic with Experience Replay <https://arxiv.org/abs/1611.01224>`_
Network Structure
-----------------
.. image:: /_static/img/design_imgs/acer.png
:width: 500px
:align: center
Algorithm Description
---------------------
Choosing an action - Discrete actions
+++++++++++++++++++++++++++++++++++++
The policy network is used in order to predict action probabilites. While training, a sample is taken from a categorical
distribution assigned with these probabilities. When testing, the action with the highest probability is used.
Training the network
++++++++++++++++++++
Each iteration perform one on-policy update with a batch of the last :math:`T_{max}` transitions,
and :math:`n` (replay ratio) off-policy updates from batches of :math:`T_{max}` transitions sampled from the replay buffer.
Each update perform the following procedure:
1. **Calculate state values:**
.. math:: V(s_t) = \mathbb{E}_{a \sim \pi} [Q(s_t,a)]
2. **Calculate Q retrace:**
.. math:: Q^{ret}(s_t,a_t) = r_t +\gamma \bar{\rho}_{t+1}[Q^{ret}(s_{t+1},a_{t+1}) - Q(s_{t+1},a_{t+1})] + \gamma V(s_{t+1})
.. math:: \text{where} \quad \bar{\rho}_{t} = \min{\left\{c,\rho_t\right\}},\quad \rho_t=\frac{\pi (a_t \mid s_t)}{\mu (a_t \mid s_t)}
3. **Accumulate gradients:**
:math:`\bullet` **Policy gradients (with bias correction):**
.. math:: \hat{g}_t^{policy} & = & \bar{\rho}_{t} \nabla \log \pi (a_t \mid s_t) [Q^{ret}(s_t,a_t) - V(s_t)] \\
& & + \mathbb{E}_{a \sim \pi} \left(\left[\frac{\rho_t(a)-c}{\rho_t(a)}\right] \nabla \log \pi (a \mid s_t) [Q(s_t,a) - V(s_t)] \right)
:math:`\bullet` **Q-Head gradients (MSE):**
.. math:: \hat{g}_t^{Q} = (Q^{ret}(s_t,a_t) - Q(s_t,a_t)) \nabla Q(s_t,a_t)\\
4. **(Optional) Trust region update:** change the policy loss gradient w.r.t network output:
.. math:: \hat{g}_t^{trust-region} = \hat{g}_t^{policy} - \max \left\{0, \frac{k^T \hat{g}_t^{policy} - \delta}{\lVert k \rVert_2^2}\right\} k
.. math:: \text{where} \quad k = \nabla D_{KL}[\pi_{avg} \parallel \pi]
The average policy network is an exponential moving average of the parameters of the network (:math:`\theta_{avg}=\alpha\theta_{avg}+(1-\alpha)\theta`).
The goal of the trust region update is to the difference between the updated policy and the average policy to ensure stability.
.. autoclass:: rl_coach.agents.acer_agent.ACERAlgorithmParameters
@@ -19,7 +19,7 @@ Training the network
1. Sample a batch of transitions from the replay buffer.
2. Using the next states from the sampled batch, run the online network in order to find the $Q$ maximizing
2. Using the next states from the sampled batch, run the online network in order to find the :math:`Q` maximizing
action :math:`argmax_a Q(s_{t+1},a)`. For these actions, use the corresponding next states and run the target
network to calculate :math:`Q(s_{t+1},argmax_a Q(s_{t+1},a))`.
@@ -26,7 +26,7 @@ Training the network
use the current states from the sampled batch, and run the online network to get the current Q values predictions.
Set those values as the targets for the actions that were not actually played.
4. For each action that was played, use the following equation for calculating the targets of the network: $$ y_t=r(s_t,a_t)+γ\cdot max_a {Q(s_{t+1},a)} $$
4. For each action that was played, use the following equation for calculating the targets of the network:
:math:`y_t=r(s_t,a_t )+\gamma \cdot max_a Q(s_{t+1})`
5. Finally, train the online network using the current states as inputs, and with the aforementioned targets.
@@ -190,6 +190,14 @@ The algorithms are ordered by their release date in descending order.
learning stability and speed, both for discrete and continuous action spaces.
</span>
</div>
<div class="algorithm discrete on-policy requires-multi-worker" data-year="201707">
<span class="badge">
<a href="components/agents/policy_optimization/acer.html">ACER</a>
<br>
Similar to A3C with the addition of experience replay and off-policy training. to reduce variance and
improve stability it also employs bias correction and trust region optimization techniques.
</span>
</div>
<div class="algorithm continuous off-policy" data-year="201509">
<span class="badge">
<a href="components/agents/policy_optimization/ddpg.html">DDPG</a>
+1 -1
View File
@@ -4,7 +4,7 @@
*
* Sphinx stylesheet -- basic theme.
*
* :copyright: Copyright 2007-2018 by the Sphinx team, see AUTHORS.
* :copyright: Copyright 2007-2019 by the Sphinx team, see AUTHORS.
* :license: BSD, see LICENSE for details.
*
*/
+1 -1
View File
@@ -4,7 +4,7 @@
*
* Sphinx JavaScript utilities for all documentation.
*
* :copyright: Copyright 2007-2018 by the Sphinx team, see AUTHORS.
* :copyright: Copyright 2007-2019 by the Sphinx team, see AUTHORS.
* :license: BSD, see LICENSE for details.
*
*/
-286
View File
@@ -7,290 +7,4 @@ var DOCUMENTATION_OPTIONS = {
HAS_SOURCE: true,
SOURCELINK_SUFFIX: '.txt',
NAVIGATION_WITH_KEYS: false,
SEARCH_LANGUAGE_STOP_WORDS: ["a","and","are","as","at","be","but","by","for","if","in","into","is","it","near","no","not","of","on","or","such","that","the","their","then","there","these","they","this","to","was","will","with"]
};
/* Non-minified version JS is _stemmer.js if file is provided */
/**
* Porter Stemmer
*/
var Stemmer = function() {
var step2list = {
ational: 'ate',
tional: 'tion',
enci: 'ence',
anci: 'ance',
izer: 'ize',
bli: 'ble',
alli: 'al',
entli: 'ent',
eli: 'e',
ousli: 'ous',
ization: 'ize',
ation: 'ate',
ator: 'ate',
alism: 'al',
iveness: 'ive',
fulness: 'ful',
ousness: 'ous',
aliti: 'al',
iviti: 'ive',
biliti: 'ble',
logi: 'log'
};
var step3list = {
icate: 'ic',
ative: '',
alize: 'al',
iciti: 'ic',
ical: 'ic',
ful: '',
ness: ''
};
var c = "[^aeiou]"; // consonant
var v = "[aeiouy]"; // vowel
var C = c + "[^aeiouy]*"; // consonant sequence
var V = v + "[aeiou]*"; // vowel sequence
var mgr0 = "^(" + C + ")?" + V + C; // [C]VC... is m>0
var meq1 = "^(" + C + ")?" + V + C + "(" + V + ")?$"; // [C]VC[V] is m=1
var mgr1 = "^(" + C + ")?" + V + C + V + C; // [C]VCVC... is m>1
var s_v = "^(" + C + ")?" + v; // vowel in stem
this.stemWord = function (w) {
var stem;
var suffix;
var firstch;
var origword = w;
if (w.length < 3)
return w;
var re;
var re2;
var re3;
var re4;
firstch = w.substr(0,1);
if (firstch == "y")
w = firstch.toUpperCase() + w.substr(1);
// Step 1a
re = /^(.+?)(ss|i)es$/;
re2 = /^(.+?)([^s])s$/;
if (re.test(w))
w = w.replace(re,"$1$2");
else if (re2.test(w))
w = w.replace(re2,"$1$2");
// Step 1b
re = /^(.+?)eed$/;
re2 = /^(.+?)(ed|ing)$/;
if (re.test(w)) {
var fp = re.exec(w);
re = new RegExp(mgr0);
if (re.test(fp[1])) {
re = /.$/;
w = w.replace(re,"");
}
}
else if (re2.test(w)) {
var fp = re2.exec(w);
stem = fp[1];
re2 = new RegExp(s_v);
if (re2.test(stem)) {
w = stem;
re2 = /(at|bl|iz)$/;
re3 = new RegExp("([^aeiouylsz])\\1$");
re4 = new RegExp("^" + C + v + "[^aeiouwxy]$");
if (re2.test(w))
w = w + "e";
else if (re3.test(w)) {
re = /.$/;
w = w.replace(re,"");
}
else if (re4.test(w))
w = w + "e";
}
}
// Step 1c
re = /^(.+?)y$/;
if (re.test(w)) {
var fp = re.exec(w);
stem = fp[1];
re = new RegExp(s_v);
if (re.test(stem))
w = stem + "i";
}
// Step 2
re = /^(.+?)(ational|tional|enci|anci|izer|bli|alli|entli|eli|ousli|ization|ation|ator|alism|iveness|fulness|ousness|aliti|iviti|biliti|logi)$/;
if (re.test(w)) {
var fp = re.exec(w);
stem = fp[1];
suffix = fp[2];
re = new RegExp(mgr0);
if (re.test(stem))
w = stem + step2list[suffix];
}
// Step 3
re = /^(.+?)(icate|ative|alize|iciti|ical|ful|ness)$/;
if (re.test(w)) {
var fp = re.exec(w);
stem = fp[1];
suffix = fp[2];
re = new RegExp(mgr0);
if (re.test(stem))
w = stem + step3list[suffix];
}
// Step 4
re = /^(.+?)(al|ance|ence|er|ic|able|ible|ant|ement|ment|ent|ou|ism|ate|iti|ous|ive|ize)$/;
re2 = /^(.+?)(s|t)(ion)$/;
if (re.test(w)) {
var fp = re.exec(w);
stem = fp[1];
re = new RegExp(mgr1);
if (re.test(stem))
w = stem;
}
else if (re2.test(w)) {
var fp = re2.exec(w);
stem = fp[1] + fp[2];
re2 = new RegExp(mgr1);
if (re2.test(stem))
w = stem;
}
// Step 5
re = /^(.+?)e$/;
if (re.test(w)) {
var fp = re.exec(w);
stem = fp[1];
re = new RegExp(mgr1);
re2 = new RegExp(meq1);
re3 = new RegExp("^" + C + v + "[^aeiouwxy]$");
if (re.test(stem) || (re2.test(stem) && !(re3.test(stem))))
w = stem;
}
re = /ll$/;
re2 = new RegExp(mgr1);
if (re.test(w) && re2.test(w)) {
re = /.$/;
w = w.replace(re,"");
}
// and turn initial Y back to y
if (firstch == "y")
w = firstch.toLowerCase() + w.substr(1);
return w;
}
}
var splitChars = (function() {
var result = {};
var singles = [96, 180, 187, 191, 215, 247, 749, 885, 903, 907, 909, 930, 1014, 1648,
1748, 1809, 2416, 2473, 2481, 2526, 2601, 2609, 2612, 2615, 2653, 2702,
2706, 2729, 2737, 2740, 2857, 2865, 2868, 2910, 2928, 2948, 2961, 2971,
2973, 3085, 3089, 3113, 3124, 3213, 3217, 3241, 3252, 3295, 3341, 3345,
3369, 3506, 3516, 3633, 3715, 3721, 3736, 3744, 3748, 3750, 3756, 3761,
3781, 3912, 4239, 4347, 4681, 4695, 4697, 4745, 4785, 4799, 4801, 4823,
4881, 5760, 5901, 5997, 6313, 7405, 8024, 8026, 8028, 8030, 8117, 8125,
8133, 8181, 8468, 8485, 8487, 8489, 8494, 8527, 11311, 11359, 11687, 11695,
11703, 11711, 11719, 11727, 11735, 12448, 12539, 43010, 43014, 43019, 43587,
43696, 43713, 64286, 64297, 64311, 64317, 64319, 64322, 64325, 65141];
var i, j, start, end;
for (i = 0; i < singles.length; i++) {
result[singles[i]] = true;
}
var ranges = [[0, 47], [58, 64], [91, 94], [123, 169], [171, 177], [182, 184], [706, 709],
[722, 735], [741, 747], [751, 879], [888, 889], [894, 901], [1154, 1161],
[1318, 1328], [1367, 1368], [1370, 1376], [1416, 1487], [1515, 1519], [1523, 1568],
[1611, 1631], [1642, 1645], [1750, 1764], [1767, 1773], [1789, 1790], [1792, 1807],
[1840, 1868], [1958, 1968], [1970, 1983], [2027, 2035], [2038, 2041], [2043, 2047],
[2070, 2073], [2075, 2083], [2085, 2087], [2089, 2307], [2362, 2364], [2366, 2383],
[2385, 2391], [2402, 2405], [2419, 2424], [2432, 2436], [2445, 2446], [2449, 2450],
[2483, 2485], [2490, 2492], [2494, 2509], [2511, 2523], [2530, 2533], [2546, 2547],
[2554, 2564], [2571, 2574], [2577, 2578], [2618, 2648], [2655, 2661], [2672, 2673],
[2677, 2692], [2746, 2748], [2750, 2767], [2769, 2783], [2786, 2789], [2800, 2820],
[2829, 2830], [2833, 2834], [2874, 2876], [2878, 2907], [2914, 2917], [2930, 2946],
[2955, 2957], [2966, 2968], [2976, 2978], [2981, 2983], [2987, 2989], [3002, 3023],
[3025, 3045], [3059, 3076], [3130, 3132], [3134, 3159], [3162, 3167], [3170, 3173],
[3184, 3191], [3199, 3204], [3258, 3260], [3262, 3293], [3298, 3301], [3312, 3332],
[3386, 3388], [3390, 3423], [3426, 3429], [3446, 3449], [3456, 3460], [3479, 3481],
[3518, 3519], [3527, 3584], [3636, 3647], [3655, 3663], [3674, 3712], [3717, 3718],
[3723, 3724], [3726, 3731], [3752, 3753], [3764, 3772], [3774, 3775], [3783, 3791],
[3802, 3803], [3806, 3839], [3841, 3871], [3892, 3903], [3949, 3975], [3980, 4095],
[4139, 4158], [4170, 4175], [4182, 4185], [4190, 4192], [4194, 4196], [4199, 4205],
[4209, 4212], [4226, 4237], [4250, 4255], [4294, 4303], [4349, 4351], [4686, 4687],
[4702, 4703], [4750, 4751], [4790, 4791], [4806, 4807], [4886, 4887], [4955, 4968],
[4989, 4991], [5008, 5023], [5109, 5120], [5741, 5742], [5787, 5791], [5867, 5869],
[5873, 5887], [5906, 5919], [5938, 5951], [5970, 5983], [6001, 6015], [6068, 6102],
[6104, 6107], [6109, 6111], [6122, 6127], [6138, 6159], [6170, 6175], [6264, 6271],
[6315, 6319], [6390, 6399], [6429, 6469], [6510, 6511], [6517, 6527], [6572, 6592],
[6600, 6607], [6619, 6655], [6679, 6687], [6741, 6783], [6794, 6799], [6810, 6822],
[6824, 6916], [6964, 6980], [6988, 6991], [7002, 7042], [7073, 7085], [7098, 7167],
[7204, 7231], [7242, 7244], [7294, 7400], [7410, 7423], [7616, 7679], [7958, 7959],
[7966, 7967], [8006, 8007], [8014, 8015], [8062, 8063], [8127, 8129], [8141, 8143],
[8148, 8149], [8156, 8159], [8173, 8177], [8189, 8303], [8306, 8307], [8314, 8318],
[8330, 8335], [8341, 8449], [8451, 8454], [8456, 8457], [8470, 8472], [8478, 8483],
[8506, 8507], [8512, 8516], [8522, 8525], [8586, 9311], [9372, 9449], [9472, 10101],
[10132, 11263], [11493, 11498], [11503, 11516], [11518, 11519], [11558, 11567],
[11622, 11630], [11632, 11647], [11671, 11679], [11743, 11822], [11824, 12292],
[12296, 12320], [12330, 12336], [12342, 12343], [12349, 12352], [12439, 12444],
[12544, 12548], [12590, 12592], [12687, 12689], [12694, 12703], [12728, 12783],
[12800, 12831], [12842, 12880], [12896, 12927], [12938, 12976], [12992, 13311],
[19894, 19967], [40908, 40959], [42125, 42191], [42238, 42239], [42509, 42511],
[42540, 42559], [42592, 42593], [42607, 42622], [42648, 42655], [42736, 42774],
[42784, 42785], [42889, 42890], [42893, 43002], [43043, 43055], [43062, 43071],
[43124, 43137], [43188, 43215], [43226, 43249], [43256, 43258], [43260, 43263],
[43302, 43311], [43335, 43359], [43389, 43395], [43443, 43470], [43482, 43519],
[43561, 43583], [43596, 43599], [43610, 43615], [43639, 43641], [43643, 43647],
[43698, 43700], [43703, 43704], [43710, 43711], [43715, 43738], [43742, 43967],
[44003, 44015], [44026, 44031], [55204, 55215], [55239, 55242], [55292, 55295],
[57344, 63743], [64046, 64047], [64110, 64111], [64218, 64255], [64263, 64274],
[64280, 64284], [64434, 64466], [64830, 64847], [64912, 64913], [64968, 65007],
[65020, 65135], [65277, 65295], [65306, 65312], [65339, 65344], [65371, 65381],
[65471, 65473], [65480, 65481], [65488, 65489], [65496, 65497]];
for (i = 0; i < ranges.length; i++) {
start = ranges[i][0];
end = ranges[i][1];
for (j = start; j <= end; j++) {
result[j] = true;
}
}
return result;
})();
function splitQuery(query) {
var result = [];
var start = -1;
for (var i = 0; i < query.length; i++) {
if (splitChars[query.charCodeAt(i)]) {
if (start !== -1) {
result.push(query.slice(start, i));
start = -1;
}
} else if (start === -1) {
start = i;
}
}
if (start !== -1) {
result.push(query.slice(start));
}
return result;
}
+297
View File
@@ -0,0 +1,297 @@
/*
* language_data.js
* ~~~~~~~~~~~~~~~~
*
* This script contains the language-specific data used by searchtools.js,
* namely the list of stopwords, stemmer, scorer and splitter.
*
* :copyright: Copyright 2007-2019 by the Sphinx team, see AUTHORS.
* :license: BSD, see LICENSE for details.
*
*/
var stopwords = ["a","and","are","as","at","be","but","by","for","if","in","into","is","it","near","no","not","of","on","or","such","that","the","their","then","there","these","they","this","to","was","will","with"];
/* Non-minified version JS is _stemmer.js if file is provided */
/**
* Porter Stemmer
*/
var Stemmer = function() {
var step2list = {
ational: 'ate',
tional: 'tion',
enci: 'ence',
anci: 'ance',
izer: 'ize',
bli: 'ble',
alli: 'al',
entli: 'ent',
eli: 'e',
ousli: 'ous',
ization: 'ize',
ation: 'ate',
ator: 'ate',
alism: 'al',
iveness: 'ive',
fulness: 'ful',
ousness: 'ous',
aliti: 'al',
iviti: 'ive',
biliti: 'ble',
logi: 'log'
};
var step3list = {
icate: 'ic',
ative: '',
alize: 'al',
iciti: 'ic',
ical: 'ic',
ful: '',
ness: ''
};
var c = "[^aeiou]"; // consonant
var v = "[aeiouy]"; // vowel
var C = c + "[^aeiouy]*"; // consonant sequence
var V = v + "[aeiou]*"; // vowel sequence
var mgr0 = "^(" + C + ")?" + V + C; // [C]VC... is m>0
var meq1 = "^(" + C + ")?" + V + C + "(" + V + ")?$"; // [C]VC[V] is m=1
var mgr1 = "^(" + C + ")?" + V + C + V + C; // [C]VCVC... is m>1
var s_v = "^(" + C + ")?" + v; // vowel in stem
this.stemWord = function (w) {
var stem;
var suffix;
var firstch;
var origword = w;
if (w.length < 3)
return w;
var re;
var re2;
var re3;
var re4;
firstch = w.substr(0,1);
if (firstch == "y")
w = firstch.toUpperCase() + w.substr(1);
// Step 1a
re = /^(.+?)(ss|i)es$/;
re2 = /^(.+?)([^s])s$/;
if (re.test(w))
w = w.replace(re,"$1$2");
else if (re2.test(w))
w = w.replace(re2,"$1$2");
// Step 1b
re = /^(.+?)eed$/;
re2 = /^(.+?)(ed|ing)$/;
if (re.test(w)) {
var fp = re.exec(w);
re = new RegExp(mgr0);
if (re.test(fp[1])) {
re = /.$/;
w = w.replace(re,"");
}
}
else if (re2.test(w)) {
var fp = re2.exec(w);
stem = fp[1];
re2 = new RegExp(s_v);
if (re2.test(stem)) {
w = stem;
re2 = /(at|bl|iz)$/;
re3 = new RegExp("([^aeiouylsz])\\1$");
re4 = new RegExp("^" + C + v + "[^aeiouwxy]$");
if (re2.test(w))
w = w + "e";
else if (re3.test(w)) {
re = /.$/;
w = w.replace(re,"");
}
else if (re4.test(w))
w = w + "e";
}
}
// Step 1c
re = /^(.+?)y$/;
if (re.test(w)) {
var fp = re.exec(w);
stem = fp[1];
re = new RegExp(s_v);
if (re.test(stem))
w = stem + "i";
}
// Step 2
re = /^(.+?)(ational|tional|enci|anci|izer|bli|alli|entli|eli|ousli|ization|ation|ator|alism|iveness|fulness|ousness|aliti|iviti|biliti|logi)$/;
if (re.test(w)) {
var fp = re.exec(w);
stem = fp[1];
suffix = fp[2];
re = new RegExp(mgr0);
if (re.test(stem))
w = stem + step2list[suffix];
}
// Step 3
re = /^(.+?)(icate|ative|alize|iciti|ical|ful|ness)$/;
if (re.test(w)) {
var fp = re.exec(w);
stem = fp[1];
suffix = fp[2];
re = new RegExp(mgr0);
if (re.test(stem))
w = stem + step3list[suffix];
}
// Step 4
re = /^(.+?)(al|ance|ence|er|ic|able|ible|ant|ement|ment|ent|ou|ism|ate|iti|ous|ive|ize)$/;
re2 = /^(.+?)(s|t)(ion)$/;
if (re.test(w)) {
var fp = re.exec(w);
stem = fp[1];
re = new RegExp(mgr1);
if (re.test(stem))
w = stem;
}
else if (re2.test(w)) {
var fp = re2.exec(w);
stem = fp[1] + fp[2];
re2 = new RegExp(mgr1);
if (re2.test(stem))
w = stem;
}
// Step 5
re = /^(.+?)e$/;
if (re.test(w)) {
var fp = re.exec(w);
stem = fp[1];
re = new RegExp(mgr1);
re2 = new RegExp(meq1);
re3 = new RegExp("^" + C + v + "[^aeiouwxy]$");
if (re.test(stem) || (re2.test(stem) && !(re3.test(stem))))
w = stem;
}
re = /ll$/;
re2 = new RegExp(mgr1);
if (re.test(w) && re2.test(w)) {
re = /.$/;
w = w.replace(re,"");
}
// and turn initial Y back to y
if (firstch == "y")
w = firstch.toLowerCase() + w.substr(1);
return w;
}
}
var splitChars = (function() {
var result = {};
var singles = [96, 180, 187, 191, 215, 247, 749, 885, 903, 907, 909, 930, 1014, 1648,
1748, 1809, 2416, 2473, 2481, 2526, 2601, 2609, 2612, 2615, 2653, 2702,
2706, 2729, 2737, 2740, 2857, 2865, 2868, 2910, 2928, 2948, 2961, 2971,
2973, 3085, 3089, 3113, 3124, 3213, 3217, 3241, 3252, 3295, 3341, 3345,
3369, 3506, 3516, 3633, 3715, 3721, 3736, 3744, 3748, 3750, 3756, 3761,
3781, 3912, 4239, 4347, 4681, 4695, 4697, 4745, 4785, 4799, 4801, 4823,
4881, 5760, 5901, 5997, 6313, 7405, 8024, 8026, 8028, 8030, 8117, 8125,
8133, 8181, 8468, 8485, 8487, 8489, 8494, 8527, 11311, 11359, 11687, 11695,
11703, 11711, 11719, 11727, 11735, 12448, 12539, 43010, 43014, 43019, 43587,
43696, 43713, 64286, 64297, 64311, 64317, 64319, 64322, 64325, 65141];
var i, j, start, end;
for (i = 0; i < singles.length; i++) {
result[singles[i]] = true;
}
var ranges = [[0, 47], [58, 64], [91, 94], [123, 169], [171, 177], [182, 184], [706, 709],
[722, 735], [741, 747], [751, 879], [888, 889], [894, 901], [1154, 1161],
[1318, 1328], [1367, 1368], [1370, 1376], [1416, 1487], [1515, 1519], [1523, 1568],
[1611, 1631], [1642, 1645], [1750, 1764], [1767, 1773], [1789, 1790], [1792, 1807],
[1840, 1868], [1958, 1968], [1970, 1983], [2027, 2035], [2038, 2041], [2043, 2047],
[2070, 2073], [2075, 2083], [2085, 2087], [2089, 2307], [2362, 2364], [2366, 2383],
[2385, 2391], [2402, 2405], [2419, 2424], [2432, 2436], [2445, 2446], [2449, 2450],
[2483, 2485], [2490, 2492], [2494, 2509], [2511, 2523], [2530, 2533], [2546, 2547],
[2554, 2564], [2571, 2574], [2577, 2578], [2618, 2648], [2655, 2661], [2672, 2673],
[2677, 2692], [2746, 2748], [2750, 2767], [2769, 2783], [2786, 2789], [2800, 2820],
[2829, 2830], [2833, 2834], [2874, 2876], [2878, 2907], [2914, 2917], [2930, 2946],
[2955, 2957], [2966, 2968], [2976, 2978], [2981, 2983], [2987, 2989], [3002, 3023],
[3025, 3045], [3059, 3076], [3130, 3132], [3134, 3159], [3162, 3167], [3170, 3173],
[3184, 3191], [3199, 3204], [3258, 3260], [3262, 3293], [3298, 3301], [3312, 3332],
[3386, 3388], [3390, 3423], [3426, 3429], [3446, 3449], [3456, 3460], [3479, 3481],
[3518, 3519], [3527, 3584], [3636, 3647], [3655, 3663], [3674, 3712], [3717, 3718],
[3723, 3724], [3726, 3731], [3752, 3753], [3764, 3772], [3774, 3775], [3783, 3791],
[3802, 3803], [3806, 3839], [3841, 3871], [3892, 3903], [3949, 3975], [3980, 4095],
[4139, 4158], [4170, 4175], [4182, 4185], [4190, 4192], [4194, 4196], [4199, 4205],
[4209, 4212], [4226, 4237], [4250, 4255], [4294, 4303], [4349, 4351], [4686, 4687],
[4702, 4703], [4750, 4751], [4790, 4791], [4806, 4807], [4886, 4887], [4955, 4968],
[4989, 4991], [5008, 5023], [5109, 5120], [5741, 5742], [5787, 5791], [5867, 5869],
[5873, 5887], [5906, 5919], [5938, 5951], [5970, 5983], [6001, 6015], [6068, 6102],
[6104, 6107], [6109, 6111], [6122, 6127], [6138, 6159], [6170, 6175], [6264, 6271],
[6315, 6319], [6390, 6399], [6429, 6469], [6510, 6511], [6517, 6527], [6572, 6592],
[6600, 6607], [6619, 6655], [6679, 6687], [6741, 6783], [6794, 6799], [6810, 6822],
[6824, 6916], [6964, 6980], [6988, 6991], [7002, 7042], [7073, 7085], [7098, 7167],
[7204, 7231], [7242, 7244], [7294, 7400], [7410, 7423], [7616, 7679], [7958, 7959],
[7966, 7967], [8006, 8007], [8014, 8015], [8062, 8063], [8127, 8129], [8141, 8143],
[8148, 8149], [8156, 8159], [8173, 8177], [8189, 8303], [8306, 8307], [8314, 8318],
[8330, 8335], [8341, 8449], [8451, 8454], [8456, 8457], [8470, 8472], [8478, 8483],
[8506, 8507], [8512, 8516], [8522, 8525], [8586, 9311], [9372, 9449], [9472, 10101],
[10132, 11263], [11493, 11498], [11503, 11516], [11518, 11519], [11558, 11567],
[11622, 11630], [11632, 11647], [11671, 11679], [11743, 11822], [11824, 12292],
[12296, 12320], [12330, 12336], [12342, 12343], [12349, 12352], [12439, 12444],
[12544, 12548], [12590, 12592], [12687, 12689], [12694, 12703], [12728, 12783],
[12800, 12831], [12842, 12880], [12896, 12927], [12938, 12976], [12992, 13311],
[19894, 19967], [40908, 40959], [42125, 42191], [42238, 42239], [42509, 42511],
[42540, 42559], [42592, 42593], [42607, 42622], [42648, 42655], [42736, 42774],
[42784, 42785], [42889, 42890], [42893, 43002], [43043, 43055], [43062, 43071],
[43124, 43137], [43188, 43215], [43226, 43249], [43256, 43258], [43260, 43263],
[43302, 43311], [43335, 43359], [43389, 43395], [43443, 43470], [43482, 43519],
[43561, 43583], [43596, 43599], [43610, 43615], [43639, 43641], [43643, 43647],
[43698, 43700], [43703, 43704], [43710, 43711], [43715, 43738], [43742, 43967],
[44003, 44015], [44026, 44031], [55204, 55215], [55239, 55242], [55292, 55295],
[57344, 63743], [64046, 64047], [64110, 64111], [64218, 64255], [64263, 64274],
[64280, 64284], [64434, 64466], [64830, 64847], [64912, 64913], [64968, 65007],
[65020, 65135], [65277, 65295], [65306, 65312], [65339, 65344], [65371, 65381],
[65471, 65473], [65480, 65481], [65488, 65489], [65496, 65497]];
for (i = 0; i < ranges.length; i++) {
start = ranges[i][0];
end = ranges[i][1];
for (j = start; j <= end; j++) {
result[j] = true;
}
}
return result;
})();
function splitQuery(query) {
var result = [];
var start = -1;
for (var i = 0; i < query.length; i++) {
if (splitChars[query.charCodeAt(i)]) {
if (start !== -1) {
result.push(query.slice(start, i));
start = -1;
}
} else if (start === -1) {
start = i;
}
}
if (start !== -1) {
result.push(query.slice(start));
}
return result;
}
+1 -2
View File
@@ -4,7 +4,7 @@
*
* Sphinx JavaScript utilities for the full-text search.
*
* :copyright: Copyright 2007-2018 by the Sphinx team, see AUTHORS.
* :copyright: Copyright 2007-2019 by the Sphinx team, see AUTHORS.
* :license: BSD, see LICENSE for details.
*
*/
@@ -138,7 +138,6 @@ var Search = {
*/
query : function(query) {
var i;
var stopwords = DOCUMENTATION_OPTIONS.SEARCH_LANGUAGE_STOP_WORDS;
// stem the searchterms and add them to the correct list
var stemmer = new Stemmer();
+1 -1
View File
@@ -4,7 +4,7 @@
*
* sphinx.websupport utilities for all documentation.
*
* :copyright: Copyright 2007-2018 by the Sphinx team, see AUTHORS.
* :copyright: Copyright 2007-2019 by the Sphinx team, see AUTHORS.
* :license: BSD, see LICENSE for details.
*
*/
+2 -1
View File
@@ -382,7 +382,8 @@ assigned</li>
<script type="text/javascript" src="../_static/jquery.js"></script>
<script type="text/javascript" src="../_static/underscore.js"></script>
<script type="text/javascript" src="../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
+5 -3
View File
@@ -30,7 +30,7 @@
<link rel="index" title="Index" href="../../../genindex.html" />
<link rel="search" title="Search" href="../../../search.html" />
<link rel="next" title="Bootstrapped DQN" href="../value_optimization/bs_dqn.html" />
<link rel="prev" title="Actor-Critic" href="../policy_optimization/ac.html" />
<link rel="prev" title="ACER" href="../policy_optimization/acer.html" />
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
@@ -107,6 +107,7 @@
<ul class="current">
<li class="toctree-l1 current"><a class="reference internal" href="../index.html">Agents</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ac.html">Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/acer.html">ACER</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Behavioral Cloning</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#network-structure">Network Structure</a></li>
<li class="toctree-l3"><a class="reference internal" href="#algorithm-description">Algorithm Description</a><ul>
@@ -252,7 +253,7 @@ the expert for each state.</p>
<a href="../value_optimization/bs_dqn.html" class="btn btn-neutral float-right" title="Bootstrapped DQN" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a>
<a href="../policy_optimization/ac.html" class="btn btn-neutral" title="Actor-Critic" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a>
<a href="../policy_optimization/acer.html" class="btn btn-neutral" title="ACER" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a>
</div>
@@ -286,7 +287,8 @@ the expert for each state.</p>
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
+3 -1
View File
@@ -107,6 +107,7 @@
<ul class="current">
<li class="toctree-l1 current"><a class="reference internal" href="../index.html">Agents</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ac.html">Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/acer.html">ACER</a></li>
<li class="toctree-l2"><a class="reference internal" href="bc.html">Behavioral Cloning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/bs_dqn.html">Bootstrapped DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/categorical_dqn.html">Categorical DQN</a></li>
@@ -301,7 +302,8 @@ The key of the state dictionary which corresponds to the value that will be used
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
+15 -15
View File
@@ -107,6 +107,7 @@
<ul class="current">
<li class="toctree-l1 current"><a class="current reference internal" href="#">Agents</a><ul>
<li class="toctree-l2"><a class="reference internal" href="policy_optimization/ac.html">Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="policy_optimization/acer.html">ACER</a></li>
<li class="toctree-l2"><a class="reference internal" href="imitation/bc.html">Behavioral Cloning</a></li>
<li class="toctree-l2"><a class="reference internal" href="value_optimization/bs_dqn.html">Bootstrapped DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="value_optimization/categorical_dqn.html">Categorical DQN</a></li>
@@ -213,6 +214,7 @@ A detailed description of those algorithms can be found by navigating to each of
<p class="caption"><span class="caption-text">Agents</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="policy_optimization/ac.html">Actor-Critic</a></li>
<li class="toctree-l1"><a class="reference internal" href="policy_optimization/acer.html">ACER</a></li>
<li class="toctree-l1"><a class="reference internal" href="imitation/bc.html">Behavioral Cloning</a></li>
<li class="toctree-l1"><a class="reference internal" href="value_optimization/bs_dqn.html">Bootstrapped DQN</a></li>
<li class="toctree-l1"><a class="reference internal" href="value_optimization/categorical_dqn.html">Categorical DQN</a></li>
@@ -334,17 +336,9 @@ training or testing.</p>
<dt id="rl_coach.agents.agent.Agent.collect_savers">
<code class="descname">collect_savers</code><span class="sig-paren">(</span><em>parent_path_suffix: str</em><span class="sig-paren">)</span> &#x2192; rl_coach.saver.SaverCollection<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.collect_savers"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.collect_savers" title="Permalink to this definition"></a></dt>
<dd><p>Collect all of agents network savers
:param parent_path_suffix: path suffix of the parent of the agent</p>
<blockquote>
<div>(could be name of level manager or composite agent)</div></blockquote>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
<col class="field-body" />
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Returns:</th><td class="field-body">collection of all agent savers</td>
</tr>
</tbody>
</table>
:param parent_path_suffix: path suffix of the parent of the agent
(could be name of level manager or composite agent)
:return: collection of all agent savers</p>
</dd></dl>
<dl class="method">
@@ -640,15 +634,20 @@ by val, and by the current phase set in self.phase.</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.run_pre_network_filter_for_inference">
<code class="descname">run_pre_network_filter_for_inference</code><span class="sig-paren">(</span><em>state: Dict[str, numpy.ndarray]</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.ndarray]<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.run_pre_network_filter_for_inference"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.run_pre_network_filter_for_inference" title="Permalink to this definition"></a></dt>
<code class="descname">run_pre_network_filter_for_inference</code><span class="sig-paren">(</span><em>state: Dict[str, numpy.ndarray], update_filter_internal_state: bool = True</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.ndarray]<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.run_pre_network_filter_for_inference"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.run_pre_network_filter_for_inference" title="Permalink to this definition"></a></dt>
<dd><p>Run filters which where defined for being applied right before using the state for inference.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
<col class="field-body" />
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>state</strong> The state to run the filters on</td>
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple">
<li><strong>state</strong> The state to run the filters on</li>
<li><strong>update_filter_internal_state</strong> Should update the filters internal state - should not update when evaluating</li>
</ul>
</td>
</tr>
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body">The filtered state</td>
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first last">The filtered state</p>
</td>
</tr>
</tbody>
</table>
@@ -860,7 +859,8 @@ Can be useful for agents that want to tweak the reward, termination signal, etc.
<script type="text/javascript" src="../../_static/jquery.js"></script>
<script type="text/javascript" src="../../_static/underscore.js"></script>
<script type="text/javascript" src="../../_static/doctools.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>

Some files were not shown because too many files have changed in this diff Show More