mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
* Currently this is specific to the case of discretizing a continuous action space. It can easily be adapted to other cases by feeding the kNN differently and removing the use of the discretizing output action filter.
886 lines
128 KiB
HTML
886 lines
128 KiB
HTML
|
|
|
|
<!DOCTYPE html>
|
|
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
|
|
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
|
|
<head>
|
|
<meta charset="utf-8">
|
|
|
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
|
|
|
<title>rl_coach.spaces — Reinforcement Learning Coach 0.12.0 documentation</title>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<script src="../../_static/js/modernizr.min.js"></script>
|
|
|
|
|
|
<script id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
|
<script src="../../_static/jquery.js"></script>
|
|
<script src="../../_static/underscore.js"></script>
|
|
<script src="../../_static/doctools.js"></script>
|
|
<script src="../../_static/language_data.js"></script>
|
|
<script async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
|
|
|
<script src="../../_static/js/theme.js"></script>
|
|
|
|
|
|
|
|
|
|
<link rel="stylesheet" href="../../_static/css/theme.css">
|
|
<link rel="stylesheet" href="../../_static/pygments.css">
|
|
<!-- custom.css is linked exactly once, after theme.css and pygments.css, so its overrides win the cascade.
     A second link to the same file would only add a redundant request. -->
<link rel="stylesheet" href="../../_static/css/custom.css">
|
|
<link rel="index" title="Index" href="../../genindex.html">
|
|
<link rel="search" title="Search" href="../../search.html">
|
|
|
|
|
|
</head>
|
|
|
|
<body class="wy-body-for-nav">
|
|
|
|
|
|
<div class="wy-grid-for-nav">
|
|
|
|
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
|
<div class="wy-side-scroll">
|
|
<div class="wy-side-nav-search">
|
|
|
|
|
|
|
|
<a class="icon icon-home" href="../../index.html"> Reinforcement Learning Coach
|
|
|
|
|
|
|
|
|
|
<img class="logo" src="../../_static/dark_logo.png" alt="Logo">
|
|
|
|
</a>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<div role="search">
|
|
<form class="wy-form" id="rtd-search-form" action="../../search.html" method="get">
|
|
<!-- The design shows no visible label, and a placeholder is not an accessible name,
     so aria-label names the query field for assistive technology. -->
<input name="q" type="text" placeholder="Search docs" aria-label="Search docs">
|
|
<input name="check_keywords" type="hidden" value="yes">
|
|
<input name="area" type="hidden" value="default">
|
|
</form>
|
|
</div>
|
|
|
|
|
|
</div>
|
|
|
|
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<p class="caption"><span class="caption-text">Intro</span></p>
|
|
<ul>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../usage.html">Usage</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../dist_usage.html">Usage - Distributed Coach</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../features/index.html">Features</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../selecting_an_algorithm.html">Selecting an Algorithm</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../dashboard.html">Coach Dashboard</a></li>
|
|
</ul>
|
|
<p class="caption"><span class="caption-text">Design</span></p>
|
|
<ul>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../design/control_flow.html">Control Flow</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../design/network.html">Network Design</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../design/horizontal_scaling.html">Distributed Coach - Horizontal Scale-Out</a></li>
|
|
</ul>
|
|
<p class="caption"><span class="caption-text">Contributing</span></p>
|
|
<ul>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../contributing/add_agent.html">Adding a New Agent</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../contributing/add_env.html">Adding a New Environment</a></li>
|
|
</ul>
|
|
<p class="caption"><span class="caption-text">Components</span></p>
|
|
<ul>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../components/agents/index.html">Agents</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../components/architectures/index.html">Architectures</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../components/data_stores/index.html">Data Stores</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../components/environments/index.html">Environments</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../components/exploration_policies/index.html">Exploration Policies</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../components/filters/index.html">Filters</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../components/memories/index.html">Memories</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../components/memory_backends/index.html">Memory Backends</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../components/orchestrators/index.html">Orchestrators</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../components/core_types.html">Core Types</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../components/spaces.html">Spaces</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../components/additional_parameters.html">Additional Parameters</a></li>
|
|
</ul>
|
|
|
|
|
|
|
|
</div>
|
|
</div>
|
|
</nav>
|
|
|
|
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
|
|
|
|
|
<nav class="wy-nav-top" aria-label="top navigation">
|
|
<!-- Header bar for the collapsed layout: a menu icon followed by a link home.
     The data-toggle hook on the icon is presumably bound by _static/js/theme.js (loaded in the head); verify there. -->
|
|
<i class="fa fa-bars" data-toggle="wy-nav-top"></i>
|
|
<a href="../../index.html">Reinforcement Learning Coach</a>
|
|
|
|
</nav>
|
|
|
|
|
|
<div class="wy-nav-content">
|
|
|
|
<div class="rst-content">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<div role="navigation" aria-label="breadcrumbs navigation">
|
|
|
|
<!-- Breadcrumb trail: Docs » Module code » this page.
     The "»" separators are decorative, so they are hidden from assistive technology;
     aria-current identifies the crumb for the page being viewed. -->
<ul class="wy-breadcrumbs">
|
|
|
|
<li><a href="../../index.html">Docs</a> <span aria-hidden="true">»</span></li>
|
|
|
|
<li><a href="../index.html">Module code</a> <span aria-hidden="true">»</span></li>
|
|
|
|
<li aria-current="page">rl_coach.spaces</li>
|
|
|
|
|
|
<li class="wy-breadcrumbs-aside">
|
|
|
|
</li>
|
|
|
|
</ul>
|
|
|
|
|
|
<hr>
|
|
</div>
|
|
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
|
<div itemprop="articleBody">
|
|
|
|
<h1>Source code for rl_coach.spaces</h1><div class="highlight"><pre>
|
|
<span></span><span class="c1">#</span>
|
|
<span class="c1"># Copyright (c) 2017 Intel Corporation</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1"># Licensed under the Apache License, Version 2.0 (the "License");</span>
|
|
<span class="c1"># you may not use this file except in compliance with the License.</span>
|
|
<span class="c1"># You may obtain a copy of the License at</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
|
|
<span class="c1"># distributed under the License is distributed on an "AS IS" BASIS,</span>
|
|
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
|
|
<span class="c1"># See the License for the specific language governing permissions and</span>
|
|
<span class="c1"># limitations under the License.</span>
|
|
<span class="c1">#</span>
|
|
|
|
<span class="kn">import</span> <span class="nn">random</span>
|
|
<span class="kn">from</span> <span class="nn">enum</span> <span class="k">import</span> <span class="n">Enum</span>
|
|
<span class="kn">from</span> <span class="nn">itertools</span> <span class="k">import</span> <span class="n">product</span>
|
|
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Callable</span>
|
|
|
|
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
|
<span class="kn">import</span> <span class="nn">scipy</span>
|
|
<span class="kn">import</span> <span class="nn">scipy.spatial</span>
|
|
|
|
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">ActionType</span><span class="p">,</span> <span class="n">ActionInfo</span>
|
|
<span class="kn">from</span> <span class="nn">rl_coach.utils</span> <span class="k">import</span> <span class="n">eps</span>
|
|
|
|
|
|
<div class="viewcode-block" id="Space"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.Space">[docs]</a><span class="k">class</span> <span class="nc">Space</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> A space defines a set of valid values</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">],</span> <span class="n">low</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span><span class="o">=-</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">,</span>
|
|
<span class="n">high</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">):</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> :param shape: the shape of the space</span>
|
|
<span class="sd"> :param low: the lowest values possible in the space. can be an array defining the lowest values per point,</span>
|
|
<span class="sd"> or a single value defining the general lowest values</span>
|
|
<span class="sd"> :param high: the highest values possible in the space. can be an array defining the highest values per point,</span>
|
|
<span class="sd"> or a single value defining the general highest values</span>
|
|
<span class="sd"> """</span>
|
|
|
|
<span class="c1"># the number of dimensions is the number of axes in the shape. it will be set in the shape setter</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">num_dimensions</span> <span class="o">=</span> <span class="mi">0</span>
|
|
|
|
<span class="c1"># the number of elements is the number of possible actions if the action space was discrete.</span>
|
|
<span class="c1"># it will be set in the shape setter</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">num_elements</span> <span class="o">=</span> <span class="mi">0</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_low</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_high</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_shape</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">shape</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_low</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">low</span> <span class="o">=</span> <span class="n">low</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_high</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">high</span> <span class="o">=</span> <span class="n">high</span>
|
|
|
|
<span class="c1"># we allow zero sized spaces which means that the space is empty. this is useful for environments with no</span>
|
|
<span class="c1"># measurements for example.</span>
|
|
<span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="nb">int</span> <span class="ow">and</span> <span class="n">shape</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"The shape of the space must be a non-negative number"</span><span class="p">)</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span> <span class="nf">shape</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_shape</span>
|
|
|
|
<span class="nd">@shape</span><span class="o">.</span><span class="n">setter</span>
|
|
<span class="k">def</span> <span class="nf">shape</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">val</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]):</span>
|
|
<span class="c1"># convert the shape to an np.ndarray</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_shape</span> <span class="o">=</span> <span class="n">val</span>
|
|
<span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_shape</span><span class="p">)</span> <span class="o">==</span> <span class="nb">int</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_shape</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">_shape</span><span class="p">])</span>
|
|
<span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_shape</span><span class="p">)</span> <span class="o">==</span> <span class="nb">tuple</span> <span class="ow">or</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_shape</span><span class="p">)</span> <span class="o">==</span> <span class="nb">list</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_shape</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_shape</span><span class="p">)</span>
|
|
|
|
<span class="c1"># the shape is now an np.ndarray</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">num_dimensions</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_shape</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">num_elements</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_shape</span><span class="p">))</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span> <span class="nf">low</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s1">'_low'</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_low</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="kc">None</span>
|
|
|
|
<span class="nd">@low</span><span class="o">.</span><span class="n">setter</span>
|
|
<span class="k">def</span> <span class="nf">low</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">val</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]):</span>
|
|
<span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="n">val</span><span class="p">)</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span> <span class="ow">and</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span> <span class="ow">and</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">val</span><span class="o">.</span><span class="n">shape</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">):</span>
|
|
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"The low values shape don't match the shape of the space"</span><span class="p">)</span>
|
|
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">high</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">high</span> <span class="o">>=</span> <span class="n">val</span><span class="p">):</span>
|
|
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"At least one of the axes-parallel lines defining the space has high values which "</span>
|
|
<span class="s2">"are lower than the given low values"</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_low</span> <span class="o">=</span> <span class="n">val</span>
|
|
<span class="c1"># we allow using a number to define the low values, but we immediately convert it to an array which defines</span>
|
|
<span class="c1"># the low values for all the space dimensions in order to expose a consistent value type</span>
|
|
<span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_low</span><span class="p">)</span> <span class="o">==</span> <span class="nb">int</span> <span class="ow">or</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_low</span><span class="p">)</span> <span class="o">==</span> <span class="nb">float</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_low</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">_low</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span> <span class="nf">high</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s1">'_high'</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_high</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="kc">None</span>
|
|
|
|
<span class="nd">@high</span><span class="o">.</span><span class="n">setter</span>
|
|
<span class="k">def</span> <span class="nf">high</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">val</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]):</span>
|
|
<span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="n">val</span><span class="p">)</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span> <span class="ow">and</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span> <span class="ow">and</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">val</span><span class="o">.</span><span class="n">shape</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">):</span>
|
|
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"The high values shape don't match the shape of the space"</span><span class="p">)</span>
|
|
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">low</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">low</span> <span class="o"><=</span> <span class="n">val</span><span class="p">):</span>
|
|
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"At least one of the axes-parallel lines defining the space has low values which "</span>
|
|
<span class="s2">"are higher than the given high values"</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_high</span> <span class="o">=</span> <span class="n">val</span>
|
|
<span class="c1"># we allow using a number to define the high values, but we immediately convert it to an array which defines</span>
|
|
<span class="c1"># the high values for all the space dimensions in order to expose a consistent value type</span>
|
|
<span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_high</span><span class="p">)</span> <span class="o">==</span> <span class="nb">int</span> <span class="ow">or</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_high</span><span class="p">)</span> <span class="o">==</span> <span class="nb">float</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_high</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">_high</span>
|
|
|
|
<div class="viewcode-block" id="Space.contains"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.Space.contains">[docs]</a> <span class="k">def</span> <span class="nf">contains</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">val</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">])</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> Checks if value is contained by this space. The shape must match and</span>
|
|
<span class="sd"> all of the values must be within the low and high bounds.</span>
|
|
|
|
<span class="sd"> :param val: a value to check</span>
|
|
<span class="sd"> :return: True / False depending on if the val matches the space definition</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">if</span> <span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">val</span><span class="p">)</span> <span class="o">==</span> <span class="nb">int</span> <span class="ow">or</span> <span class="nb">type</span><span class="p">(</span><span class="n">val</span><span class="p">)</span> <span class="o">==</span> <span class="nb">float</span><span class="p">)</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="mi">1</span><span class="p">)):</span>
|
|
<span class="k">return</span> <span class="kc">False</span>
|
|
<span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="n">val</span><span class="p">)</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">val</span><span class="o">.</span><span class="n">shape</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="kc">False</span>
|
|
<span class="k">if</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">low</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">val</span> <span class="o">>=</span> <span class="bp">self</span><span class="o">.</span><span class="n">low</span><span class="p">))</span> \
|
|
<span class="ow">or</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">high</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">val</span> <span class="o"><=</span> <span class="bp">self</span><span class="o">.</span><span class="n">high</span><span class="p">)):</span>
|
|
<span class="c1"># TODO: check the performance overhead this causes</span>
|
|
<span class="k">return</span> <span class="kc">False</span>
|
|
<span class="k">return</span> <span class="kc">True</span></div>
|
|
|
|
<div class="viewcode-block" id="Space.is_valid_index"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.Space.is_valid_index">[docs]</a> <span class="k">def</span> <span class="nf">is_valid_index</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> Checks if a given multidimensional index is within the bounds of the shape of the space</span>
|
|
|
|
<span class="sd"> :param index: a multidimensional index</span>
|
|
<span class="sd"> :return: True if the index is within the shape of the space. False otherwise</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">index</span><span class="p">)</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_dimensions</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="kc">False</span>
|
|
<span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">index</span> <span class="o"><</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_dimensions</span><span class="p">))</span> <span class="ow">or</span> <span class="n">np</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">index</span> <span class="o">>=</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="kc">False</span>
|
|
<span class="k">return</span> <span class="kc">True</span></div>
|
|
|
|
<div class="viewcode-block" id="Space.sample"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.Space.sample">[docs]</a> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> Sample the defined space, either uniformly, if all space bounds are finite, or from a standard normal distribution</span>
|
|
<span class="sd"> (mean 0, std 1) if any bound is infinite</span>
|
|
|
|
<span class="sd"> :return: A numpy array sampled from the space</span>
|
|
<span class="sd"> """</span>
|
|
<span class="c1"># if there are infinite bounds, we sample using gaussian noise with mean 0 and std 1</span>
|
|
<span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">low</span> <span class="o">==</span> <span class="o">-</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">)</span> <span class="ow">or</span> <span class="n">np</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">high</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">low</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">high</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></div>
|
|
|
|
<span class="k">def</span> <span class="nf">val_matches_space_definition</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">val</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">])</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
|
<span class="n">screen</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
|
<span class="s2">"Space.val_matches_space_definition will be deprecated soon. Use "</span>
|
|
<span class="s2">"contains instead."</span>
|
|
<span class="p">)</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">contains</span><span class="p">(</span><span class="n">val</span><span class="p">)</span>
|
|
|
|
<span class="k">def</span> <span class="nf">is_point_in_space_shape</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">point</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
|
<span class="n">screen</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
|
<span class="s2">"Space.is_point_in_space_shape will be deprecated soon. Use "</span>
|
|
<span class="s2">"is_valid_index instead."</span>
|
|
<span class="p">)</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_valid_index</span><span class="p">(</span><span class="n">point</span><span class="p">)</span></div>
|
|
|
|
|
|
<span class="k">class</span> <span class="nc">RewardSpace</span><span class="p">(</span><span class="n">Space</span><span class="p">):</span>
|
|
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">],</span> <span class="n">low</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span><span class="o">=-</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">,</span>
|
|
<span class="n">high</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">,</span>
|
|
<span class="n">reward_success_threshold</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">low</span><span class="p">,</span> <span class="n">high</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">reward_success_threshold</span> <span class="o">=</span> <span class="n">reward_success_threshold</span>
|
|
|
|
|
|
<span class="sd">"""</span>
|
|
<span class="sd">Observation Spaces</span>
|
|
<span class="sd">"""</span>
|
|
|
|
|
|
<div class="viewcode-block" id="ObservationSpace"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.ObservationSpace">[docs]</a><span class="k">class</span> <span class="nc">ObservationSpace</span><span class="p">(</span><span class="n">Space</span><span class="p">):</span>
|
|
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">],</span> <span class="n">low</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span><span class="o">=-</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">,</span>
|
|
<span class="n">high</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">):</span>
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">low</span><span class="p">,</span> <span class="n">high</span><span class="p">)</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="VectorObservationSpace"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.VectorObservationSpace">[docs]</a><span class="k">class</span> <span class="nc">VectorObservationSpace</span><span class="p">(</span><span class="n">ObservationSpace</span><span class="p">):</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> An observation space which is defined as a vector of elements. This can be particularly useful for environments</span>
|
|
<span class="sd"> which return measurements, such as in robotic environments.</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">low</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span><span class="o">=-</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">,</span>
|
|
<span class="n">high</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">,</span> <span class="n">measurements_names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="n">measurements_names</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">measurements_names</span> <span class="o">=</span> <span class="p">[]</span>
|
|
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">measurements_names</span><span class="p">)</span> <span class="o">></span> <span class="n">shape</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"measurement_names size </span><span class="si">{}</span><span class="s2"> is larger than shape </span><span class="si">{}</span><span class="s2">."</span><span class="o">.</span><span class="n">format</span><span class="p">(</span>
|
|
<span class="nb">len</span><span class="p">(</span><span class="n">measurements_names</span><span class="p">),</span> <span class="n">shape</span><span class="p">))</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">measurements_names</span> <span class="o">=</span> <span class="n">measurements_names</span>
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">low</span><span class="p">,</span> <span class="n">high</span><span class="p">)</span></div>
|
|
|
|
|
|
<span class="k">class</span> <span class="nc">TensorObservationSpace</span><span class="p">(</span><span class="n">ObservationSpace</span><span class="p">):</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> An observation space which defines observations with arbitrary shape. This can be particularly useful for</span>
|
|
<span class="sd"> environments with non-image input.</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">low</span><span class="p">:</span> <span class="o">-</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">,</span>
|
|
<span class="n">high</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">):</span>
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">low</span><span class="p">,</span> <span class="n">high</span><span class="p">)</span>
|
|
|
|
|
|
<div class="viewcode-block" id="PlanarMapsObservationSpace"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.PlanarMapsObservationSpace">[docs]</a><span class="k">class</span> <span class="nc">PlanarMapsObservationSpace</span><span class="p">(</span><span class="n">ObservationSpace</span><span class="p">):</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> An observation space which defines a stack of 2D observations. For example, an environment which returns</span>
|
|
<span class="sd"> a stack of segmentation maps, as in StarCraft.</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">],</span> <span class="n">low</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">high</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">channels_axis</span><span class="p">:</span> <span class="nb">int</span><span class="o">=-</span><span class="mi">1</span><span class="p">):</span>
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">low</span><span class="p">,</span> <span class="n">high</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">channels_axis</span> <span class="o">=</span> <span class="n">channels_axis</span>
|
|
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="mi">2</span> <span class="o"><=</span> <span class="nb">len</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span> <span class="o"><=</span> <span class="mi">3</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Planar maps observations must have 3 dimensions - a channels dimension and 2 maps "</span>
|
|
<span class="s2">"dimensions, not </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">shape</span><span class="p">)))</span>
|
|
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="mi">1</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">shape</span><span class="p">[</span><span class="n">channels_axis</span><span class="p">]</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="ImageObservationSpace"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.ImageObservationSpace">[docs]</a><span class="k">class</span> <span class="nc">ImageObservationSpace</span><span class="p">(</span><span class="n">PlanarMapsObservationSpace</span><span class="p">):</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> An observation space which is a special case of the PlanarMapsObservationSpace, where the stack of 2D observations</span>
|
|
<span class="sd"> represents an RGB image or a grayscale image.</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">],</span> <span class="n">high</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">channels_axis</span><span class="p">:</span> <span class="nb">int</span><span class="o">=-</span><span class="mi">1</span><span class="p">):</span>
|
|
<span class="c1"># TODO: consider allowing arbitrary low values for images</span>
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">high</span><span class="p">,</span> <span class="n">channels_axis</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">has_colors</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="mi">3</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="mi">3</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Image observations must have 1 or 3 channels, not </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">channels</span><span class="p">))</span></div>
|
|
|
|
|
|
<span class="c1"># TODO: mixed observation spaces (image + measurements, image + segmentation + depth map, etc.)</span>
|
|
<span class="k">class</span> <span class="nc">StateSpace</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
|
|
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sub_spaces</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Space</span><span class="p">]):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">sub_spaces</span> <span class="o">=</span> <span class="n">sub_spaces</span>
|
|
|
|
<span class="k">def</span> <span class="nf">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">item</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">sub_spaces</span><span class="p">[</span><span class="n">item</span><span class="p">]</span>
|
|
|
|
<span class="k">def</span> <span class="nf">__setitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">sub_spaces</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span>
|
|
|
|
|
|
<span class="sd">"""</span>
|
|
<span class="sd">Action Spaces</span>
|
|
<span class="sd">"""</span>
|
|
|
|
|
|
<div class="viewcode-block" id="ActionSpace"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.ActionSpace">[docs]</a><span class="k">class</span> <span class="nc">ActionSpace</span><span class="p">(</span><span class="n">Space</span><span class="p">):</span>
|
|
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">],</span> <span class="n">low</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span><span class="o">=-</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">,</span>
|
|
<span class="n">high</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">,</span> <span class="n">descriptions</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Dict</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">default_action</span><span class="p">:</span> <span class="n">ActionType</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">low</span><span class="p">,</span> <span class="n">high</span><span class="p">)</span>
|
|
<span class="c1"># we allow a mismatch between the number of descriptions and the number of actions.</span>
|
|
<span class="c1"># in this case the descriptions for the actions that were not given will be the action index</span>
|
|
<span class="k">if</span> <span class="n">descriptions</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">descriptions</span> <span class="o">=</span> <span class="n">descriptions</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">descriptions</span> <span class="o">=</span> <span class="p">{}</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">default_action</span> <span class="o">=</span> <span class="n">default_action</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span> <span class="nf">actions</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">ActionType</span><span class="p">]:</span>
|
|
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">"The action space does not have an explicit actions list"</span><span class="p">)</span>
|
|
|
|
<div class="viewcode-block" id="ActionSpace.sample_with_info"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.ActionSpace.sample_with_info">[docs]</a> <span class="k">def</span> <span class="nf">sample_with_info</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">ActionInfo</span><span class="p">:</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> Get a random action with additional "fake" info</span>
|
|
|
|
<span class="sd"> :return: An action info instance</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">return</span> <span class="n">ActionInfo</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sample</span><span class="p">())</span></div>
|
|
|
|
<div class="viewcode-block" id="ActionSpace.clip_action_to_space"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.ActionSpace.clip_action_to_space">[docs]</a> <span class="k">def</span> <span class="nf">clip_action_to_space</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action</span><span class="p">:</span> <span class="n">ActionType</span><span class="p">)</span> <span class="o">-></span> <span class="n">ActionType</span><span class="p">:</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> Given an action, clip its values to fit to the action space ranges</span>
|
|
|
|
<span class="sd"> :param action: a given action</span>
|
|
<span class="sd"> :return: the clipped action</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">return</span> <span class="n">action</span></div>
|
|
|
|
<span class="k">def</span> <span class="nf">get_description</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">""</span><span class="p">)</span>
|
|
|
|
<span class="k">def</span> <span class="nf">__str__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="s2">"</span><span class="si">{}</span><span class="s2">: shape = </span><span class="si">{}</span><span class="s2">, low = </span><span class="si">{}</span><span class="s2">, high = </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">low</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">high</span><span class="p">)</span>
|
|
|
|
<span class="k">def</span> <span class="nf">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="fm">__str__</span><span class="p">()</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="AttentionActionSpace"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.AttentionActionSpace">[docs]</a><span class="k">class</span> <span class="nc">AttentionActionSpace</span><span class="p">(</span><span class="n">ActionSpace</span><span class="p">):</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> A box selection continuous action space, meaning that the actions are defined as selecting a multidimensional box</span>
|
|
<span class="sd"> from a given range.</span>
|
|
<span class="sd"> The actions will be in the form:</span>
|
|
<span class="sd"> [[low_x, low_y, ...], [high_x, high_y, ...]]</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">low</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span><span class="o">=-</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">,</span>
|
|
<span class="n">high</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">,</span> <span class="n">descriptions</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Dict</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">default_action</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">forced_attention_size</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">low</span><span class="p">,</span> <span class="n">high</span><span class="p">,</span> <span class="n">descriptions</span><span class="p">)</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">forced_attention_size</span> <span class="o">=</span> <span class="n">forced_attention_size</span>
|
|
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">forced_attention_size</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">forced_attention_size</span><span class="p">,</span> <span class="nb">float</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">forced_attention_size</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">forced_attention_size</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">forced_attention_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">forced_attention_size</span> <span class="o">></span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">high</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">low</span><span class="p">)):</span>
|
|
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"The forced attention size is larger than the action space"</span><span class="p">)</span>
|
|
|
|
<span class="c1"># default action</span>
|
|
<span class="k">if</span> <span class="n">default_action</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">forced_attention_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">default_action</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">low</span><span class="o">*</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span>
|
|
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">low</span><span class="o">+</span><span class="bp">self</span><span class="o">.</span><span class="n">forced_attention_size</span><span class="p">)</span><span class="o">*</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)]</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">default_action</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">low</span><span class="o">*</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">high</span><span class="o">*</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)]</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">default_action</span> <span class="o">=</span> <span class="n">default_action</span>
|
|
|
|
<span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">forced_attention_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">sampled_low</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">low</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">high</span><span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">forced_attention_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
|
<span class="n">sampled_high</span> <span class="o">=</span> <span class="n">sampled_low</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">forced_attention_size</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">sampled_low</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">low</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">high</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
|
<span class="n">sampled_high</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">sampled_low</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">high</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
|
<span class="k">return</span> <span class="p">[</span><span class="n">sampled_low</span><span class="p">,</span> <span class="n">sampled_high</span><span class="p">]</span>
|
|
|
|
<span class="k">def</span> <span class="nf">clip_action_to_space</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action</span><span class="p">:</span> <span class="n">ActionType</span><span class="p">)</span> <span class="o">-></span> <span class="n">ActionType</span><span class="p">:</span>
|
|
<span class="n">action</span> <span class="o">=</span> <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">action</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">low</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">high</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">action</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">low</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">high</span><span class="p">)]</span>
|
|
<span class="k">return</span> <span class="n">action</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="BoxActionSpace"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.BoxActionSpace">[docs]</a><span class="k">class</span> <span class="nc">BoxActionSpace</span><span class="p">(</span><span class="n">ActionSpace</span><span class="p">):</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> A multidimensional bounded or unbounded continuous action space</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">],</span> <span class="n">low</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span><span class="o">=-</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">,</span>
|
|
<span class="n">high</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">,</span> <span class="n">descriptions</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Dict</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">default_action</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">low</span><span class="p">,</span> <span class="n">high</span><span class="p">,</span> <span class="n">descriptions</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_abs_range</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">low</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">high</span><span class="p">))</span>
|
|
|
|
<span class="c1"># default action</span>
|
|
<span class="k">if</span> <span class="n">default_action</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">isinf</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">low</span><span class="p">))</span> <span class="ow">or</span> <span class="n">np</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">isinf</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">high</span><span class="p">)):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">default_action</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">default_action</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">low</span> <span class="o">+</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">high</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">low</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">default_action</span> <span class="o">=</span> <span class="n">default_action</span>
|
|
|
|
<span class="k">def</span> <span class="nf">clip_action_to_space</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action</span><span class="p">:</span> <span class="n">ActionType</span><span class="p">)</span> <span class="o">-></span> <span class="n">ActionType</span><span class="p">:</span>
|
|
<span class="n">action</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">action</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">low</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">high</span><span class="p">)</span>
|
|
<span class="k">return</span> <span class="n">action</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="DiscreteActionSpace"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.DiscreteActionSpace">[docs]</a><span class="k">class</span> <span class="nc">DiscreteActionSpace</span><span class="p">(</span><span class="n">ActionSpace</span><span class="p">):</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> A discrete action space with action indices as actions</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_actions</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">descriptions</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Dict</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">default_action</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">filtered_action_space</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">low</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">num_actions</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">descriptions</span><span class="o">=</span><span class="n">descriptions</span><span class="p">)</span>
|
|
<span class="c1"># the number of actions is mapped to high</span>
|
|
|
|
<span class="c1"># default action</span>
|
|
<span class="k">if</span> <span class="n">default_action</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">default_action</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">default_action</span> <span class="o">=</span> <span class="n">default_action</span>
|
|
|
|
<span class="k">if</span> <span class="n">filtered_action_space</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">filtered_action_space</span> <span class="o">=</span> <span class="n">filtered_action_space</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span> <span class="nf">actions</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">ActionType</span><span class="p">]:</span>
|
|
<span class="k">return</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">high</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
|
|
|
|
<span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">actions</span><span class="p">)</span>
|
|
|
|
<span class="k">def</span> <span class="nf">sample_with_info</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">ActionInfo</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="n">ActionInfo</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sample</span><span class="p">(),</span>
|
|
<span class="n">all_action_probabilities</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">full</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">actions</span><span class="p">),</span> <span class="mf">1.</span> <span class="o">/</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">high</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">low</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)))</span>
|
|
|
|
<span class="k">def</span> <span class="nf">get_description</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">descriptions</span><span class="p">)</span> <span class="o">==</span> <span class="nb">list</span> <span class="ow">and</span> <span class="mi">0</span> <span class="o"><=</span> <span class="n">action</span> <span class="o"><</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">descriptions</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">descriptions</span><span class="p">[</span><span class="n">action</span><span class="p">]</span>
|
|
<span class="k">elif</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">descriptions</span><span class="p">)</span> <span class="o">==</span> <span class="nb">dict</span> <span class="ow">and</span> <span class="n">action</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">descriptions</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">descriptions</span><span class="p">[</span><span class="n">action</span><span class="p">]</span>
|
|
<span class="k">elif</span> <span class="mi">0</span> <span class="o"><=</span> <span class="n">action</span> <span class="o"><</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="nb">str</span><span class="p">(</span><span class="n">action</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"The given action is outside of the action space"</span><span class="p">)</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="MultiSelectActionSpace"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.MultiSelectActionSpace">[docs]</a><span class="k">class</span> <span class="nc">MultiSelectActionSpace</span><span class="p">(</span><span class="n">ActionSpace</span><span class="p">):</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> A discrete action space where multiple actions can be selected at once. The actions are encoded as multi-hot vectors</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">max_simultaneous_selected_actions</span><span class="p">:</span> <span class="nb">int</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">descriptions</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Dict</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">default_action</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">allow_no_action_to_be_selected</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">low</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">descriptions</span><span class="o">=</span><span class="n">descriptions</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_simultaneous_selected_actions</span> <span class="o">=</span> <span class="n">max_simultaneous_selected_actions</span>
|
|
|
|
<span class="k">if</span> <span class="n">max_simultaneous_selected_actions</span> <span class="o">></span> <span class="n">size</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"The maximum simultaneous selected actions can't be larger the max number of actions"</span><span class="p">)</span>
|
|
|
|
<span class="c1"># create all combinations of actions as a list of actions</span>
|
|
<span class="n">I</span> <span class="o">=</span> <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">size</span><span class="p">)]</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">max_simultaneous_selected_actions</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_actions</span> <span class="o">=</span> <span class="p">[]</span>
|
|
<span class="k">if</span> <span class="n">allow_no_action_to_be_selected</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_actions</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">size</span><span class="p">))</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_actions</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">unique</span><span class="p">([</span><span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">t</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">product</span><span class="p">(</span><span class="o">*</span><span class="n">I</span><span class="p">)],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)))</span>
|
|
|
|
<span class="c1"># default action</span>
|
|
<span class="k">if</span> <span class="n">default_action</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">default_action</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_actions</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">default_action</span> <span class="o">=</span> <span class="n">default_action</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span> <span class="nf">actions</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">ActionType</span><span class="p">]:</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_actions</span>
|
|
|
|
<span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
|
|
<span class="c1"># samples a multi-hot vector</span>
|
|
<span class="k">return</span> <span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">actions</span><span class="p">)</span>
|
|
|
|
<span class="k">def</span> <span class="nf">sample_with_info</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">ActionInfo</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="n">ActionInfo</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sample</span><span class="p">(),</span> <span class="n">all_action_probabilities</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">full</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">actions</span><span class="p">),</span> <span class="mf">1.</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">actions</span><span class="p">)))</span>
|
|
|
|
<span class="k">def</span> <span class="nf">get_description</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">action</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)[</span><span class="mi">0</span><span class="p">]))</span> <span class="o">+</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">action</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]))</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span> <span class="ow">or</span> \
|
|
<span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">action</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]))</span> <span class="o">></span> <span class="bp">self</span><span class="o">.</span><span class="n">max_simultaneous_selected_actions</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"The given action is not in the action space"</span><span class="p">)</span>
|
|
<span class="n">selected_actions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">action</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
|
|
<span class="n">description</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">descriptions</span><span class="p">[</span><span class="n">a</span><span class="p">]</span> <span class="k">for</span> <span class="n">a</span> <span class="ow">in</span> <span class="n">selected_actions</span><span class="p">]</span>
|
|
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">description</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="n">description</span> <span class="o">=</span> <span class="p">[</span><span class="s1">'no-op'</span><span class="p">]</span>
|
|
<span class="k">return</span> <span class="s1">' + '</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">description</span><span class="p">)</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="CompoundActionSpace"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.CompoundActionSpace">[docs]</a><span class="k">class</span> <span class="nc">CompoundActionSpace</span><span class="p">(</span><span class="n">ActionSpace</span><span class="p">):</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> An action space which consists of multiple sub-action spaces.</span>
|
|
<span class="sd"> For example, in Starcraft the agent should choose an action identifier from ~550 options (Discrete(550)),</span>
|
|
<span class="sd"> but it also needs to choose 13 different arguments for the selected action identifier, where each argument is</span>
|
|
<span class="sd"> by itself an action space. In Starcraft, the arguments are Discrete action spaces as well, but this is not mandatory.</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sub_spaces</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">ActionSpace</span><span class="p">]):</span>
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">sub_action_spaces</span> <span class="o">=</span> <span class="n">sub_spaces</span>
|
|
<span class="c1"># TODO: define the shape, low and high value in a better way</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span> <span class="nf">actions</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">ActionType</span><span class="p">]:</span>
|
|
<span class="k">return</span> <span class="p">[</span><span class="n">action_space</span><span class="o">.</span><span class="n">actions</span> <span class="k">for</span> <span class="n">action_space</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">sub_action_spaces</span><span class="p">]</span>
|
|
|
|
<span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">ActionType</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="p">[</span><span class="n">action_space</span><span class="o">.</span><span class="n">sample</span><span class="p">()</span> <span class="k">for</span> <span class="n">action_space</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">sub_action_spaces</span><span class="p">]</span>
|
|
|
|
<span class="k">def</span> <span class="nf">clip_action_to_space</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">actions</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">ActionType</span><span class="p">])</span> <span class="o">-></span> <span class="n">ActionType</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">actions</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">actions</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sub_action_spaces</span><span class="p">):</span>
|
|
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"The actions to be clipped must be a list with the same number of sub-actions as "</span>
|
|
<span class="s2">"defined in the compound action space."</span><span class="p">)</span>
|
|
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sub_action_spaces</span><span class="p">)):</span>
|
|
<span class="n">actions</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sub_action_spaces</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span><span class="o">.</span><span class="n">clip_action_to_space</span><span class="p">(</span><span class="n">actions</span><span class="p">[</span><span class="n">idx</span><span class="p">])</span>
|
|
<span class="k">return</span> <span class="n">actions</span>
|
|
|
|
<span class="k">def</span> <span class="nf">get_description</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">actions</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span>
|
|
<span class="n">description</span> <span class="o">=</span> <span class="p">[</span><span class="n">action_space</span><span class="o">.</span><span class="n">get_description</span><span class="p">(</span><span class="n">action</span><span class="p">)</span> <span class="k">for</span> <span class="n">action_space</span><span class="p">,</span> <span class="n">action</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sub_action_spaces</span><span class="p">,</span> <span class="n">actions</span><span class="p">)]</span>
|
|
<span class="k">return</span> <span class="s1">' + '</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">description</span><span class="p">)</span></div>
|
|
|
|
|
|
<span class="sd">"""</span>
|
|
<span class="sd">Goals</span>
|
|
<span class="sd">"""</span>
|
|
|
|
|
|
<span class="k">class</span> <span class="nc">GoalToRewardConversion</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
|
|
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">goal_reaching_reward</span><span class="p">:</span> <span class="nb">float</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">goal_reaching_reward</span> <span class="o">=</span> <span class="n">goal_reaching_reward</span>
|
|
|
|
<span class="k">def</span> <span class="nf">convert_distance_to_reward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">distance</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">bool</span><span class="p">]:</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> Given a distance from the goal, return a reward and a flag representing if the goal was reached</span>
|
|
|
|
<span class="sd"> :param distance: the distance from the goal</span>
|
|
<span class="sd"> :return:</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">""</span><span class="p">)</span>
|
|
|
|
|
|
<span class="k">class</span> <span class="nc">ReachingGoal</span><span class="p">(</span><span class="n">GoalToRewardConversion</span><span class="p">):</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> get a reward if the goal was reached and 0 otherwise</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">distance_from_goal_threshold</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">],</span> <span class="n">goal_reaching_reward</span><span class="p">:</span> <span class="nb">float</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
|
<span class="n">default_reward</span><span class="p">:</span> <span class="nb">float</span><span class="o">=-</span><span class="mi">1</span><span class="p">):</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> :param distance_from_goal_threshold: consider getting to this distance from the goal the same as getting</span>
|
|
<span class="sd"> to the goal</span>
|
|
<span class="sd"> :param goal_reaching_reward: the reward the agent will get when reaching the goal</span>
|
|
<span class="sd"> :param default_reward: the reward the agent will get until it reaches the goal</span>
|
|
<span class="sd"> """</span>
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">goal_reaching_reward</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">distance_from_goal_threshold</span> <span class="o">=</span> <span class="n">distance_from_goal_threshold</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">default_reward</span> <span class="o">=</span> <span class="n">default_reward</span>
|
|
|
|
<span class="k">def</span> <span class="nf">convert_distance_to_reward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">distance</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">bool</span><span class="p">]:</span>
|
|
<span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">distance</span> <span class="o"><=</span> <span class="bp">self</span><span class="o">.</span><span class="n">distance_from_goal_threshold</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">goal_reaching_reward</span><span class="p">,</span> <span class="kc">True</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">default_reward</span><span class="p">,</span> <span class="kc">False</span>
|
|
|
|
|
|
<span class="k">class</span> <span class="nc">InverseDistanceFromGoal</span><span class="p">(</span><span class="n">GoalToRewardConversion</span><span class="p">):</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> get a reward inversely proportional to the distance from the goal</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">distance_from_goal_threshold</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">],</span> <span class="n">max_reward</span><span class="p">:</span> <span class="nb">float</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> :param distance_from_goal_threshold: consider getting to this distance from the goal the same as getting</span>
|
|
<span class="sd"> to the goal</span>
|
|
<span class="sd"> :param max_reward: the max reward the agent can get</span>
|
|
<span class="sd"> """</span>
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">goal_reaching_reward</span><span class="o">=</span><span class="n">max_reward</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">distance_from_goal_threshold</span> <span class="o">=</span> <span class="n">distance_from_goal_threshold</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_reward</span> <span class="o">=</span> <span class="n">max_reward</span>
|
|
|
|
<span class="k">def</span> <span class="nf">convert_distance_to_reward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">distance</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">bool</span><span class="p">]:</span>
|
|
<span class="k">return</span> <span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_reward</span><span class="p">,</span> <span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="n">distance</span> <span class="o">+</span> <span class="n">eps</span><span class="p">)),</span> <span class="n">distance</span> <span class="o"><=</span> <span class="bp">self</span><span class="o">.</span><span class="n">distance_from_goal_threshold</span>
|
|
|
|
|
|
<div class="viewcode-block" id="GoalsSpace"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.GoalsSpace">[docs]</a><span class="k">class</span> <span class="nc">GoalsSpace</span><span class="p">(</span><span class="n">VectorObservationSpace</span><span class="p">,</span> <span class="n">ActionSpace</span><span class="p">):</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> A multidimensional space with a goal type definition. It also behaves as an action space, so that hierarchical</span>
|
|
<span class="sd"> agents can use it as an output action space.</span>
|
|
<span class="sd"> The class acts as a wrapper to the target space. So after setting the target space, all the values of the class</span>
|
|
<span class="sd"> will match the values of the target space (the shape, low, high, etc.)</span>
|
|
<span class="sd"> """</span>
|
|
<div class="viewcode-block" id="GoalsSpace.DistanceMetric"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.GoalsSpace.DistanceMetric">[docs]</a> <span class="k">class</span> <span class="nc">DistanceMetric</span><span class="p">(</span><span class="n">Enum</span><span class="p">):</span>
|
|
<span class="n">Euclidean</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="n">Cosine</span> <span class="o">=</span> <span class="mi">1</span>
|
|
<span class="n">Manhattan</span> <span class="o">=</span> <span class="mi">2</span></div>
|
|
|
|
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">goal_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">reward_type</span><span class="p">:</span> <span class="n">GoalToRewardConversion</span><span class="p">,</span>
|
|
<span class="n">distance_metric</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">DistanceMetric</span><span class="p">,</span> <span class="n">Callable</span><span class="p">]):</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> :param goal_name: the name of the observation space to use as the achieved goal.</span>
|
|
<span class="sd"> :param reward_type: the reward type to use for converting distances from goal to rewards</span>
|
|
<span class="sd"> :param distance_metric: the distance metric to use. could be either one of the distances in the</span>
|
|
<span class="sd"> DistanceMetric enum, or a custom function that gets two vectors as input and</span>
|
|
<span class="sd"> returns the distance between them</span>
|
|
<span class="sd"> """</span>
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">goal_name</span> <span class="o">=</span> <span class="n">goal_name</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">distance_metric</span> <span class="o">=</span> <span class="n">distance_metric</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">reward_type</span> <span class="o">=</span> <span class="n">reward_type</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">target_space</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_abs_range</span> <span class="o">=</span> <span class="kc">None</span>
|
|
|
|
<span class="k">def</span> <span class="nf">set_target_space</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">target_space</span><span class="p">:</span> <span class="n">Space</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">target_space</span> <span class="o">=</span> <span class="n">target_space</span>
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">target_space</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">target_space</span><span class="o">.</span><span class="n">low</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">target_space</span><span class="o">.</span><span class="n">high</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_abs_range</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">low</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">high</span><span class="p">))</span>
|
|
|
|
<div class="viewcode-block" id="GoalsSpace.goal_from_state"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.GoalsSpace.goal_from_state">[docs]</a> <span class="k">def</span> <span class="nf">goal_from_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">Dict</span><span class="p">):</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> Given a state, extract an observation according to the goal_name</span>
|
|
|
|
<span class="sd"> :param state: a dictionary of observations</span>
|
|
<span class="sd"> :return: the observation corresponding to the goal_name</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">return</span> <span class="n">state</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">goal_name</span><span class="p">]</span></div>
|
|
|
|
<div class="viewcode-block" id="GoalsSpace.distance_from_goal"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.GoalsSpace.distance_from_goal">[docs]</a> <span class="k">def</span> <span class="nf">distance_from_goal</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">goal</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="nb">dict</span><span class="p">)</span> <span class="o">-></span> <span class="nb">float</span><span class="p">:</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> Given a state, check its distance from the goal</span>
|
|
|
|
<span class="sd"> :param goal: a numpy array representing the goal</span>
|
|
<span class="sd"> :param state: a dict representing the state</span>
|
|
<span class="sd"> :return: the distance from the goal</span>
|
|
<span class="sd"> """</span>
|
|
<span class="n">state_value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">goal_from_state</span><span class="p">(</span><span class="n">state</span><span class="p">)</span>
|
|
|
|
<span class="c1"># calculate distance</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">distance_metric</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">DistanceMetric</span><span class="o">.</span><span class="n">Cosine</span><span class="p">:</span>
|
|
<span class="n">dist</span> <span class="o">=</span> <span class="n">scipy</span><span class="o">.</span><span class="n">spatial</span><span class="o">.</span><span class="n">distance</span><span class="o">.</span><span class="n">cosine</span><span class="p">(</span><span class="n">goal</span><span class="p">,</span> <span class="n">state_value</span><span class="p">)</span>
|
|
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">distance_metric</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">DistanceMetric</span><span class="o">.</span><span class="n">Euclidean</span><span class="p">:</span>
|
|
<span class="n">dist</span> <span class="o">=</span> <span class="n">scipy</span><span class="o">.</span><span class="n">spatial</span><span class="o">.</span><span class="n">distance</span><span class="o">.</span><span class="n">euclidean</span><span class="p">(</span><span class="n">goal</span><span class="p">,</span> <span class="n">state_value</span><span class="p">)</span>
|
|
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">distance_metric</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">DistanceMetric</span><span class="o">.</span><span class="n">Manhattan</span><span class="p">:</span>
|
|
<span class="n">dist</span> <span class="o">=</span> <span class="n">scipy</span><span class="o">.</span><span class="n">spatial</span><span class="o">.</span><span class="n">distance</span><span class="o">.</span><span class="n">cityblock</span><span class="p">(</span><span class="n">goal</span><span class="p">,</span> <span class="n">state_value</span><span class="p">)</span>
|
|
<span class="k">elif</span> <span class="n">callable</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">distance_metric</span><span class="p">):</span>
|
|
<span class="n">dist</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">distance_metric</span><span class="p">(</span><span class="n">goal</span><span class="p">,</span> <span class="n">state_value</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"The given distance metric for the goal is not valid."</span><span class="p">)</span>
|
|
|
|
<span class="k">return</span> <span class="n">dist</span></div>
|
|
|
|
<div class="viewcode-block" id="GoalsSpace.get_reward_for_goal_and_state"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.GoalsSpace.get_reward_for_goal_and_state">[docs]</a> <span class="k">def</span> <span class="nf">get_reward_for_goal_and_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">goal</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="nb">dict</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">bool</span><span class="p">]:</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> Given a state, check if the goal was reached and return a reward accordingly</span>
|
|
|
|
<span class="sd"> :param goal: a numpy array representing the goal</span>
|
|
<span class="sd"> :param state: a dict representing the state</span>
|
|
<span class="sd"> :return: the reward for the current goal and state pair and a boolean representing if the goal was reached</span>
|
|
<span class="sd"> """</span>
|
|
<span class="n">dist</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">distance_from_goal</span><span class="p">(</span><span class="n">goal</span><span class="p">,</span> <span class="n">state</span><span class="p">)</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">reward_type</span><span class="o">.</span><span class="n">convert_distance_to_reward</span><span class="p">(</span><span class="n">dist</span><span class="p">)</span></div></div>
|
|
|
|
|
|
<span class="k">class</span> <span class="nc">AgentSelection</span><span class="p">(</span><span class="n">DiscreteActionSpace</span><span class="p">):</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> An discrete action space which is bounded by the number of agents to select from</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_agents</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">num_agents</span><span class="p">)</span>
|
|
|
|
|
|
<span class="k">class</span> <span class="nc">SpacesDefinition</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> A container class that allows passing the definitions of all the spaces at once</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
|
<span class="n">state</span><span class="p">:</span> <span class="n">StateSpace</span><span class="p">,</span>
|
|
<span class="n">goal</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">ObservationSpace</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span>
|
|
<span class="n">action</span><span class="p">:</span> <span class="n">ActionSpace</span><span class="p">,</span>
|
|
<span class="n">reward</span><span class="p">:</span> <span class="n">RewardSpace</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">state</span> <span class="o">=</span> <span class="n">state</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">goal</span> <span class="o">=</span> <span class="n">goal</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">action</span> <span class="o">=</span> <span class="n">action</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">reward</span> <span class="o">=</span> <span class="n">reward</span>
|
|
</pre></div>
|
|
|
|
</div>
|
|
|
|
</div>
|
|
<footer>
|
|
|
|
|
|
<hr/>
|
|
|
|
<div role="contentinfo">
|
|
<p>
|
|
© Copyright 2018-2019, Intel AI Lab
|
|
|
|
</p>
|
|
</div>
|
|
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
|
|
|
</footer>
|
|
|
|
</div>
|
|
</div>
|
|
|
|
</section>
|
|
|
|
</div>
|
|
|
|
|
|
|
|
<script type="text/javascript">
|
|
jQuery(function () {
|
|
SphinxRtdTheme.Navigation.enable(true);
|
|
});
|
|
</script>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
</body>
|
|
</html> |