Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 19:20:19 +01:00)
update of api docstrings across coach and tutorials [WIP] (#91)
* updating the documentation website
* adding the built docs
* update of api docstrings across coach and tutorials 0-2
* added some missing api documentation
* New Sphinx based documentation
docs/_modules/rl_coach/agents/actor_critic_agent.html | 413 (new file)
@@ -0,0 +1,413 @@
rl_coach.agents.actor_critic_agent — Reinforcement Learning Coach 0.11.0 documentation
Docs » Module code » rl_coach.agents.actor_critic_agent

Source code for rl_coach.agents.actor_critic_agent

#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Union

import numpy as np
import scipy.signal

from rl_coach.agents.policy_optimization_agent import PolicyOptimizationAgent, PolicyGradientRescaler
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.head_parameters import PolicyHeadParameters, VHeadParameters
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters
from rl_coach.base_parameters import AlgorithmParameters, NetworkParameters, \
    AgentParameters
from rl_coach.exploration_policies.categorical import CategoricalParameters
from rl_coach.exploration_policies.continuous_entropy import ContinuousEntropyParameters
from rl_coach.logger import screen
from rl_coach.memories.episodic.single_episode_buffer import SingleEpisodeBufferParameters
from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace
from rl_coach.utils import last_sample

<div class="viewcode-block" id="ActorCriticAlgorithmParameters"><a class="viewcode-back" href="../../../components/agents/policy_optimization/ac.html#rl_coach.agents.actor_critic_agent.ActorCriticAlgorithmParameters">[docs]</a><span class="k">class</span> <span class="nc">ActorCriticAlgorithmParameters</span><span class="p">(</span><span class="n">AlgorithmParameters</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> :param policy_gradient_rescaler: (PolicyGradientRescaler)</span>
|
||||
<span class="sd"> The value that will be used to rescale the policy gradient</span>
|
||||
|
||||
<span class="sd"> :param apply_gradients_every_x_episodes: (int)</span>
|
||||
<span class="sd"> The number of episodes to wait before applying the accumulated gradients to the network.</span>
|
||||
<span class="sd"> The training iterations only accumulate gradients without actually applying them.</span>
|
||||
|
||||
<span class="sd"> :param beta_entropy: (float)</span>
|
||||
<span class="sd"> The weight that will be given to the entropy regularization which is used in order to improve exploration.</span>
|
||||
|
||||
<span class="sd"> :param num_steps_between_gradient_updates: (int)</span>
|
||||
<span class="sd"> Every num_steps_between_gradient_updates transitions will be considered as a single batch and use for</span>
|
||||
<span class="sd"> accumulating gradients. This is also the number of steps used for bootstrapping according to the n-step formulation.</span>
|
||||
|
||||
<span class="sd"> :param gae_lambda: (float)</span>
|
||||
<span class="sd"> If the policy gradient rescaler was defined as PolicyGradientRescaler.GAE, the generalized advantage estimation</span>
|
||||
<span class="sd"> scheme will be used, in which case the lambda value controls the decay for the different n-step lengths.</span>
|
||||
|
||||
<span class="sd"> :param estimate_state_value_using_gae: (bool)</span>
|
||||
<span class="sd"> If set to True, the state value targets for the V head will be estimated using the GAE scheme.</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">policy_gradient_rescaler</span> <span class="o">=</span> <span class="n">PolicyGradientRescaler</span><span class="o">.</span><span class="n">A_VALUE</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">apply_gradients_every_x_episodes</span> <span class="o">=</span> <span class="mi">5</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">beta_entropy</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">num_steps_between_gradient_updates</span> <span class="o">=</span> <span class="mi">5000</span> <span class="c1"># this is called t_max in all the papers</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">gae_lambda</span> <span class="o">=</span> <span class="mf">0.96</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">estimate_state_value_using_gae</span> <span class="o">=</span> <span class="kc">False</span></div>
|
||||
class ActorCriticNetworkParameters(NetworkParameters):
    def __init__(self):
        super().__init__()
        self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
        self.middleware_parameters = FCMiddlewareParameters()
        self.heads_parameters = [VHeadParameters(loss_weight=0.5), PolicyHeadParameters(loss_weight=1.0)]
        self.optimizer_type = 'Adam'
        self.clip_gradients = 40.0
        self.async_training = True


class ActorCriticAgentParameters(AgentParameters):
    def __init__(self):
        super().__init__(algorithm=ActorCriticAlgorithmParameters(),
                         exploration={DiscreteActionSpace: CategoricalParameters(),
                                      BoxActionSpace: ContinuousEntropyParameters()},
                         memory=SingleEpisodeBufferParameters(),
                         networks={"main": ActorCriticNetworkParameters()})

    @property
    def path(self):
        return 'rl_coach.agents.actor_critic_agent:ActorCriticAgent'

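The parameter classes above are plain attribute containers, so a preset would typically construct ActorCriticAgentParameters and override individual fields on its algorithm sub-object. A rough sketch of that pattern, using only names defined in this module (the specific values are arbitrary placeholders, not recommendations):

agent_params = ActorCriticAgentParameters()
agent_params.algorithm.policy_gradient_rescaler = PolicyGradientRescaler.GAE  # instead of the A_VALUE default
agent_params.algorithm.gae_lambda = 0.95
agent_params.algorithm.beta_entropy = 0.01  # turn on entropy regularization
agent_params.algorithm.num_steps_between_gradient_updates = 20  # t_max
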
<span class="c1"># Actor Critic - https://arxiv.org/abs/1602.01783</span>
|
||||
<span class="k">class</span> <span class="nc">ActorCriticAgent</span><span class="p">(</span><span class="n">PolicyOptimizationAgent</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s1">'LevelManager'</span><span class="p">,</span> <span class="s1">'CompositeAgent'</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">last_gradient_update_step_idx</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">action_advantages</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">'Advantages'</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">state_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">'Values'</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">value_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">'Value Loss'</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">policy_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">'Policy Loss'</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># Discounting function used to calculate discounted returns.</span>
|
||||
<span class="k">def</span> <span class="nf">discount</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">gamma</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">scipy</span><span class="o">.</span><span class="n">signal</span><span class="o">.</span><span class="n">lfilter</span><span class="p">([</span><span class="mi">1</span><span class="p">],</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="n">gamma</span><span class="p">],</span> <span class="n">x</span><span class="p">[::</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)[::</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
    def get_general_advantage_estimation_values(self, rewards, values):
        # values contain n+1 elements (t ... t+n+1), rewards contain n elements (t ... t + n)
        bootstrap_extended_rewards = np.array(rewards.tolist() + [values[-1]])

        # Approximation based calculation of GAE (mathematically correct only when Tmax = inf,
        # although in practice works even in much smaller Tmax values, e.g. 20)
        deltas = rewards + self.ap.algorithm.discount * values[1:] - values[:-1]
        gae = self.discount(deltas, self.ap.algorithm.discount * self.ap.algorithm.gae_lambda)

        if self.ap.algorithm.estimate_state_value_using_gae:
            discounted_returns = np.expand_dims(gae + values[:-1], -1)
        else:
            discounted_returns = np.expand_dims(np.array(self.discount(bootstrap_extended_rewards,
                                                                       self.ap.algorithm.discount)), 1)[:-1]
        return gae, discounted_returns
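Here deltas holds the one-step TD errors delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), and discounting them with gamma * gae_lambda yields the GAE advantages A_t = sum_l (gamma * lambda)^l * delta_{t+l}. A standalone toy version of the computation, with the discount factor and lambda as free variables instead of self.ap.algorithm fields:

import numpy as np
import scipy.signal

def discount(x, gamma):
    return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]

rewards = np.array([0.0, 0.0, 1.0])      # n rewards
values = np.array([0.5, 0.6, 0.7, 0.2])  # n+1 state values; the last one is the bootstrap
gamma, gae_lambda = 0.99, 0.96

deltas = rewards + gamma * values[1:] - values[:-1]  # one-step TD errors
gae = discount(deltas, gamma * gae_lambda)           # GAE advantage per transition
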
<span class="k">def</span> <span class="nf">learn_from_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
|
||||
<span class="c1"># batch contains a list of episodes to learn from</span>
|
||||
<span class="n">network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># get the values for the current states</span>
|
||||
|
||||
<span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))</span>
|
||||
<span class="n">current_state_values</span> <span class="o">=</span> <span class="n">result</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">state_values</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">current_state_values</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># the targets for the state value estimator</span>
|
||||
<span class="n">num_transitions</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">size</span>
|
||||
<span class="n">state_value_head_targets</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">num_transitions</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
|
||||
|
||||
<span class="c1"># estimate the advantage function</span>
|
||||
<span class="n">action_advantages</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">num_transitions</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
|
||||
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">policy_gradient_rescaler</span> <span class="o">==</span> <span class="n">PolicyGradientRescaler</span><span class="o">.</span><span class="n">A_VALUE</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">batch</span><span class="o">.</span><span class="n">game_overs</span><span class="p">()[</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span>
|
||||
<span class="n">R</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">R</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">last_sample</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">next_states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">)))[</span><span class="mi">0</span><span class="p">]</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">reversed</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">num_transitions</span><span class="p">)):</span>
|
||||
<span class="n">R</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">rewards</span><span class="p">()[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">discount</span> <span class="o">*</span> <span class="n">R</span>
|
||||
<span class="n">state_value_head_targets</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">R</span>
|
||||
<span class="n">action_advantages</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">R</span> <span class="o">-</span> <span class="n">current_state_values</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
|
||||
|
||||
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">policy_gradient_rescaler</span> <span class="o">==</span> <span class="n">PolicyGradientRescaler</span><span class="o">.</span><span class="n">GAE</span><span class="p">:</span>
|
||||
<span class="c1"># get bootstraps</span>
|
||||
<span class="n">bootstrapped_value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">last_sample</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">next_states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">)))[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="n">values</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">current_state_values</span><span class="p">,</span> <span class="n">bootstrapped_value</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">batch</span><span class="o">.</span><span class="n">game_overs</span><span class="p">()[</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span>
|
||||
<span class="n">values</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
|
||||
<span class="c1"># get general discounted returns table</span>
|
||||
<span class="n">gae_values</span><span class="p">,</span> <span class="n">state_value_head_targets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_general_advantage_estimation_values</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">rewards</span><span class="p">(),</span> <span class="n">values</span><span class="p">)</span>
|
||||
<span class="n">action_advantages</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">vstack</span><span class="p">(</span><span class="n">gae_values</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">screen</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">"WARNING: The requested policy gradient rescaler is not available"</span><span class="p">)</span>
|
||||
|
||||
<span class="n">action_advantages</span> <span class="o">=</span> <span class="n">action_advantages</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||||
<span class="n">actions</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">action</span><span class="p">,</span> <span class="n">DiscreteActionSpace</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">actions</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o"><</span> <span class="mi">2</span><span class="p">:</span>
|
||||
<span class="n">actions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">actions</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># train</span>
|
||||
<span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">accumulate_gradients</span><span class="p">({</span><span class="o">**</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">),</span>
|
||||
<span class="s1">'output_1_0'</span><span class="p">:</span> <span class="n">actions</span><span class="p">},</span>
|
||||
<span class="p">[</span><span class="n">state_value_head_targets</span><span class="p">,</span> <span class="n">action_advantages</span><span class="p">])</span>
|
||||
|
||||
<span class="c1"># logging</span>
|
||||
<span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span> <span class="o">=</span> <span class="n">result</span><span class="p">[:</span><span class="mi">3</span><span class="p">]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">action_advantages</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">action_advantages</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">unclipped_grads</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">unclipped_grads</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">value_loss</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">losses</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">policy_loss</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">losses</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">get_prediction</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">):</span>
|
||||
<span class="n">tf_input_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_batch_for_inference</span><span class="p">(</span><span class="n">states</span><span class="p">,</span> <span class="s2">"main"</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">tf_input_state</span><span class="p">)[</span><span class="mi">1</span><span class="p">:]</span> <span class="c1"># index 0 is the state value</span>
docs/_modules/rl_coach/agents/agent.html | 1153 (new file)
(diff suppressed because it is too large)
docs/_modules/rl_coach/agents/bc_agent.html | 308 (new file)
@@ -0,0 +1,308 @@
rl_coach.agents.bc_agent — Reinforcement Learning Coach 0.11.0 documentation
Docs » Module code » rl_coach.agents.bc_agent

Source code for rl_coach.agents.bc_agent

#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Union

import numpy as np

from rl_coach.agents.imitation_agent import ImitationAgent
from rl_coach.architectures.head_parameters import PolicyHeadParameters
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.base_parameters import AgentParameters, AlgorithmParameters, NetworkParameters, \
    MiddlewareScheme
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters


class BCAlgorithmParameters(AlgorithmParameters):
    def __init__(self):
        super().__init__()


class BCNetworkParameters(NetworkParameters):
    def __init__(self):
        super().__init__()
        self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
        self.middleware_parameters = FCMiddlewareParameters(scheme=MiddlewareScheme.Medium)
        self.heads_parameters = [PolicyHeadParameters()]
        self.optimizer_type = 'Adam'
        self.batch_size = 32
        self.replace_mse_with_huber_loss = False
        self.create_target_network = False


class BCAgentParameters(AgentParameters):
    def __init__(self):
        super().__init__(algorithm=BCAlgorithmParameters(),
                         exploration=EGreedyParameters(),
                         memory=ExperienceReplayParameters(),
                         networks={"main": BCNetworkParameters()})

    @property
    def path(self):
        return 'rl_coach.agents.bc_agent:BCAgent'


# Behavioral Cloning Agent
class BCAgent(ImitationAgent):
    def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
        super().__init__(agent_parameters, parent)

    def learn_from_batch(self, batch):
        network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()

        # When using a policy head, the targets refer to the advantages that we are normally feeding the head with.
        # In this case, we need the policy head to just predict probabilities, so while we usually train the network
        # with log(Pi)*Advantages, in this specific case we will train it to log(Pi), which after the softmax will
        # predict Pi (=probabilities)
        targets = np.ones(batch.actions().shape[0])

        result = self.networks['main'].train_and_sync_networks({**batch.states(network_keys),
                                                                'output_0_0': batch.actions()},
                                                               targets)
        total_loss, losses, unclipped_grads = result[:3]

        return total_loss, losses, unclipped_grads
382
docs/_modules/rl_coach/agents/categorical_dqn_agent.html
Normal file
@@ -0,0 +1,382 @@
<!DOCTYPE html>
|
||||
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
|
||||
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
|
||||
<title>rl_coach.agents.categorical_dqn_agent — Reinforcement Learning Coach 0.11.0 documentation</title>
<link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../../_static/css/custom.css" type="text/css" />
|
||||
<link rel="index" title="Index" href="../../../genindex.html" />
|
||||
<link rel="search" title="Search" href="../../../search.html" />
|
||||
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
|
||||
|
||||
|
||||
|
||||
<script src="../../../_static/js/modernizr.min.js"></script>
|
||||
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||||
<div class="wy-side-scroll">
|
||||
<div class="wy-side-nav-search">
<a href="../../../index.html" class="icon icon-home"> Reinforcement Learning Coach
<img src="../../../_static/dark_logo.png" class="logo" alt="Logo"/>
|
||||
|
||||
</a>
<div role="search">
|
||||
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
|
||||
<input type="text" name="q" placeholder="Search docs" />
|
||||
<input type="hidden" name="check_keywords" value="yes" />
|
||||
<input type="hidden" name="area" value="default" />
|
||||
</form>
|
||||
</div>
</div>
|
||||
|
||||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption"><span class="caption-text">Intro</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../features/index.html">Features</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../selecting_an_algorithm.html">Selecting an Algorithm</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../dashboard.html">Coach Dashboard</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Design</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../design/control_flow.html">Control Flow</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../design/network.html">Network Design</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Contributing</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_agent.html">Adding a New Agent</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_env.html">Adding a New Environment</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Components</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/agents/index.html">Agents</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/architectures/index.html">Architectures</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/environments/index.html">Environments</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/exploration_policies/index.html">Exploration Policies</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/filters/index.html">Filters</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/memories/index.html">Memories</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/core_types.html">Core Types</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/spaces.html">Spaces</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/additional_parameters.html">Additional Parameters</a></li>
|
||||
</ul>
</div>
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
|
||||
|
||||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||||
<a href="../../../index.html">Reinforcement Learning Coach</a>
|
||||
|
||||
</nav>
<div class="wy-nav-content">
|
||||
|
||||
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
|
||||
|
||||
<ul class="wy-breadcrumbs">
|
||||
|
||||
<li><a href="../../../index.html">Docs</a> »</li>
|
||||
|
||||
<li><a href="../../index.html">Module code</a> »</li>
|
||||
|
||||
<li>rl_coach.agents.categorical_dqn_agent</li>
<li class="wy-breadcrumbs-aside">
|
||||
|
||||
</li>
|
||||
|
||||
</ul>
<hr/>
|
||||
</div>
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
<h1>Source code for rl_coach.agents.categorical_dqn_agent</h1><div class="highlight"><pre>
|
||||
<span></span><span class="c1">#</span>
|
||||
<span class="c1"># Copyright (c) 2017 Intel Corporation </span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Licensed under the Apache License, Version 2.0 (the "License");</span>
|
||||
<span class="c1"># you may not use this file except in compliance with the License.</span>
|
||||
<span class="c1"># You may obtain a copy of the License at</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
|
||||
<span class="c1"># distributed under the License is distributed on an "AS IS" BASIS,</span>
|
||||
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
|
||||
<span class="c1"># See the License for the specific language governing permissions and</span>
|
||||
<span class="c1"># limitations under the License.</span>
|
||||
<span class="c1">#</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.agents.dqn_agent</span> <span class="k">import</span> <span class="n">DQNNetworkParameters</span><span class="p">,</span> <span class="n">DQNAlgorithmParameters</span><span class="p">,</span> <span class="n">DQNAgentParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.agents.value_optimization_agent</span> <span class="k">import</span> <span class="n">ValueOptimizationAgent</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.architectures.head_parameters</span> <span class="k">import</span> <span class="n">CategoricalQHeadParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">StateType</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.e_greedy</span> <span class="k">import</span> <span class="n">EGreedyParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.memories.non_episodic.prioritized_experience_replay</span> <span class="k">import</span> <span class="n">PrioritizedExperienceReplay</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.schedules</span> <span class="k">import</span> <span class="n">LinearSchedule</span>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">CategoricalDQNNetworkParameters</span><span class="p">(</span><span class="n">DQNNetworkParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">CategoricalQHeadParameters</span><span class="p">()]</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="CategoricalDQNAlgorithmParameters"><a class="viewcode-back" href="../../../components/agents/value_optimization/categorical_dqn.html#rl_coach.agents.categorical_dqn_agent.CategoricalDQNAlgorithmParameters">[docs]</a><span class="k">class</span> <span class="nc">CategoricalDQNAlgorithmParameters</span><span class="p">(</span><span class="n">DQNAlgorithmParameters</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> :param v_min: (float)</span>
|
||||
<span class="sd"> The minimal value that will be represented in the network output for predicting the Q value.</span>
|
||||
<span class="sd"> Corresponds to :math:`v_{min}` in the paper.</span>
|
||||
|
||||
<span class="sd"> :param v_max: (float)</span>
|
||||
<span class="sd"> The maximum value that will be represented in the network output for predicting the Q value.</span>
|
||||
<span class="sd"> Corresponds to :math:`v_{max}` in the paper.</span>
|
||||
|
||||
<span class="sd"> :param atoms: (int)</span>
|
||||
<span class="sd"> The number of atoms that will be used to discretize the range between v_min and v_max.</span>
|
||||
<span class="sd"> For the C51 algorithm described in the paper, the number of atoms is 51.</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">v_min</span> <span class="o">=</span> <span class="o">-</span><span class="mf">10.0</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">v_max</span> <span class="o">=</span> <span class="mf">10.0</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">atoms</span> <span class="o">=</span> <span class="mi">51</span></div>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">CategoricalDQNExplorationParameters</span><span class="p">(</span><span class="n">EGreedyParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">epsilon_schedule</span> <span class="o">=</span> <span class="n">LinearSchedule</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mf">0.01</span><span class="p">,</span> <span class="mi">1000000</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">evaluation_epsilon</span> <span class="o">=</span> <span class="mf">0.001</span>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">CategoricalDQNAgentParameters</span><span class="p">(</span><span class="n">DQNAgentParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">algorithm</span> <span class="o">=</span> <span class="n">CategoricalDQNAlgorithmParameters</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">exploration</span> <span class="o">=</span> <span class="n">CategoricalDQNExplorationParameters</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">network_wrappers</span> <span class="o">=</span> <span class="p">{</span><span class="s2">"main"</span><span class="p">:</span> <span class="n">CategoricalDQNNetworkParameters</span><span class="p">()}</span>
|
||||
|
||||
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="s1">'rl_coach.agents.categorical_dqn_agent:CategoricalDQNAgent'</span>
|
||||
|
||||
|
||||
<span class="c1"># Categorical Deep Q Network - https://arxiv.org/pdf/1707.06887.pdf</span>
|
||||
<span class="k">class</span> <span class="nc">CategoricalDQNAgent</span><span class="p">(</span><span class="n">ValueOptimizationAgent</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s1">'LevelManager'</span><span class="p">,</span> <span class="s1">'CompositeAgent'</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">z_values</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">v_min</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">v_max</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">atoms</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">distribution_prediction_to_q_values</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">prediction</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># prediction's format is (batch,actions,atoms)</span>
|
||||
<span class="k">def</span> <span class="nf">get_all_q_values_for_states</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">:</span> <span class="n">StateType</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">requires_action_values</span><span class="p">():</span>
|
||||
<span class="n">prediction</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_prediction</span><span class="p">(</span><span class="n">states</span><span class="p">)</span>
|
||||
<span class="n">q_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">distribution_prediction_to_q_values</span><span class="p">(</span><span class="n">prediction</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">q_values</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="k">return</span> <span class="n">q_values</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">learn_from_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
|
||||
<span class="n">network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># for the action we actually took, the error is calculated by the atoms distribution</span>
|
||||
<span class="c1"># for all other actions, the error is 0</span>
|
||||
<span class="n">distributional_q_st_plus_1</span><span class="p">,</span> <span class="n">TD_targets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">parallel_prediction</span><span class="p">([</span>
|
||||
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">target_network</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">next_states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">)),</span>
|
||||
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))</span>
|
||||
<span class="p">])</span>
|
||||
|
||||
<span class="c1"># select the optimal actions for the next state</span>
|
||||
<span class="n">target_actions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">distribution_prediction_to_q_values</span><span class="p">(</span><span class="n">distributional_q_st_plus_1</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||||
<span class="n">m</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="o">.</span><span class="n">size</span><span class="p">))</span>
|
||||
|
||||
<span class="n">batches</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># an alternative to the for loop. 3.7x perf improvement vs. the same code done with for looping.</span>
|
||||
<span class="c1"># only 10% speedup overall - leaving commented out as the code is not as clear.</span>
|
||||
|
||||
<span class="c1"># tzj_ = np.fmax(np.fmin(batch.rewards() + (1.0 - batch.game_overs()) * self.ap.algorithm.discount *</span>
|
||||
<span class="c1"># np.transpose(np.repeat(self.z_values[np.newaxis, :], batch.size, axis=0), (1, 0)),</span>
|
||||
<span class="c1"># self.z_values[-1]),</span>
|
||||
<span class="c1"># self.z_values[0])</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># bj_ = (tzj_ - self.z_values[0]) / (self.z_values[1] - self.z_values[0])</span>
|
||||
<span class="c1"># u_ = (np.ceil(bj_)).astype(int)</span>
|
||||
<span class="c1"># l_ = (np.floor(bj_)).astype(int)</span>
|
||||
<span class="c1"># m_ = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))</span>
|
||||
<span class="c1"># np.add.at(m_, [batches, l_],</span>
|
||||
<span class="c1"># np.transpose(distributional_q_st_plus_1[batches, target_actions], (1, 0)) * (u_ - bj_))</span>
|
||||
<span class="c1"># np.add.at(m_, [batches, u_],</span>
|
||||
<span class="c1"># np.transpose(distributional_q_st_plus_1[batches, target_actions], (1, 0)) * (bj_ - l_))</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="o">.</span><span class="n">size</span><span class="p">):</span>
|
||||
<span class="n">tzj</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">fmax</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">fmin</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">rewards</span><span class="p">()</span> <span class="o">+</span>
|
||||
<span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">batch</span><span class="o">.</span><span class="n">game_overs</span><span class="p">())</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">discount</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="p">[</span><span class="n">j</span><span class="p">],</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]),</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
|
||||
<span class="n">bj</span> <span class="o">=</span> <span class="p">(</span><span class="n">tzj</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span><span class="o">/</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
|
||||
<span class="n">u</span> <span class="o">=</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">bj</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">)</span>
|
||||
<span class="n">l</span> <span class="o">=</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">bj</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">)</span>
|
||||
<span class="n">m</span><span class="p">[</span><span class="n">batches</span><span class="p">,</span> <span class="n">l</span><span class="p">]</span> <span class="o">+=</span> <span class="p">(</span><span class="n">distributional_q_st_plus_1</span><span class="p">[</span><span class="n">batches</span><span class="p">,</span> <span class="n">target_actions</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="n">u</span> <span class="o">-</span> <span class="n">bj</span><span class="p">))</span>
|
||||
<span class="n">m</span><span class="p">[</span><span class="n">batches</span><span class="p">,</span> <span class="n">u</span><span class="p">]</span> <span class="o">+=</span> <span class="p">(</span><span class="n">distributional_q_st_plus_1</span><span class="p">[</span><span class="n">batches</span><span class="p">,</span> <span class="n">target_actions</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="n">bj</span> <span class="o">-</span> <span class="n">l</span><span class="p">))</span>
|
||||
|
||||
<span class="c1"># total_loss = cross entropy between actual result above and predicted result for the given action</span>
|
||||
<span class="c1"># only update the action that we have actually done in this transition</span>
|
||||
<span class="n">TD_targets</span><span class="p">[</span><span class="n">batches</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()]</span> <span class="o">=</span> <span class="n">m</span>
|
||||
|
||||
<span class="c1"># update errors in prioritized replay buffer</span>
|
||||
<span class="n">importance_weights</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">'weight'</span><span class="p">)</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="p">,</span> <span class="n">PrioritizedExperienceReplay</span><span class="p">)</span> <span class="k">else</span> <span class="kc">None</span>
|
||||
|
||||
<span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">train_and_sync_networks</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">),</span> <span class="n">TD_targets</span><span class="p">,</span>
|
||||
<span class="n">importance_weights</span><span class="o">=</span><span class="n">importance_weights</span><span class="p">)</span>
|
||||
|
||||
<span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span> <span class="o">=</span> <span class="n">result</span><span class="p">[:</span><span class="mi">3</span><span class="p">]</span>
|
||||
|
||||
<span class="c1"># TODO: fix this spaghetti code</span>
|
||||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="p">,</span> <span class="n">PrioritizedExperienceReplay</span><span class="p">):</span>
|
||||
<span class="n">errors</span> <span class="o">=</span> <span class="n">losses</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">),</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">'update_priorities'</span><span class="p">,</span> <span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">'idx'</span><span class="p">),</span> <span class="n">errors</span><span class="p">))</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span>
|
||||
|
||||
</pre></div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<footer>
|
||||
|
||||
|
||||
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<p>
|
||||
© Copyright 2018, Intel AI Lab
|
||||
|
||||
</p>
|
||||
</div>
|
||||
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||||
|
||||
</footer>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
</div>
<script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/jquery.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/underscore.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/doctools.js"></script>
|
||||
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/js/theme.js"></script>
|
||||
|
||||
<script type="text/javascript">
|
||||
jQuery(function () {
|
||||
SphinxRtdTheme.Navigation.enable(true);
|
||||
});
|
||||
</script>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
314
docs/_modules/rl_coach/agents/cil_agent.html
Normal file
@@ -0,0 +1,314 @@
<!DOCTYPE html>
|
||||
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
|
||||
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
|
||||
<title>rl_coach.agents.cil_agent — Reinforcement Learning Coach 0.11.0 documentation</title>
<link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../../_static/css/custom.css" type="text/css" />
|
||||
<link rel="index" title="Index" href="../../../genindex.html" />
|
||||
<link rel="search" title="Search" href="../../../search.html" />
|
||||
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
|
||||
|
||||
|
||||
|
||||
<script src="../../../_static/js/modernizr.min.js"></script>
|
||||
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||||
<div class="wy-side-scroll">
|
||||
<div class="wy-side-nav-search">
<a href="../../../index.html" class="icon icon-home"> Reinforcement Learning Coach
<img src="../../../_static/dark_logo.png" class="logo" alt="Logo"/>
|
||||
|
||||
</a>
<div role="search">
|
||||
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
|
||||
<input type="text" name="q" placeholder="Search docs" />
|
||||
<input type="hidden" name="check_keywords" value="yes" />
|
||||
<input type="hidden" name="area" value="default" />
|
||||
</form>
|
||||
</div>
</div>
|
||||
|
||||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption"><span class="caption-text">Intro</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../features/index.html">Features</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../selecting_an_algorithm.html">Selecting an Algorithm</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../dashboard.html">Coach Dashboard</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Design</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../design/control_flow.html">Control Flow</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../design/network.html">Network Design</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Contributing</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_agent.html">Adding a New Agent</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_env.html">Adding a New Environment</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Components</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/agents/index.html">Agents</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/architectures/index.html">Architectures</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/environments/index.html">Environments</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/exploration_policies/index.html">Exploration Policies</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/filters/index.html">Filters</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/memories/index.html">Memories</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/core_types.html">Core Types</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/spaces.html">Spaces</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/additional_parameters.html">Additional Parameters</a></li>
|
||||
</ul>
</div>
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
|
||||
|
||||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||||
<a href="../../../index.html">Reinforcement Learning Coach</a>
|
||||
|
||||
</nav>
<div class="wy-nav-content">
|
||||
|
||||
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
|
||||
|
||||
<ul class="wy-breadcrumbs">
|
||||
|
||||
<li><a href="../../../index.html">Docs</a> »</li>
|
||||
|
||||
<li><a href="../../index.html">Module code</a> »</li>
|
||||
|
||||
<li>rl_coach.agents.cil_agent</li>
<li class="wy-breadcrumbs-aside">
|
||||
|
||||
</li>
|
||||
|
||||
</ul>
<hr/>
|
||||
</div>
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
<h1>Source code for rl_coach.agents.cil_agent</h1><div class="highlight"><pre>
|
||||
<span></span><span class="c1">#</span>
|
||||
<span class="c1"># Copyright (c) 2017 Intel Corporation </span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Licensed under the Apache License, Version 2.0 (the "License");</span>
|
||||
<span class="c1"># you may not use this file except in compliance with the License.</span>
|
||||
<span class="c1"># You may obtain a copy of the License at</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
|
||||
<span class="c1"># distributed under the License is distributed on an "AS IS" BASIS,</span>
|
||||
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
|
||||
<span class="c1"># See the License for the specific language governing permissions and</span>
|
||||
<span class="c1"># limitations under the License.</span>
|
||||
<span class="c1">#</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.agents.imitation_agent</span> <span class="k">import</span> <span class="n">ImitationAgent</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.architectures.embedder_parameters</span> <span class="k">import</span> <span class="n">InputEmbedderParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.architectures.head_parameters</span> <span class="k">import</span> <span class="n">RegressionHeadParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.architectures.middleware_parameters</span> <span class="k">import</span> <span class="n">FCMiddlewareParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.base_parameters</span> <span class="k">import</span> <span class="n">AgentParameters</span><span class="p">,</span> <span class="n">MiddlewareScheme</span><span class="p">,</span> <span class="n">NetworkParameters</span><span class="p">,</span> <span class="n">AlgorithmParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.e_greedy</span> <span class="k">import</span> <span class="n">EGreedyParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.memories.non_episodic.balanced_experience_replay</span> <span class="k">import</span> <span class="n">BalancedExperienceReplayParameters</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="CILAlgorithmParameters"><a class="viewcode-back" href="../../../components/agents/imitation/cil.html#rl_coach.agents.cil_agent.CILAlgorithmParameters">[docs]</a><span class="k">class</span> <span class="nc">CILAlgorithmParameters</span><span class="p">(</span><span class="n">AlgorithmParameters</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> :param state_key_with_the_class_index: (str)</span>
|
||||
<span class="sd"> The key of the state dictionary which corresponds to the value that will be used to control the class index.</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">state_key_with_the_class_index</span> <span class="o">=</span> <span class="s1">'high_level_command'</span></div>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">CILNetworkParameters</span><span class="p">(</span><span class="n">NetworkParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">input_embedders_parameters</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'observation'</span><span class="p">:</span> <span class="n">InputEmbedderParameters</span><span class="p">()}</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">middleware_parameters</span> <span class="o">=</span> <span class="n">FCMiddlewareParameters</span><span class="p">(</span><span class="n">scheme</span><span class="o">=</span><span class="n">MiddlewareScheme</span><span class="o">.</span><span class="n">Medium</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">RegressionHeadParameters</span><span class="p">()]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_type</span> <span class="o">=</span> <span class="s1">'Adam'</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">32</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">replace_mse_with_huber_loss</span> <span class="o">=</span> <span class="kc">False</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">create_target_network</span> <span class="o">=</span> <span class="kc">False</span>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">CILAgentParameters</span><span class="p">(</span><span class="n">AgentParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">algorithm</span><span class="o">=</span><span class="n">CILAlgorithmParameters</span><span class="p">(),</span>
|
||||
<span class="n">exploration</span><span class="o">=</span><span class="n">EGreedyParameters</span><span class="p">(),</span>
|
||||
<span class="n">memory</span><span class="o">=</span><span class="n">BalancedExperienceReplayParameters</span><span class="p">(),</span>
|
||||
<span class="n">networks</span><span class="o">=</span><span class="p">{</span><span class="s2">"main"</span><span class="p">:</span> <span class="n">CILNetworkParameters</span><span class="p">()})</span>
|
||||
|
||||
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="s1">'rl_coach.agents.cil_agent:CILAgent'</span>
|
||||
|
||||
|
||||
<span class="c1"># Conditional Imitation Learning Agent: https://arxiv.org/abs/1710.02410</span>
|
||||
<span class="k">class</span> <span class="nc">CILAgent</span><span class="p">(</span><span class="n">ImitationAgent</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s1">'LevelManager'</span><span class="p">,</span> <span class="s1">'CompositeAgent'</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">current_high_level_control</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">choose_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">curr_state</span><span class="p">):</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">current_high_level_control</span> <span class="o">=</span> <span class="n">curr_state</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">state_key_with_the_class_index</span><span class="p">]</span>
|
||||
<span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">choose_action</span><span class="p">(</span><span class="n">curr_state</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">extract_action_values</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">prediction</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">current_high_level_control</span><span class="p">]</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">learn_from_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
|
||||
<span class="n">network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
|
||||
|
||||
<span class="n">target_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">({</span><span class="o">**</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">)})</span>
|
||||
|
||||
<span class="n">branch_to_update</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">state_key_with_the_class_index</span><span class="p">])[</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">state_key_with_the_class_index</span><span class="p">]</span>
|
||||
<span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">branch</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">branch_to_update</span><span class="p">):</span>
|
||||
<span class="n">target_values</span><span class="p">[</span><span class="n">branch</span><span class="p">][</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()[</span><span class="n">idx</span><span class="p">]</span>
|
||||
|
||||
<span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">train_and_sync_networks</span><span class="p">({</span><span class="o">**</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">)},</span> <span class="n">target_values</span><span class="p">)</span>
|
||||
<span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span> <span class="o">=</span> <span class="n">result</span><span class="p">[:</span><span class="mi">3</span><span class="p">]</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span>
|
||||
</pre></div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<footer>
|
||||
|
||||
|
||||
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<p>
|
||||
© Copyright 2018, Intel AI Lab
|
||||
|
||||
</p>
|
||||
</div>
|
||||
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||||
|
||||
</footer>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
</div>
<script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/jquery.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/underscore.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/doctools.js"></script>
|
||||
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/js/theme.js"></script>
|
||||
|
||||
<script type="text/javascript">
|
||||
jQuery(function () {
|
||||
SphinxRtdTheme.Navigation.enable(true);
|
||||
});
|
||||
</script>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
563
docs/_modules/rl_coach/agents/clipped_ppo_agent.html
Normal file
@@ -0,0 +1,563 @@
<!DOCTYPE html>
|
||||
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
|
||||
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
|
||||
<title>rl_coach.agents.clipped_ppo_agent — Reinforcement Learning Coach 0.11.0 documentation</title>
<link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../../_static/css/custom.css" type="text/css" />
|
||||
<link rel="index" title="Index" href="../../../genindex.html" />
|
||||
<link rel="search" title="Search" href="../../../search.html" />
|
||||
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
|
||||
|
||||
|
||||
|
||||
<script src="../../../_static/js/modernizr.min.js"></script>
|
||||
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||||
<div class="wy-side-scroll">
|
||||
<div class="wy-side-nav-search">
<a href="../../../index.html" class="icon icon-home"> Reinforcement Learning Coach
<img src="../../../_static/dark_logo.png" class="logo" alt="Logo"/>
|
||||
|
||||
</a>
<div role="search">
|
||||
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
|
||||
<input type="text" name="q" placeholder="Search docs" />
|
||||
<input type="hidden" name="check_keywords" value="yes" />
|
||||
<input type="hidden" name="area" value="default" />
|
||||
</form>
|
||||
</div>
</div>
|
||||
|
||||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption"><span class="caption-text">Intro</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../features/index.html">Features</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../selecting_an_algorithm.html">Selecting an Algorithm</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../dashboard.html">Coach Dashboard</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Design</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../design/control_flow.html">Control Flow</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../design/network.html">Network Design</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Contributing</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_agent.html">Adding a New Agent</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_env.html">Adding a New Environment</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Components</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/agents/index.html">Agents</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/architectures/index.html">Architectures</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/environments/index.html">Environments</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/exploration_policies/index.html">Exploration Policies</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/filters/index.html">Filters</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/memories/index.html">Memories</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/core_types.html">Core Types</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/spaces.html">Spaces</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/additional_parameters.html">Additional Parameters</a></li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||||
|
||||
|
||||
<nav class="wy-nav-top" aria-label="top navigation">
|
||||
|
||||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||||
<a href="../../../index.html">Reinforcement Learning Coach</a>
|
||||
|
||||
</nav>
|
||||
|
||||
|
||||
<div class="wy-nav-content">
|
||||
|
||||
<div class="rst-content">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||||
|
||||
<ul class="wy-breadcrumbs">
|
||||
|
||||
<li><a href="../../../index.html">Docs</a> »</li>
|
||||
|
||||
<li><a href="../../index.html">Module code</a> »</li>
|
||||
|
||||
<li>rl_coach.agents.clipped_ppo_agent</li>
|
||||
|
||||
|
||||
<li class="wy-breadcrumbs-aside">
|
||||
|
||||
</li>
|
||||
|
||||
</ul>
|
||||
|
||||
|
||||
<hr/>
|
||||
</div>
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
<h1>Source code for rl_coach.agents.clipped_ppo_agent</h1><div class="highlight"><pre>
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import copy
from collections import OrderedDict
from random import shuffle
from typing import Union

import numpy as np

from rl_coach.agents.actor_critic_agent import ActorCriticAgent
from rl_coach.agents.policy_optimization_agent import PolicyGradientRescaler
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.head_parameters import PPOHeadParameters, VHeadParameters
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters
from rl_coach.base_parameters import AlgorithmParameters, NetworkParameters, \
    AgentParameters
from rl_coach.core_types import EnvironmentSteps, Batch, EnvResponse, StateType
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
from rl_coach.exploration_policies.categorical import CategoricalParameters
from rl_coach.logger import screen
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
from rl_coach.schedules import ConstantSchedule
from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace


class ClippedPPONetworkParameters(NetworkParameters):
    def __init__(self):
        super().__init__()
        self.input_embedders_parameters = {'observation': InputEmbedderParameters(activation_function='tanh')}
        self.middleware_parameters = FCMiddlewareParameters(activation_function='tanh')
        self.heads_parameters = [VHeadParameters(), PPOHeadParameters()]
        self.batch_size = 64
        self.optimizer_type = 'Adam'
        self.clip_gradients = None
        self.use_separate_networks_per_head = True
        self.async_training = False
        self.l2_regularization = 0

        # The target network is used in order to freeze the old policy, while making updates to the new one
        # in train_network()
        self.create_target_network = True
        self.shared_optimizer = True
        self.scale_down_gradients_by_number_of_workers_for_sync_training = True

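# Illustrative note (added for clarity; not part of the original source): with the head ordering
# above, output head 0 is the V head (the state-value estimate trained against the value targets)
# and output head 1 is the PPO head (the policy distribution trained with the clipped surrogate
# loss); train_network() further below indexes output_heads[1] accordingly.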

class ClippedPPOAlgorithmParameters(AlgorithmParameters):
    """
    :param policy_gradient_rescaler: (PolicyGradientRescaler)
        This represents how the critic will be used to update the actor. The critic value function is typically
        used to rescale the gradients calculated by the actor. There are several ways of doing this, such as using
        the advantage of the action, or the generalized advantage estimation (GAE) value.

    :param gae_lambda: (float)
        The :math:`\lambda` value is used within the GAE function in order to weight different bootstrap length
        estimations. Typical values are in the range 0.9-1, and define an exponential decay over the different
        n-step estimations.

    :param clip_likelihood_ratio_using_epsilon: (float)
        If not None, the likelihood ratio between the current and new policy in the PPO loss function will be
        clipped to the range [1-clip_likelihood_ratio_using_epsilon, 1+clip_likelihood_ratio_using_epsilon].
        This is typically used in the Clipped PPO version of PPO, and should be set to None in regular PPO
        implementations.

    :param value_targets_mix_fraction: (float)
        The targets for the value network are an exponentially weighted moving average which uses this mix
        fraction to define how much of the new targets will be taken into account when calculating the loss.
        This value should be set to the range (0, 1], where 1 means that only the new targets will be taken
        into account.

    :param estimate_state_value_using_gae: (bool)
        If set to True, the state value will be estimated using the GAE technique.

    :param use_kl_regularization: (bool)
        If set to True, the loss function will be regularized using the KL divergence between the current and new
        policy, to bound the change of the policy during the network update.

    :param beta_entropy: (float)
        An entropy regularization term can be added to the loss function in order to control exploration. This term
        is weighted using the :math:`\beta` value defined by beta_entropy.

    :param optimization_epochs: (int)
        For each training phase, the collected dataset will be used for multiple epochs, which are defined by the
        optimization_epochs value.

    :param clipping_decay_schedule: (Schedule)
        Can be used to define a schedule over the clipping of the likelihood ratio.
    """
    def __init__(self):
        super().__init__()
        self.num_episodes_in_experience_replay = 1000000
        self.policy_gradient_rescaler = PolicyGradientRescaler.GAE
        self.gae_lambda = 0.95
        self.use_kl_regularization = False
        self.clip_likelihood_ratio_using_epsilon = 0.2
        self.estimate_state_value_using_gae = True
        self.beta_entropy = 0.01  # should be 0 for mujoco
        self.num_consecutive_playing_steps = EnvironmentSteps(2048)
        self.optimization_epochs = 10
        self.normalization_stats = None
        self.clipping_decay_schedule = ConstantSchedule(1)
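
# --- Illustrative sketch (added for clarity; not part of the original module) ---
# The parameters above feed the clipped-PPO surrogate objective. Writing the
# likelihood ratio as r = pi_new(a|s) / pi_old(a|s) and the advantage estimate as A,
# the per-sample loss minimized by the PPO head is essentially
#     L(r, A) = -min(r * A, clip(r, 1 - epsilon, 1 + epsilon) * A)
# where epsilon is clip_likelihood_ratio_using_epsilon (0.2 above), optionally
# decayed over time by clipping_decay_schedule. A minimal numpy version of that formula:
def _clipped_surrogate_loss_sketch(ratio, advantage, epsilon=0.2):
    clipped_ratio = np.clip(ratio, 1.0 - epsilon, 1.0 + epsilon)
    return float(np.mean(-np.minimum(ratio * advantage, clipped_ratio * advantage)))
# ---------------------------------------------------------------------------------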


class ClippedPPOAgentParameters(AgentParameters):
    def __init__(self):
        super().__init__(algorithm=ClippedPPOAlgorithmParameters(),
                         exploration={DiscreteActionSpace: CategoricalParameters(),
                                      BoxActionSpace: AdditiveNoiseParameters()},
                         memory=EpisodicExperienceReplayParameters(),
                         networks={"main": ClippedPPONetworkParameters()})

    @property
    def path(self):
        return 'rl_coach.agents.clipped_ppo_agent:ClippedPPOAgent'


# Clipped Proximal Policy Optimization - https://arxiv.org/abs/1707.06347
class ClippedPPOAgent(ActorCriticAgent):
    def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
        super().__init__(agent_parameters, parent)
        # signals definition
        self.value_loss = self.register_signal('Value Loss')
        self.policy_loss = self.register_signal('Policy Loss')
        self.total_kl_divergence_during_training_process = 0.0
        self.unclipped_grads = self.register_signal('Grads (unclipped)')
        self.value_targets = self.register_signal('Value Targets')
        self.kl_divergence = self.register_signal('KL Divergence')
        self.likelihood_ratio = self.register_signal('Likelihood Ratio')
        self.clipped_likelihood_ratio = self.register_signal('Clipped Likelihood Ratio')

    def set_session(self, sess):
        super().set_session(sess)
        if self.ap.algorithm.normalization_stats is not None:
            self.ap.algorithm.normalization_stats.set_session(sess)

    def fill_advantages(self, batch):
        network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()

        current_state_values = self.networks['main'].online_network.predict(batch.states(network_keys))[0]
        current_state_values = current_state_values.squeeze()
        self.state_values.add_sample(current_state_values)

        # calculate advantages
        advantages = []
        value_targets = []
        total_returns = batch.n_step_discounted_rewards()

        if self.policy_gradient_rescaler == PolicyGradientRescaler.A_VALUE:
            advantages = total_returns - current_state_values
        elif self.policy_gradient_rescaler == PolicyGradientRescaler.GAE:
            # get bootstraps
            episode_start_idx = 0
            advantages = np.array([])
            value_targets = np.array([])
            for idx, game_over in enumerate(batch.game_overs()):
                if game_over:
                    # get advantages for the rollout
                    value_bootstrapping = np.zeros((1,))
                    rollout_state_values = np.append(current_state_values[episode_start_idx:idx+1], value_bootstrapping)

                    rollout_advantages, gae_based_value_targets = \
                        self.get_general_advantage_estimation_values(batch.rewards()[episode_start_idx:idx+1],
                                                                     rollout_state_values)
                    episode_start_idx = idx + 1
                    advantages = np.append(advantages, rollout_advantages)
                    value_targets = np.append(value_targets, gae_based_value_targets)
        else:
            screen.warning("WARNING: The requested policy gradient rescaler is not available")

        # standardize
        advantages = (advantages - np.mean(advantages)) / np.std(advantages)

        for transition, advantage, value_target in zip(batch.transitions, advantages, value_targets):
            transition.info['advantage'] = advantage
            transition.info['gae_based_value_target'] = value_target

        self.action_advantages.add_sample(advantages)

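    # Illustrative note (added for clarity; not part of the original source):
    # get_general_advantage_estimation_values() is inherited from ActorCriticAgent. Given per-step
    # rewards r_t, state values V(s_t) (with a zero bootstrap value appended at the episode end, as
    # above) and a discount gamma, GAE computes
    #     delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    #     A_t = delta_t + (gamma * lambda) * A_{t+1}
    # with lambda = gae_lambda (0.95 above), and derives the GAE-based value targets from the
    # advantages, roughly V(s_t) + A_t.
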
    def train_network(self, batch, epochs):
        batch_results = []
        for j in range(epochs):
            batch.shuffle()
            batch_results = {
                'total_loss': [],
                'losses': [],
                'unclipped_grads': [],
                'kl_divergence': [],
                'entropy': []
            }

            fetches = [self.networks['main'].online_network.output_heads[1].kl_divergence,
                       self.networks['main'].online_network.output_heads[1].entropy,
                       self.networks['main'].online_network.output_heads[1].likelihood_ratio,
                       self.networks['main'].online_network.output_heads[1].clipped_likelihood_ratio]

            for i in range(int(batch.size / self.ap.network_wrappers['main'].batch_size)):
                start = i * self.ap.network_wrappers['main'].batch_size
                end = (i + 1) * self.ap.network_wrappers['main'].batch_size

                network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()
                actions = batch.actions()[start:end]
                gae_based_value_targets = batch.info('gae_based_value_target')[start:end]
                if not isinstance(self.spaces.action, DiscreteActionSpace) and len(actions.shape) == 1:
                    actions = np.expand_dims(actions, -1)

                # get old policy probabilities and distribution

                # TODO-perf - the target network ("old_policy") is not changing, so this can be calculated once
                # for all epochs. The shuffling that is done should then only be performed on the indices.
                result = self.networks['main'].target_network.predict({k: v[start:end] for k, v in batch.states(network_keys).items()})
                old_policy_distribution = result[1:]

                total_returns = batch.n_step_discounted_rewards(expand_dims=True)

                # calculate the gradients and apply them to both the local and the global policy networks
                if self.ap.algorithm.estimate_state_value_using_gae:
                    value_targets = np.expand_dims(gae_based_value_targets, -1)
                else:
                    value_targets = total_returns[start:end]

                inputs = copy.copy({k: v[start:end] for k, v in batch.states(network_keys).items()})
                inputs['output_1_0'] = actions

                # The old_policy_distribution needs to be represented as a list, because in the case of
                # discrete controls it has just a mean. Otherwise, it has both a mean and a standard deviation.
                for input_index, input in enumerate(old_policy_distribution):
                    inputs['output_1_{}'.format(input_index + 1)] = input

                # update the clipping decay schedule value
                inputs['output_1_{}'.format(len(old_policy_distribution)+1)] = \
                    self.ap.algorithm.clipping_decay_schedule.current_value

                total_loss, losses, unclipped_grads, fetch_result = \
                    self.networks['main'].train_and_sync_networks(
                        inputs, [value_targets, batch.info('advantage')[start:end]], additional_fetches=fetches
                    )

                batch_results['total_loss'].append(total_loss)
                batch_results['losses'].append(losses)
                batch_results['unclipped_grads'].append(unclipped_grads)
                batch_results['kl_divergence'].append(fetch_result[0])
                batch_results['entropy'].append(fetch_result[1])

                self.unclipped_grads.add_sample(unclipped_grads)
                self.value_targets.add_sample(value_targets)
                self.likelihood_ratio.add_sample(fetch_result[2])
                self.clipped_likelihood_ratio.add_sample(fetch_result[3])

            for key in batch_results.keys():
                batch_results[key] = np.mean(batch_results[key], 0)

            self.value_loss.add_sample(batch_results['losses'][0])
            self.policy_loss.add_sample(batch_results['losses'][1])
            self.loss.add_sample(batch_results['total_loss'])

            if self.ap.network_wrappers['main'].learning_rate_decay_rate != 0:
                curr_learning_rate = self.networks['main'].online_network.get_variable_value(
                    self.networks['main'].online_network.adaptive_learning_rate_scheme)
                self.curr_learning_rate.add_sample(curr_learning_rate)
            else:
                curr_learning_rate = self.ap.network_wrappers['main'].learning_rate

            # log training parameters
            screen.log_dict(
                OrderedDict([
                    ("Surrogate loss", batch_results['losses'][1]),
                    ("KL divergence", batch_results['kl_divergence']),
                    ("Entropy", batch_results['entropy']),
                    ("training epoch", j),
                    ("learning_rate", curr_learning_rate)
                ]),
                prefix="Policy training"
            )

        self.total_kl_divergence_during_training_process = batch_results['kl_divergence']
        self.entropy.add_sample(batch_results['entropy'])
        self.kl_divergence.add_sample(batch_results['kl_divergence'])
        return batch_results['losses']

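    # Illustrative note (added for clarity; not part of the original source): in train_network()
    # above, head index 1 is the PPO head (see ClippedPPONetworkParameters.heads_parameters), so
    # 'output_1_0' carries the taken actions, 'output_1_1'..'output_1_k' carry the old policy
    # distribution parameters (a mean only for discrete actions, a mean and a standard deviation
    # for continuous ones), and the last 'output_1_{k+1}' input carries the current clipping decay
    # value. The targets list [value_targets, advantages] is matched to the heads in the same
    # order: value head first, PPO head second.
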
    def post_training_commands(self):
        # clean memory
        self.call_memory('clean')

    def _should_train_helper(self, wait_for_full_episode=True):
        return super()._should_train_helper(True)

    def train(self):
        if self._should_train(wait_for_full_episode=True):
            for network in self.networks.values():
                network.set_is_training(True)

            dataset = self.memory.transitions
            dataset = self.pre_network_filter.filter(dataset, deep_copy=False)
            batch = Batch(dataset)

            for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
                self.networks['main'].sync()
                self.fill_advantages(batch)

                # take only the requested number of steps
                if isinstance(self.ap.algorithm.num_consecutive_playing_steps, EnvironmentSteps):
                    dataset = dataset[:self.ap.algorithm.num_consecutive_playing_steps.num_steps]
                shuffle(dataset)
                batch = Batch(dataset)

                self.train_network(batch, self.ap.algorithm.optimization_epochs)

            for network in self.networks.values():
                network.set_is_training(False)

            self.post_training_commands()
            self.training_iteration += 1
            # this is done in order to update the data that has been accumulated *while not playing*
            self.update_log()
            return None

<span class="k">def</span> <span class="nf">run_pre_network_filter_for_inference</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">StateType</span><span class="p">):</span>
|
||||
<span class="n">dummy_env_response</span> <span class="o">=</span> <span class="n">EnvResponse</span><span class="p">(</span><span class="n">next_state</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">reward</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">game_over</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">dummy_env_response</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="o">=</span><span class="kc">False</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">next_state</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">choose_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">curr_state</span><span class="p">):</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">clipping_decay_schedule</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
|
||||
<span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">choose_action</span><span class="p">(</span><span class="n">curr_state</span><span class="p">)</span>
|
||||
</pre></div>
443
docs/_modules/rl_coach/agents/ddpg_agent.html
Normal file
@@ -0,0 +1,443 @@
<h1>Source code for rl_coach.agents.ddpg_agent</h1><div class="highlight"><pre>
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import copy
from typing import Union
from collections import OrderedDict

import numpy as np

from rl_coach.agents.actor_critic_agent import ActorCriticAgent
from rl_coach.agents.agent import Agent
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.head_parameters import DDPGActorHeadParameters, VHeadParameters
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters
from rl_coach.base_parameters import NetworkParameters, AlgorithmParameters, \
    AgentParameters, EmbedderScheme
from rl_coach.core_types import ActionInfo, EnvironmentSteps
from rl_coach.exploration_policies.ou_process import OUProcessParameters
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
from rl_coach.spaces import BoxActionSpace, GoalsSpace


class DDPGCriticNetworkParameters(NetworkParameters):
    def __init__(self):
        super().__init__()
        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=True),
                                           'action': InputEmbedderParameters(scheme=EmbedderScheme.Shallow)}
        self.middleware_parameters = FCMiddlewareParameters()
        self.heads_parameters = [VHeadParameters()]
        self.optimizer_type = 'Adam'
        self.batch_size = 64
        self.async_training = False
        self.learning_rate = 0.001
        self.create_target_network = True
        self.shared_optimizer = True
        self.scale_down_gradients_by_number_of_workers_for_sync_training = False


class DDPGActorNetworkParameters(NetworkParameters):
    def __init__(self):
        super().__init__()
        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=True)}
        self.middleware_parameters = FCMiddlewareParameters(batchnorm=True)
        self.heads_parameters = [DDPGActorHeadParameters()]
        self.optimizer_type = 'Adam'
        self.batch_size = 64
        self.async_training = False
        self.learning_rate = 0.0001
        self.create_target_network = True
        self.shared_optimizer = True
        self.scale_down_gradients_by_number_of_workers_for_sync_training = False


class DDPGAlgorithmParameters(AlgorithmParameters):
    """
    :param num_steps_between_copying_online_weights_to_target: (StepMethod)
        The number of steps between copying the online network weights to the target network weights.

    :param rate_for_copying_weights_to_target: (float)
        When copying the online network weights to the target network weights, a soft update will be used, which
        weights the new online network weights by rate_for_copying_weights_to_target.

    :param num_consecutive_playing_steps: (StepMethod)
        The number of consecutive steps to act between every two training iterations.

    :param use_target_network_for_evaluation: (bool)
        If set to True, the target network will be used for predicting the actions when choosing actions to act.
        Since the target network weights change more slowly, the predicted actions will be more consistent.

    :param action_penalty: (float)
        The amount by which to penalize the network on high action feature (pre-activation) values.
        This can prevent the action features from saturating the TanH activation function, and therefore prevent the
        gradients from becoming very low.

    :param clip_critic_targets: (Tuple[float, float] or None)
        The range to clip the critic target to in order to prevent overestimation of the action values.

    :param use_non_zero_discount_for_terminal_states: (bool)
        If set to True, the discount factor will be used for terminal states to bootstrap the next predicted state
        values. If set to False, the terminal state's reward will be taken as the target return for the network.
    """
    def __init__(self):
        super().__init__()
        self.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(1)
        self.rate_for_copying_weights_to_target = 0.001
        self.num_consecutive_playing_steps = EnvironmentSteps(1)
        self.use_target_network_for_evaluation = False
        self.action_penalty = 0
        self.clip_critic_targets = None  # expected to be a tuple of the form (min_clip_value, max_clip_value) or None
        self.use_non_zero_discount_for_terminal_states = False


class DDPGAgentParameters(AgentParameters):
    def __init__(self):
        super().__init__(algorithm=DDPGAlgorithmParameters(),
                         exploration=OUProcessParameters(),
                         memory=EpisodicExperienceReplayParameters(),
                         networks=OrderedDict([("actor", DDPGActorNetworkParameters()),
                                               ("critic", DDPGCriticNetworkParameters())]))

    @property
    def path(self):
        return 'rl_coach.agents.ddpg_agent:DDPGAgent'


# Deep Deterministic Policy Gradients Network - https://arxiv.org/pdf/1509.02971.pdf
class DDPGAgent(ActorCriticAgent):
    def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
        super().__init__(agent_parameters, parent)

        self.q_values = self.register_signal("Q")
        self.TD_targets_signal = self.register_signal("TD targets")
        self.action_signal = self.register_signal("actions")

    def learn_from_batch(self, batch):
        actor = self.networks['actor']
        critic = self.networks['critic']

        actor_keys = self.ap.network_wrappers['actor'].input_embedders_parameters.keys()
        critic_keys = self.ap.network_wrappers['critic'].input_embedders_parameters.keys()

        # TD error = r + discount*max(q_st_plus_1) - q_st
        next_actions, actions_mean = actor.parallel_prediction([
            (actor.target_network, batch.next_states(actor_keys)),
            (actor.online_network, batch.states(actor_keys))
        ])

        critic_inputs = copy.copy(batch.next_states(critic_keys))
        critic_inputs['action'] = next_actions
        q_st_plus_1 = critic.target_network.predict(critic_inputs)

        # calculate the bootstrapped TD targets while discounting terminal states according to
        # use_non_zero_discount_for_terminal_states
        if self.ap.algorithm.use_non_zero_discount_for_terminal_states:
            TD_targets = batch.rewards(expand_dims=True) + self.ap.algorithm.discount * q_st_plus_1
        else:
            TD_targets = batch.rewards(expand_dims=True) + \
                         (1.0 - batch.game_overs(expand_dims=True)) * self.ap.algorithm.discount * q_st_plus_1

        # clip the TD targets to prevent overestimation errors
        if self.ap.algorithm.clip_critic_targets:
            TD_targets = np.clip(TD_targets, *self.ap.algorithm.clip_critic_targets)

        self.TD_targets_signal.add_sample(TD_targets)

        # get the gradients of the critic output with respect to the action
        critic_inputs = copy.copy(batch.states(critic_keys))
        critic_inputs['action'] = actions_mean
        action_gradients = critic.online_network.predict(critic_inputs,
                                                         outputs=critic.online_network.gradients_wrt_inputs[0]['action'])

        # train the critic
        critic_inputs = copy.copy(batch.states(critic_keys))
        critic_inputs['action'] = batch.actions(len(batch.actions().shape) == 1)
        result = critic.train_and_sync_networks(critic_inputs, TD_targets)
        total_loss, losses, unclipped_grads = result[:3]

        # apply the gradients from the critic to the actor
        initial_feed_dict = {actor.online_network.gradients_weights_ph[0]: -action_gradients}
        gradients = actor.online_network.predict(batch.states(actor_keys),
                                                 outputs=actor.online_network.weighted_gradients[0],
                                                 initial_feed_dict=initial_feed_dict)

        if actor.has_global:
            actor.apply_gradients_to_global_network(gradients)
            actor.update_online_network()
        else:
            actor.apply_gradients_to_online_network(gradients)

        return total_loss, losses, unclipped_grads

    def train(self):
        return Agent.train(self)

    def choose_action(self, curr_state):
        if not (isinstance(self.spaces.action, BoxActionSpace) or isinstance(self.spaces.action, GoalsSpace)):
            raise ValueError("DDPG works only for continuous control problems")
        # convert to batch so we can run it through the network
        tf_input_state = self.prepare_batch_for_inference(curr_state, 'actor')
        if self.ap.algorithm.use_target_network_for_evaluation:
            actor_network = self.networks['actor'].target_network
        else:
            actor_network = self.networks['actor'].online_network

        action_values = actor_network.predict(tf_input_state).squeeze()

        action = self.exploration_policy.get_action(action_values)

        self.action_signal.add_sample(action)

        # get q value
        tf_input_state = self.prepare_batch_for_inference(curr_state, 'critic')
        action_batch = np.expand_dims(action, 0)
        if type(action) != np.ndarray:
            action_batch = np.array([[action]])
        tf_input_state['action'] = action_batch
        q_value = self.networks['critic'].online_network.predict(tf_input_state)[0]
        self.q_values.add_sample(q_value)

        action_info = ActionInfo(action=action,
                                 action_value=q_value)

        return action_info
</pre></div>
475
docs/_modules/rl_coach/agents/dfp_agent.html
Normal file
@@ -0,0 +1,475 @@
<h1>Source code for rl_coach.agents.dfp_agent</h1><div class="highlight"><pre>
<span></span><span class="c1">#</span>
<span class="c1"># Copyright (c) 2017 Intel Corporation</span>
<span class="c1">#</span>
<span class="c1"># Licensed under the Apache License, Version 2.0 (the "License");</span>
<span class="c1"># you may not use this file except in compliance with the License.</span>
<span class="c1"># You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1">#      http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
<span class="c1"># distributed under the License is distributed on an "AS IS" BASIS,</span>
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="c1">#</span>

<span class="kn">import</span> <span class="nn">copy</span>
<span class="kn">from</span> <span class="nn">enum</span> <span class="k">import</span> <span class="n">Enum</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span>

<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>

<span class="kn">from</span> <span class="nn">rl_coach.agents.agent</span> <span class="k">import</span> <span class="n">Agent</span>
<span class="kn">from</span> <span class="nn">rl_coach.architectures.head_parameters</span> <span class="k">import</span> <span class="n">MeasurementsPredictionHeadParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.architectures.embedder_parameters</span> <span class="k">import</span> <span class="n">InputEmbedderParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.architectures.middleware_parameters</span> <span class="k">import</span> <span class="n">FCMiddlewareParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.architectures.tensorflow_components.layers</span> <span class="k">import</span> <span class="n">Conv2d</span><span class="p">,</span> <span class="n">Dense</span>
<span class="kn">from</span> <span class="nn">rl_coach.base_parameters</span> <span class="k">import</span> <span class="n">AlgorithmParameters</span><span class="p">,</span> <span class="n">AgentParameters</span><span class="p">,</span> <span class="n">NetworkParameters</span><span class="p">,</span> \
<span class="n">MiddlewareScheme</span>
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">ActionInfo</span><span class="p">,</span> <span class="n">EnvironmentSteps</span><span class="p">,</span> <span class="n">RunPhase</span>
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.e_greedy</span> <span class="k">import</span> <span class="n">EGreedyParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.memories.episodic.episodic_experience_replay</span> <span class="k">import</span> <span class="n">EpisodicExperienceReplayParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.memories.memory</span> <span class="k">import</span> <span class="n">MemoryGranularity</span>
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">SpacesDefinition</span><span class="p">,</span> <span class="n">VectorObservationSpace</span>


<span class="k">class</span> <span class="nc">HandlingTargetsAfterEpisodeEnd</span><span class="p">(</span><span class="n">Enum</span><span class="p">):</span>
<span class="n">LastStep</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">NAN</span> <span class="o">=</span> <span class="mi">1</span>


<span class="k">class</span> <span class="nc">DFPNetworkParameters</span><span class="p">(</span><span class="n">NetworkParameters</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">input_embedders_parameters</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'observation'</span><span class="p">:</span> <span class="n">InputEmbedderParameters</span><span class="p">(</span><span class="n">activation_function</span><span class="o">=</span><span class="s1">'leaky_relu'</span><span class="p">),</span>
<span class="s1">'measurements'</span><span class="p">:</span> <span class="n">InputEmbedderParameters</span><span class="p">(</span><span class="n">activation_function</span><span class="o">=</span><span class="s1">'leaky_relu'</span><span class="p">),</span>
<span class="s1">'goal'</span><span class="p">:</span> <span class="n">InputEmbedderParameters</span><span class="p">(</span><span class="n">activation_function</span><span class="o">=</span><span class="s1">'leaky_relu'</span><span class="p">)}</span>

<span class="bp">self</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="p">[</span><span class="s1">'observation'</span><span class="p">]</span><span class="o">.</span><span class="n">scheme</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">Conv2d</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span>
<span class="n">Conv2d</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span>
<span class="n">Conv2d</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">Dense</span><span class="p">(</span><span class="mi">512</span><span class="p">),</span>
<span class="p">]</span>

<span class="bp">self</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="p">[</span><span class="s1">'measurements'</span><span class="p">]</span><span class="o">.</span><span class="n">scheme</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">Dense</span><span class="p">(</span><span class="mi">128</span><span class="p">),</span>
<span class="n">Dense</span><span class="p">(</span><span class="mi">128</span><span class="p">),</span>
<span class="n">Dense</span><span class="p">(</span><span class="mi">128</span><span class="p">),</span>
<span class="p">]</span>

<span class="bp">self</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="p">[</span><span class="s1">'goal'</span><span class="p">]</span><span class="o">.</span><span class="n">scheme</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">Dense</span><span class="p">(</span><span class="mi">128</span><span class="p">),</span>
<span class="n">Dense</span><span class="p">(</span><span class="mi">128</span><span class="p">),</span>
<span class="n">Dense</span><span class="p">(</span><span class="mi">128</span><span class="p">),</span>
<span class="p">]</span>

<span class="bp">self</span><span class="o">.</span><span class="n">middleware_parameters</span> <span class="o">=</span> <span class="n">FCMiddlewareParameters</span><span class="p">(</span><span class="n">activation_function</span><span class="o">=</span><span class="s1">'leaky_relu'</span><span class="p">,</span>
<span class="n">scheme</span><span class="o">=</span><span class="n">MiddlewareScheme</span><span class="o">.</span><span class="n">Empty</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">MeasurementsPredictionHeadParameters</span><span class="p">(</span><span class="n">activation_function</span><span class="o">=</span><span class="s1">'leaky_relu'</span><span class="p">)]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">async_training</span> <span class="o">=</span> <span class="kc">False</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">64</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">adam_optimizer_beta1</span> <span class="o">=</span> <span class="mf">0.95</span>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">DFPMemoryParameters</span><span class="p">(</span><span class="n">EpisodicExperienceReplayParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">max_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">MemoryGranularity</span><span class="o">.</span><span class="n">Transitions</span><span class="p">,</span> <span class="mi">20000</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">shared_memory</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="DFPAlgorithmParameters"><a class="viewcode-back" href="../../../components/agents/other/dfp.html#rl_coach.agents.dfp_agent.DFPAlgorithmParameters">[docs]</a><span class="k">class</span> <span class="nc">DFPAlgorithmParameters</span><span class="p">(</span><span class="n">AlgorithmParameters</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> :param num_predicted_steps_ahead: (int)</span>
|
||||
<span class="sd"> Number of future steps to predict measurements for. The future steps won't be sequential, but rather jump</span>
|
||||
<span class="sd"> in multiples of 2. For example, if num_predicted_steps_ahead = 3, then the steps will be: t+1, t+2, t+4</span>
|
||||
|
||||
<span class="sd"> :param goal_vector: (List[float])</span>
|
||||
<span class="sd"> The goal vector will weight each of the measurements to form an optimization goal. The vector should have</span>
|
||||
<span class="sd"> the same length as the number of measurements, and it will be vector multiplied by the measurements.</span>
|
||||
<span class="sd"> Positive values correspond to trying to maximize the particular measurement, and negative values</span>
|
||||
<span class="sd"> correspond to trying to minimize the particular measurement.</span>
|
||||
|
||||
<span class="sd"> :param future_measurements_weights: (List[float])</span>
|
||||
<span class="sd"> The future_measurements_weights weight the contribution of each of the predicted timesteps to the optimization</span>
|
||||
<span class="sd"> goal. For example, if there are 6 steps predicted ahead, and a future_measurements_weights vector with 3 values,</span>
|
||||
<span class="sd"> then only the 3 last timesteps will be taken into account, according to the weights in the</span>
|
||||
<span class="sd"> future_measurements_weights vector.</span>
|
||||
|
||||
<span class="sd"> :param use_accumulated_reward_as_measurement: (bool)</span>
|
||||
<span class="sd"> If set to True, the accumulated reward from the beginning of the episode will be added as a measurement to</span>
|
||||
<span class="sd"> the measurements vector in the state. This van be useful in environments where the given measurements don't</span>
|
||||
<span class="sd"> include enough information for the particular goal the agent should achieve.</span>
|
||||
|
||||
<span class="sd"> :param handling_targets_after_episode_end: (HandlingTargetsAfterEpisodeEnd)</span>
|
||||
<span class="sd"> Dictates how to handle measurements that are outside the episode length.</span>
|
||||
|
||||
<span class="sd"> :param scale_measurements_targets: (Dict[str, float])</span>
|
||||
<span class="sd"> Allows rescaling the values of each of the measurements available. This van be useful when the measurements</span>
|
||||
<span class="sd"> have a different scale and you want to normalize them to the same scale.</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">num_predicted_steps_ahead</span> <span class="o">=</span> <span class="mi">6</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">goal_vector</span> <span class="o">=</span> <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">future_measurements_weights</span> <span class="o">=</span> <span class="p">[</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">use_accumulated_reward_as_measurement</span> <span class="o">=</span> <span class="kc">False</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">handling_targets_after_episode_end</span> <span class="o">=</span> <span class="n">HandlingTargetsAfterEpisodeEnd</span><span class="o">.</span><span class="n">NAN</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">scale_measurements_targets</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">num_consecutive_playing_steps</span> <span class="o">=</span> <span class="n">EnvironmentSteps</span><span class="p">(</span><span class="mi">8</span><span class="p">)</span></div>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">DFPAgentParameters</span><span class="p">(</span><span class="n">AgentParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">algorithm</span><span class="o">=</span><span class="n">DFPAlgorithmParameters</span><span class="p">(),</span>
|
||||
<span class="n">exploration</span><span class="o">=</span><span class="n">EGreedyParameters</span><span class="p">(),</span>
|
||||
<span class="n">memory</span><span class="o">=</span><span class="n">DFPMemoryParameters</span><span class="p">(),</span>
|
||||
<span class="n">networks</span><span class="o">=</span><span class="p">{</span><span class="s2">"main"</span><span class="p">:</span> <span class="n">DFPNetworkParameters</span><span class="p">()})</span>
|
||||
|
||||
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="s1">'rl_coach.agents.dfp_agent:DFPAgent'</span>
|
||||
|
||||
|
||||
<span class="c1"># Direct Future Prediction Agent - http://vladlen.info/papers/learning-to-act.pdf</span>
|
||||
<span class="k">class</span> <span class="nc">DFPAgent</span><span class="p">(</span><span class="n">Agent</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s1">'LevelManager'</span><span class="p">,</span> <span class="s1">'CompositeAgent'</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">current_goal</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">goal_vector</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">target_measurements_scale_factors</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">learn_from_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
|
||||
<span class="n">network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
|
||||
|
||||
<span class="n">network_inputs</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">)</span>
|
||||
<span class="n">network_inputs</span><span class="p">[</span><span class="s1">'goal'</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">current_goal</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># get the current outputs of the network</span>
|
||||
<span class="n">targets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">network_inputs</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># change the targets for the taken actions</span>
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">):</span>
|
||||
<span class="n">targets</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()[</span><span class="n">i</span><span class="p">]]</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">info</span><span class="p">[</span><span class="s1">'future_measurements'</span><span class="p">]</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span>
|
||||
|
||||
<span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">train_and_sync_networks</span><span class="p">(</span><span class="n">network_inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span>
|
||||
<span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span> <span class="o">=</span> <span class="n">result</span><span class="p">[:</span><span class="mi">3</span><span class="p">]</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">choose_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">curr_state</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">requires_action_values</span><span class="p">():</span>
|
||||
<span class="c1"># predict the future measurements</span>
|
||||
<span class="n">tf_input_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_batch_for_inference</span><span class="p">(</span><span class="n">curr_state</span><span class="p">,</span> <span class="s1">'main'</span><span class="p">)</span>
|
||||
<span class="n">tf_input_state</span><span class="p">[</span><span class="s1">'goal'</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">current_goal</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">measurements_future_prediction</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">tf_input_state</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="n">action_values</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">action</span><span class="o">.</span><span class="n">actions</span><span class="p">))</span>
|
||||
<span class="n">num_steps_used_for_objective</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">future_measurements_weights</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># calculate the score of each action by multiplying it's future measurements with the goal vector</span>
|
||||
<span class="k">for</span> <span class="n">action_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">action</span><span class="o">.</span><span class="n">actions</span><span class="p">)):</span>
|
||||
<span class="n">action_measurements</span> <span class="o">=</span> <span class="n">measurements_future_prediction</span><span class="p">[</span><span class="n">action_idx</span><span class="p">]</span>
|
||||
<span class="n">action_measurements</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">action_measurements</span><span class="p">,</span>
|
||||
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">num_predicted_steps_ahead</span><span class="p">,</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">state</span><span class="p">[</span><span class="s1">'measurements'</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span>
|
||||
<span class="n">future_steps_values</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">action_measurements</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_goal</span><span class="p">)</span>
|
||||
<span class="n">action_values</span><span class="p">[</span><span class="n">action_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">future_steps_values</span><span class="p">[</span><span class="o">-</span><span class="n">num_steps_used_for_objective</span><span class="p">:],</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">future_measurements_weights</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">action_values</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
|
||||
<span class="c1"># choose action according to the exploration policy and the current phase (evaluating or training the agent)</span>
|
||||
<span class="n">action</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">get_action</span><span class="p">(</span><span class="n">action_values</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">action_values</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="n">action_values</span> <span class="o">=</span> <span class="n">action_values</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
|
||||
<span class="n">action_info</span> <span class="o">=</span> <span class="n">ActionInfo</span><span class="p">(</span><span class="n">action</span><span class="o">=</span><span class="n">action</span><span class="p">,</span> <span class="n">action_value</span><span class="o">=</span><span class="n">action_values</span><span class="p">[</span><span class="n">action</span><span class="p">])</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">action_info</span> <span class="o">=</span> <span class="n">ActionInfo</span><span class="p">(</span><span class="n">action</span><span class="o">=</span><span class="n">action</span><span class="p">)</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">action_info</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">set_environment_parameters</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">spaces</span><span class="p">:</span> <span class="n">SpacesDefinition</span><span class="p">):</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">spaces</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">spaces</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">goal</span> <span class="o">=</span> <span class="n">VectorObservationSpace</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">state</span><span class="p">[</span><span class="s1">'measurements'</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
|
||||
<span class="n">measurements_names</span><span class="o">=</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">state</span><span class="p">[</span><span class="s1">'measurements'</span><span class="p">]</span><span class="o">.</span><span class="n">measurements_names</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># if the user has filled some scale values, check that he got the names right</span>
|
||||
<span class="k">if</span> <span class="nb">set</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">state</span><span class="p">[</span><span class="s1">'measurements'</span><span class="p">]</span><span class="o">.</span><span class="n">measurements_names</span><span class="p">)</span><span class="o">.</span><span class="n">intersection</span><span class="p">(</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">scale_measurements_targets</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="o">!=</span>\
|
||||
<span class="nb">set</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">scale_measurements_targets</span><span class="o">.</span><span class="n">keys</span><span class="p">()):</span>
|
||||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Some of the keys in parameter scale_measurements_targets (</span><span class="si">{}</span><span class="s2">) are not defined in "</span>
|
||||
<span class="s2">"the measurements space </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">scale_measurements_targets</span><span class="o">.</span><span class="n">keys</span><span class="p">(),</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">state</span><span class="p">[</span><span class="s1">'measurements'</span><span class="p">]</span><span class="o">.</span><span class="n">measurements_names</span><span class="p">))</span>
|
||||
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">set_environment_parameters</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># the below is done after calling the base class method, as it might add accumulated reward as a measurement</span>
|
||||
|
||||
<span class="c1"># fill out the missing measurements scale factors</span>
|
||||
<span class="k">for</span> <span class="n">measurement_name</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">state</span><span class="p">[</span><span class="s1">'measurements'</span><span class="p">]</span><span class="o">.</span><span class="n">measurements_names</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">measurement_name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">scale_measurements_targets</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">scale_measurements_targets</span><span class="p">[</span><span class="n">measurement_name</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">target_measurements_scale_factors</span> <span class="o">=</span> \
|
||||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">scale_measurements_targets</span><span class="p">[</span><span class="n">measurement_name</span><span class="p">]</span> <span class="k">for</span> <span class="n">measurement_name</span> <span class="ow">in</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">state</span><span class="p">[</span><span class="s1">'measurements'</span><span class="p">]</span><span class="o">.</span><span class="n">measurements_names</span><span class="p">])</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">handle_episode_ended</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="n">last_episode</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode_buffer</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="ow">in</span> <span class="p">[</span><span class="n">RunPhase</span><span class="o">.</span><span class="n">TRAIN</span><span class="p">,</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">HEATUP</span><span class="p">]</span> <span class="ow">and</span> <span class="n">last_episode</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_update_measurements_targets</span><span class="p">(</span><span class="n">last_episode</span><span class="p">,</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">num_predicted_steps_ahead</span><span class="p">)</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">handle_episode_ended</span><span class="p">()</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">_update_measurements_targets</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">episode</span><span class="p">,</span> <span class="n">num_steps</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="s1">'measurements'</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">episode</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">state</span> <span class="ow">or</span> <span class="n">episode</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">state</span><span class="p">[</span><span class="s1">'measurements'</span><span class="p">]</span> <span class="o">==</span> <span class="p">[]:</span>
|
||||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Measurements are not present in the transitions of the last episode played. "</span><span class="p">)</span>
|
||||
<span class="n">measurements_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">state</span><span class="p">[</span><span class="s1">'measurements'</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="k">for</span> <span class="n">transition_idx</span><span class="p">,</span> <span class="n">transition</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">episode</span><span class="o">.</span><span class="n">transitions</span><span class="p">):</span>
|
||||
<span class="n">transition</span><span class="o">.</span><span class="n">info</span><span class="p">[</span><span class="s1">'future_measurements'</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">num_steps</span><span class="p">,</span> <span class="n">measurements_size</span><span class="p">))</span>
|
||||
<span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_steps</span><span class="p">):</span>
|
||||
<span class="n">offset_idx</span> <span class="o">=</span> <span class="n">transition_idx</span> <span class="o">+</span> <span class="mi">2</span> <span class="o">**</span> <span class="n">step</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">offset_idx</span> <span class="o">>=</span> <span class="n">episode</span><span class="o">.</span><span class="n">length</span><span class="p">():</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">handling_targets_after_episode_end</span> <span class="o">==</span> <span class="n">HandlingTargetsAfterEpisodeEnd</span><span class="o">.</span><span class="n">NAN</span><span class="p">:</span>
|
||||
<span class="c1"># the special MSE loss will ignore those entries so that the gradient will be 0 for these</span>
|
||||
<span class="n">transition</span><span class="o">.</span><span class="n">info</span><span class="p">[</span><span class="s1">'future_measurements'</span><span class="p">][</span><span class="n">step</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">nan</span>
|
||||
<span class="k">continue</span>
|
||||
|
||||
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">handling_targets_after_episode_end</span> <span class="o">==</span> <span class="n">HandlingTargetsAfterEpisodeEnd</span><span class="o">.</span><span class="n">LastStep</span><span class="p">:</span>
|
||||
<span class="n">offset_idx</span> <span class="o">=</span> <span class="o">-</span> <span class="mi">1</span>
|
||||
|
||||
<span class="n">transition</span><span class="o">.</span><span class="n">info</span><span class="p">[</span><span class="s1">'future_measurements'</span><span class="p">][</span><span class="n">step</span><span class="p">]</span> <span class="o">=</span> \
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">target_measurements_scale_factors</span> <span class="o">*</span> \
|
||||
<span class="p">(</span><span class="n">episode</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="n">offset_idx</span><span class="p">]</span><span class="o">.</span><span class="n">state</span><span class="p">[</span><span class="s1">'measurements'</span><span class="p">]</span> <span class="o">-</span> <span class="n">transition</span><span class="o">.</span><span class="n">state</span><span class="p">[</span><span class="s1">'measurements'</span><span class="p">])</span>
|
||||
</pre></div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<footer>
|
||||
|
||||
|
||||
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<p>
|
||||
© Copyright 2018, Intel AI Lab
|
||||
|
||||
</p>
|
||||
</div>
|
||||
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||||
|
||||
</footer>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/jquery.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/underscore.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/doctools.js"></script>
|
||||
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
||||
|
||||
|
||||
|
||||
|
||||
<script type="text/javascript" src="../../../_static/js/theme.js"></script>
|
||||
|
||||
<script type="text/javascript">
|
||||
jQuery(function () {
|
||||
SphinxRtdTheme.Navigation.enable(true);
|
||||
});
|
||||
</script>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
326
docs/_modules/rl_coach/agents/dqn_agent.html
Normal file
@@ -0,0 +1,326 @@
|
||||
|
||||
|
||||
<!DOCTYPE html>
|
||||
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
|
||||
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
|
||||
<title>rl_coach.agents.dqn_agent — Reinforcement Learning Coach 0.11.0 documentation</title>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../../_static/css/custom.css" type="text/css" />
|
||||
<link rel="index" title="Index" href="../../../genindex.html" />
|
||||
<link rel="search" title="Search" href="../../../search.html" />
|
||||
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
|
||||
|
||||
|
||||
|
||||
<script src="../../../_static/js/modernizr.min.js"></script>
|
||||
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
|
||||
|
||||
|
||||
<div class="wy-grid-for-nav">
|
||||
|
||||
|
||||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||||
<div class="wy-side-scroll">
|
||||
<div class="wy-side-nav-search">
|
||||
|
||||
|
||||
|
||||
<a href="../../../index.html" class="icon icon-home"> Reinforcement Learning Coach
|
||||
|
||||
|
||||
|
||||
|
||||
<img src="../../../_static/dark_logo.png" class="logo" alt="Logo"/>
|
||||
|
||||
</a>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="search">
|
||||
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
|
||||
<input type="text" name="q" placeholder="Search docs" />
|
||||
<input type="hidden" name="check_keywords" value="yes" />
|
||||
<input type="hidden" name="area" value="default" />
|
||||
</form>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Intro</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../features/index.html">Features</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../selecting_an_algorithm.html">Selecting an Algorithm</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../dashboard.html">Coach Dashboard</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Design</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../design/control_flow.html">Control Flow</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../design/network.html">Network Design</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Contributing</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_agent.html">Adding a New Agent</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_env.html">Adding a New Environment</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Components</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/agents/index.html">Agents</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/architectures/index.html">Architectures</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/environments/index.html">Environments</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/exploration_policies/index.html">Exploration Policies</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/filters/index.html">Filters</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/memories/index.html">Memories</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/core_types.html">Core Types</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/spaces.html">Spaces</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/additional_parameters.html">Additional Parameters</a></li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||||
|
||||
|
||||
<nav class="wy-nav-top" aria-label="top navigation">
|
||||
|
||||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||||
<a href="../../../index.html">Reinforcement Learning Coach</a>
|
||||
|
||||
</nav>
|
||||
|
||||
|
||||
<div class="wy-nav-content">
|
||||
|
||||
<div class="rst-content">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||||
|
||||
<ul class="wy-breadcrumbs">
|
||||
|
||||
<li><a href="../../../index.html">Docs</a> »</li>
|
||||
|
||||
<li><a href="../../index.html">Module code</a> »</li>
|
||||
|
||||
<li>rl_coach.agents.dqn_agent</li>
|
||||
|
||||
|
||||
<li class="wy-breadcrumbs-aside">
|
||||
|
||||
</li>
|
||||
|
||||
</ul>
|
||||
|
||||
|
||||
<hr/>
|
||||
</div>
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
<h1>Source code for rl_coach.agents.dqn_agent</h1><div class="highlight"><pre>
|
||||
<span></span><span class="c1">#</span>
|
||||
<span class="c1"># Copyright (c) 2017 Intel Corporation</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Licensed under the Apache License, Version 2.0 (the "License");</span>
|
||||
<span class="c1"># you may not use this file except in compliance with the License.</span>
|
||||
<span class="c1"># You may obtain a copy of the License at</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
|
||||
<span class="c1"># distributed under the License is distributed on an "AS IS" BASIS,</span>
|
||||
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
|
||||
<span class="c1"># See the License for the specific language governing permissions and</span>
|
||||
<span class="c1"># limitations under the License.</span>
|
||||
<span class="c1">#</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.agents.value_optimization_agent</span> <span class="k">import</span> <span class="n">ValueOptimizationAgent</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.architectures.embedder_parameters</span> <span class="k">import</span> <span class="n">InputEmbedderParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.architectures.head_parameters</span> <span class="k">import</span> <span class="n">QHeadParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.architectures.middleware_parameters</span> <span class="k">import</span> <span class="n">FCMiddlewareParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.base_parameters</span> <span class="k">import</span> <span class="n">AlgorithmParameters</span><span class="p">,</span> <span class="n">NetworkParameters</span><span class="p">,</span> <span class="n">AgentParameters</span><span class="p">,</span> \
|
||||
<span class="n">MiddlewareScheme</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">EnvironmentSteps</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.e_greedy</span> <span class="k">import</span> <span class="n">EGreedyParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.memories.non_episodic.experience_replay</span> <span class="k">import</span> <span class="n">ExperienceReplayParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.schedules</span> <span class="k">import</span> <span class="n">LinearSchedule</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="DQNAlgorithmParameters"><a class="viewcode-back" href="../../../components/agents/value_optimization/dqn.html#rl_coach.agents.dqn_agent.DQNAlgorithmParameters">[docs]</a><span class="k">class</span> <span class="nc">DQNAlgorithmParameters</span><span class="p">(</span><span class="n">AlgorithmParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">num_steps_between_copying_online_weights_to_target</span> <span class="o">=</span> <span class="n">EnvironmentSteps</span><span class="p">(</span><span class="mi">10000</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">num_consecutive_playing_steps</span> <span class="o">=</span> <span class="n">EnvironmentSteps</span><span class="p">(</span><span class="mi">4</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">discount</span> <span class="o">=</span> <span class="mf">0.99</span></div>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">DQNNetworkParameters</span><span class="p">(</span><span class="n">NetworkParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">input_embedders_parameters</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'observation'</span><span class="p">:</span> <span class="n">InputEmbedderParameters</span><span class="p">()}</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">middleware_parameters</span> <span class="o">=</span> <span class="n">FCMiddlewareParameters</span><span class="p">(</span><span class="n">scheme</span><span class="o">=</span><span class="n">MiddlewareScheme</span><span class="o">.</span><span class="n">Medium</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">QHeadParameters</span><span class="p">()]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_type</span> <span class="o">=</span> <span class="s1">'Adam'</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">32</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">replace_mse_with_huber_loss</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">create_target_network</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">DQNAgentParameters</span><span class="p">(</span><span class="n">AgentParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">algorithm</span><span class="o">=</span><span class="n">DQNAlgorithmParameters</span><span class="p">(),</span>
|
||||
<span class="n">exploration</span><span class="o">=</span><span class="n">EGreedyParameters</span><span class="p">(),</span>
|
||||
<span class="n">memory</span><span class="o">=</span><span class="n">ExperienceReplayParameters</span><span class="p">(),</span>
|
||||
<span class="n">networks</span><span class="o">=</span><span class="p">{</span><span class="s2">"main"</span><span class="p">:</span> <span class="n">DQNNetworkParameters</span><span class="p">()})</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">exploration</span><span class="o">.</span><span class="n">epsilon_schedule</span> <span class="o">=</span> <span class="n">LinearSchedule</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mi">1000000</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">exploration</span><span class="o">.</span><span class="n">evaluation_epsilon</span> <span class="o">=</span> <span class="mf">0.05</span>
|
||||
|
||||
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="s1">'rl_coach.agents.dqn_agent:DQNAgent'</span>
|
||||
|
||||
|
||||
<span class="c1"># Deep Q Network - https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf</span>
|
||||
<div class="viewcode-block" id="DQNAgent"><a class="viewcode-back" href="../../../test.html#rl_coach.agents.dqn_agent.DQNAgent">[docs]</a><span class="k">class</span> <span class="nc">DQNAgent</span><span class="p">(</span><span class="n">ValueOptimizationAgent</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s1">'LevelManager'</span><span class="p">,</span> <span class="s1">'CompositeAgent'</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">)</span>
|
||||
|
||||
<div class="viewcode-block" id="DQNAgent.learn_from_batch"><a class="viewcode-back" href="../../../test.html#rl_coach.agents.dqn_agent.DQNAgent.learn_from_batch">[docs]</a> <span class="k">def</span> <span class="nf">learn_from_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
|
||||
<span class="n">network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># for the action we actually took, the error is:</span>
|
||||
<span class="c1"># TD error = r + discount*max(q_st_plus_1) - q_st</span>
|
||||
<span class="c1"># # for all other actions, the error is 0</span>
|
||||
<span class="n">q_st_plus_1</span><span class="p">,</span> <span class="n">TD_targets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">parallel_prediction</span><span class="p">([</span>
|
||||
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">target_network</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">next_states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">)),</span>
|
||||
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))</span>
|
||||
<span class="p">])</span>
|
||||
|
||||
<span class="c1"># only update the action that we have actually done in this transition</span>
|
||||
<span class="n">TD_errors</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">):</span>
|
||||
<span class="n">new_target</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">rewards</span><span class="p">()[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span>\
|
||||
<span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">batch</span><span class="o">.</span><span class="n">game_overs</span><span class="p">()[</span><span class="n">i</span><span class="p">])</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">discount</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">q_st_plus_1</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">TD_errors</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">new_target</span> <span class="o">-</span> <span class="n">TD_targets</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()[</span><span class="n">i</span><span class="p">]]))</span>
|
||||
<span class="n">TD_targets</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()[</span><span class="n">i</span><span class="p">]]</span> <span class="o">=</span> <span class="n">new_target</span>
|
||||
|
||||
<span class="c1"># update errors in prioritized replay buffer</span>
|
||||
<span class="n">importance_weights</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">update_transition_priorities_and_get_weights</span><span class="p">(</span><span class="n">TD_errors</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
|
||||
|
||||
<span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">train_and_sync_networks</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">),</span> <span class="n">TD_targets</span><span class="p">,</span>
|
||||
<span class="n">importance_weights</span><span class="o">=</span><span class="n">importance_weights</span><span class="p">)</span>
|
||||
|
||||
<span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span> <span class="o">=</span> <span class="n">result</span><span class="p">[:</span><span class="mi">3</span><span class="p">]</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span></div></div>
|
||||
</pre></div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<footer>
|
||||
|
||||
|
||||
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<p>
|
||||
© Copyright 2018, Intel AI Lab
|
||||
|
||||
</p>
|
||||
</div>
|
||||
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||||
|
||||
</footer>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/jquery.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/underscore.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/doctools.js"></script>
|
||||
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
||||
|
||||
|
||||
|
||||
|
||||
<script type="text/javascript" src="../../../_static/js/theme.js"></script>
|
||||
|
||||
<script type="text/javascript">
|
||||
jQuery(function () {
|
||||
SphinxRtdTheme.Navigation.enable(true);
|
||||
});
|
||||
</script>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
306
docs/_modules/rl_coach/agents/mmc_agent.html
Normal file
@@ -0,0 +1,306 @@
|
||||
|
||||
|
||||
<!DOCTYPE html>
|
||||
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
|
||||
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
|
||||
<title>rl_coach.agents.mmc_agent — Reinforcement Learning Coach 0.11.0 documentation</title>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../../_static/css/custom.css" type="text/css" />
|
||||
<link rel="index" title="Index" href="../../../genindex.html" />
|
||||
<link rel="search" title="Search" href="../../../search.html" />
|
||||
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
|
||||
|
||||
|
||||
|
||||
<script src="../../../_static/js/modernizr.min.js"></script>
|
||||
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
|
||||
|
||||
|
||||
<div class="wy-grid-for-nav">
|
||||
|
||||
|
||||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||||
<div class="wy-side-scroll">
|
||||
<div class="wy-side-nav-search">
|
||||
|
||||
|
||||
|
||||
<a href="../../../index.html" class="icon icon-home"> Reinforcement Learning Coach
|
||||
|
||||
|
||||
|
||||
|
||||
<img src="../../../_static/dark_logo.png" class="logo" alt="Logo"/>
|
||||
|
||||
</a>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="search">
|
||||
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
|
||||
<input type="text" name="q" placeholder="Search docs" />
|
||||
<input type="hidden" name="check_keywords" value="yes" />
|
||||
<input type="hidden" name="area" value="default" />
|
||||
</form>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Intro</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../features/index.html">Features</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../selecting_an_algorithm.html">Selecting an Algorithm</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../dashboard.html">Coach Dashboard</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Design</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../design/control_flow.html">Control Flow</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../design/network.html">Network Design</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Contributing</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_agent.html">Adding a New Agent</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_env.html">Adding a New Environment</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Components</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/agents/index.html">Agents</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/architectures/index.html">Architectures</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/environments/index.html">Environments</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/exploration_policies/index.html">Exploration Policies</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/filters/index.html">Filters</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/memories/index.html">Memories</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/core_types.html">Core Types</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/spaces.html">Spaces</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/additional_parameters.html">Additional Parameters</a></li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||||
|
||||
|
||||
<nav class="wy-nav-top" aria-label="top navigation">
|
||||
|
||||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||||
<a href="../../../index.html">Reinforcement Learning Coach</a>
|
||||
|
||||
</nav>
|
||||
|
||||
|
||||
<div class="wy-nav-content">
|
||||
|
||||
<div class="rst-content">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||||
|
||||
<ul class="wy-breadcrumbs">
|
||||
|
||||
<li><a href="../../../index.html">Docs</a> »</li>
|
||||
|
||||
<li><a href="../../index.html">Module code</a> »</li>
|
||||
|
||||
<li>rl_coach.agents.mmc_agent</li>
|
||||
|
||||
|
||||
<li class="wy-breadcrumbs-aside">
|
||||
|
||||
</li>
|
||||
|
||||
</ul>
|
||||
|
||||
|
||||
<hr/>
|
||||
</div>
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
<h1>Source code for rl_coach.agents.mmc_agent</h1><div class="highlight"><pre>
|
||||
<span></span><span class="c1">#</span>
|
||||
<span class="c1"># Copyright (c) 2017 Intel Corporation </span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Licensed under the Apache License, Version 2.0 (the "License");</span>
|
||||
<span class="c1"># you may not use this file except in compliance with the License.</span>
|
||||
<span class="c1"># You may obtain a copy of the License at</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
|
||||
<span class="c1"># distributed under the License is distributed on an "AS IS" BASIS,</span>
|
||||
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
|
||||
<span class="c1"># See the License for the specific language governing permissions and</span>
|
||||
<span class="c1"># limitations under the License.</span>
|
||||
<span class="c1">#</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.agents.dqn_agent</span> <span class="k">import</span> <span class="n">DQNAgentParameters</span><span class="p">,</span> <span class="n">DQNAlgorithmParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.agents.value_optimization_agent</span> <span class="k">import</span> <span class="n">ValueOptimizationAgent</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.memories.episodic.episodic_experience_replay</span> <span class="k">import</span> <span class="n">EpisodicExperienceReplayParameters</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="MixedMonteCarloAlgorithmParameters"><a class="viewcode-back" href="../../../components/agents/value_optimization/mmc.html#rl_coach.agents.mmc_agent.MixedMonteCarloAlgorithmParameters">[docs]</a><span class="k">class</span> <span class="nc">MixedMonteCarloAlgorithmParameters</span><span class="p">(</span><span class="n">DQNAlgorithmParameters</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> :param monte_carlo_mixing_rate: (float)</span>
|
||||
<span class="sd"> The mixing rate is used for setting the amount of monte carlo estimate (full return) that will be mixes into</span>
|
||||
<span class="sd"> the single-step bootstrapped targets.</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">monte_carlo_mixing_rate</span> <span class="o">=</span> <span class="mf">0.1</span></div>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">MixedMonteCarloAgentParameters</span><span class="p">(</span><span class="n">DQNAgentParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">algorithm</span> <span class="o">=</span> <span class="n">MixedMonteCarloAlgorithmParameters</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span> <span class="o">=</span> <span class="n">EpisodicExperienceReplayParameters</span><span class="p">()</span>
|
||||
|
||||
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="s1">'rl_coach.agents.mmc_agent:MixedMonteCarloAgent'</span>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">MixedMonteCarloAgent</span><span class="p">(</span><span class="n">ValueOptimizationAgent</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s1">'LevelManager'</span><span class="p">,</span> <span class="s1">'CompositeAgent'</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">mixing_rate</span> <span class="o">=</span> <span class="n">agent_parameters</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">monte_carlo_mixing_rate</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">learn_from_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
|
||||
<span class="n">network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># for the 1-step, we use the double-dqn target. hence actions are taken greedily according to the online network</span>
|
||||
<span class="n">selected_actions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">next_states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">)),</span> <span class="mi">1</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># TD_targets are initialized with the current prediction so that we will</span>
|
||||
<span class="c1"># only update the action that we have actually done in this transition</span>
|
||||
<span class="n">q_st_plus_1</span><span class="p">,</span> <span class="n">TD_targets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">parallel_prediction</span><span class="p">([</span>
|
||||
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">target_network</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">next_states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">)),</span>
|
||||
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))</span>
|
||||
<span class="p">])</span>
|
||||
|
||||
<span class="n">total_returns</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">n_step_discounted_rewards</span><span class="p">()</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">):</span>
|
||||
<span class="n">one_step_target</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">rewards</span><span class="p">()[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> \
|
||||
<span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">batch</span><span class="o">.</span><span class="n">game_overs</span><span class="p">()[</span><span class="n">i</span><span class="p">])</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">discount</span> <span class="o">*</span> \
|
||||
<span class="n">q_st_plus_1</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">selected_actions</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span>
|
||||
<span class="n">monte_carlo_target</span> <span class="o">=</span> <span class="n">total_returns</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
|
||||
<span class="n">TD_targets</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()[</span><span class="n">i</span><span class="p">]]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixing_rate</span><span class="p">)</span> <span class="o">*</span> <span class="n">one_step_target</span> <span class="o">+</span> \
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">mixing_rate</span> <span class="o">*</span> <span class="n">monte_carlo_target</span>
|
||||
|
||||
<span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">train_and_sync_networks</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">),</span> <span class="n">TD_targets</span><span class="p">)</span>
|
||||
<span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span> <span class="o">=</span> <span class="n">result</span><span class="p">[:</span><span class="mi">3</span><span class="p">]</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span>
|
||||
</pre></div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<footer>
|
||||
|
||||
|
||||
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<p>
|
||||
© Copyright 2018, Intel AI Lab
|
||||
|
||||
</p>
|
||||
</div>
|
||||
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||||
|
||||
</footer>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/jquery.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/underscore.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/doctools.js"></script>
|
||||
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
||||
|
||||
|
||||
|
||||
|
||||
<script type="text/javascript" src="../../../_static/js/theme.js"></script>
|
||||
|
||||
<script type="text/javascript">
|
||||
jQuery(function () {
|
||||
SphinxRtdTheme.Navigation.enable(true);
|
||||
});
|
||||
</script>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
373
docs/_modules/rl_coach/agents/n_step_q_agent.html
Normal file
@@ -0,0 +1,373 @@
|
||||
|
||||
|
||||
<!DOCTYPE html>
|
||||
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
|
||||
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
|
||||
<title>rl_coach.agents.n_step_q_agent — Reinforcement Learning Coach 0.11.0 documentation</title>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../../_static/css/custom.css" type="text/css" />
|
||||
<link rel="index" title="Index" href="../../../genindex.html" />
|
||||
<link rel="search" title="Search" href="../../../search.html" />
|
||||
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
|
||||
|
||||
|
||||
|
||||
<script src="../../../_static/js/modernizr.min.js"></script>
|
||||
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
|
||||
|
||||
|
||||
<div class="wy-grid-for-nav">
|
||||
|
||||
|
||||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||||
<div class="wy-side-scroll">
|
||||
<div class="wy-side-nav-search">
|
||||
|
||||
|
||||
|
||||
<a href="../../../index.html" class="icon icon-home"> Reinforcement Learning Coach
|
||||
|
||||
|
||||
|
||||
|
||||
<img src="../../../_static/dark_logo.png" class="logo" alt="Logo"/>
|
||||
|
||||
</a>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="search">
|
||||
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
|
||||
<input type="text" name="q" placeholder="Search docs" />
|
||||
<input type="hidden" name="check_keywords" value="yes" />
|
||||
<input type="hidden" name="area" value="default" />
|
||||
</form>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Intro</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../features/index.html">Features</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../selecting_an_algorithm.html">Selecting an Algorithm</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../dashboard.html">Coach Dashboard</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Design</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../design/control_flow.html">Control Flow</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../design/network.html">Network Design</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Contributing</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_agent.html">Adding a New Agent</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_env.html">Adding a New Environment</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Components</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/agents/index.html">Agents</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/architectures/index.html">Architectures</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/environments/index.html">Environments</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/exploration_policies/index.html">Exploration Policies</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/filters/index.html">Filters</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/memories/index.html">Memories</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/core_types.html">Core Types</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/spaces.html">Spaces</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/additional_parameters.html">Additional Parameters</a></li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||||
|
||||
|
||||
<nav class="wy-nav-top" aria-label="top navigation">
|
||||
|
||||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||||
<a href="../../../index.html">Reinforcement Learning Coach</a>
|
||||
|
||||
</nav>
|
||||
|
||||
|
||||
<div class="wy-nav-content">
|
||||
|
||||
<div class="rst-content">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||||
|
||||
<ul class="wy-breadcrumbs">
|
||||
|
||||
<li><a href="../../../index.html">Docs</a> »</li>
|
||||
|
||||
<li><a href="../../index.html">Module code</a> »</li>
|
||||
|
||||
<li>rl_coach.agents.n_step_q_agent</li>
|
||||
|
||||
|
||||
<li class="wy-breadcrumbs-aside">
|
||||
|
||||
</li>
|
||||
|
||||
</ul>
|
||||
|
||||
|
||||
<hr/>
|
||||
</div>
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
<h1>Source code for rl_coach.agents.n_step_q_agent</h1><div class="highlight"><pre>
|
||||
<span></span><span class="c1">#</span>
|
||||
<span class="c1"># Copyright (c) 2017 Intel Corporation </span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Licensed under the Apache License, Version 2.0 (the "License");</span>
|
||||
<span class="c1"># you may not use this file except in compliance with the License.</span>
|
||||
<span class="c1"># You may obtain a copy of the License at</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
|
||||
<span class="c1"># distributed under the License is distributed on an "AS IS" BASIS,</span>
|
||||
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
|
||||
<span class="c1"># See the License for the specific language governing permissions and</span>
|
||||
<span class="c1"># limitations under the License.</span>
|
||||
<span class="c1">#</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.agents.policy_optimization_agent</span> <span class="k">import</span> <span class="n">PolicyOptimizationAgent</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.agents.value_optimization_agent</span> <span class="k">import</span> <span class="n">ValueOptimizationAgent</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.architectures.embedder_parameters</span> <span class="k">import</span> <span class="n">InputEmbedderParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.architectures.head_parameters</span> <span class="k">import</span> <span class="n">QHeadParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.architectures.middleware_parameters</span> <span class="k">import</span> <span class="n">FCMiddlewareParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.base_parameters</span> <span class="k">import</span> <span class="n">AlgorithmParameters</span><span class="p">,</span> <span class="n">AgentParameters</span><span class="p">,</span> <span class="n">NetworkParameters</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">EnvironmentSteps</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.e_greedy</span> <span class="k">import</span> <span class="n">EGreedyParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.memories.episodic.single_episode_buffer</span> <span class="k">import</span> <span class="n">SingleEpisodeBufferParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.utils</span> <span class="k">import</span> <span class="n">last_sample</span>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">NStepQNetworkParameters</span><span class="p">(</span><span class="n">NetworkParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">input_embedders_parameters</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'observation'</span><span class="p">:</span> <span class="n">InputEmbedderParameters</span><span class="p">()}</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">middleware_parameters</span> <span class="o">=</span> <span class="n">FCMiddlewareParameters</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">QHeadParameters</span><span class="p">()]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_type</span> <span class="o">=</span> <span class="s1">'Adam'</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">async_training</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">shared_optimizer</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">create_target_network</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="NStepQAlgorithmParameters"><a class="viewcode-back" href="../../../components/agents/value_optimization/n_step.html#rl_coach.agents.n_step_q_agent.NStepQAlgorithmParameters">[docs]</a><span class="k">class</span> <span class="nc">NStepQAlgorithmParameters</span><span class="p">(</span><span class="n">AlgorithmParameters</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> :param num_steps_between_copying_online_weights_to_target: (StepMethod)</span>
|
||||
<span class="sd"> The number of steps between copying the online network weights to the target network weights.</span>
|
||||
|
||||
<span class="sd"> :param apply_gradients_every_x_episodes: (int)</span>
|
||||
<span class="sd"> The number of episodes between applying the accumulated gradients to the network. After every</span>
|
||||
<span class="sd"> num_steps_between_gradient_updates steps, the agent will calculate the gradients for the collected data,</span>
|
||||
<span class="sd"> it will then accumulate it in internal accumulators, and will only apply them to the network once in every</span>
|
||||
<span class="sd"> apply_gradients_every_x_episodes episodes.</span>
|
||||
|
||||
<span class="sd"> :param num_steps_between_gradient_updates: (int)</span>
|
||||
<span class="sd"> The number of steps between calculating gradients for the collected data. In the A3C paper, this parameter is</span>
|
||||
<span class="sd"> called t_max. Since this algorithm is on-policy, only the steps collected between each two gradient calculations</span>
|
||||
<span class="sd"> are used in the batch.</span>
|
||||
|
||||
<span class="sd"> :param targets_horizon: (str)</span>
|
||||
<span class="sd"> Should be either 'N-Step' or '1-Step', and defines the length for which to bootstrap the network values over.</span>
|
||||
<span class="sd"> Essentially, 1-Step follows the regular 1 step bootstrapping Q learning update. For more information,</span>
|
||||
<span class="sd"> please refer to the original paper (https://arxiv.org/abs/1602.01783)</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">num_steps_between_copying_online_weights_to_target</span> <span class="o">=</span> <span class="n">EnvironmentSteps</span><span class="p">(</span><span class="mi">10000</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">apply_gradients_every_x_episodes</span> <span class="o">=</span> <span class="mi">1</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">num_steps_between_gradient_updates</span> <span class="o">=</span> <span class="mi">5</span> <span class="c1"># this is called t_max in all the papers</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">targets_horizon</span> <span class="o">=</span> <span class="s1">'N-Step'</span></div>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">NStepQAgentParameters</span><span class="p">(</span><span class="n">AgentParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">algorithm</span><span class="o">=</span><span class="n">NStepQAlgorithmParameters</span><span class="p">(),</span>
|
||||
<span class="n">exploration</span><span class="o">=</span><span class="n">EGreedyParameters</span><span class="p">(),</span>
|
||||
<span class="n">memory</span><span class="o">=</span><span class="n">SingleEpisodeBufferParameters</span><span class="p">(),</span>
|
||||
<span class="n">networks</span><span class="o">=</span><span class="p">{</span><span class="s2">"main"</span><span class="p">:</span> <span class="n">NStepQNetworkParameters</span><span class="p">()})</span>
|
||||
|
||||
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="s1">'rl_coach.agents.n_step_q_agent:NStepQAgent'</span>
|
||||
|
||||
|
||||
<span class="c1"># N Step Q Learning Agent - https://arxiv.org/abs/1602.01783</span>
|
||||
<span class="k">class</span> <span class="nc">NStepQAgent</span><span class="p">(</span><span class="n">ValueOptimizationAgent</span><span class="p">,</span> <span class="n">PolicyOptimizationAgent</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s1">'LevelManager'</span><span class="p">,</span> <span class="s1">'CompositeAgent'</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">last_gradient_update_step_idx</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">q_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">'Q Values'</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">value_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">'Value Loss'</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">learn_from_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
|
||||
<span class="c1"># batch contains a list of episodes to learn from</span>
|
||||
<span class="n">network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># get the values for the current states</span>
|
||||
<span class="n">state_value_head_targets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))</span>
|
||||
|
||||
<span class="c1"># the targets for the state value estimator</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">targets_horizon</span> <span class="o">==</span> <span class="s1">'1-Step'</span><span class="p">:</span>
|
||||
<span class="c1"># 1-Step Q learning</span>
|
||||
<span class="n">q_st_plus_1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">target_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">next_states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">reversed</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">)):</span>
|
||||
<span class="n">state_value_head_targets</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()[</span><span class="n">i</span><span class="p">]]</span> <span class="o">=</span> \
|
||||
<span class="n">batch</span><span class="o">.</span><span class="n">rewards</span><span class="p">()[</span><span class="n">i</span><span class="p">]</span> \
|
||||
<span class="o">+</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">batch</span><span class="o">.</span><span class="n">game_overs</span><span class="p">()[</span><span class="n">i</span><span class="p">])</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">discount</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">q_st_plus_1</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="mi">0</span><span class="p">)</span>
|
||||
|
||||
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">targets_horizon</span> <span class="o">==</span> <span class="s1">'N-Step'</span><span class="p">:</span>
|
||||
<span class="c1"># N-Step Q learning</span>
|
||||
<span class="k">if</span> <span class="n">batch</span><span class="o">.</span><span class="n">game_overs</span><span class="p">()[</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span>
|
||||
<span class="n">R</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">R</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">target_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">last_sample</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">next_states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))))</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">reversed</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">)):</span>
|
||||
<span class="n">R</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">rewards</span><span class="p">()[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">discount</span> <span class="o">*</span> <span class="n">R</span>
|
||||
<span class="n">state_value_head_targets</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()[</span><span class="n">i</span><span class="p">]]</span> <span class="o">=</span> <span class="n">R</span>
|
||||
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">assert</span> <span class="kc">True</span><span class="p">,</span> <span class="s1">'The available values for targets_horizon are: 1-Step, N-Step'</span>
|
||||
|
||||
<span class="c1"># train</span>
|
||||
<span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">accumulate_gradients</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">),</span> <span class="p">[</span><span class="n">state_value_head_targets</span><span class="p">])</span>
|
||||
|
||||
<span class="c1"># logging</span>
|
||||
<span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span> <span class="o">=</span> <span class="n">result</span><span class="p">[:</span><span class="mi">3</span><span class="p">]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">value_loss</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">losses</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="c1"># update the target network of every network that has a target network</span>
|
||||
<span class="k">if</span> <span class="nb">any</span><span class="p">([</span><span class="n">network</span><span class="o">.</span><span class="n">has_target</span> <span class="k">for</span> <span class="n">network</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="o">.</span><span class="n">values</span><span class="p">()])</span> \
|
||||
<span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">_should_update_online_weights_to_target</span><span class="p">():</span>
|
||||
<span class="k">for</span> <span class="n">network</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
|
||||
<span class="n">network</span><span class="o">.</span><span class="n">update_target_network</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">rate_for_copying_weights_to_target</span><span class="p">)</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">'Update Target Network'</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">'Update Target Network'</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">overwrite</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">PolicyOptimizationAgent</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<footer>
|
||||
|
||||
|
||||
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<p>
|
||||
© Copyright 2018, Intel AI Lab
|
||||
|
||||
</p>
|
||||
</div>
|
||||
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||||
|
||||
</footer>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/jquery.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/underscore.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/doctools.js"></script>
|
||||
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
||||
|
||||
|
||||
|
||||
|
||||
<script type="text/javascript" src="../../../_static/js/theme.js"></script>
|
||||
|
||||
<script type="text/javascript">
|
||||
jQuery(function () {
|
||||
SphinxRtdTheme.Navigation.enable(true);
|
||||
});
|
||||
</script>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
354
docs/_modules/rl_coach/agents/naf_agent.html
Normal file
@@ -0,0 +1,354 @@
|
||||
|
||||
|
||||
<!DOCTYPE html>
|
||||
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
|
||||
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
|
||||
<title>rl_coach.agents.naf_agent — Reinforcement Learning Coach 0.11.0 documentation</title>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../../_static/css/custom.css" type="text/css" />
|
||||
<link rel="index" title="Index" href="../../../genindex.html" />
|
||||
<link rel="search" title="Search" href="../../../search.html" />
|
||||
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
|
||||
|
||||
|
||||
|
||||
<script src="../../../_static/js/modernizr.min.js"></script>
|
||||
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
<h1>Source code for rl_coach.agents.naf_agent</h1><div class="highlight"><pre>
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Union

import numpy as np

from rl_coach.agents.value_optimization_agent import ValueOptimizationAgent
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.head_parameters import NAFHeadParameters
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters
from rl_coach.base_parameters import AlgorithmParameters, AgentParameters, \
    NetworkParameters

from rl_coach.core_types import ActionInfo, EnvironmentSteps
from rl_coach.exploration_policies.ou_process import OUProcessParameters
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
from rl_coach.spaces import BoxActionSpace


class NAFNetworkParameters(NetworkParameters):
    def __init__(self):
        super().__init__()
        self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
        self.middleware_parameters = FCMiddlewareParameters()
        self.heads_parameters = [NAFHeadParameters()]
        self.optimizer_type = 'Adam'
        self.learning_rate = 0.001
        self.async_training = True
        self.create_target_network = True


class NAFAlgorithmParameters(AlgorithmParameters):
    def __init__(self):
        super().__init__()
        self.num_consecutive_training_steps = 5
        self.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(1)
        self.rate_for_copying_weights_to_target = 0.001


class NAFAgentParameters(AgentParameters):
    def __init__(self):
        super().__init__(algorithm=NAFAlgorithmParameters(),
                         exploration=OUProcessParameters(),
                         memory=EpisodicExperienceReplayParameters(),
                         networks={"main": NAFNetworkParameters()})

    @property
    def path(self):
        return 'rl_coach.agents.naf_agent:NAFAgent'


# Normalized Advantage Functions - https://arxiv.org/pdf/1603.00748.pdf
class NAFAgent(ValueOptimizationAgent):
    def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
        super().__init__(agent_parameters, parent)
        self.l_values = self.register_signal("L")
        self.a_values = self.register_signal("Advantage")
        self.mu_values = self.register_signal("Action")
        self.v_values = self.register_signal("V")
        self.TD_targets = self.register_signal("TD targets")

    def learn_from_batch(self, batch):
        network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()

        # TD error = r + discount*v_st_plus_1 - q_st
        v_st_plus_1 = self.networks['main'].target_network.predict(
            batch.next_states(network_keys),
            self.networks['main'].target_network.output_heads[0].V,
            squeeze_output=False,
        )
        TD_targets = np.expand_dims(batch.rewards(), -1) + \
                     (1.0 - np.expand_dims(batch.game_overs(), -1)) * self.ap.algorithm.discount * v_st_plus_1

        self.TD_targets.add_sample(TD_targets)

        result = self.networks['main'].train_and_sync_networks({**batch.states(network_keys),
                                                                 'output_0_0': batch.actions(len(batch.actions().shape) == 1)
                                                                 }, TD_targets)
        total_loss, losses, unclipped_grads = result[:3]

        return total_loss, losses, unclipped_grads

    def choose_action(self, curr_state):
        if type(self.spaces.action) != BoxActionSpace:
            raise ValueError('NAF works only for continuous control problems')

        # convert to batch so we can run it through the network
        tf_input_state = self.prepare_batch_for_inference(curr_state, 'main')
        naf_head = self.networks['main'].online_network.output_heads[0]
        action_values = self.networks['main'].online_network.predict(tf_input_state, outputs=naf_head.mu,
                                                                     squeeze_output=False)

        # get the actual action to use
        action = self.exploration_policy.get_action(action_values)

        # get the internal values for logging
        outputs = [naf_head.mu, naf_head.Q, naf_head.L, naf_head.A, naf_head.V]
        result = self.networks['main'].online_network.predict(
            {**tf_input_state, 'output_0_0': action_values},
            outputs=outputs
        )
        mu, Q, L, A, V = result

        # store the q values statistics for logging
        self.q_values.add_sample(Q)
        self.l_values.add_sample(L)
        self.a_values.add_sample(A)
        self.mu_values.add_sample(mu)
        self.v_values.add_sample(V)

        action_info = ActionInfo(action=action, action_value=Q)

        return action_info
</pre></div>
</body>
</html>
435
docs/_modules/rl_coach/agents/nec_agent.html
Normal file
435
docs/_modules/rl_coach/agents/nec_agent.html
Normal file
@@ -0,0 +1,435 @@
<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">

<title>rl_coach.agents.nec_agent — Reinforcement Learning Coach 0.11.0 documentation</title>

<link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../../_static/css/custom.css" type="text/css" />
<link rel="index" title="Index" href="../../../genindex.html" />
<link rel="search" title="Search" href="../../../search.html" />
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">

<script src="../../../_static/js/modernizr.min.js"></script>

</head>

<body class="wy-body-for-nav">
<h1>Source code for rl_coach.agents.nec_agent</h1><div class="highlight"><pre>
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import pickle
from typing import Union

import numpy as np

from rl_coach.agents.value_optimization_agent import ValueOptimizationAgent
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.head_parameters import DNDQHeadParameters
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters
from rl_coach.base_parameters import AlgorithmParameters, NetworkParameters, AgentParameters

from rl_coach.core_types import RunPhase, EnvironmentSteps, Episode, StateType
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
from rl_coach.logger import screen
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters, MemoryGranularity
from rl_coach.schedules import ConstantSchedule


class NECNetworkParameters(NetworkParameters):
    def __init__(self):
        super().__init__()
        self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
        self.middleware_parameters = FCMiddlewareParameters()
        self.heads_parameters = [DNDQHeadParameters()]
        self.optimizer_type = 'Adam'


class NECAlgorithmParameters(AlgorithmParameters):
    """
    :param dnd_size: (int)
        Defines the number of transitions that will be stored in each one of the DNDs. Note that the total number
        of transitions that will be stored is dnd_size x num_actions.

    :param l2_norm_added_delta: (float)
        A small value that will be added when calculating the weight of each of the DND entries. This follows the
        :math:`\delta` parameter defined in the paper.

    :param new_value_shift_coefficient: (float)
        In the case where a new embedding that was added to the DND was already present, the value that will be
        stored in the DND is a mix between the existing value and the new value. The mix rate is defined by
        new_value_shift_coefficient.

    :param number_of_knn: (int)
        The number of neighbors that will be retrieved for each DND query.

    :param DND_key_error_threshold: (float)
        When the DND is queried for a specific embedding, this threshold will be used to determine if the embedding
        exists in the DND, since exact matches of embeddings are very rare.

    :param propagate_updates_to_DND: (bool)
        If set to True, when the gradients of the network are calculated, the gradients will also be
        backpropagated through the keys of the DND. The keys will then be updated as well, as if they were regular
        network weights.

    :param n_step: (int)
        The bootstrap length that will be used when calculating the state values to store in the DND.

    :param bootstrap_total_return_from_old_policy: (bool)
        If set to True, the bootstrap that will be used to calculate each state-action value is the network value
        when the state was first seen, and not the latest, most up-to-date network value.
    """
    def __init__(self):
        super().__init__()
        self.dnd_size = 500000
        self.l2_norm_added_delta = 0.001
        self.new_value_shift_coefficient = 0.1
        self.number_of_knn = 50
        self.DND_key_error_threshold = 0
        self.num_consecutive_playing_steps = EnvironmentSteps(4)
        self.propagate_updates_to_DND = False
        self.n_step = 100
        self.bootstrap_total_return_from_old_policy = True


class NECMemoryParameters(EpisodicExperienceReplayParameters):
    def __init__(self):
        super().__init__()
        self.max_size = (MemoryGranularity.Transitions, 100000)


class NECAgentParameters(AgentParameters):
    def __init__(self):
        super().__init__(algorithm=NECAlgorithmParameters(),
                         exploration=EGreedyParameters(),
                         memory=NECMemoryParameters(),
                         networks={"main": NECNetworkParameters()})
        self.exploration.epsilon_schedule = ConstantSchedule(0.1)
        self.exploration.evaluation_epsilon = 0.01

    @property
    def path(self):
        return 'rl_coach.agents.nec_agent:NECAgent'


# Neural Episodic Control - https://arxiv.org/pdf/1703.01988.pdf
class NECAgent(ValueOptimizationAgent):
    def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
        super().__init__(agent_parameters, parent)
        self.current_episode_state_embeddings = []
        self.training_started = False
        self.current_episode_buffer = \
            Episode(discount=self.ap.algorithm.discount,
                    n_step=self.ap.algorithm.n_step,
                    bootstrap_total_return_from_old_policy=self.ap.algorithm.bootstrap_total_return_from_old_policy)

    def learn_from_batch(self, batch):
        if not self.networks['main'].online_network.output_heads[0].DND.has_enough_entries(self.ap.algorithm.number_of_knn):
            return 0, [], 0
        else:
            if not self.training_started:
                self.training_started = True
                screen.log_title("Finished collecting initial entries in DND. Starting to train network...")

        network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()

        TD_targets = self.networks['main'].online_network.predict(batch.states(network_keys))
        bootstrapped_return_from_old_policy = batch.n_step_discounted_rewards()
        # only update the action that we have actually done in this transition
        for i in range(self.ap.network_wrappers['main'].batch_size):
            TD_targets[i, batch.actions()[i]] = bootstrapped_return_from_old_policy[i]

        # set the gradients to fetch for the DND update
        fetches = []
        head = self.networks['main'].online_network.output_heads[0]
        if self.ap.algorithm.propagate_updates_to_DND:
            fetches = [head.dnd_embeddings_grad, head.dnd_values_grad, head.dnd_indices]

        # train the neural network
        result = self.networks['main'].train_and_sync_networks(batch.states(network_keys), TD_targets, fetches)

        total_loss, losses, unclipped_grads = result[:3]

        # update the DND keys and values using the extracted gradients
        if self.ap.algorithm.propagate_updates_to_DND:
            embedding_gradients = np.swapaxes(result[-1][0], 0, 1)
            value_gradients = np.swapaxes(result[-1][1], 0, 1)
            indices = np.swapaxes(result[-1][2], 0, 1)
            head.DND.update_keys_and_values(batch.actions(), embedding_gradients, value_gradients, indices)

        return total_loss, losses, unclipped_grads

    def act(self):
        if self.phase == RunPhase.HEATUP:
            # get embedding in heatup (otherwise we get it through get_prediction)
            embedding = self.networks['main'].online_network.predict(
                self.prepare_batch_for_inference(self.curr_state, 'main'),
                outputs=self.networks['main'].online_network.state_embedding)
            self.current_episode_state_embeddings.append(embedding)

        return super().act()

    def get_all_q_values_for_states(self, states: StateType):
        # we need to store the state embeddings regardless if the action is random or not
        return self.get_prediction(states)

    def get_prediction(self, states):
        # get the actions q values and the state embedding
        embedding, actions_q_values = self.networks['main'].online_network.predict(
            self.prepare_batch_for_inference(states, 'main'),
            outputs=[self.networks['main'].online_network.state_embedding,
                     self.networks['main'].online_network.output_heads[0].output]
        )
        if self.phase != RunPhase.TEST:
            # store the state embedding for inserting it to the DND later
            self.current_episode_state_embeddings.append(embedding.squeeze())
        actions_q_values = actions_q_values[0][0]
        return actions_q_values

    def reset_internal_state(self):
        super().reset_internal_state()
        self.current_episode_state_embeddings = []
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">current_episode_buffer</span> <span class="o">=</span> \
|
||||
<span class="n">Episode</span><span class="p">(</span><span class="n">discount</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">discount</span><span class="p">,</span>
|
||||
<span class="n">n_step</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">n_step</span><span class="p">,</span>
|
||||
<span class="n">bootstrap_total_return_from_old_policy</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">bootstrap_total_return_from_old_policy</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">handle_episode_ended</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">handle_episode_ended</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># get the last full episode that we have collected</span>
|
||||
<span class="n">episode</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">'get_last_complete_episode'</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">episode</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">!=</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span><span class="p">:</span>
|
||||
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">current_episode_state_embeddings</span><span class="p">)</span> <span class="o">==</span> <span class="n">episode</span><span class="o">.</span><span class="n">length</span><span class="p">()</span>
|
||||
<span class="n">discounted_rewards</span> <span class="o">=</span> <span class="n">episode</span><span class="o">.</span><span class="n">get_transitions_attribute</span><span class="p">(</span><span class="s1">'n_step_discounted_rewards'</span><span class="p">)</span>
|
||||
<span class="n">actions</span> <span class="o">=</span> <span class="n">episode</span><span class="o">.</span><span class="n">get_transitions_attribute</span><span class="p">(</span><span class="s1">'action'</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">DND</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">current_episode_state_embeddings</span><span class="p">,</span>
|
||||
<span class="n">actions</span><span class="p">,</span> <span class="n">discounted_rewards</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">save_checkpoint</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">checkpoint_id</span><span class="p">):</span>
|
||||
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">task_parameters</span><span class="o">.</span><span class="n">checkpoint_save_dir</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="n">checkpoint_id</span><span class="p">)</span> <span class="o">+</span> <span class="s1">'.dnd'</span><span class="p">),</span> <span class="s1">'wb'</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
|
||||
<span class="n">pickle</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">DND</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">pickle</span><span class="o">.</span><span class="n">HIGHEST_PROTOCOL</span><span class="p">)</span>
|
||||
</pre></div>
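The save_checkpoint method above pickles the agent's DND (differentiable neural dictionary) next to the regular network checkpoint, using the checkpoint id plus a '.dnd' suffix as the file name. A minimal sketch of reading such a file back is shown below; the directory name, checkpoint id, and the idea of re-attaching the loaded DND to a live agent are illustrative assumptions, not an API shown in this file.

import os
import pickle

# Illustrative values -- assumptions, not taken from the code above.
checkpoint_save_dir = './checkpoints'
checkpoint_id = 0

# The '.dnd' suffix and the join with the checkpoint id mirror save_checkpoint above.
dnd_path = os.path.join(checkpoint_save_dir, str(checkpoint_id) + '.dnd')
with open(dnd_path, 'rb') as f:
    dnd = pickle.load(f)

# Re-attaching it to a restored agent would presumably look like this (assumption):
# agent.networks['main'].online_network.output_heads[0].DND = dnd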
334
docs/_modules/rl_coach/agents/pal_agent.html
Normal file
334
docs/_modules/rl_coach/agents/pal_agent.html
Normal file
@@ -0,0 +1,334 @@
Source code for rl_coach.agents.pal_agent

#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Union

import numpy as np

from rl_coach.agents.dqn_agent import DQNAgentParameters, DQNAlgorithmParameters
from rl_coach.agents.value_optimization_agent import ValueOptimizationAgent
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters


class PALAlgorithmParameters(DQNAlgorithmParameters):
    """
    :param pal_alpha: (float)
        A factor that weights the amount by which the advantage learning update is taken into account.

    :param persistent_advantage_learning: (bool)
        If set to True, the persistent mode of advantage learning is used, which encourages the agent to keep
        taking the same action across consecutive steps instead of switching actions.

    :param monte_carlo_mixing_rate: (float)
        The proportion of the Monte Carlo return to mix into the network targets. The Monte Carlo values are simply
        the total discounted returns, and they can reduce the time it takes the network to adapt to newly seen
        values, since they do not rely on bootstrapping from the current network values.
    """
    def __init__(self):
        super().__init__()
        self.pal_alpha = 0.9
        self.persistent_advantage_learning = False
        self.monte_carlo_mixing_rate = 0.1

class PALAgentParameters(DQNAgentParameters):
    def __init__(self):
        super().__init__()
        self.algorithm = PALAlgorithmParameters()
        self.memory = EpisodicExperienceReplayParameters()

    @property
    def path(self):
        return 'rl_coach.agents.pal_agent:PALAgent'


# Persistent Advantage Learning - https://arxiv.org/pdf/1512.04860.pdf
class PALAgent(ValueOptimizationAgent):
    def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
        super().__init__(agent_parameters, parent)
        self.alpha = agent_parameters.algorithm.pal_alpha
        self.persistent = agent_parameters.algorithm.persistent_advantage_learning
        self.monte_carlo_mixing_rate = agent_parameters.algorithm.monte_carlo_mixing_rate

    def learn_from_batch(self, batch):
        network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()

        # next state values
        q_st_plus_1_target, q_st_plus_1_online = self.networks['main'].parallel_prediction([
            (self.networks['main'].target_network, batch.next_states(network_keys)),
            (self.networks['main'].online_network, batch.next_states(network_keys))
        ])
        selected_actions = np.argmax(q_st_plus_1_online, 1)
        v_st_plus_1_target = np.max(q_st_plus_1_target, 1)

        # current state values
        q_st_target, q_st_online = self.networks['main'].parallel_prediction([
            (self.networks['main'].target_network, batch.states(network_keys)),
            (self.networks['main'].online_network, batch.states(network_keys))
        ])
        v_st_target = np.max(q_st_target, 1)

        # calculate the TD targets
        TD_targets = np.copy(q_st_online)
        total_returns = batch.n_step_discounted_rewards()
        for i in range(self.ap.network_wrappers['main'].batch_size):
            TD_targets[i, batch.actions()[i]] = batch.rewards()[i] + \
                                                (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * \
                                                q_st_plus_1_target[i][selected_actions[i]]
            advantage_learning_update = v_st_target[i] - q_st_target[i, batch.actions()[i]]
            next_advantage_learning_update = v_st_plus_1_target[i] - q_st_plus_1_target[i, selected_actions[i]]
            # Persistent Advantage Learning or Regular Advantage Learning
            if self.persistent:
                TD_targets[i, batch.actions()[i]] -= self.alpha * min(advantage_learning_update,
                                                                      next_advantage_learning_update)
            else:
                TD_targets[i, batch.actions()[i]] -= self.alpha * advantage_learning_update

            # mixing monte carlo updates
            monte_carlo_target = total_returns[i]
            TD_targets[i, batch.actions()[i]] = (1 - self.monte_carlo_mixing_rate) * TD_targets[i, batch.actions()[i]] \
                                                + self.monte_carlo_mixing_rate * monte_carlo_target

        result = self.networks['main'].train_and_sync_networks(batch.states(network_keys), TD_targets)
        total_loss, losses, unclipped_grads = result[:3]

        return total_loss, losses, unclipped_grads
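Written out as equations, the per-sample target assembled in the loop above is the following (this only restates the code: Q_target and V_target are the target network's action values and state values, a* is the online network's greedy next action, d_i the terminal flag, R_i the n-step discounted return, alpha is pal_alpha, and rho_MC is monte_carlo_mixing_rate).

\[
T_i = r_i + (1 - d_i)\,\gamma\,Q_{\text{target}}(s'_i, a^{*})
      - \alpha\,\big[V_{\text{target}}(s_i) - Q_{\text{target}}(s_i, a_i)\big]
\]

\[
\text{persistent variant:}\quad
- \alpha\,\min\!\big(V_{\text{target}}(s_i) - Q_{\text{target}}(s_i, a_i),\;
V_{\text{target}}(s'_i) - Q_{\text{target}}(s'_i, a^{*})\big)
\]

\[
T_i \leftarrow (1 - \rho_{\text{MC}})\,T_i + \rho_{\text{MC}}\,R_i
\]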
356
docs/_modules/rl_coach/agents/policy_gradients_agent.html
Normal file
356
docs/_modules/rl_coach/agents/policy_gradients_agent.html
Normal file
@@ -0,0 +1,356 @@
Source code for rl_coach.agents.policy_gradients_agent

#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Union

import numpy as np

from rl_coach.agents.policy_optimization_agent import PolicyOptimizationAgent, PolicyGradientRescaler
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.head_parameters import PolicyHeadParameters
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters
from rl_coach.base_parameters import NetworkParameters, AlgorithmParameters, \
    AgentParameters

from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
from rl_coach.exploration_policies.categorical import CategoricalParameters
from rl_coach.logger import screen
from rl_coach.memories.episodic.single_episode_buffer import SingleEpisodeBufferParameters
from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace


class PolicyGradientNetworkParameters(NetworkParameters):
    def __init__(self):
        super().__init__()
        self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
        self.middleware_parameters = FCMiddlewareParameters()
        self.heads_parameters = [PolicyHeadParameters()]
        self.async_training = True


class PolicyGradientAlgorithmParameters(AlgorithmParameters):
    """
    :param policy_gradient_rescaler: (PolicyGradientRescaler)
        The rescaler type to use for the policy gradient loss. For policy gradients, we calculate the log
        probability of the action and then multiply it by the policy gradient rescaler. The most basic rescaler is
        the discounted return, but there are other rescalers that are intended to reduce the variance of the updates.

    :param apply_gradients_every_x_episodes: (int)
        The number of episodes between applying the accumulated gradients to the network. After every
        num_steps_between_gradient_updates steps, the agent calculates the gradients for the collected data and
        accumulates them in internal accumulators; the accumulated gradients are applied to the network only once
        every apply_gradients_every_x_episodes episodes.

    :param beta_entropy: (float)
        A factor that defines the amount of entropy regularization to apply to the network. The entropy of the
        actions is added to the loss and scaled by the given beta factor.

    :param num_steps_between_gradient_updates: (int)
        The number of steps between calculating gradients for the collected data. In the A3C paper, this parameter
        is called t_max. Since this algorithm is on-policy, only the steps collected between two consecutive
        gradient calculations are used in the batch.
    """
    def __init__(self):
        super().__init__()
        self.policy_gradient_rescaler = PolicyGradientRescaler.FUTURE_RETURN_NORMALIZED_BY_TIMESTEP
        self.apply_gradients_every_x_episodes = 5
        self.beta_entropy = 0
        self.num_steps_between_gradient_updates = 20000  # this is called t_max in all the papers

class PolicyGradientsAgentParameters(AgentParameters):
    def __init__(self):
        super().__init__(algorithm=PolicyGradientAlgorithmParameters(),
                         exploration={DiscreteActionSpace: CategoricalParameters(),
                                      BoxActionSpace: AdditiveNoiseParameters()},
                         memory=SingleEpisodeBufferParameters(),
                         networks={"main": PolicyGradientNetworkParameters()})

    @property
    def path(self):
        return 'rl_coach.agents.policy_gradients_agent:PolicyGradientsAgent'


class PolicyGradientsAgent(PolicyOptimizationAgent):
    def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
        super().__init__(agent_parameters, parent)
        self.returns_mean = self.register_signal('Returns Mean')
        self.returns_variance = self.register_signal('Returns Variance')
        self.last_gradient_update_step_idx = 0

    def learn_from_batch(self, batch):
        # batch contains a list of episodes to learn from
        network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()

        total_returns = batch.n_step_discounted_rewards()
        for i in reversed(range(batch.size)):
            if self.policy_gradient_rescaler == PolicyGradientRescaler.TOTAL_RETURN:
                total_returns[i] = total_returns[0]
            elif self.policy_gradient_rescaler == PolicyGradientRescaler.FUTURE_RETURN:
                # just take the total return as it is
                pass
            elif self.policy_gradient_rescaler == PolicyGradientRescaler.FUTURE_RETURN_NORMALIZED_BY_EPISODE:
                # we can get a single transition episode while playing Doom Basic, causing the std to be 0
                if self.std_discounted_return != 0:
                    total_returns[i] = (total_returns[i] - self.mean_discounted_return) / self.std_discounted_return
                else:
                    total_returns[i] = 0
            elif self.policy_gradient_rescaler == PolicyGradientRescaler.FUTURE_RETURN_NORMALIZED_BY_TIMESTEP:
                total_returns[i] -= self.mean_return_over_multiple_episodes[i]
            else:
                screen.warning("WARNING: The requested policy gradient rescaler is not available")

        targets = total_returns
        actions = batch.actions()
        if type(self.spaces.action) != DiscreteActionSpace and len(actions.shape) < 2:
            actions = np.expand_dims(actions, -1)

        self.returns_mean.add_sample(np.mean(total_returns))
        self.returns_variance.add_sample(np.std(total_returns))

        result = self.networks['main'].online_network.accumulate_gradients(
            {**batch.states(network_keys), 'output_0_0': actions}, targets
|
||||
<span class="p">)</span>
|
||||
<span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span> <span class="o">=</span> <span class="n">result</span><span class="p">[:</span><span class="mi">3</span><span class="p">]</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span>
|
||||
</pre></div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<footer>
|
||||
|
||||
|
||||
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<p>
|
||||
© Copyright 2018, Intel AI Lab
|
||||
|
||||
</p>
|
||||
</div>
|
||||
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||||
|
||||
</footer>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/jquery.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/underscore.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/doctools.js"></script>
|
||||
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
||||
|
||||
|
||||
|
||||
|
||||
<script type="text/javascript" src="../../../_static/js/theme.js"></script>
|
||||
|
||||
<script type="text/javascript">
|
||||
jQuery(function () {
|
||||
SphinxRtdTheme.Navigation.enable(true);
|
||||
});
|
||||
</script>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
620
docs/_modules/rl_coach/agents/ppo_agent.html
Normal file
@@ -0,0 +1,620 @@
|
||||
|
||||
|
||||
<!DOCTYPE html>
|
||||
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
|
||||
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
|
||||
<title>rl_coach.agents.ppo_agent — Reinforcement Learning Coach 0.11.0 documentation</title>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../../_static/css/custom.css" type="text/css" />
|
||||
<link rel="index" title="Index" href="../../../genindex.html" />
|
||||
<link rel="search" title="Search" href="../../../search.html" />
|
||||
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
|
||||
|
||||
|
||||
|
||||
<script src="../../../_static/js/modernizr.min.js"></script>
|
||||
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
|
||||
|
||||
|
||||
<div class="wy-grid-for-nav">
|
||||
|
||||
|
||||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||||
<div class="wy-side-scroll">
|
||||
<div class="wy-side-nav-search">
|
||||
|
||||
|
||||
|
||||
<a href="../../../index.html" class="icon icon-home"> Reinforcement Learning Coach
|
||||
|
||||
|
||||
|
||||
|
||||
<img src="../../../_static/dark_logo.png" class="logo" alt="Logo"/>
|
||||
|
||||
</a>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="search">
|
||||
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
|
||||
<input type="text" name="q" placeholder="Search docs" />
|
||||
<input type="hidden" name="check_keywords" value="yes" />
|
||||
<input type="hidden" name="area" value="default" />
|
||||
</form>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Intro</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../features/index.html">Features</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../selecting_an_algorithm.html">Selecting an Algorithm</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../dashboard.html">Coach Dashboard</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Design</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../design/control_flow.html">Control Flow</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../design/network.html">Network Design</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Contributing</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_agent.html">Adding a New Agent</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_env.html">Adding a New Environment</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Components</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/agents/index.html">Agents</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/architectures/index.html">Architectures</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/environments/index.html">Environments</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/exploration_policies/index.html">Exploration Policies</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/filters/index.html">Filters</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/memories/index.html">Memories</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/core_types.html">Core Types</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/spaces.html">Spaces</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/additional_parameters.html">Additional Parameters</a></li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||||
|
||||
|
||||
<nav class="wy-nav-top" aria-label="top navigation">
|
||||
|
||||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||||
<a href="../../../index.html">Reinforcement Learning Coach</a>
|
||||
|
||||
</nav>
|
||||
|
||||
|
||||
<div class="wy-nav-content">
|
||||
|
||||
<div class="rst-content">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||||
|
||||
<ul class="wy-breadcrumbs">
|
||||
|
||||
<li><a href="../../../index.html">Docs</a> »</li>
|
||||
|
||||
<li><a href="../../index.html">Module code</a> »</li>
|
||||
|
||||
<li>rl_coach.agents.ppo_agent</li>
|
||||
|
||||
|
||||
<li class="wy-breadcrumbs-aside">
|
||||
|
||||
</li>
|
||||
|
||||
</ul>
|
||||
|
||||
|
||||
<hr/>
|
||||
</div>
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
<h1>Source code for rl_coach.agents.ppo_agent</h1><div class="highlight"><pre>
|
||||
<span></span><span class="c1">#</span>
|
||||
<span class="c1"># Copyright (c) 2017 Intel Corporation</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Licensed under the Apache License, Version 2.0 (the "License");</span>
|
||||
<span class="c1"># you may not use this file except in compliance with the License.</span>
|
||||
<span class="c1"># You may obtain a copy of the License at</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
|
||||
<span class="c1"># distributed under the License is distributed on an "AS IS" BASIS,</span>
|
||||
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
|
||||
<span class="c1"># See the License for the specific language governing permissions and</span>
|
||||
<span class="c1"># limitations under the License.</span>
|
||||
<span class="c1">#</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">copy</span>
|
||||
<span class="kn">from</span> <span class="nn">collections</span> <span class="k">import</span> <span class="n">OrderedDict</span>
|
||||
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.agents.actor_critic_agent</span> <span class="k">import</span> <span class="n">ActorCriticAgent</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.agents.policy_optimization_agent</span> <span class="k">import</span> <span class="n">PolicyGradientRescaler</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.architectures.embedder_parameters</span> <span class="k">import</span> <span class="n">InputEmbedderParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.architectures.head_parameters</span> <span class="k">import</span> <span class="n">PPOHeadParameters</span><span class="p">,</span> <span class="n">VHeadParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.architectures.middleware_parameters</span> <span class="k">import</span> <span class="n">FCMiddlewareParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.base_parameters</span> <span class="k">import</span> <span class="n">AlgorithmParameters</span><span class="p">,</span> <span class="n">NetworkParameters</span><span class="p">,</span> \
|
||||
<span class="n">AgentParameters</span><span class="p">,</span> <span class="n">DistributedTaskParameters</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">EnvironmentSteps</span><span class="p">,</span> <span class="n">Batch</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.additive_noise</span> <span class="k">import</span> <span class="n">AdditiveNoiseParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.categorical</span> <span class="k">import</span> <span class="n">CategoricalParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.logger</span> <span class="k">import</span> <span class="n">screen</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.memories.episodic.episodic_experience_replay</span> <span class="k">import</span> <span class="n">EpisodicExperienceReplayParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">DiscreteActionSpace</span><span class="p">,</span> <span class="n">BoxActionSpace</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.utils</span> <span class="k">import</span> <span class="n">force_list</span>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">PPOCriticNetworkParameters</span><span class="p">(</span><span class="n">NetworkParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">input_embedders_parameters</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'observation'</span><span class="p">:</span> <span class="n">InputEmbedderParameters</span><span class="p">(</span><span class="n">activation_function</span><span class="o">=</span><span class="s1">'tanh'</span><span class="p">)}</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">middleware_parameters</span> <span class="o">=</span> <span class="n">FCMiddlewareParameters</span><span class="p">(</span><span class="n">activation_function</span><span class="o">=</span><span class="s1">'tanh'</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">VHeadParameters</span><span class="p">()]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">async_training</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">l2_regularization</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">create_target_network</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">128</span>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">PPOActorNetworkParameters</span><span class="p">(</span><span class="n">NetworkParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">input_embedders_parameters</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'observation'</span><span class="p">:</span> <span class="n">InputEmbedderParameters</span><span class="p">(</span><span class="n">activation_function</span><span class="o">=</span><span class="s1">'tanh'</span><span class="p">)}</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">middleware_parameters</span> <span class="o">=</span> <span class="n">FCMiddlewareParameters</span><span class="p">(</span><span class="n">activation_function</span><span class="o">=</span><span class="s1">'tanh'</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">PPOHeadParameters</span><span class="p">()]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_type</span> <span class="o">=</span> <span class="s1">'Adam'</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">async_training</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">l2_regularization</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">create_target_network</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">128</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="PPOAlgorithmParameters"><a class="viewcode-back" href="../../../components/agents/policy_optimization/ppo.html#rl_coach.agents.ppo_agent.PPOAlgorithmParameters">[docs]</a><span class="k">class</span> <span class="nc">PPOAlgorithmParameters</span><span class="p">(</span><span class="n">AlgorithmParameters</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> :param policy_gradient_rescaler: (PolicyGradientRescaler)</span>
|
||||
<span class="sd"> This represents how the critic will be used to update the actor. The critic value function is typically used</span>
|
||||
<span class="sd"> to rescale the gradients calculated by the actor. There are several ways for doing this, such as using the</span>
|
||||
<span class="sd"> advantage of the action, or the generalized advantage estimation (GAE) value.</span>
|
||||
|
||||
<span class="sd"> :param gae_lambda: (float)</span>
|
||||
<span class="sd"> The :math:`\lambda` value is used within the GAE function in order to weight different bootstrap length</span>
|
||||
<span class="sd"> estimations. Typical values are in the range 0.9-1, and define an exponential decay over the different</span>
|
||||
<span class="sd"> n-step estimations.</span>
|
||||
|
||||
<span class="sd"> :param target_kl_divergence: (float)</span>
|
||||
<span class="sd"> The target kl divergence between the current policy distribution and the new policy. PPO uses a heuristic to</span>
|
||||
<span class="sd"> bring the KL divergence to this value, by adding a penalty if the kl divergence is higher.</span>
|
||||
|
||||
<span class="sd"> :param initial_kl_coefficient: (float)</span>
|
||||
<span class="sd"> The initial weight that will be given to the KL divergence between the current and the new policy in the</span>
|
||||
<span class="sd"> regularization factor.</span>
|
||||
|
||||
<span class="sd"> :param high_kl_penalty_coefficient: (float)</span>
|
||||
<span class="sd"> The penalty that will be given for KL divergence values which are highes than what was defined as the target.</span>
|
||||
|
||||
<span class="sd"> :param clip_likelihood_ratio_using_epsilon: (float)</span>
|
||||
<span class="sd"> If not None, the likelihood ratio between the current and new policy in the PPO loss function will be</span>
|
||||
<span class="sd"> clipped to the range [1-clip_likelihood_ratio_using_epsilon, 1+clip_likelihood_ratio_using_epsilon].</span>
|
||||
<span class="sd"> This is typically used in the Clipped PPO version of PPO, and should be set to None in regular PPO</span>
|
||||
<span class="sd"> implementations.</span>
|
||||
|
||||
<span class="sd"> :param value_targets_mix_fraction: (float)</span>
|
||||
<span class="sd"> The targets for the value network are an exponential weighted moving average which uses this mix fraction to</span>
|
||||
<span class="sd"> define how much of the new targets will be taken into account when calculating the loss.</span>
|
||||
<span class="sd"> This value should be set to the range (0,1], where 1 means that only the new targets will be taken into account.</span>
|
||||
|
||||
<span class="sd"> :param estimate_state_value_using_gae: (bool)</span>
|
||||
<span class="sd"> If set to True, the state value will be estimated using the GAE technique.</span>
|
||||
|
||||
<span class="sd"> :param use_kl_regularization: (bool)</span>
|
||||
<span class="sd"> If set to True, the loss function will be regularized using the KL diveregence between the current and new</span>
|
||||
<span class="sd"> policy, to bound the change of the policy during the network update.</span>
|
||||
|
||||
<span class="sd"> :param beta_entropy: (float)</span>
|
||||
<span class="sd"> An entropy regulaization term can be added to the loss function in order to control exploration. This term</span>
|
||||
<span class="sd"> is weighted using the :math:`\beta` value defined by beta_entropy.</span>
|
||||
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">policy_gradient_rescaler</span> <span class="o">=</span> <span class="n">PolicyGradientRescaler</span><span class="o">.</span><span class="n">GAE</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">gae_lambda</span> <span class="o">=</span> <span class="mf">0.96</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">target_kl_divergence</span> <span class="o">=</span> <span class="mf">0.01</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">initial_kl_coefficient</span> <span class="o">=</span> <span class="mf">1.0</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">high_kl_penalty_coefficient</span> <span class="o">=</span> <span class="mi">1000</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">clip_likelihood_ratio_using_epsilon</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">value_targets_mix_fraction</span> <span class="o">=</span> <span class="mf">0.1</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">estimate_state_value_using_gae</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">use_kl_regularization</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">beta_entropy</span> <span class="o">=</span> <span class="mf">0.01</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">num_consecutive_playing_steps</span> <span class="o">=</span> <span class="n">EnvironmentSteps</span><span class="p">(</span><span class="mi">5000</span><span class="p">)</span></div>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">PPOAgentParameters</span><span class="p">(</span><span class="n">AgentParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">algorithm</span><span class="o">=</span><span class="n">PPOAlgorithmParameters</span><span class="p">(),</span>
|
||||
<span class="n">exploration</span><span class="o">=</span><span class="p">{</span><span class="n">DiscreteActionSpace</span><span class="p">:</span> <span class="n">CategoricalParameters</span><span class="p">(),</span>
|
||||
<span class="n">BoxActionSpace</span><span class="p">:</span> <span class="n">AdditiveNoiseParameters</span><span class="p">()},</span>
|
||||
<span class="n">memory</span><span class="o">=</span><span class="n">EpisodicExperienceReplayParameters</span><span class="p">(),</span>
|
||||
<span class="n">networks</span><span class="o">=</span><span class="p">{</span><span class="s2">"critic"</span><span class="p">:</span> <span class="n">PPOCriticNetworkParameters</span><span class="p">(),</span> <span class="s2">"actor"</span><span class="p">:</span> <span class="n">PPOActorNetworkParameters</span><span class="p">()})</span>
|
||||
|
||||
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="s1">'rl_coach.agents.ppo_agent:PPOAgent'</span>
|
||||
|
||||
|
||||
<span class="c1"># Proximal Policy Optimization - https://arxiv.org/pdf/1707.06347.pdf</span>
|
||||
<span class="k">class</span> <span class="nc">PPOAgent</span><span class="p">(</span><span class="n">ActorCriticAgent</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s1">'LevelManager'</span><span class="p">,</span> <span class="s1">'CompositeAgent'</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># signals definition</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">value_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">'Value Loss'</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">policy_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">'Policy Loss'</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">kl_divergence</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">'KL Divergence'</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">total_kl_divergence_during_training_process</span> <span class="o">=</span> <span class="mf">0.0</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">unclipped_grads</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">'Grads (unclipped)'</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">fill_advantages</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
|
||||
<span class="n">batch</span> <span class="o">=</span> <span class="n">Batch</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
|
||||
<span class="n">network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'critic'</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># * Found not to have any impact *</span>
|
||||
<span class="c1"># current_states_with_timestep = self.concat_state_and_timestep(batch)</span>
|
||||
|
||||
<span class="n">current_state_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'critic'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
|
||||
<span class="n">total_returns</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">n_step_discounted_rewards</span><span class="p">()</span>
|
||||
<span class="c1"># calculate advantages</span>
|
||||
<span class="n">advantages</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">policy_gradient_rescaler</span> <span class="o">==</span> <span class="n">PolicyGradientRescaler</span><span class="o">.</span><span class="n">A_VALUE</span><span class="p">:</span>
|
||||
<span class="n">advantages</span> <span class="o">=</span> <span class="n">total_returns</span> <span class="o">-</span> <span class="n">current_state_values</span>
|
||||
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">policy_gradient_rescaler</span> <span class="o">==</span> <span class="n">PolicyGradientRescaler</span><span class="o">.</span><span class="n">GAE</span><span class="p">:</span>
|
||||
<span class="c1"># get bootstraps</span>
|
||||
<span class="n">episode_start_idx</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="n">advantages</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([])</span>
|
||||
<span class="c1"># current_state_values[batch.game_overs()] = 0</span>
|
||||
<span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">game_over</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">game_overs</span><span class="p">()):</span>
|
||||
<span class="k">if</span> <span class="n">game_over</span><span class="p">:</span>
|
||||
<span class="c1"># get advantages for the rollout</span>
|
||||
<span class="n">value_bootstrapping</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,))</span>
|
||||
<span class="n">rollout_state_values</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">current_state_values</span><span class="p">[</span><span class="n">episode_start_idx</span><span class="p">:</span><span class="n">idx</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span> <span class="n">value_bootstrapping</span><span class="p">)</span>
|
||||
|
||||
<span class="n">rollout_advantages</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> \
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">get_general_advantage_estimation_values</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">rewards</span><span class="p">()[</span><span class="n">episode_start_idx</span><span class="p">:</span><span class="n">idx</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span>
|
||||
<span class="n">rollout_state_values</span><span class="p">)</span>
|
||||
<span class="n">episode_start_idx</span> <span class="o">=</span> <span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span>
|
||||
<span class="n">advantages</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">advantages</span><span class="p">,</span> <span class="n">rollout_advantages</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">screen</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">"WARNING: The requested policy gradient rescaler is not available"</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># standardize</span>
|
||||
<span class="n">advantages</span> <span class="o">=</span> <span class="p">(</span><span class="n">advantages</span> <span class="o">-</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">advantages</span><span class="p">))</span> <span class="o">/</span> <span class="n">np</span><span class="o">.</span><span class="n">std</span><span class="p">(</span><span class="n">advantages</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># TODO: this will be problematic with a shared memory</span>
|
||||
<span class="k">for</span> <span class="n">transition</span><span class="p">,</span> <span class="n">advantage</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">transitions</span><span class="p">,</span> <span class="n">advantages</span><span class="p">):</span>
|
||||
<span class="n">transition</span><span class="o">.</span><span class="n">info</span><span class="p">[</span><span class="s1">'advantage'</span><span class="p">]</span> <span class="o">=</span> <span class="n">advantage</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">action_advantages</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">advantages</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">train_value_network</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">,</span> <span class="n">epochs</span><span class="p">):</span>
|
||||
<span class="n">loss</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">batch</span> <span class="o">=</span> <span class="n">Batch</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span>
|
||||
<span class="n">network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'critic'</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># * Found not to have any impact *</span>
|
||||
<span class="c1"># add a timestep to the observation</span>
|
||||
<span class="c1"># current_states_with_timestep = self.concat_state_and_timestep(dataset)</span>
|
||||
|
||||
<span class="n">mix_fraction</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">value_targets_mix_fraction</span>
|
||||
<span class="n">total_returns</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">n_step_discounted_rewards</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">epochs</span><span class="p">):</span>
|
||||
<span class="n">curr_batch_size</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">size</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'critic'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">optimizer_type</span> <span class="o">!=</span> <span class="s1">'LBFGS'</span><span class="p">:</span>
|
||||
<span class="n">curr_batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'critic'</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span>
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span> <span class="o">//</span> <span class="n">curr_batch_size</span><span class="p">):</span>
|
||||
<span class="c1"># split to batches for first order optimization techniques</span>
|
||||
<span class="n">current_states_batch</span> <span class="o">=</span> <span class="p">{</span>
|
||||
<span class="n">k</span><span class="p">:</span> <span class="n">v</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">curr_batch_size</span><span class="p">:(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">curr_batch_size</span><span class="p">]</span>
|
||||
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">)</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
|
||||
<span class="p">}</span>
|
||||
<span class="n">total_return_batch</span> <span class="o">=</span> <span class="n">total_returns</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">curr_batch_size</span><span class="p">:(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">curr_batch_size</span><span class="p">]</span>
|
||||
<span class="n">old_policy_values</span> <span class="o">=</span> <span class="n">force_list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'critic'</span><span class="p">]</span><span class="o">.</span><span class="n">target_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span>
|
||||
<span class="n">current_states_batch</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">())</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'critic'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">optimizer_type</span> <span class="o">!=</span> <span class="s1">'LBFGS'</span><span class="p">:</span>
|
||||
<span class="n">targets</span> <span class="o">=</span> <span class="n">total_return_batch</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">current_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'critic'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">current_states_batch</span><span class="p">)</span>
|
||||
<span class="n">targets</span> <span class="o">=</span> <span class="n">current_values</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">mix_fraction</span><span class="p">)</span> <span class="o">+</span> <span class="n">total_return_batch</span> <span class="o">*</span> <span class="n">mix_fraction</span>
|
||||
|
||||
<span class="n">inputs</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">current_states_batch</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">input_index</span><span class="p">,</span> <span class="nb">input</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">old_policy_values</span><span class="p">):</span>
|
||||
<span class="n">name</span> <span class="o">=</span> <span class="s1">'output_0_</span><span class="si">{}</span><span class="s1">'</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">input_index</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">name</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'critic'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">inputs</span><span class="p">:</span>
|
||||
<span class="n">inputs</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="nb">input</span>
|
||||
|
||||
<span class="n">value_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'critic'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">accumulate_gradients</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'critic'</span><span class="p">]</span><span class="o">.</span><span class="n">apply_gradients_to_online_network</span><span class="p">()</span>
|
||||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">task_parameters</span><span class="p">,</span> <span class="n">DistributedTaskParameters</span><span class="p">):</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'critic'</span><span class="p">]</span><span class="o">.</span><span class="n">apply_gradients_to_global_network</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'critic'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">reset_accumulated_gradients</span><span class="p">()</span>
|
||||
|
||||
<span class="n">loss</span><span class="o">.</span><span class="n">append</span><span class="p">([</span><span class="n">value_loss</span><span class="p">[</span><span class="mi">0</span><span class="p">]])</span>
|
||||
<span class="n">loss</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">loss</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">concat_state_and_timestep</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">):</span>
|
||||
<span class="n">current_states_with_timestep</span> <span class="o">=</span> <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">transition</span><span class="o">.</span><span class="n">state</span><span class="p">[</span><span class="s1">'observation'</span><span class="p">],</span> <span class="n">transition</span><span class="o">.</span><span class="n">info</span><span class="p">[</span><span class="s1">'timestep'</span><span class="p">])</span>
|
||||
<span class="k">for</span> <span class="n">transition</span> <span class="ow">in</span> <span class="n">dataset</span><span class="p">]</span>
|
||||
<span class="n">current_states_with_timestep</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">current_states_with_timestep</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">current_states_with_timestep</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">train_policy_network</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">,</span> <span class="n">epochs</span><span class="p">):</span>
|
||||
<span class="n">loss</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">epochs</span><span class="p">):</span>
|
||||
<span class="n">loss</span> <span class="o">=</span> <span class="p">{</span>
|
||||
<span class="s1">'total_loss'</span><span class="p">:</span> <span class="p">[],</span>
|
||||
<span class="s1">'policy_losses'</span><span class="p">:</span> <span class="p">[],</span>
|
||||
<span class="s1">'unclipped_grads'</span><span class="p">:</span> <span class="p">[],</span>
|
||||
<span class="s1">'fetch_result'</span><span class="p">:</span> <span class="p">[]</span>
|
||||
<span class="p">}</span>
|
||||
<span class="c1">#shuffle(dataset)</span>
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'actor'</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">):</span>
|
||||
<span class="n">batch</span> <span class="o">=</span> <span class="n">Batch</span><span class="p">(</span><span class="n">dataset</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'actor'</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">:</span>
|
||||
<span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'actor'</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">])</span>
|
||||
|
||||
<span class="n">network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'actor'</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
|
||||
|
||||
<span class="n">advantages</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">'advantage'</span><span class="p">)</span>
|
||||
<span class="n">actions</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">action</span><span class="p">,</span> <span class="n">DiscreteActionSpace</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">actions</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
||||
<span class="n">actions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">actions</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># get old policy probabilities and distribution</span>
|
||||
<span class="n">old_policy</span> <span class="o">=</span> <span class="n">force_list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'actor'</span><span class="p">]</span><span class="o">.</span><span class="n">target_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">)))</span>
|
||||
|
||||
<span class="c1"># calculate gradients and apply on both the local policy network and on the global policy network</span>
|
||||
<span class="n">fetches</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'actor'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">kl_divergence</span><span class="p">,</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'actor'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">entropy</span><span class="p">]</span>
|
||||
|
||||
<span class="n">inputs</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))</span>
|
||||
<span class="n">inputs</span><span class="p">[</span><span class="s1">'output_0_0'</span><span class="p">]</span> <span class="o">=</span> <span class="n">actions</span>
|
||||
|
||||
<span class="c1"># old_policy_distribution needs to be represented as a list, because in the event of discrete controls,</span>
|
||||
<span class="c1"># it has just a mean. otherwise, it has both a mean and standard deviation</span>
|
||||
<span class="k">for</span> <span class="n">input_index</span><span class="p">,</span> <span class="nb">input</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">old_policy</span><span class="p">):</span>
|
||||
<span class="n">inputs</span><span class="p">[</span><span class="s1">'output_0_</span><span class="si">{}</span><span class="s1">'</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">input_index</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span> <span class="o">=</span> <span class="nb">input</span>
|
||||
|
||||
<span class="n">total_loss</span><span class="p">,</span> <span class="n">policy_losses</span><span class="p">,</span> <span class="n">unclipped_grads</span><span class="p">,</span> <span class="n">fetch_result</span> <span class="o">=</span>\
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'actor'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">accumulate_gradients</span><span class="p">(</span>
|
||||
<span class="n">inputs</span><span class="p">,</span> <span class="p">[</span><span class="n">advantages</span><span class="p">],</span> <span class="n">additional_fetches</span><span class="o">=</span><span class="n">fetches</span><span class="p">)</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'actor'</span><span class="p">]</span><span class="o">.</span><span class="n">apply_gradients_to_online_network</span><span class="p">()</span>
|
||||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">task_parameters</span><span class="p">,</span> <span class="n">DistributedTaskParameters</span><span class="p">):</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'actor'</span><span class="p">]</span><span class="o">.</span><span class="n">apply_gradients_to_global_network</span><span class="p">()</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'actor'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">reset_accumulated_gradients</span><span class="p">()</span>
|
||||
|
||||
<span class="n">loss</span><span class="p">[</span><span class="s1">'total_loss'</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">total_loss</span><span class="p">)</span>
|
||||
<span class="n">loss</span><span class="p">[</span><span class="s1">'policy_losses'</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">policy_losses</span><span class="p">)</span>
|
||||
<span class="n">loss</span><span class="p">[</span><span class="s1">'unclipped_grads'</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">unclipped_grads</span><span class="p">)</span>
|
||||
<span class="n">loss</span><span class="p">[</span><span class="s1">'fetch_result'</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">fetch_result</span><span class="p">)</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">unclipped_grads</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">unclipped_grads</span><span class="p">)</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">loss</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
|
||||
<span class="n">loss</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">loss</span><span class="p">[</span><span class="n">key</span><span class="p">],</span> <span class="mi">0</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'critic'</span><span class="p">]</span><span class="o">.</span><span class="n">learning_rate_decay_rate</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
|
||||
<span class="n">curr_learning_rate</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'critic'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">get_variable_value</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">curr_learning_rate</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">curr_learning_rate</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">curr_learning_rate</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'critic'</span><span class="p">]</span><span class="o">.</span><span class="n">learning_rate</span>
|
||||
|
||||
<span class="c1"># log training parameters</span>
|
||||
<span class="n">screen</span><span class="o">.</span><span class="n">log_dict</span><span class="p">(</span>
|
||||
<span class="n">OrderedDict</span><span class="p">([</span>
|
||||
<span class="p">(</span><span class="s2">"Surrogate loss"</span><span class="p">,</span> <span class="n">loss</span><span class="p">[</span><span class="s1">'policy_losses'</span><span class="p">][</span><span class="mi">0</span><span class="p">]),</span>
|
||||
<span class="p">(</span><span class="s2">"KL divergence"</span><span class="p">,</span> <span class="n">loss</span><span class="p">[</span><span class="s1">'fetch_result'</span><span class="p">][</span><span class="mi">0</span><span class="p">]),</span>
|
||||
<span class="p">(</span><span class="s2">"Entropy"</span><span class="p">,</span> <span class="n">loss</span><span class="p">[</span><span class="s1">'fetch_result'</span><span class="p">][</span><span class="mi">1</span><span class="p">]),</span>
|
||||
<span class="p">(</span><span class="s2">"training epoch"</span><span class="p">,</span> <span class="n">j</span><span class="p">),</span>
|
||||
<span class="p">(</span><span class="s2">"learning_rate"</span><span class="p">,</span> <span class="n">curr_learning_rate</span><span class="p">)</span>
|
||||
<span class="p">]),</span>
|
||||
<span class="n">prefix</span><span class="o">=</span><span class="s2">"Policy training"</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">total_kl_divergence_during_training_process</span> <span class="o">=</span> <span class="n">loss</span><span class="p">[</span><span class="s1">'fetch_result'</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">entropy</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">loss</span><span class="p">[</span><span class="s1">'fetch_result'</span><span class="p">][</span><span class="mi">1</span><span class="p">])</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">kl_divergence</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">loss</span><span class="p">[</span><span class="s1">'fetch_result'</span><span class="p">][</span><span class="mi">0</span><span class="p">])</span>
|
||||
<span class="k">return</span> <span class="n">loss</span><span class="p">[</span><span class="s1">'total_loss'</span><span class="p">]</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">update_kl_coefficient</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="c1"># John Schulman takes the mean kl divergence only over the last epoch which is strange but we will follow</span>
|
||||
<span class="c1"># his implementation for now because we know it works well</span>
|
||||
<span class="n">screen</span><span class="o">.</span><span class="n">log_title</span><span class="p">(</span><span class="s2">"KL = </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">total_kl_divergence_during_training_process</span><span class="p">))</span>
|
||||
|
||||
<span class="c1"># update kl coefficient</span>
|
||||
<span class="n">kl_target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">target_kl_divergence</span>
|
||||
<span class="n">kl_coefficient</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'actor'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">get_variable_value</span><span class="p">(</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'actor'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">kl_coefficient</span><span class="p">)</span>
|
||||
<span class="n">new_kl_coefficient</span> <span class="o">=</span> <span class="n">kl_coefficient</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_kl_divergence_during_training_process</span> <span class="o">></span> <span class="mf">1.3</span> <span class="o">*</span> <span class="n">kl_target</span><span class="p">:</span>
|
||||
<span class="c1"># kl too high => increase regularization</span>
|
||||
<span class="n">new_kl_coefficient</span> <span class="o">*=</span> <span class="mf">1.5</span>
|
||||
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_kl_divergence_during_training_process</span> <span class="o"><</span> <span class="mf">0.7</span> <span class="o">*</span> <span class="n">kl_target</span><span class="p">:</span>
|
||||
<span class="c1"># kl too low => decrease regularization</span>
|
||||
<span class="n">new_kl_coefficient</span> <span class="o">/=</span> <span class="mf">1.5</span>
|
||||
|
||||
<span class="c1"># update the kl coefficient variable</span>
|
||||
<span class="k">if</span> <span class="n">kl_coefficient</span> <span class="o">!=</span> <span class="n">new_kl_coefficient</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'actor'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">set_variable_value</span><span class="p">(</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'actor'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">assign_kl_coefficient</span><span class="p">,</span>
|
||||
<span class="n">new_kl_coefficient</span><span class="p">,</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'actor'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">kl_coefficient_ph</span><span class="p">)</span>
|
||||
|
||||
<span class="n">screen</span><span class="o">.</span><span class="n">log_title</span><span class="p">(</span><span class="s2">"KL penalty coefficient change = </span><span class="si">{}</span><span class="s2"> -> </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">kl_coefficient</span><span class="p">,</span> <span class="n">new_kl_coefficient</span><span class="p">))</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">post_training_commands</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">use_kl_regularization</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">update_kl_coefficient</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># clean memory</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">'clean'</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">_should_train_helper</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">wait_for_full_episode</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">_should_train_helper</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="n">loss</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_should_train</span><span class="p">(</span><span class="n">wait_for_full_episode</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||||
<span class="k">for</span> <span class="n">network</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
|
||||
<span class="n">network</span><span class="o">.</span><span class="n">set_is_training</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">training_step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">num_consecutive_training_steps</span><span class="p">):</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'actor'</span><span class="p">]</span><span class="o">.</span><span class="n">sync</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'critic'</span><span class="p">]</span><span class="o">.</span><span class="n">sync</span><span class="p">()</span>
|
||||
|
||||
<span class="n">dataset</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">transitions</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">fill_advantages</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># take only the requested number of steps</span>
|
||||
<span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="p">[:</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">num_consecutive_playing_steps</span><span class="o">.</span><span class="n">num_steps</span><span class="p">]</span>
|
||||
|
||||
<span class="n">value_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_value_network</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||||
<span class="n">policy_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_policy_network</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">value_loss</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">value_loss</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">policy_loss</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">policy_loss</span><span class="p">)</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">network</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
|
||||
<span class="n">network</span><span class="o">.</span><span class="n">set_is_training</span><span class="p">(</span><span class="kc">False</span><span class="p">)</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">post_training_commands</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">training_iteration</span> <span class="o">+=</span> <span class="mi">1</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">update_log</span><span class="p">()</span> <span class="c1"># should be done in order to update the data that has been accumulated * while not playing *</span>
|
||||
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">value_loss</span><span class="p">,</span> <span class="n">policy_loss</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">get_prediction</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">):</span>
|
||||
<span class="n">tf_input_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_batch_for_inference</span><span class="p">(</span><span class="n">states</span><span class="p">,</span> <span class="s2">"actor"</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'actor'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">tf_input_state</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<footer>
|
||||
|
||||
|
||||
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<p>
|
||||
© Copyright 2018, Intel AI Lab
|
||||
|
||||
</p>
|
||||
</div>
|
||||
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||||
|
||||
</footer>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/jquery.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/underscore.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/doctools.js"></script>
|
||||
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
||||
|
||||
|
||||
|
||||
|
||||
<script type="text/javascript" src="../../../_static/js/theme.js"></script>
|
||||
|
||||
<script type="text/javascript">
|
||||
jQuery(function () {
|
||||
SphinxRtdTheme.Navigation.enable(true);
|
||||
});
|
||||
</script>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
347
docs/_modules/rl_coach/agents/qr_dqn_agent.html
Normal file
@@ -0,0 +1,347 @@
<h1>Source code for rl_coach.agents.qr_dqn_agent</h1><div class="highlight"><pre>
|
||||
<span></span><span class="c1">#</span>
|
||||
<span class="c1"># Copyright (c) 2017 Intel Corporation </span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Licensed under the Apache License, Version 2.0 (the "License");</span>
|
||||
<span class="c1"># you may not use this file except in compliance with the License.</span>
|
||||
<span class="c1"># You may obtain a copy of the License at</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
|
||||
<span class="c1"># distributed under the License is distributed on an "AS IS" BASIS,</span>
|
||||
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
|
||||
<span class="c1"># See the License for the specific language governing permissions and</span>
|
||||
<span class="c1"># limitations under the License.</span>
|
||||
<span class="c1">#</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.agents.dqn_agent</span> <span class="k">import</span> <span class="n">DQNAgentParameters</span><span class="p">,</span> <span class="n">DQNNetworkParameters</span><span class="p">,</span> <span class="n">DQNAlgorithmParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.agents.value_optimization_agent</span> <span class="k">import</span> <span class="n">ValueOptimizationAgent</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.architectures.head_parameters</span> <span class="k">import</span> <span class="n">QuantileRegressionQHeadParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">StateType</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.schedules</span> <span class="k">import</span> <span class="n">LinearSchedule</span>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">QuantileRegressionDQNNetworkParameters</span><span class="p">(</span><span class="n">DQNNetworkParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">QuantileRegressionQHeadParameters</span><span class="p">()]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">0.00005</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_epsilon</span> <span class="o">=</span> <span class="mf">0.01</span> <span class="o">/</span> <span class="mi">32</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="QuantileRegressionDQNAlgorithmParameters"><a class="viewcode-back" href="../../../components/agents/value_optimization/qr_dqn.html#rl_coach.agents.qr_dqn_agent.QuantileRegressionDQNAlgorithmParameters">[docs]</a><span class="k">class</span> <span class="nc">QuantileRegressionDQNAlgorithmParameters</span><span class="p">(</span><span class="n">DQNAlgorithmParameters</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> :param atoms: (int)</span>
|
||||
<span class="sd"> the number of atoms to predict for each action</span>
|
||||
|
||||
<span class="sd"> :param huber_loss_interval: (float)</span>
|
||||
<span class="sd"> One of the huber loss parameters, and is referred to as :math:`\kapa` in the paper.</span>
|
||||
<span class="sd"> It describes the interval [-k, k] in which the huber loss acts as a MSE loss.</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">atoms</span> <span class="o">=</span> <span class="mi">200</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">huber_loss_interval</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># called k in the paper</span></div>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">QuantileRegressionDQNAgentParameters</span><span class="p">(</span><span class="n">DQNAgentParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">algorithm</span> <span class="o">=</span> <span class="n">QuantileRegressionDQNAlgorithmParameters</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">network_wrappers</span> <span class="o">=</span> <span class="p">{</span><span class="s2">"main"</span><span class="p">:</span> <span class="n">QuantileRegressionDQNNetworkParameters</span><span class="p">()}</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">exploration</span><span class="o">.</span><span class="n">epsilon_schedule</span> <span class="o">=</span> <span class="n">LinearSchedule</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mf">0.01</span><span class="p">,</span> <span class="mi">1000000</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">exploration</span><span class="o">.</span><span class="n">evaluation_epsilon</span> <span class="o">=</span> <span class="mf">0.001</span>
|
||||
|
||||
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="s1">'rl_coach.agents.qr_dqn_agent:QuantileRegressionDQNAgent'</span>
|
||||
|
||||
|
||||
<span class="c1"># Quantile Regression Deep Q Network - https://arxiv.org/pdf/1710.10044v1.pdf</span>
|
||||
<span class="k">class</span> <span class="nc">QuantileRegressionDQNAgent</span><span class="p">(</span><span class="n">ValueOptimizationAgent</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s1">'LevelManager'</span><span class="p">,</span> <span class="s1">'CompositeAgent'</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">quantile_probabilities</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">atoms</span><span class="p">)</span> <span class="o">/</span> <span class="nb">float</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">atoms</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">get_q_values</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">quantile_values</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">quantile_values</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">quantile_probabilities</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># prediction's format is (batch,actions,atoms)</span>
|
||||
<span class="k">def</span> <span class="nf">get_all_q_values_for_states</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">:</span> <span class="n">StateType</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">requires_action_values</span><span class="p">():</span>
|
||||
<span class="n">quantile_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_prediction</span><span class="p">(</span><span class="n">states</span><span class="p">)</span>
|
||||
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_q_values</span><span class="p">(</span><span class="n">quantile_values</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="k">return</span> <span class="n">actions_q_values</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">learn_from_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
|
||||
<span class="n">network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># get the quantiles of the next states and current states</span>
|
||||
<span class="n">next_state_quantiles</span><span class="p">,</span> <span class="n">current_quantiles</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">parallel_prediction</span><span class="p">([</span>
|
||||
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">target_network</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">next_states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">)),</span>
|
||||
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))</span>
|
||||
<span class="p">])</span>
|
||||
|
||||
<span class="c1"># get the optimal actions to take for the next states</span>
|
||||
<span class="n">target_actions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">get_q_values</span><span class="p">(</span><span class="n">next_state_quantiles</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># calculate the Bellman update</span>
|
||||
<span class="n">batch_idx</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">))</span>
|
||||
|
||||
<span class="n">TD_targets</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">rewards</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">batch</span><span class="o">.</span><span class="n">game_overs</span><span class="p">(</span><span class="kc">True</span><span class="p">))</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">discount</span> \
|
||||
<span class="o">*</span> <span class="n">next_state_quantiles</span><span class="p">[</span><span class="n">batch_idx</span><span class="p">,</span> <span class="n">target_actions</span><span class="p">]</span>
|
||||
|
||||
<span class="c1"># get the locations of the selected actions within the batch for indexing purposes</span>
|
||||
<span class="n">actions_locations</span> <span class="o">=</span> <span class="p">[[</span><span class="n">b</span><span class="p">,</span> <span class="n">a</span><span class="p">]</span> <span class="k">for</span> <span class="n">b</span><span class="p">,</span> <span class="n">a</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">batch_idx</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">())]</span>
|
||||
|
||||
<span class="c1"># calculate the cumulative quantile probabilities and reorder them to fit the sorted quantiles order</span>
|
||||
<span class="n">cumulative_probabilities</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">atoms</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="nb">float</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">atoms</span><span class="p">)</span> <span class="c1"># tau_i</span>
|
||||
<span class="n">quantile_midpoints</span> <span class="o">=</span> <span class="mf">0.5</span><span class="o">*</span><span class="p">(</span><span class="n">cumulative_probabilities</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="o">+</span> <span class="n">cumulative_probabilities</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span> <span class="c1"># tau^hat_i</span>
|
||||
<span class="n">quantile_midpoints</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">quantile_midpoints</span><span class="p">,</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
|
||||
<span class="n">sorted_quantiles</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">current_quantiles</span><span class="p">[</span><span class="n">batch_idx</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()])</span>
|
||||
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">):</span>
|
||||
<span class="n">quantile_midpoints</span><span class="p">[</span><span class="n">idx</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">quantile_midpoints</span><span class="p">[</span><span class="n">idx</span><span class="p">,</span> <span class="n">sorted_quantiles</span><span class="p">[</span><span class="n">idx</span><span class="p">]]</span>
|
||||
|
||||
<span class="c1"># train</span>
|
||||
<span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">train_and_sync_networks</span><span class="p">({</span>
|
||||
<span class="o">**</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">),</span>
|
||||
<span class="s1">'output_0_0'</span><span class="p">:</span> <span class="n">actions_locations</span><span class="p">,</span>
|
||||
<span class="s1">'output_0_1'</span><span class="p">:</span> <span class="n">quantile_midpoints</span><span class="p">,</span>
|
||||
<span class="p">},</span> <span class="n">TD_targets</span><span class="p">)</span>
|
||||
<span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span> <span class="o">=</span> <span class="n">result</span><span class="p">[:</span><span class="mi">3</span><span class="p">]</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span>
|
||||
|
||||
</pre></div>
359
docs/_modules/rl_coach/agents/rainbow_dqn_agent.html
Normal file
@@ -0,0 +1,359 @@
<h1>Source code for rl_coach.agents.rainbow_dqn_agent</h1><div class="highlight"><pre>
|
||||
<span></span><span class="c1">#</span>
|
||||
<span class="c1"># Copyright (c) 2017 Intel Corporation </span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Licensed under the Apache License, Version 2.0 (the "License");</span>
|
||||
<span class="c1"># you may not use this file except in compliance with the License.</span>
|
||||
<span class="c1"># You may obtain a copy of the License at</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
|
||||
<span class="c1"># distributed under the License is distributed on an "AS IS" BASIS,</span>
|
||||
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
|
||||
<span class="c1"># See the License for the specific language governing permissions and</span>
|
||||
<span class="c1"># limitations under the License.</span>
|
||||
<span class="c1">#</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.agents.categorical_dqn_agent</span> <span class="k">import</span> <span class="n">CategoricalDQNAlgorithmParameters</span><span class="p">,</span> \
|
||||
<span class="n">CategoricalDQNAgent</span><span class="p">,</span> <span class="n">CategoricalDQNAgentParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.agents.dqn_agent</span> <span class="k">import</span> <span class="n">DQNNetworkParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.architectures.head_parameters</span> <span class="k">import</span> <span class="n">RainbowQHeadParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.architectures.middleware_parameters</span> <span class="k">import</span> <span class="n">FCMiddlewareParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.base_parameters</span> <span class="k">import</span> <span class="n">MiddlewareScheme</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.parameter_noise</span> <span class="k">import</span> <span class="n">ParameterNoiseParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.memories.non_episodic.prioritized_experience_replay</span> <span class="k">import</span> <span class="n">PrioritizedExperienceReplayParameters</span><span class="p">,</span> \
|
||||
<span class="n">PrioritizedExperienceReplay</span>
<span class="k">class</span> <span class="nc">RainbowDQNNetworkParameters</span><span class="p">(</span><span class="n">DQNNetworkParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">RainbowQHeadParameters</span><span class="p">()]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">middleware_parameters</span> <span class="o">=</span> <span class="n">FCMiddlewareParameters</span><span class="p">(</span><span class="n">scheme</span><span class="o">=</span><span class="n">MiddlewareScheme</span><span class="o">.</span><span class="n">Empty</span><span class="p">)</span>
<div class="viewcode-block" id="RainbowDQNAlgorithmParameters"><a class="viewcode-back" href="../../../components/agents/value_optimization/rainbow.html#rl_coach.agents.rainbow_dqn_agent.RainbowDQNAlgorithmParameters">[docs]</a><span class="k">class</span> <span class="nc">RainbowDQNAlgorithmParameters</span><span class="p">(</span><span class="n">CategoricalDQNAlgorithmParameters</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> :param n_step: (int)</span>
|
||||
<span class="sd"> The number of steps to bootstrap the network over. The first N-1 steps actual rewards will be accumulated</span>
|
||||
<span class="sd"> using an exponentially growing discount factor, and the Nth step will be bootstrapped from the network</span>
|
||||
<span class="sd"> prediction.</span>
|
||||
|
||||
<span class="sd"> :param store_transitions_only_when_episodes_are_terminated: (bool)</span>
|
||||
<span class="sd"> If set to True, the transitions will be stored in an Episode object until the episode ends, and just then</span>
|
||||
<span class="sd"> written to the memory. This is useful since we want to calculate the N-step discounted rewards before saving the</span>
|
||||
<span class="sd"> transitions into the memory, and to do so we need the entire episode first.</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">n_step</span> <span class="o">=</span> <span class="mi">3</span>
|
||||
|
||||
<span class="c1"># needed for n-step updates to work. i.e. waiting for a full episode to be closed before storing each transition</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">store_transitions_only_when_episodes_are_terminated</span> <span class="o">=</span> <span class="kc">True</span></div>
<span class="k">class</span> <span class="nc">RainbowDQNAgentParameters</span><span class="p">(</span><span class="n">CategoricalDQNAgentParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">algorithm</span> <span class="o">=</span> <span class="n">RainbowDQNAlgorithmParameters</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">exploration</span> <span class="o">=</span> <span class="n">ParameterNoiseParameters</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span> <span class="o">=</span> <span class="n">PrioritizedExperienceReplayParameters</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">network_wrappers</span> <span class="o">=</span> <span class="p">{</span><span class="s2">"main"</span><span class="p">:</span> <span class="n">RainbowDQNNetworkParameters</span><span class="p">()}</span>
|
||||
|
||||
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="s1">'rl_coach.agents.rainbow_dqn_agent:RainbowDQNAgent'</span>
<span class="c1"># Rainbow Deep Q Network - https://arxiv.org/abs/1710.02298</span>
|
||||
<span class="c1"># Agent implementation is composed of:</span>
|
||||
<span class="c1"># 1. NoisyNets</span>
|
||||
<span class="c1"># 2. C51</span>
|
||||
<span class="c1"># 3. Prioritized ER</span>
|
||||
<span class="c1"># 4. DDQN</span>
|
||||
<span class="c1"># 5. Dueling DQN</span>
|
||||
<span class="c1"># 6. N-step returns</span>
|
||||
|
||||
<span class="k">class</span> <span class="nc">RainbowDQNAgent</span><span class="p">(</span><span class="n">CategoricalDQNAgent</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s1">'LevelManager'</span><span class="p">,</span> <span class="s1">'CompositeAgent'</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">learn_from_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
|
||||
<span class="n">network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
|
||||
|
||||
<span class="n">ddqn_selected_actions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">distribution_prediction_to_q_values</span><span class="p">(</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">next_states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># for the action we actually took, the error is calculated by the atoms distribution</span>
|
||||
<span class="c1"># for all other actions, the error is 0</span>
|
||||
<span class="n">distributional_q_st_plus_n</span><span class="p">,</span> <span class="n">TD_targets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">parallel_prediction</span><span class="p">([</span>
|
||||
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">target_network</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">next_states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">)),</span>
|
||||
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))</span>
|
||||
<span class="p">])</span>
|
||||
|
||||
<span class="c1"># only update the action that we have actually done in this transition (using the Double-DQN selected actions)</span>
|
||||
<span class="n">target_actions</span> <span class="o">=</span> <span class="n">ddqn_selected_actions</span>
|
||||
<span class="n">m</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="o">.</span><span class="n">size</span><span class="p">))</span>
|
||||
|
||||
<span class="n">batches</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="o">.</span><span class="n">size</span><span class="p">):</span>
|
||||
<span class="c1"># we use batch.info('should_bootstrap_next_state') instead of (1 - batch.game_overs()) since with n-step,</span>
|
||||
<span class="c1"># we will not bootstrap for the last n-step transitions in the episode</span>
|
||||
<span class="n">tzj</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">fmax</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">fmin</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">n_step_discounted_rewards</span><span class="p">()</span> <span class="o">+</span> <span class="n">batch</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">'should_bootstrap_next_state'</span><span class="p">)</span> <span class="o">*</span>
|
||||
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">discount</span> <span class="o">**</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">n_step</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="p">[</span><span class="n">j</span><span class="p">],</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]),</span> <span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
|
||||
<span class="n">bj</span> <span class="o">=</span> <span class="p">(</span><span class="n">tzj</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span><span class="o">/</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
|
||||
<span class="n">u</span> <span class="o">=</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">bj</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">)</span>
|
||||
<span class="n">l</span> <span class="o">=</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">bj</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">)</span>
|
||||
<span class="n">m</span><span class="p">[</span><span class="n">batches</span><span class="p">,</span> <span class="n">l</span><span class="p">]</span> <span class="o">+=</span> <span class="p">(</span><span class="n">distributional_q_st_plus_n</span><span class="p">[</span><span class="n">batches</span><span class="p">,</span> <span class="n">target_actions</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="n">u</span> <span class="o">-</span> <span class="n">bj</span><span class="p">))</span>
|
||||
<span class="n">m</span><span class="p">[</span><span class="n">batches</span><span class="p">,</span> <span class="n">u</span><span class="p">]</span> <span class="o">+=</span> <span class="p">(</span><span class="n">distributional_q_st_plus_n</span><span class="p">[</span><span class="n">batches</span><span class="p">,</span> <span class="n">target_actions</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="n">bj</span> <span class="o">-</span> <span class="n">l</span><span class="p">))</span>
|
||||
|
||||
<span class="c1"># total_loss = cross entropy between actual result above and predicted result for the given action</span>
|
||||
<span class="n">TD_targets</span><span class="p">[</span><span class="n">batches</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()]</span> <span class="o">=</span> <span class="n">m</span>
|
||||
|
||||
<span class="c1"># update errors in prioritized replay buffer</span>
|
||||
<span class="n">importance_weights</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">'weight'</span><span class="p">)</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="p">,</span> <span class="n">PrioritizedExperienceReplay</span><span class="p">)</span> <span class="k">else</span> <span class="kc">None</span>
|
||||
|
||||
<span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">train_and_sync_networks</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">),</span> <span class="n">TD_targets</span><span class="p">,</span>
|
||||
<span class="n">importance_weights</span><span class="o">=</span><span class="n">importance_weights</span><span class="p">)</span>
|
||||
|
||||
<span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span> <span class="o">=</span> <span class="n">result</span><span class="p">[:</span><span class="mi">3</span><span class="p">]</span>
|
||||
|
||||
<span class="c1"># TODO: fix this spaghetti code</span>
|
||||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="p">,</span> <span class="n">PrioritizedExperienceReplay</span><span class="p">):</span>
|
||||
<span class="n">errors</span> <span class="o">=</span> <span class="n">losses</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">),</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">'update_priorities'</span><span class="p">,</span> <span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">'idx'</span><span class="p">),</span> <span class="n">errors</span><span class="p">))</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span>
|
||||
|
||||
</pre></div>
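<p>The <code>learn_from_batch</code> method above combines an n-step return with the C51 categorical projection:
each atom of the target distribution is shifted by the n-step discounted reward, clipped to the support, and its
probability mass is split between the two neighbouring support atoms. The following standalone NumPy sketch walks
through that projection for a single transition. It is an illustration only: the support, rewards and variable
names are assumptions for the example, and it adds an explicit branch for atoms that land exactly on a support
point.</p>
<div class="highlight"><pre>
import numpy as np

# Support of 11 atoms between v_min=-10 and v_max=10 (illustrative values, not Coach defaults).
z_values = np.linspace(-10.0, 10.0, 11)
delta_z = z_values[1] - z_values[0]

# Dummy inputs for a single transition: the first n actual rewards, a bootstrap flag
# (0.0 if the episode ended within the n steps), and a uniform target-network distribution at s_{t+n}.
gamma, n_step = 0.99, 3
rewards = [0.5, 0.7, 0.3]
n_step_reward = sum(gamma ** k * r for k, r in enumerate(rewards))
bootstrap = 1.0
p_next = np.full(z_values.size, 1.0 / z_values.size)

# Project each shifted atom back onto the fixed support, splitting its mass between the two
# neighbouring atoms, mirroring the arithmetic of the loop over j in learn_from_batch above.
m = np.zeros_like(z_values)
for j, z_j in enumerate(z_values):
    tzj = np.clip(n_step_reward + bootstrap * gamma ** n_step * z_j, z_values[0], z_values[-1])
    bj = (tzj - z_values[0]) / delta_z
    l, u = int(np.floor(bj)), int(np.ceil(bj))
    if l == u:
        m[l] += p_next[j]
    else:
        m[l] += p_next[j] * (u - bj)
        m[u] += p_next[j] * (bj - l)

# Projected target distribution; sums to 1 and serves as the cross-entropy target for the chosen action.
print(m.round(3), m.sum())
</pre></div>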
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<footer>
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<p>
|
||||
© Copyright 2018, Intel AI Lab
|
||||
|
||||
</p>
|
||||
</div>
|
||||
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||||
|
||||
</footer>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
</div>
<script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/jquery.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/underscore.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/doctools.js"></script>
|
||||
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/js/theme.js"></script>
|
||||
|
||||
<script type="text/javascript">
|
||||
jQuery(function () {
|
||||
SphinxRtdTheme.Navigation.enable(true);
|
||||
});
|
||||
</script>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
325
docs/_modules/rl_coach/agents/value_optimization_agent.html
Normal file
325
docs/_modules/rl_coach/agents/value_optimization_agent.html
Normal file
@@ -0,0 +1,325 @@
<!DOCTYPE html>
|
||||
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
|
||||
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
|
||||
<title>rl_coach.agents.value_optimization_agent — Reinforcement Learning Coach 0.11.0 documentation</title>
<link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../../_static/css/custom.css" type="text/css" />
|
||||
<link rel="index" title="Index" href="../../../genindex.html" />
|
||||
<link rel="search" title="Search" href="../../../search.html" />
|
||||
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
<script src="../../../_static/js/modernizr.min.js"></script>
|
||||
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
|
||||
|
||||
|
||||
<div class="wy-grid-for-nav">
|
||||
|
||||
|
||||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||||
<div class="wy-side-scroll">
|
||||
<div class="wy-side-nav-search">
<a href="../../../index.html" class="icon icon-home"> Reinforcement Learning Coach
<img src="../../../_static/dark_logo.png" class="logo" alt="Logo"/>
|
||||
|
||||
</a>
<div role="search">
|
||||
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
|
||||
<input type="text" name="q" placeholder="Search docs" />
|
||||
<input type="hidden" name="check_keywords" value="yes" />
|
||||
<input type="hidden" name="area" value="default" />
|
||||
</form>
|
||||
</div>
</div>
|
||||
|
||||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption"><span class="caption-text">Intro</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../features/index.html">Features</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../selecting_an_algorithm.html">Selecting an Algorithm</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../dashboard.html">Coach Dashboard</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Design</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../design/control_flow.html">Control Flow</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../design/network.html">Network Design</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Contributing</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_agent.html">Adding a New Agent</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_env.html">Adding a New Environment</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Components</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/agents/index.html">Agents</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/architectures/index.html">Architectures</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/environments/index.html">Environments</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/exploration_policies/index.html">Exploration Policies</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/filters/index.html">Filters</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/memories/index.html">Memories</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/core_types.html">Core Types</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/spaces.html">Spaces</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/additional_parameters.html">Additional Parameters</a></li>
|
||||
</ul>
</div>
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
|
||||
|
||||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||||
<a href="../../../index.html">Reinforcement Learning Coach</a>
|
||||
|
||||
</nav>
<div class="wy-nav-content">
|
||||
|
||||
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
|
||||
|
||||
<ul class="wy-breadcrumbs">
|
||||
|
||||
<li><a href="../../../index.html">Docs</a> »</li>
|
||||
|
||||
<li><a href="../../index.html">Module code</a> »</li>
|
||||
|
||||
<li>rl_coach.agents.value_optimization_agent</li>
<li class="wy-breadcrumbs-aside">
|
||||
|
||||
</li>
|
||||
|
||||
</ul>
<hr/>
|
||||
</div>
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
<h1>Source code for rl_coach.agents.value_optimization_agent</h1><div class="highlight"><pre>
|
||||
<span></span><span class="c1">#</span>
|
||||
<span class="c1"># Copyright (c) 2017 Intel Corporation</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Licensed under the Apache License, Version 2.0 (the "License");</span>
|
||||
<span class="c1"># you may not use this file except in compliance with the License.</span>
|
||||
<span class="c1"># You may obtain a copy of the License at</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
|
||||
<span class="c1"># distributed under the License is distributed on an "AS IS" BASIS,</span>
|
||||
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
|
||||
<span class="c1"># See the License for the specific language governing permissions and</span>
|
||||
<span class="c1"># limitations under the License.</span>
|
||||
<span class="c1">#</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.agents.agent</span> <span class="k">import</span> <span class="n">Agent</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">ActionInfo</span><span class="p">,</span> <span class="n">StateType</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.memories.non_episodic.prioritized_experience_replay</span> <span class="k">import</span> <span class="n">PrioritizedExperienceReplay</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">DiscreteActionSpace</span>
<span class="c1">## This is an abstract agent - there is no learn_from_batch method ##</span>
<span class="k">class</span> <span class="nc">ValueOptimizationAgent</span><span class="p">(</span><span class="n">Agent</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s1">'LevelManager'</span><span class="p">,</span> <span class="s1">'CompositeAgent'</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">q_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s2">"Q"</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">q_value_for_action</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">init_environment_dependent_modules</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">init_environment_dependent_modules</span><span class="p">()</span>
|
||||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">action</span><span class="p">,</span> <span class="n">DiscreteActionSpace</span><span class="p">):</span>
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">action</span><span class="o">.</span><span class="n">actions</span><span class="p">)):</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">q_value_for_action</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s2">"Q for action </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">i</span><span class="p">),</span>
|
||||
<span class="n">dump_one_value_per_episode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||||
<span class="n">dump_one_value_per_step</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># Algorithms for which q_values are calculated from predictions will override this function</span>
|
||||
<span class="k">def</span> <span class="nf">get_all_q_values_for_states</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">:</span> <span class="n">StateType</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">requires_action_values</span><span class="p">():</span>
|
||||
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_prediction</span><span class="p">(</span><span class="n">states</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="k">return</span> <span class="n">actions_q_values</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">get_prediction</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">prepare_batch_for_inference</span><span class="p">(</span><span class="n">states</span><span class="p">,</span> <span class="s1">'main'</span><span class="p">))</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">update_transition_priorities_and_get_weights</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">TD_errors</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
|
||||
<span class="c1"># update errors in prioritized replay buffer</span>
|
||||
<span class="n">importance_weights</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="p">,</span> <span class="n">PrioritizedExperienceReplay</span><span class="p">):</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">'update_priorities'</span><span class="p">,</span> <span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">'idx'</span><span class="p">),</span> <span class="n">TD_errors</span><span class="p">))</span>
|
||||
<span class="n">importance_weights</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">'weight'</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">importance_weights</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">_validate_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">policy</span><span class="p">,</span> <span class="n">action</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">action</span><span class="p">)</span><span class="o">.</span><span class="n">shape</span> <span class="o">!=</span> <span class="p">():</span>
|
||||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">((</span>
|
||||
<span class="s1">'The exploration_policy </span><span class="si">{}</span><span class="s1"> returned a vector of actions '</span>
|
||||
<span class="s1">'instead of a single action. ValueOptimizationAgents '</span>
|
||||
<span class="s1">'require exploration policies which return a single action.'</span>
|
||||
<span class="p">)</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">policy</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">))</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">choose_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">curr_state</span><span class="p">):</span>
|
||||
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_all_q_values_for_states</span><span class="p">(</span><span class="n">curr_state</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># choose action according to the exploration policy and the current phase (evaluating or training the agent)</span>
|
||||
<span class="n">action</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">get_action</span><span class="p">(</span><span class="n">actions_q_values</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_validate_action</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="p">,</span> <span class="n">action</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">actions_q_values</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="c1"># this is for bootstrapped dqn</span>
|
||||
<span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="n">actions_q_values</span><span class="p">)</span> <span class="o">==</span> <span class="nb">list</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">actions_q_values</span><span class="p">)</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
|
||||
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">last_action_values</span>
|
||||
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="n">actions_q_values</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># store the q values statistics for logging</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">q_values</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">actions_q_values</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">q_value</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">actions_q_values</span><span class="p">):</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">q_value_for_action</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">q_value</span><span class="p">)</span>
|
||||
|
||||
<span class="n">action_info</span> <span class="o">=</span> <span class="n">ActionInfo</span><span class="p">(</span><span class="n">action</span><span class="o">=</span><span class="n">action</span><span class="p">,</span>
|
||||
<span class="n">action_value</span><span class="o">=</span><span class="n">actions_q_values</span><span class="p">[</span><span class="n">action</span><span class="p">],</span>
|
||||
<span class="n">max_action_value</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">actions_q_values</span><span class="p">))</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">action_info</span> <span class="o">=</span> <span class="n">ActionInfo</span><span class="p">(</span><span class="n">action</span><span class="o">=</span><span class="n">action</span><span class="p">)</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">action_info</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">learn_from_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
|
||||
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">"ValueOptimizationAgent is an abstract agent. Not to be used directly."</span><span class="p">)</span>
|
||||
</pre></div>
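<p><code>choose_action</code> above only gathers the per-action Q values, delegates the actual selection to the
configured exploration policy, and logs the resulting value statistics into <code>ActionInfo</code>. As a rough,
standalone illustration of what a discrete-action exploration policy does with <code>actions_q_values</code>, here
is a minimal epsilon-greedy sketch in NumPy; the epsilon value and the dummy Q values are assumptions for the
example, not the behaviour of any specific Coach exploration policy.</p>
<div class="highlight"><pre>
import numpy as np

def epsilon_greedy(q_values, epsilon=0.1):
    """Pick a uniformly random action with probability epsilon, otherwise the greedy one."""
    if np.random.rand() < epsilon:
        return int(np.random.randint(q_values.size))
    return int(np.argmax(q_values))

q = np.array([0.2, 1.3, -0.5, 0.9])      # dummy Q values for a 4-action space
action = epsilon_greedy(q, epsilon=0.1)
# roughly the fields that end up in ActionInfo(action, action_value, max_action_value)
print(action, q[action], q.max())
</pre></div>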
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<footer>
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<p>
|
||||
© Copyright 2018, Intel AI Lab
|
||||
|
||||
</p>
|
||||
</div>
|
||||
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||||
|
||||
</footer>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
</div>
<script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/jquery.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/underscore.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/doctools.js"></script>
|
||||
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/js/theme.js"></script>
|
||||
|
||||
<script type="text/javascript">
|
||||
jQuery(function () {
|
||||
SphinxRtdTheme.Navigation.enable(true);
|
||||
});
|
||||
</script>
|
||||
|
||||
</body>
|
||||
</html>