mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

RL in Large Discrete Action Spaces - Wolpertinger Agent (#394)

* Currently this is specific to the case of discretizing a continuous action space. It can easily be adapted to other cases by feeding the kNN differently and removing the use of the discretizing output action filter.
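For readers unfamiliar with the approach, the selection loop the agent implements (see the source further down) is: the actor proposes a continuous "proto-action", a kNN lookup returns the k closest discrete actions, and the critic re-ranks those candidates by Q-value. Below is a minimal, framework-free sketch of that loop, assuming generic actor/critic callables and a plain numpy embedding table; none of these names come from Coach itself.

import numpy as np

def wolpertinger_select(state, actor, critic, action_embeddings, k=5):
    """Pick a discrete action index via proto-action -> kNN -> critic re-ranking."""
    # 1. the actor proposes a continuous "proto-action" in the embedding space
    proto_action = actor(state)
    # 2. find the k discrete actions whose embeddings are closest to the proto-action
    #    (brute-force search stands in for the approximate kNN index used in practice)
    distances = np.linalg.norm(action_embeddings - proto_action, axis=1)
    candidates = np.argsort(distances)[:k]
    # 3. let the critic score each candidate and keep the highest-valued one
    q_values = np.array([critic(state, action_embeddings[i]) for i in candidates])
    return int(candidates[np.argmax(q_values)])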
This commit is contained in:
Gal Leibovich
2019-09-08 12:53:49 +03:00
committed by GitHub
parent fc50398544
commit 138ced23ba
46 changed files with 1193 additions and 51 deletions

Binary file not shown. (Before: 60 KiB | After: 63 KiB)

Binary file not shown. (After: 49 KiB)

View File

@@ -202,6 +202,7 @@
<li><a href="rl_coach/agents/soft_actor_critic_agent.html">rl_coach.agents.soft_actor_critic_agent</a></li>
<li><a href="rl_coach/agents/td3_agent.html">rl_coach.agents.td3_agent</a></li>
<li><a href="rl_coach/agents/value_optimization_agent.html">rl_coach.agents.value_optimization_agent</a></li>
<li><a href="rl_coach/agents/wolpertinger_agent.html">rl_coach.agents.wolpertinger_agent</a></li>
<li><a href="rl_coach/architectures/architecture.html">rl_coach.architectures.architecture</a></li>
<li><a href="rl_coach/architectures/network_wrapper.html">rl_coach.architectures.network_wrapper</a></li>
<li><a href="rl_coach/base_parameters.html">rl_coach.base_parameters</a></li>

View File

@@ -756,6 +756,9 @@
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">!=</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="p">,</span> <span class="n">EpisodicExperienceReplay</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">override_episode_rewards_with_the_last_transition_reward</span><span class="p">:</span>
<span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode_buffer</span><span class="o">.</span><span class="n">transitions</span><span class="p">:</span>
<span class="n">t</span><span class="o">.</span><span class="n">reward</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode_buffer</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">reward</span>
<span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">&#39;store_episode&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode_buffer</span><span class="p">)</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">store_transitions_only_when_episodes_are_terminated</span><span class="p">:</span>
<span class="k">for</span> <span class="n">transition</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode_buffer</span><span class="o">.</span><span class="n">transitions</span><span class="p">:</span>
@@ -910,7 +913,8 @@
<span class="c1"># update counters</span>
<span class="bp">self</span><span class="o">.</span><span class="n">training_iteration</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">batch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">deep_copy</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">update_internal_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_train</span>
<span class="n">batch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="o">=</span><span class="n">update_internal_state</span><span class="p">,</span> <span class="n">deep_copy</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="c1"># if the batch returned empty then there are not enough samples in the replay buffer -&gt; skip</span>
<span class="c1"># training step</span>
@@ -1020,7 +1024,8 @@
<span class="c1"># informed action</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># before choosing an action, first use the pre_network_filter to filter out the current state</span>
<span class="n">update_filter_internal_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span>
<span class="n">update_filter_internal_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_inference</span> <span class="ow">and</span> \
<span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span>
<span class="n">curr_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">run_pre_network_filter_for_inference</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">curr_state</span><span class="p">,</span> <span class="n">update_filter_internal_state</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
@@ -1048,6 +1053,10 @@
<span class="sd"> :return: The filtered state</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">dummy_env_response</span> <span class="o">=</span> <span class="n">EnvResponse</span><span class="p">(</span><span class="n">next_state</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">reward</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">game_over</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="c1"># TODO actually we only want to run the observation filters. No point in running the reward filters as the</span>
<span class="c1"># filtered reward is being ignored anyway (and it might unncecessarily affect the reward filters&#39; internal</span>
<span class="c1"># state).</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">dummy_env_response</span><span class="p">,</span>
<span class="n">update_internal_state</span><span class="o">=</span><span class="n">update_filter_internal_state</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">next_state</span></div>
@@ -1177,7 +1186,7 @@
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Allows setting a directive for the agent to follow. This is useful in hierarchy structures, where the agent</span>
<span class="sd"> has another master agent that is controlling it. In such cases, the master agent can define the goals for the</span>
<span class="sd"> slave agent, define it&#39;s observation, possible actions, etc. The directive type is defined by the agent</span>
<span class="sd"> slave agent, define its observation, possible actions, etc. The directive type is defined by the agent</span>
<span class="sd"> in-action-space.</span>
<span class="sd"> :param action: The action that should be set as the directive</span>

View File

@@ -295,7 +295,9 @@
<span class="bp">self</span><span class="o">.</span><span class="n">optimization_epochs</span> <span class="o">=</span> <span class="mi">10</span>
<span class="bp">self</span><span class="o">.</span><span class="n">normalization_stats</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">clipping_decay_schedule</span> <span class="o">=</span> <span class="n">ConstantSchedule</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">act_for_full_episodes</span> <span class="o">=</span> <span class="kc">True</span></div>
<span class="bp">self</span><span class="o">.</span><span class="n">act_for_full_episodes</span> <span class="o">=</span> <span class="kc">True</span>
<span class="bp">self</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_train</span> <span class="o">=</span> <span class="kc">True</span>
<span class="bp">self</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_inference</span> <span class="o">=</span> <span class="kc">False</span></div>
<span class="k">class</span> <span class="nc">ClippedPPOAgentParameters</span><span class="p">(</span><span class="n">AgentParameters</span><span class="p">):</span>
@@ -486,7 +488,9 @@
<span class="n">network</span><span class="o">.</span><span class="n">set_is_training</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
<span class="n">dataset</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">transitions</span>
<span class="n">dataset</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">deep_copy</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">update_internal_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_train</span>
<span class="n">dataset</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">deep_copy</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">update_internal_state</span><span class="o">=</span><span class="n">update_internal_state</span><span class="p">)</span>
<span class="n">batch</span> <span class="o">=</span> <span class="n">Batch</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span>
<span class="k">for</span> <span class="n">training_step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">num_consecutive_training_steps</span><span class="p">):</span>
@@ -512,7 +516,9 @@
<span class="k">def</span> <span class="nf">run_pre_network_filter_for_inference</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">StateType</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="n">dummy_env_response</span> <span class="o">=</span> <span class="n">EnvResponse</span><span class="p">(</span><span class="n">next_state</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">reward</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">game_over</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">dummy_env_response</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="o">=</span><span class="kc">False</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">next_state</span>
<span class="n">update_internal_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_inference</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span>
<span class="n">dummy_env_response</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="o">=</span><span class="n">update_internal_state</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">next_state</span>
<span class="k">def</span> <span class="nf">choose_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">curr_state</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">clipping_decay_schedule</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>

View File

@@ -0,0 +1,356 @@
[New Sphinx-generated documentation page: "rl_coach.agents.wolpertinger_agent — Reinforcement Learning Coach 0.12.0 documentation". Standard Read the Docs theme markup (page head, navigation sidebar, search form, breadcrumbs) omitted; the module source follows.]
Source code for rl_coach.agents.wolpertinger_agent

#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import copy
from typing import Union
from collections import OrderedDict

import numpy as np

from rl_coach.agents.ddpg_agent import DDPGAlgorithmParameters, DDPGActorNetworkParameters, \
    DDPGCriticNetworkParameters, DDPGAgent
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import ActionInfo
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
from rl_coach.memories.non_episodic.differentiable_neural_dictionary import AnnoyDictionary
from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace
from rl_coach.architectures.head_parameters import WolpertingerActorHeadParameters


class WolpertingerCriticNetworkParameters(DDPGCriticNetworkParameters):
    def __init__(self, use_batchnorm=False):
        super().__init__(use_batchnorm=use_batchnorm)


class WolpertingerActorNetworkParameters(DDPGActorNetworkParameters):
    def __init__(self, use_batchnorm=False):
        super().__init__()
        self.heads_parameters = [WolpertingerActorHeadParameters(batchnorm=use_batchnorm)]


class WolpertingerAlgorithmParameters(DDPGAlgorithmParameters):
    def __init__(self):
        super().__init__()
        self.action_embedding_width = 1
        self.k = 1


class WolpertingerAgentParameters(AgentParameters):
    def __init__(self, use_batchnorm=False):
        exploration_params = AdditiveNoiseParameters()
        exploration_params.noise_as_percentage_from_action_space = False
        super().__init__(algorithm=WolpertingerAlgorithmParameters(),
                         exploration=exploration_params,
                         memory=EpisodicExperienceReplayParameters(),
                         networks=OrderedDict(
                             [("actor", WolpertingerActorNetworkParameters(use_batchnorm=use_batchnorm)),
                              ("critic", WolpertingerCriticNetworkParameters(use_batchnorm=use_batchnorm))]))

    @property
    def path(self):
        return 'rl_coach.agents.wolpertinger_agent:WolpertingerAgent'


# Deep Reinforcement Learning in Large Discrete Action Spaces - https://arxiv.org/pdf/1512.07679.pdf
class WolpertingerAgent(DDPGAgent):
    def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent'] = None):
        super().__init__(agent_parameters, parent)

    def learn_from_batch(self, batch):
        # replay buffer holds the actions in the discrete manner, as the agent is expected to act with discrete actions
        # with the BoxDiscretization output filter. But DDPG needs to work on continuous actions, thus converting to
        # continuous actions. This is actually a duplicate since this filtering is also done before applying actions on
        # the environment. So might want to somehow reuse that conversion. Maybe can hold this information in the info
        # dictionary of the transition.
        output_action_filter = \
            list(self.output_filter.action_filters.values())[0]
        continuous_actions = []
        for action in batch.actions():
            continuous_actions.append(output_action_filter.filter(action))
        batch._actions = np.array(continuous_actions).squeeze()

        return super().learn_from_batch(batch)

    def train(self):
        return super().train()

    def choose_action(self, curr_state):
        if not isinstance(self.spaces.action, DiscreteActionSpace):
            raise ValueError("WolpertingerAgent works only for discrete control problems")

        # convert to batch so we can run it through the network
        tf_input_state = self.prepare_batch_for_inference(curr_state, 'actor')
        actor_network = self.networks['actor'].online_network
        critic_network = self.networks['critic'].online_network

        proto_action = actor_network.predict(tf_input_state)
        proto_action = np.expand_dims(self.exploration_policy.get_action(proto_action), 0)
        nn_action_embeddings, indices, _, _ = self.knn_tree.query(keys=proto_action, k=self.ap.algorithm.k)

        # now move the actions through the critic and choose the one with the highest q value
        critic_inputs = copy.copy(tf_input_state)
        critic_inputs['observation'] = np.tile(critic_inputs['observation'], (self.ap.algorithm.k, 1))
        critic_inputs['action'] = nn_action_embeddings[0]
        q_values = critic_network.predict(critic_inputs)[0]

        action = int(indices[0][np.argmax(q_values)])
        self.action_signal.add_sample(action)

        return ActionInfo(action=action, action_value=0)

    def init_environment_dependent_modules(self):
        super().init_environment_dependent_modules()
        self.knn_tree = self.get_initialized_knn()

    # TODO - ideally the knn should not be defined here, but somehow be defined by the user in the preset
    def get_initialized_knn(self):
        num_actions = len(self.spaces.action.actions)
        action_max_abs_range = self.spaces.action.filtered_action_space.max_abs_range if \
            (hasattr(self.spaces.action, 'filtered_action_space') and
             isinstance(self.spaces.action.filtered_action_space, BoxActionSpace)) \
            else 1.0
        keys = np.expand_dims((np.arange(num_actions) / (num_actions - 1) - 0.5) * 2, 1) * action_max_abs_range
        values = np.expand_dims(np.arange(num_actions), 1)
        knn_tree = AnnoyDictionary(dict_size=num_actions, key_width=self.ap.algorithm.action_embedding_width)
        knn_tree.add(keys, values, force_rebuild_tree=True)
        return knn_tree

View File

@@ -396,6 +396,14 @@
<span class="c1"># Support for parameter noise</span>
<span class="bp">self</span><span class="o">.</span><span class="n">supports_parameter_noise</span> <span class="o">=</span> <span class="kc">False</span>
<span class="c1"># Override, in retrospective, all the episode rewards with the last reward in the episode</span>
<span class="c1"># (sometimes useful for sparse, end of the episode, rewards problems)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">override_episode_rewards_with_the_last_transition_reward</span> <span class="o">=</span> <span class="kc">False</span>
<span class="c1"># Filters - TODO consider creating a FilterParameters class and initialize the filters with it</span>
<span class="bp">self</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_train</span> <span class="o">=</span> <span class="kc">False</span>
<span class="bp">self</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_inference</span> <span class="o">=</span> <span class="kc">True</span>
<div class="viewcode-block" id="PresetValidationParameters"><a class="viewcode-back" href="../../components/additional_parameters.html#rl_coach.base_parameters.PresetValidationParameters">[docs]</a><span class="k">class</span> <span class="nc">PresetValidationParameters</span><span class="p">(</span><span class="n">Parameters</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>

View File

@@ -298,6 +298,12 @@
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_steps</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">num_steps</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">__truediv__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">other</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">other</span><span class="p">,</span> <span class="n">EnvironmentSteps</span><span class="p">):</span>
<span class="k">return</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_steps</span> <span class="o">/</span> <span class="n">other</span><span class="o">.</span><span class="n">num_steps</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__truediv__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">other</span><span class="p">)</span>
<span class="k">class</span> <span class="nc">Time</span><span class="p">(</span><span class="n">StepMethod</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_steps</span><span class="p">):</span>

View File

@@ -200,15 +200,17 @@
<span class="kn">import</span> <span class="nn">uuid</span>
<span class="kn">from</span> <span class="nn">rl_coach.data_stores.data_store</span> <span class="k">import</span> <span class="n">DataStore</span><span class="p">,</span> <span class="n">DataStoreParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.data_stores.data_store</span> <span class="k">import</span> <span class="n">DataStoreParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.data_stores.checkpoint_data_store</span> <span class="k">import</span> <span class="n">CheckpointDataStore</span>
<span class="k">class</span> <span class="nc">NFSDataStoreParameters</span><span class="p">(</span><span class="n">DataStoreParameters</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ds_params</span><span class="p">,</span> <span class="n">deployed</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">server</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">path</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ds_params</span><span class="p">,</span> <span class="n">deployed</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">server</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">path</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">checkpoint_dir</span><span class="p">:</span> <span class="nb">str</span><span class="o">=</span><span class="s2">&quot;&quot;</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">ds_params</span><span class="o">.</span><span class="n">store_type</span><span class="p">,</span> <span class="n">ds_params</span><span class="o">.</span><span class="n">orchestrator_type</span><span class="p">,</span> <span class="n">ds_params</span><span class="o">.</span><span class="n">orchestrator_params</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">namespace</span> <span class="o">=</span> <span class="s2">&quot;default&quot;</span>
<span class="k">if</span> <span class="s2">&quot;namespace&quot;</span> <span class="ow">in</span> <span class="n">ds_params</span><span class="o">.</span><span class="n">orchestrator_params</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">namespace</span> <span class="o">=</span> <span class="n">ds_params</span><span class="o">.</span><span class="n">orchestrator_params</span><span class="p">[</span><span class="s2">&quot;namespace&quot;</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_dir</span> <span class="o">=</span> <span class="n">checkpoint_dir</span>
<span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">pvc_name</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">pv_name</span> <span class="o">=</span> <span class="kc">None</span>
@@ -221,7 +223,7 @@
<span class="bp">self</span><span class="o">.</span><span class="n">path</span> <span class="o">=</span> <span class="n">path</span>
<div class="viewcode-block" id="NFSDataStore"><a class="viewcode-back" href="../../../components/data_stores/index.html#rl_coach.data_stores.nfs_data_store.NFSDataStore">[docs]</a><span class="k">class</span> <span class="nc">NFSDataStore</span><span class="p">(</span><span class="n">DataStore</span><span class="p">):</span>
<div class="viewcode-block" id="NFSDataStore"><a class="viewcode-back" href="../../../components/data_stores/index.html#rl_coach.data_stores.nfs_data_store.NFSDataStore">[docs]</a><span class="k">class</span> <span class="nc">NFSDataStore</span><span class="p">(</span><span class="n">CheckpointDataStore</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> An implementation of data store which uses NFS for storing policy checkpoints when using Coach in distributed mode.</span>
<span class="sd"> The policy checkpoints are written by the trainer and read by the rollout worker.</span>

View File

@@ -198,7 +198,8 @@
<span class="c1">#</span>
<span class="kn">from</span> <span class="nn">rl_coach.data_stores.data_store</span> <span class="k">import</span> <span class="n">DataStore</span><span class="p">,</span> <span class="n">DataStoreParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.data_stores.data_store</span> <span class="k">import</span> <span class="n">DataStoreParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.data_stores.checkpoint_data_store</span> <span class="k">import</span> <span class="n">CheckpointDataStore</span>
<span class="kn">from</span> <span class="nn">minio</span> <span class="k">import</span> <span class="n">Minio</span>
<span class="kn">from</span> <span class="nn">minio.error</span> <span class="k">import</span> <span class="n">ResponseError</span>
<span class="kn">from</span> <span class="nn">configparser</span> <span class="k">import</span> <span class="n">ConfigParser</span><span class="p">,</span> <span class="n">Error</span>
@@ -222,7 +223,7 @@
<span class="bp">self</span><span class="o">.</span><span class="n">expt_dir</span> <span class="o">=</span> <span class="n">expt_dir</span>
<div class="viewcode-block" id="S3DataStore"><a class="viewcode-back" href="../../../components/data_stores/index.html#rl_coach.data_stores.s3_data_store.S3DataStore">[docs]</a><span class="k">class</span> <span class="nc">S3DataStore</span><span class="p">(</span><span class="n">DataStore</span><span class="p">):</span>
<div class="viewcode-block" id="S3DataStore"><a class="viewcode-back" href="../../../components/data_stores/index.html#rl_coach.data_stores.s3_data_store.S3DataStore">[docs]</a><span class="k">class</span> <span class="nc">S3DataStore</span><span class="p">(</span><span class="n">CheckpointDataStore</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> An implementation of the data store using S3 for storing policy checkpoints when using Coach in distributed mode.</span>
<span class="sd"> The policy checkpoints are written by the trainer and read by the rollout worker.</span>

View File

@@ -245,7 +245,9 @@
<span class="bp">self</span><span class="o">.</span><span class="n">evaluation_noise</span> <span class="o">=</span> <span class="n">evaluation_noise</span>
<span class="bp">self</span><span class="o">.</span><span class="n">noise_as_percentage_from_action_space</span> <span class="o">=</span> <span class="n">noise_as_percentage_from_action_space</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_space</span><span class="p">,</span> <span class="n">BoxActionSpace</span><span class="p">):</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_space</span><span class="p">,</span> <span class="n">BoxActionSpace</span><span class="p">)</span> <span class="ow">and</span> \
<span class="p">(</span><span class="nb">hasattr</span><span class="p">(</span><span class="n">action_space</span><span class="p">,</span> <span class="s1">&#39;filtered_action_space&#39;</span><span class="p">)</span> <span class="ow">and</span> <span class="ow">not</span>
<span class="nb">isinstance</span><span class="p">(</span><span class="n">action_space</span><span class="o">.</span><span class="n">filtered_action_space</span><span class="p">,</span> <span class="n">BoxActionSpace</span><span class="p">)):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Additive noise exploration works only for continuous controls.&quot;</span>
<span class="s2">&quot;The given action space is of type: </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">action_space</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">))</span>

View File

@@ -298,7 +298,10 @@
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> :param action_space: the action space used by the environment</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_space</span><span class="p">,</span> <span class="n">BoxActionSpace</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_space</span><span class="p">,</span> <span class="n">GoalsSpace</span><span class="p">)</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_space</span><span class="p">,</span> <span class="n">BoxActionSpace</span><span class="p">)</span> <span class="ow">or</span> \
<span class="p">(</span><span class="nb">hasattr</span><span class="p">(</span><span class="n">action_space</span><span class="p">,</span> <span class="s1">&#39;filtered_action_space&#39;</span><span class="p">)</span> <span class="ow">and</span>
<span class="nb">isinstance</span><span class="p">(</span><span class="n">action_space</span><span class="o">.</span><span class="n">filtered_action_space</span><span class="p">,</span> <span class="n">BoxActionSpace</span><span class="p">))</span> <span class="ow">or</span> \
<span class="nb">isinstance</span><span class="p">(</span><span class="n">action_space</span><span class="p">,</span> <span class="n">GoalsSpace</span><span class="p">)</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">action_space</span><span class="p">)</span>
</pre></div>

View File

@@ -271,9 +271,6 @@
<span class="k">else</span><span class="p">:</span>
<span class="n">action_values_std</span> <span class="o">=</span> <span class="n">current_noise</span>
<span class="c1"># scale the noise to the action space range</span>
<span class="n">action_values_std</span> <span class="o">=</span> <span class="n">current_noise</span> <span class="o">*</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">high</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">low</span><span class="p">)</span>
<span class="c1"># extract the mean values</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_values</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
<span class="c1"># the action values are expected to be a list with the action mean and optionally the action stdev</span>

View File

@@ -231,7 +231,8 @@
<span class="k">def</span> <span class="nf">get_unfiltered_action_space</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output_action_space</span><span class="p">:</span> <span class="n">ActionSpace</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">DiscreteActionSpace</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">output_action_space</span> <span class="o">=</span> <span class="n">output_action_space</span>
<span class="bp">self</span><span class="o">.</span><span class="n">input_action_space</span> <span class="o">=</span> <span class="n">DiscreteActionSpace</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">target_actions</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">descriptions</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">input_action_space</span> <span class="o">=</span> <span class="n">DiscreteActionSpace</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">target_actions</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">descriptions</span><span class="p">,</span>
<span class="n">filtered_action_space</span><span class="o">=</span><span class="n">output_action_space</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_action_space</span>
<span class="k">def</span> <span class="nf">filter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action</span><span class="p">:</span> <span class="n">ActionType</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">ActionType</span><span class="p">:</span>

View File

@@ -261,11 +261,18 @@
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="s1">&#39;namespace&#39;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">orchestrator_params</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">orchestrator_params</span><span class="p">[</span><span class="s1">&#39;namespace&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="s2">&quot;default&quot;</span>
<span class="kn">from</span> <span class="nn">kubernetes</span> <span class="k">import</span> <span class="n">client</span>
<span class="kn">from</span> <span class="nn">kubernetes</span> <span class="k">import</span> <span class="n">client</span><span class="p">,</span> <span class="n">config</span>
<span class="n">container</span> <span class="o">=</span> <span class="n">client</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">redis_server_name</span><span class="p">,</span>
<span class="n">image</span><span class="o">=</span><span class="s1">&#39;redis:4-alpine&#39;</span><span class="p">,</span>
<span class="n">resources</span><span class="o">=</span><span class="n">client</span><span class="o">.</span><span class="n">V1ResourceRequirements</span><span class="p">(</span>
<span class="n">limits</span><span class="o">=</span><span class="p">{</span>
<span class="s2">&quot;cpu&quot;</span><span class="p">:</span> <span class="s2">&quot;8&quot;</span><span class="p">,</span>
<span class="s2">&quot;memory&quot;</span><span class="p">:</span> <span class="s2">&quot;4Gi&quot;</span>
<span class="c1"># &quot;nvidia.com/gpu&quot;: &quot;0&quot;,</span>
<span class="p">}</span>
<span class="p">),</span>
<span class="p">)</span>
<span class="n">template</span> <span class="o">=</span> <span class="n">client</span><span class="o">.</span><span class="n">V1PodTemplateSpec</span><span class="p">(</span>
<span class="n">metadata</span><span class="o">=</span><span class="n">client</span><span class="o">.</span><span class="n">V1ObjectMeta</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;app&#39;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">redis_server_name</span><span class="p">}),</span>
@@ -288,8 +295,10 @@
<span class="n">spec</span><span class="o">=</span><span class="n">deployment_spec</span>
<span class="p">)</span>
<span class="n">config</span><span class="o">.</span><span class="n">load_kube_config</span><span class="p">()</span>
<span class="n">api_client</span> <span class="o">=</span> <span class="n">client</span><span class="o">.</span><span class="n">AppsV1Api</span><span class="p">()</span>
<span class="k">try</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">orchestrator_params</span><span class="p">)</span>
<span class="n">api_client</span><span class="o">.</span><span class="n">create_namespaced_deployment</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">orchestrator_params</span><span class="p">[</span><span class="s1">&#39;namespace&#39;</span><span class="p">],</span> <span class="n">deployment</span><span class="p">)</span>
<span class="k">except</span> <span class="n">client</span><span class="o">.</span><span class="n">rest</span><span class="o">.</span><span class="n">ApiException</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Got exception: </span><span class="si">%s</span><span class="se">\n</span><span class="s2"> while creating redis-server&quot;</span><span class="p">,</span> <span class="n">e</span><span class="p">)</span>

View File

@@ -240,7 +240,7 @@
<span class="bp">self</span><span class="o">.</span><span class="n">built_capacity</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">def</span> <span class="nf">add</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">keys</span><span class="p">,</span> <span class="n">values</span><span class="p">,</span> <span class="n">additional_data</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">add</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">keys</span><span class="p">,</span> <span class="n">values</span><span class="p">,</span> <span class="n">additional_data</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">force_rebuild_tree</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">additional_data</span><span class="p">:</span>
<span class="n">additional_data</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">keys</span><span class="p">)</span>
@@ -279,7 +279,7 @@
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffered_indices</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">min_update_size</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">min_update_size</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">initial_update_size</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">curr_size</span> <span class="o">*</span> <span class="mf">0.02</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_rebuild_index</span><span class="p">()</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">rebuild_on_every_update</span><span class="p">:</span>
<span class="k">elif</span> <span class="n">force_rebuild_tree</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">rebuild_on_every_update</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_rebuild_index</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">current_timestamp</span> <span class="o">+=</span> <span class="mi">1</span>

View File

@@ -307,6 +307,11 @@
<span class="sd"> &quot;&quot;&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">memory_backend</span><span class="o">.</span><span class="n">deploy</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">&quot;redis&quot;</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">data_store</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">redis_address</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory_backend</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">redis_address</span>
<span class="bp">self</span><span class="o">.</span><span class="n">data_store</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">redis_port</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory_backend</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">redis_port</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_store</span><span class="o">.</span><span class="n">deploy</span><span class="p">():</span>
<span class="k">return</span> <span class="kc">False</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">&quot;nfs&quot;</span><span class="p">:</span>
@@ -329,6 +334,8 @@
<span class="n">trainer_params</span><span class="o">.</span><span class="n">command</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;--data_store_params&#39;</span><span class="p">,</span> <span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)]</span>
<span class="n">name</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="si">{}</span><span class="s2">-</span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">run_type</span><span class="p">,</span> <span class="n">uuid</span><span class="o">.</span><span class="n">uuid4</span><span class="p">())</span>
<span class="c1"># TODO: instead of defining each container and template spec from scratch, loaded default</span>
<span class="c1"># configuration and modify them as necessary depending on the store type</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">&quot;nfs&quot;</span><span class="p">:</span>
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
@@ -354,7 +361,7 @@
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">&#39;Never&#39;</span>
<span class="p">),</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">&quot;s3&quot;</span><span class="p">:</span>
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
<span class="n">image</span><span class="o">=</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">image</span><span class="p">,</span>
@@ -373,6 +380,34 @@
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">&#39;Never&#39;</span>
<span class="p">),</span>
<span class="p">)</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">&quot;redis&quot;</span><span class="p">:</span>
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
<span class="n">image</span><span class="o">=</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">image</span><span class="p">,</span>
<span class="n">command</span><span class="o">=</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">command</span><span class="p">,</span>
<span class="n">args</span><span class="o">=</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">arguments</span><span class="p">,</span>
<span class="n">image_pull_policy</span><span class="o">=</span><span class="s1">&#39;Always&#39;</span><span class="p">,</span>
<span class="n">stdin</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">tty</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">resources</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1ResourceRequirements</span><span class="p">(</span>
<span class="n">limits</span><span class="o">=</span><span class="p">{</span>
<span class="s2">&quot;cpu&quot;</span><span class="p">:</span> <span class="s2">&quot;40&quot;</span><span class="p">,</span>
<span class="s2">&quot;memory&quot;</span><span class="p">:</span> <span class="s2">&quot;4Gi&quot;</span><span class="p">,</span>
<span class="s2">&quot;nvidia.com/gpu&quot;</span><span class="p">:</span> <span class="s2">&quot;1&quot;</span><span class="p">,</span>
<span class="p">}</span>
<span class="p">),</span>
<span class="p">)</span>
<span class="n">template</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1PodTemplateSpec</span><span class="p">(</span>
<span class="n">metadata</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1ObjectMeta</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;app&#39;</span><span class="p">:</span> <span class="n">name</span><span class="p">}),</span>
<span class="n">spec</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1PodSpec</span><span class="p">(</span>
<span class="n">containers</span><span class="o">=</span><span class="p">[</span><span class="n">container</span><span class="p">],</span>
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">&#39;Never&#39;</span>
<span class="p">),</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;unexpected store_type </span><span class="si">{}</span><span class="s2">. expected &#39;s3&#39;, &#39;nfs&#39;, &#39;redis&#39;&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span>
<span class="p">))</span>
<span class="n">job_spec</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1JobSpec</span><span class="p">(</span>
<span class="n">completions</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
@@ -404,12 +439,17 @@
<span class="k">if</span> <span class="ow">not</span> <span class="n">worker_params</span><span class="p">:</span>
<span class="k">return</span> <span class="kc">False</span>
<span class="c1"># At this point, the memory backend and data store have been deployed and in the process,</span>
<span class="c1"># these parameters have been updated to include things like the hostname and port the</span>
<span class="c1"># service can be found at.</span>
<span class="n">worker_params</span><span class="o">.</span><span class="n">command</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;--memory_backend_params&#39;</span><span class="p">,</span> <span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">memory_backend_parameters</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)]</span>
<span class="n">worker_params</span><span class="o">.</span><span class="n">command</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;--data_store_params&#39;</span><span class="p">,</span> <span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)]</span>
<span class="n">worker_params</span><span class="o">.</span><span class="n">command</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;--num_workers&#39;</span><span class="p">,</span> <span class="s1">&#39;</span><span class="si">{}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">worker_params</span><span class="o">.</span><span class="n">num_replicas</span><span class="p">)]</span>
<span class="n">name</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="si">{}</span><span class="s2">-</span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">worker_params</span><span class="o">.</span><span class="n">run_type</span><span class="p">,</span> <span class="n">uuid</span><span class="o">.</span><span class="n">uuid4</span><span class="p">())</span>
<span class="c1"># TODO: instead of defining each container and template spec from scratch, loaded default</span>
<span class="c1"># configuration and modify them as necessary depending on the store type</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">&quot;nfs&quot;</span><span class="p">:</span>
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
@@ -435,7 +475,7 @@
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">&#39;Never&#39;</span>
<span class="p">),</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">&quot;s3&quot;</span><span class="p">:</span>
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
<span class="n">image</span><span class="o">=</span><span class="n">worker_params</span><span class="o">.</span><span class="n">image</span><span class="p">,</span>
@@ -454,6 +494,32 @@
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">&#39;Never&#39;</span>
<span class="p">)</span>
<span class="p">)</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">&quot;redis&quot;</span><span class="p">:</span>
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
<span class="n">image</span><span class="o">=</span><span class="n">worker_params</span><span class="o">.</span><span class="n">image</span><span class="p">,</span>
<span class="n">command</span><span class="o">=</span><span class="n">worker_params</span><span class="o">.</span><span class="n">command</span><span class="p">,</span>
<span class="n">args</span><span class="o">=</span><span class="n">worker_params</span><span class="o">.</span><span class="n">arguments</span><span class="p">,</span>
<span class="n">image_pull_policy</span><span class="o">=</span><span class="s1">&#39;Always&#39;</span><span class="p">,</span>
<span class="n">stdin</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">tty</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">resources</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1ResourceRequirements</span><span class="p">(</span>
<span class="n">limits</span><span class="o">=</span><span class="p">{</span>
<span class="s2">&quot;cpu&quot;</span><span class="p">:</span> <span class="s2">&quot;8&quot;</span><span class="p">,</span>
<span class="s2">&quot;memory&quot;</span><span class="p">:</span> <span class="s2">&quot;4Gi&quot;</span><span class="p">,</span>
<span class="c1"># &quot;nvidia.com/gpu&quot;: &quot;0&quot;,</span>
<span class="p">}</span>
<span class="p">),</span>
<span class="p">)</span>
<span class="n">template</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1PodTemplateSpec</span><span class="p">(</span>
<span class="n">metadata</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1ObjectMeta</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;app&#39;</span><span class="p">:</span> <span class="n">name</span><span class="p">}),</span>
<span class="n">spec</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1PodSpec</span><span class="p">(</span>
<span class="n">containers</span><span class="o">=</span><span class="p">[</span><span class="n">container</span><span class="p">],</span>
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">&#39;Never&#39;</span>
<span class="p">)</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;unexpected store type </span><span class="si">{}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span><span class="p">))</span>
<span class="n">job_spec</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1JobSpec</span><span class="p">(</span>
<span class="n">completions</span><span class="o">=</span><span class="n">worker_params</span><span class="o">.</span><span class="n">num_replicas</span><span class="p">,</span>

View File

@@ -568,7 +568,8 @@
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> A discrete action space with action indices as actions</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_actions</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">descriptions</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Dict</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">default_action</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_actions</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">descriptions</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Dict</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">default_action</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">filtered_action_space</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">low</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">num_actions</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">descriptions</span><span class="o">=</span><span class="n">descriptions</span><span class="p">)</span>
<span class="c1"># the number of actions is mapped to high</span>
@@ -578,6 +579,9 @@
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">default_action</span> <span class="o">=</span> <span class="n">default_action</span>
<span class="k">if</span> <span class="n">filtered_action_space</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">filtered_action_space</span> <span class="o">=</span> <span class="n">filtered_action_space</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">actions</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">ActionType</span><span class="p">]:</span>
<span class="k">return</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">high</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>

View File

@@ -21,8 +21,6 @@ A detailed description of those algorithms can be found by navigating to each of
imitation/cil
policy_optimization/cppo
policy_optimization/ddpg
policy_optimization/td3
policy_optimization/sac
other/dfp
value_optimization/double_dqn
value_optimization/dqn
@@ -36,6 +34,10 @@ A detailed description of those algorithms can be found by navigating to each of
policy_optimization/ppo
value_optimization/rainbow
value_optimization/qr_dqn
policy_optimization/sac
policy_optimization/td3
policy_optimization/wolpertinger
.. autoclass:: rl_coach.base_parameters.AgentParameters

View File

@@ -0,0 +1,56 @@
Wolpertinger
=============
**Actions space:** Discrete
**References:** `Deep Reinforcement Learning in Large Discrete Action Spaces <https://arxiv.org/abs/1512.07679>`_
Network Structure
-----------------
.. image:: /_static/img/design_imgs/wolpertinger.png
:align: center
Algorithm Description
---------------------
Choosing an action
++++++++++++++++++
Pass the current states through the actor network, and get a proto action :math:`\mu`.
While in the training phase, use a continuous exploration policy, such as additive Gaussian noise,
to add exploration noise to the proto action. Then, pass the noisy proto action to a k-NN tree to find valid
action candidates in the surrounding neighborhood of the proto action. Those candidates are passed to the
critic to evaluate their Q values, and the discrete index of the candidate with the highest Q value is chosen.
When testing, the same flow is used, but no exploration noise is added.
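The flow above can be summarized with a minimal sketch (illustrative only, not Coach's actual implementation;
the ``actor``, ``critic`` and ``knn_tree`` objects below are hypothetical stand-ins):

.. code-block:: python

    import numpy as np

    def choose_action(state, actor, critic, knn_tree, noise_std=0.1, k=32, training=True):
        # 1. the actor outputs a continuous proto action
        proto_action = actor.predict(state)
        # 2. exploration noise is added only during training
        if training:
            proto_action = proto_action + np.random.normal(0.0, noise_std, proto_action.shape)
        # 3. retrieve the k nearest valid discrete actions around the proto action
        indices, candidate_actions = knn_tree.query(proto_action, k)
        # 4. the critic scores each candidate; the index of the best candidate is returned
        q_values = [critic.predict(state, action) for action in candidate_actions]
        return indices[int(np.argmax(q_values))]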
Training the network
++++++++++++++++++++
Training the network is exactly the same as in DDPG. Unlike when choosing an action, the proto action is not passed
through the k-NN tree; it is passed directly to the critic.
Start by sampling a batch of transitions from the experience replay.
* To train the **critic network**, use the following targets:
:math:`y_t=r(s_t,a_t )+\gamma \cdot Q(s_{t+1},\mu(s_{t+1} ))`
First run the actor target network, using the next states as the inputs, and get :math:`\mu (s_{t+1} )`.
Next, run the critic target network using the next states and :math:`\mu (s_{t+1} )`, and use the output to
calculate :math:`y_t` according to the equation above. To train the network, use the current states and actions
as the inputs, and :math:`y_t` as the targets.
* To train the **actor network**, use the following equation:
:math:`\nabla_{\theta^\mu } J \approx E_{s_t \tilde{} \rho^\beta } [\nabla_a Q(s,a)|_{s=s_t,a=\mu (s_t ) } \cdot \nabla_{\theta^\mu} \mu(s)|_{s=s_t} ]`
Use the actor's online network to get the action mean values using the current states as the inputs.
Then, use the critic online network in order to get the gradients of the critic output with respect to the
action mean values :math:`\nabla _a Q(s,a)|_{s=s_t,a=\mu(s_t ) }`.
Using the chain rule, calculate the gradients of the actor's output, with respect to the actor weights,
given :math:`\nabla_a Q(s,a)`. Finally, apply those gradients to the actor network.
After every training step, do a soft update of the critic and actor target networks' weights from the online networks.
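As a rough illustration of the critic targets and the soft target update described above (a sketch under assumed
helper objects, not the actual Coach code):

.. code-block:: python

    def critic_targets(batch, actor_target, critic_target, gamma=0.99):
        # y_t = r(s_t, a_t) + gamma * Q(s_{t+1}, mu(s_{t+1}))
        # (terminal transitions are typically masked in practice)
        next_proto_actions = actor_target.predict(batch.next_states)
        next_q = critic_target.predict(batch.next_states, next_proto_actions)
        return batch.rewards + gamma * next_q

    def soft_update(online_weights, target_weights, tau=0.001):
        # target <- tau * online + (1 - tau) * target
        return [tau * w + (1.0 - tau) * t for w, t in zip(online_weights, target_weights)]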
.. autoclass:: rl_coach.agents.wolpertinger_agent.WolpertingerAlgorithmParameters

View File

@@ -117,8 +117,6 @@
<li class="toctree-l2"><a class="reference internal" href="imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="policy_optimization/td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="value_optimization/double_dqn.html">Double DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="value_optimization/dqn.html">Deep Q Networks</a></li>
@@ -132,6 +130,9 @@
<li class="toctree-l2"><a class="reference internal" href="policy_optimization/ppo.html">Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="value_optimization/rainbow.html">Rainbow</a></li>
<li class="toctree-l2"><a class="reference internal" href="value_optimization/qr_dqn.html">Quantile Regression DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="policy_optimization/td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="policy_optimization/wolpertinger.html">Wolpertinger</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../architectures/index.html">Architectures</a></li>
@@ -226,8 +227,6 @@ A detailed description of those algorithms can be found by navigating to each of
<li class="toctree-l1"><a class="reference internal" href="imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l1"><a class="reference internal" href="policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l1"><a class="reference internal" href="policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l1"><a class="reference internal" href="policy_optimization/td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l1"><a class="reference internal" href="policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l1"><a class="reference internal" href="other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l1"><a class="reference internal" href="value_optimization/double_dqn.html">Double DQN</a></li>
<li class="toctree-l1"><a class="reference internal" href="value_optimization/dqn.html">Deep Q Networks</a></li>
@@ -241,6 +240,9 @@ A detailed description of those algorithms can be found by navigating to each of
<li class="toctree-l1"><a class="reference internal" href="policy_optimization/ppo.html">Proximal Policy Optimization</a></li>
<li class="toctree-l1"><a class="reference internal" href="value_optimization/rainbow.html">Rainbow</a></li>
<li class="toctree-l1"><a class="reference internal" href="value_optimization/qr_dqn.html">Quantile Regression DQN</a></li>
<li class="toctree-l1"><a class="reference internal" href="policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l1"><a class="reference internal" href="policy_optimization/td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l1"><a class="reference internal" href="policy_optimization/wolpertinger.html">Wolpertinger</a></li>
</ul>
</div>
<dl class="class">
@@ -512,7 +514,7 @@ given observation</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.prepare_batch_for_inference">
<code class="sig-name descname">prepare_batch_for_inference</code><span class="sig-paren">(</span><em class="sig-param">states: Union[Dict[str, numpy.ndarray], List[Dict[str, numpy.ndarray]]], network_name: str</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.array]<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.prepare_batch_for_inference"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.prepare_batch_for_inference" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">prepare_batch_for_inference</code><span class="sig-paren">(</span><em class="sig-param">states: Union[Dict[str, numpy.ndarray], List[Dict[str, numpy.ndarray]]], network_name: str</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.core.multiarray.array]<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.prepare_batch_for_inference"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.prepare_batch_for_inference" title="Permalink to this definition"></a></dt>
<dd><p>Convert curr_state into input tensors tensorflow is expecting. i.e. if we have several input states, stack all
observations together, measurements together, etc.</p>
<dl class="field-list simple">
@@ -652,7 +654,7 @@ dependent on those values, by calling init_environment_dependent_modules</p>
<code class="sig-name descname">set_incoming_directive</code><span class="sig-paren">(</span><em class="sig-param">action: Union[int, float, numpy.ndarray, List]</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.set_incoming_directive"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.set_incoming_directive" title="Permalink to this definition"></a></dt>
<dd><p>Allows setting a directive for the agent to follow. This is useful in hierarchy structures, where the agent
has another master agent that is controlling it. In such cases, the master agent can define the goals for the
slave agent, define its observation, possible actions, etc. The directive type is defined by the agent
slave agent, define its observation, possible actions, etc. The directive type is defined by the agent
in-action-space.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>

View File

@@ -0,0 +1,276 @@
<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Wolpertinger &mdash; Reinforcement Learning Coach 0.12.0 documentation</title>
<script type="text/javascript" src="../../../_static/js/modernizr.min.js"></script>
<script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/js/theme.js"></script>
<link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../../_static/css/custom.css" type="text/css" />
<link rel="index" title="Index" href="../../../genindex.html" />
<link rel="search" title="Search" href="../../../search.html" />
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../../index.html" class="icon icon-home"> Reinforcement Learning Coach
<img src="../../../_static/dark_logo.png" class="logo" alt="Logo"/>
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption"><span class="caption-text">Intro</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../dist_usage.html">Usage - Distributed Coach</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../features/index.html">Features</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../selecting_an_algorithm.html">Selecting an Algorithm</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../dashboard.html">Coach Dashboard</a></li>
</ul>
<p class="caption"><span class="caption-text">Design</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../design/control_flow.html">Control Flow</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../design/network.html">Network Design</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../design/horizontal_scaling.html">Distributed Coach - Horizontal Scale-Out</a></li>
</ul>
<p class="caption"><span class="caption-text">Contributing</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_agent.html">Adding a New Agent</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_env.html">Adding a New Environment</a></li>
</ul>
<p class="caption"><span class="caption-text">Components</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../index.html">Agents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../architectures/index.html">Architectures</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../data_stores/index.html">Data Stores</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../environments/index.html">Environments</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../exploration_policies/index.html">Exploration Policies</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../filters/index.html">Filters</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../memories/index.html">Memories</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../memory_backends/index.html">Memory Backends</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../orchestrators/index.html">Orchestrators</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../core_types.html">Core Types</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../spaces.html">Spaces</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../additional_parameters.html">Additional Parameters</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../../index.html">Reinforcement Learning Coach</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../../index.html">Docs</a> &raquo;</li>
<li>Wolpertinger</li>
<li class="wy-breadcrumbs-aside">
<a href="../../../_sources/components/agents/policy_optimization/wolpertinger.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="section" id="wolpertinger">
<h1>Wolpertinger<a class="headerlink" href="#wolpertinger" title="Permalink to this headline"></a></h1>
<p><strong>Actions space:</strong> Discrete</p>
<p><strong>References:</strong> <a class="reference external" href="https://arxiv.org/abs/1512.07679">Deep Reinforcement Learning in Large Discrete Action Spaces</a></p>
<div class="section" id="network-structure">
<h2>Network Structure<a class="headerlink" href="#network-structure" title="Permalink to this headline"></a></h2>
<img alt="../../../_images/wolpertinger.png" class="align-center" src="../../../_images/wolpertinger.png" />
</div>
<div class="section" id="algorithm-description">
<h2>Algorithm Description<a class="headerlink" href="#algorithm-description" title="Permalink to this headline"></a></h2>
<div class="section" id="choosing-an-action">
<h3>Choosing an action<a class="headerlink" href="#choosing-an-action" title="Permalink to this headline"></a></h3>
<p>Pass the current states through the actor network, and get a proto action <span class="math notranslate nohighlight">\(\mu\)</span>.
While in the training phase, use a continuous exploration policy, such as additive Gaussian noise,
to add exploration noise to the proto action. Then, pass the noisy proto action to a k-NN tree to find valid
action candidates in the surrounding neighborhood of the proto action. Those candidates are passed to the
critic to evaluate their Q values, and the discrete index of the candidate with the highest Q value is chosen.
When testing, the same flow is used, but no exploration noise is added.</p>
</div>
<div class="section" id="training-the-network">
<h3>Training the network<a class="headerlink" href="#training-the-network" title="Permalink to this headline"></a></h3>
<p>Training the network is exactly the same as in DDPG. Unlike when choosing an action, the proto action is not passed
through the k-NN tree; it is passed directly to the critic.</p>
<p>Start by sampling a batch of transitions from the experience replay.</p>
<ul>
<li><p>To train the <strong>critic network</strong>, use the following targets:</p>
<p><span class="math notranslate nohighlight">\(y_t=r(s_t,a_t )+\gamma \cdot Q(s_{t+1},\mu(s_{t+1} ))\)</span></p>
<p>First run the actor target network, using the next states as the inputs, and get <span class="math notranslate nohighlight">\(\mu (s_{t+1} )\)</span>.
Next, run the critic target network using the next states and <span class="math notranslate nohighlight">\(\mu (s_{t+1} )\)</span>, and use the output to
calculate <span class="math notranslate nohighlight">\(y_t\)</span> according to the equation above. To train the network, use the current states and actions
as the inputs, and <span class="math notranslate nohighlight">\(y_t\)</span> as the targets.</p>
</li>
<li><p>To train the <strong>actor network</strong>, use the following equation:</p>
<p><span class="math notranslate nohighlight">\(\nabla_{\theta^\mu } J \approx E_{s_t \tilde{} \rho^\beta } [\nabla_a Q(s,a)|_{s=s_t,a=\mu (s_t ) } \cdot \nabla_{\theta^\mu} \mu(s)|_{s=s_t} ]\)</span></p>
<p>Use the actor's online network to get the action mean values using the current states as the inputs.
Then, use the critic online network in order to get the gradients of the critic output with respect to the
action mean values <span class="math notranslate nohighlight">\(\nabla _a Q(s,a)|_{s=s_t,a=\mu(s_t ) }\)</span>.
Using the chain rule, calculate the gradients of the actor's output, with respect to the actor weights,
given <span class="math notranslate nohighlight">\(\nabla_a Q(s,a)\)</span>. Finally, apply those gradients to the actor network.</p>
</li>
</ul>
<p>After every training step, do a soft update of the critic and actor target networks' weights from the online networks.</p>
<dl class="class">
<dt id="rl_coach.agents.wolpertinger_agent.WolpertingerAlgorithmParameters">
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.wolpertinger_agent.</code><code class="sig-name descname">WolpertingerAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/wolpertinger_agent.html#WolpertingerAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.wolpertinger_agent.WolpertingerAlgorithmParameters" title="Permalink to this definition"></a></dt>
<dd></dd></dl>
</div>
</div>
</div>
</div>
</div>
<footer>
<hr/>
<div role="contentinfo">
<p>
&copy; Copyright 2018-2019, Intel AI Lab
</p>
</div>
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>

View File

@@ -442,7 +442,7 @@ The actions will be in the form:
<h3>DiscreteActionSpace<a class="headerlink" href="#discreteactionspace" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.spaces.DiscreteActionSpace">
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.spaces.</code><code class="sig-name descname">DiscreteActionSpace</code><span class="sig-paren">(</span><em class="sig-param">num_actions: int</em>, <em class="sig-param">descriptions: Union[None</em>, <em class="sig-param">List</em>, <em class="sig-param">Dict] = None</em>, <em class="sig-param">default_action: numpy.ndarray = None</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#DiscreteActionSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.DiscreteActionSpace" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.spaces.</code><code class="sig-name descname">DiscreteActionSpace</code><span class="sig-paren">(</span><em class="sig-param">num_actions: int</em>, <em class="sig-param">descriptions: Union[None</em>, <em class="sig-param">List</em>, <em class="sig-param">Dict] = None</em>, <em class="sig-param">default_action: numpy.ndarray = None</em>, <em class="sig-param">filtered_action_space=None</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#DiscreteActionSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.DiscreteActionSpace" title="Permalink to this definition"></a></dt>
<dd><p>A discrete action space with action indices as actions</p>
</dd></dl>

View File

@@ -37,7 +37,7 @@
<link rel="stylesheet" href="../_static/css/custom.css" type="text/css" />
<link rel="index" title="Index" href="../genindex.html" />
<link rel="search" title="Search" href="../search.html" />
<link rel="next" title="Batch Reinforcement Learning" href="batch_rl.html" />
<link rel="next" title="Selecting an Algorithm" href="../selecting_an_algorithm.html" />
<link rel="prev" title="Environments" href="environments.html" />
<link href="../_static/css/custom.css" rel="stylesheet" type="text/css">
@@ -95,7 +95,6 @@
<li class="toctree-l2"><a class="reference internal" href="algorithms.html">Algorithms</a></li>
<li class="toctree-l2"><a class="reference internal" href="environments.html">Environments</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Benchmarks</a></li>
<li class="toctree-l2"><a class="reference internal" href="batch_rl.html">Batch Reinforcement Learning</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../selecting_an_algorithm.html">Selecting an Algorithm</a></li>
@@ -221,7 +220,7 @@ benchmarks stay intact as Coach continues to develop.</p>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="batch_rl.html" class="btn btn-neutral float-right" title="Batch Reinforcement Learning" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a>
<a href="../selecting_an_algorithm.html" class="btn btn-neutral float-right" title="Selecting an Algorithm" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a>
<a href="environments.html" class="btn btn-neutral float-left" title="Environments" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a>

View File

@@ -95,7 +95,6 @@
<li class="toctree-l2"><a class="reference internal" href="algorithms.html">Algorithms</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Environments</a></li>
<li class="toctree-l2"><a class="reference internal" href="benchmarks.html">Benchmarks</a></li>
<li class="toctree-l2"><a class="reference internal" href="batch_rl.html">Batch Reinforcement Learning</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../selecting_an_algorithm.html">Selecting an Algorithm</a></li>

View File

@@ -206,6 +206,7 @@
| <a href="#T"><strong>T</strong></a>
| <a href="#U"><strong>U</strong></a>
| <a href="#V"><strong>V</strong></a>
| <a href="#W"><strong>W</strong></a>
</div>
<h2 id="A">A</h2>
@@ -956,6 +957,14 @@
</ul></td>
</tr></table>
<h2 id="W">W</h2>
<table style="width: 100%" class="indextable genindextable"><tr>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="components/agents/policy_optimization/wolpertinger.html#rl_coach.agents.wolpertinger_agent.WolpertingerAlgorithmParameters">WolpertingerAlgorithmParameters (class in rl_coach.agents.wolpertinger_agent)</a>
</li>
</ul></td>
</tr></table>
</div>

Binary file not shown.

File diff suppressed because one or more lines are too long

View File

@@ -439,7 +439,7 @@ given observation</p>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.prepare_batch_for_inference">
<code class="sig-name descname">prepare_batch_for_inference</code><span class="sig-paren">(</span><em class="sig-param">states: Union[Dict[str, numpy.ndarray], List[Dict[str, numpy.ndarray]]], network_name: str</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.array]<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.prepare_batch_for_inference" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">prepare_batch_for_inference</code><span class="sig-paren">(</span><em class="sig-param">states: Union[Dict[str, numpy.ndarray], List[Dict[str, numpy.ndarray]]], network_name: str</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.core.multiarray.array]<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.prepare_batch_for_inference" title="Permalink to this definition"></a></dt>
<dd><p>Convert curr_state into input tensors tensorflow is expecting. i.e. if we have several input states, stack all
observations together, measurements together, etc.</p>
<dl class="field-list simple">