1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

more clear names for methods of Space (#181)

* rename Space.val_matches_space_definition -> contains; Space.is_point_in_space_shape -> valid_index
* rename valid_index -> is_valid_index
This commit is contained in:
Zach Dwiel
2019-01-14 15:02:53 -05:00
committed by GitHub
parent 0ccc333d77
commit cd812b0d25
19 changed files with 77 additions and 62 deletions

View File

@@ -459,7 +459,7 @@
<span class="sd"> :return: the environment response as returned in get_last_env_response</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">action</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">clip_action_to_space</span><span class="p">(</span><span class="n">action</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">action_space</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">val_matches_space_definition</span><span class="p">(</span><span class="n">action</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">action_space</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">contains</span><span class="p">(</span><span class="n">action</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;The given action does not match the action space definition. &quot;</span>
<span class="s2">&quot;Action = </span><span class="si">{}</span><span class="s2">, action space definition = </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">action</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="p">))</span>

View File

@@ -222,7 +222,7 @@
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">target_actions</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;The target actions were not set&quot;</span><span class="p">)</span>
<span class="k">for</span> <span class="n">v</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">target_actions</span><span class="p">:</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">output_action_space</span><span class="o">.</span><span class="n">val_matches_space_definition</span><span class="p">(</span><span class="n">v</span><span class="p">):</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">output_action_space</span><span class="o">.</span><span class="n">contains</span><span class="p">(</span><span class="n">v</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;The values in the output actions (</span><span class="si">{}</span><span class="s2">) do not match the output action &quot;</span>
<span class="s2">&quot;space definition (</span><span class="si">{}</span><span class="s2">)&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">output_action_space</span><span class="p">))</span>

View File

@@ -251,8 +251,8 @@
<span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">crop_high</span> <span class="o">&gt;</span> <span class="n">input_observation_space</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="ow">or</span> \
<span class="n">np</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">crop_low</span> <span class="o">&gt;</span> <span class="n">input_observation_space</span><span class="o">.</span><span class="n">shape</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;The cropping values are outside of the observation space&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">input_observation_space</span><span class="o">.</span><span class="n">is_point_in_space_shape</span><span class="p">(</span><span class="n">crop_low</span><span class="p">)</span> <span class="ow">or</span> \
<span class="ow">not</span> <span class="n">input_observation_space</span><span class="o">.</span><span class="n">is_point_in_space_shape</span><span class="p">(</span><span class="n">crop_high</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">input_observation_space</span><span class="o">.</span><span class="n">is_valid_index</span><span class="p">(</span><span class="n">crop_low</span><span class="p">)</span> <span class="ow">or</span> \
<span class="ow">not</span> <span class="n">input_observation_space</span><span class="o">.</span><span class="n">is_valid_index</span><span class="p">(</span><span class="n">crop_high</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;The cropping indices are outside of the observation space&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">filter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">observation</span><span class="p">:</span> <span class="n">ObservationType</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">ObservationType</span><span class="p">:</span>

View File

@@ -259,7 +259,7 @@
<span class="s2">&quot;functionality&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">validate_input_observation_space</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_observation_space</span><span class="p">:</span> <span class="n">ObservationSpace</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stack</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">input_observation_space</span><span class="o">.</span><span class="n">val_matches_space_definition</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stack</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]):</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stack</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">input_observation_space</span><span class="o">.</span><span class="n">contains</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stack</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;The given input observation space is different than the observations already stored in&quot;</span>
<span class="s2">&quot;the filters memory&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">input_observation_space</span><span class="o">.</span><span class="n">num_dimensions</span> <span class="o">&lt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">stacking_axis</span><span class="p">:</span>

View File

@@ -297,7 +297,7 @@
<span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_high</span><span class="p">)</span> <span class="o">==</span> <span class="nb">int</span> <span class="ow">or</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_high</span><span class="p">)</span> <span class="o">==</span> <span class="nb">float</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_high</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">_high</span>
<div class="viewcode-block" id="Space.val_matches_space_definition"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.Space.val_matches_space_definition">[docs]</a> <span class="k">def</span> <span class="nf">val_matches_space_definition</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">val</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<div class="viewcode-block" id="Space.contains"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.Space.contains">[docs]</a> <span class="k">def</span> <span class="nf">contains</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">val</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Checks if the given value matches the space definition in terms of shape and values</span>
@@ -314,7 +314,7 @@
<span class="k">return</span> <span class="kc">False</span>
<span class="k">return</span> <span class="kc">True</span></div>
<div class="viewcode-block" id="Space.is_point_in_space_shape"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.Space.is_point_in_space_shape">[docs]</a> <span class="k">def</span> <span class="nf">is_point_in_space_shape</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">point</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<div class="viewcode-block" id="Space.is_valid_index"><a class="viewcode-back" href="../../components/spaces.html#rl_coach.spaces.Space.is_valid_index">[docs]</a> <span class="k">def</span> <span class="nf">is_valid_index</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">point</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Checks if a given multidimensional point is within the bounds of the shape of the space</span>

View File

@@ -223,8 +223,8 @@ or a single value defining the general highest values</li>
</tbody>
</table>
<dl class="method">
<dt id="rl_coach.spaces.Space.is_point_in_space_shape">
<code class="descname">is_point_in_space_shape</code><span class="sig-paren">(</span><em>point: numpy.ndarray</em><span class="sig-paren">)</span> &#x2192; bool<a class="reference internal" href="../_modules/rl_coach/spaces.html#Space.is_point_in_space_shape"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.Space.is_point_in_space_shape" title="Permalink to this definition"></a></dt>
<dt id="rl_coach.spaces.Space.is_valid_index">
<code class="descname">is_valid_index</code><span class="sig-paren">(</span><em>point: numpy.ndarray</em><span class="sig-paren">)</span> &#x2192; bool<a class="reference internal" href="../_modules/rl_coach/spaces.html#Space.is_valid_index"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.Space.is_valid_index" title="Permalink to this definition"></a></dt>
<dd><p>Checks if a given multidimensional point is within the bounds of the shape of the space</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
@@ -254,8 +254,8 @@ bounds are defined</p>
</dd></dl>
<dl class="method">
<dt id="rl_coach.spaces.Space.val_matches_space_definition">
<code class="descname">val_matches_space_definition</code><span class="sig-paren">(</span><em>val: Union[int, float, numpy.ndarray]</em><span class="sig-paren">)</span> &#x2192; bool<a class="reference internal" href="../_modules/rl_coach/spaces.html#Space.val_matches_space_definition"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.Space.val_matches_space_definition" title="Permalink to this definition"></a></dt>
<dt id="rl_coach.spaces.Space.contains">
<code class="descname">contains</code><span class="sig-paren">(</span><em>val: Union[int, float, numpy.ndarray]</em><span class="sig-paren">)</span> &#x2192; bool<a class="reference internal" href="../_modules/rl_coach/spaces.html#Space.contains"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.Space.contains" title="Permalink to this definition"></a></dt>
<dd><p>Checks if the given value matches the space definition in terms of shape and values</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
@@ -278,8 +278,8 @@ bounds are defined</p>
<dt id="rl_coach.spaces.ObservationSpace">
<em class="property">class </em><code class="descclassname">rl_coach.spaces.</code><code class="descname">ObservationSpace</code><span class="sig-paren">(</span><em>shape: Union[int, numpy.ndarray], low: Union[None, int, float, numpy.ndarray] = -inf, high: Union[None, int, float, numpy.ndarray] = inf</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#ObservationSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.ObservationSpace" title="Permalink to this definition"></a></dt>
<dd><dl class="method">
<dt id="rl_coach.spaces.ObservationSpace.is_point_in_space_shape">
<code class="descname">is_point_in_space_shape</code><span class="sig-paren">(</span><em>point: numpy.ndarray</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.ObservationSpace.is_point_in_space_shape" title="Permalink to this definition"></a></dt>
<dt id="rl_coach.spaces.ObservationSpace.is_valid_index">
<code class="descname">is_valid_index</code><span class="sig-paren">(</span><em>point: numpy.ndarray</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.ObservationSpace.is_valid_index" title="Permalink to this definition"></a></dt>
<dd><p>Checks if a given multidimensional point is within the bounds of the shape of the space</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
@@ -309,8 +309,8 @@ bounds are defined</p>
</dd></dl>
<dl class="method">
<dt id="rl_coach.spaces.ObservationSpace.val_matches_space_definition">
<code class="descname">val_matches_space_definition</code><span class="sig-paren">(</span><em>val: Union[int, float, numpy.ndarray]</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.ObservationSpace.val_matches_space_definition" title="Permalink to this definition"></a></dt>
<dt id="rl_coach.spaces.ObservationSpace.contains">
<code class="descname">contains</code><span class="sig-paren">(</span><em>val: Union[int, float, numpy.ndarray]</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.ObservationSpace.contains" title="Permalink to this definition"></a></dt>
<dd><p>Checks if the given value matches the space definition in terms of shape and values</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
@@ -379,8 +379,8 @@ represent a RGB image, or a grayscale image.</p>
</dd></dl>
<dl class="method">
<dt id="rl_coach.spaces.ActionSpace.is_point_in_space_shape">
<code class="descname">is_point_in_space_shape</code><span class="sig-paren">(</span><em>point: numpy.ndarray</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.ActionSpace.is_point_in_space_shape" title="Permalink to this definition"></a></dt>
<dt id="rl_coach.spaces.ActionSpace.is_valid_index">
<code class="descname">is_valid_index</code><span class="sig-paren">(</span><em>point: numpy.ndarray</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.ActionSpace.is_valid_index" title="Permalink to this definition"></a></dt>
<dd><p>Checks if a given multidimensional point is within the bounds of the shape of the space</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
@@ -424,8 +424,8 @@ bounds are defined</p>
</dd></dl>
<dl class="method">
<dt id="rl_coach.spaces.ActionSpace.val_matches_space_definition">
<code class="descname">val_matches_space_definition</code><span class="sig-paren">(</span><em>val: Union[int, float, numpy.ndarray]</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.ActionSpace.val_matches_space_definition" title="Permalink to this definition"></a></dt>
<dt id="rl_coach.spaces.ActionSpace.contains">
<code class="descname">contains</code><span class="sig-paren">(</span><em>val: Union[int, float, numpy.ndarray]</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.ActionSpace.contains" title="Permalink to this definition"></a></dt>
<dd><p>Checks if the given value matches the space definition in terms of shape and values</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
@@ -598,8 +598,8 @@ returns the distance between them</li>
</dd></dl>
<dl class="method">
<dt id="rl_coach.spaces.GoalsSpace.is_point_in_space_shape">
<code class="descname">is_point_in_space_shape</code><span class="sig-paren">(</span><em>point: numpy.ndarray</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.GoalsSpace.is_point_in_space_shape" title="Permalink to this definition"></a></dt>
<dt id="rl_coach.spaces.GoalsSpace.is_valid_index">
<code class="descname">is_valid_index</code><span class="sig-paren">(</span><em>point: numpy.ndarray</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.GoalsSpace.is_valid_index" title="Permalink to this definition"></a></dt>
<dd><p>Checks if a given multidimensional point is within the bounds of the shape of the space</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
@@ -643,8 +643,8 @@ bounds are defined</p>
</dd></dl>
<dl class="method">
<dt id="rl_coach.spaces.GoalsSpace.val_matches_space_definition">
<code class="descname">val_matches_space_definition</code><span class="sig-paren">(</span><em>val: Union[int, float, numpy.ndarray]</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.GoalsSpace.val_matches_space_definition" title="Permalink to this definition"></a></dt>
<dt id="rl_coach.spaces.GoalsSpace.contains">
<code class="descname">contains</code><span class="sig-paren">(</span><em>val: Union[int, float, numpy.ndarray]</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.GoalsSpace.contains" title="Permalink to this definition"></a></dt>
<dd><p>Checks if the given value matches the space definition in terms of shape and values</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />

View File

@@ -508,14 +508,14 @@
</li>
<li><a href="components/core_types.html#rl_coach.core_types.Episode.is_empty">is_empty() (rl_coach.core_types.Episode method)</a>
</li>
<li><a href="components/spaces.html#rl_coach.spaces.ActionSpace.is_point_in_space_shape">is_point_in_space_shape() (rl_coach.spaces.ActionSpace method)</a>
<li><a href="components/spaces.html#rl_coach.spaces.ActionSpace.is_valid_index">is_valid_index() (rl_coach.spaces.ActionSpace method)</a>
<ul>
<li><a href="components/spaces.html#rl_coach.spaces.GoalsSpace.is_point_in_space_shape">(rl_coach.spaces.GoalsSpace method)</a>
<li><a href="components/spaces.html#rl_coach.spaces.GoalsSpace.is_valid_index">(rl_coach.spaces.GoalsSpace method)</a>
</li>
<li><a href="components/spaces.html#rl_coach.spaces.ObservationSpace.is_point_in_space_shape">(rl_coach.spaces.ObservationSpace method)</a>
<li><a href="components/spaces.html#rl_coach.spaces.ObservationSpace.is_valid_index">(rl_coach.spaces.ObservationSpace method)</a>
</li>
<li><a href="components/spaces.html#rl_coach.spaces.Space.is_point_in_space_shape">(rl_coach.spaces.Space method)</a>
<li><a href="components/spaces.html#rl_coach.spaces.Space.is_valid_index">(rl_coach.spaces.Space method)</a>
</li>
</ul></li>
</ul></td>
@@ -912,14 +912,14 @@
<h2 id="V">V</h2>
<table style="width: 100%" class="indextable genindextable"><tr>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="components/spaces.html#rl_coach.spaces.ActionSpace.val_matches_space_definition">val_matches_space_definition() (rl_coach.spaces.ActionSpace method)</a>
<li><a href="components/spaces.html#rl_coach.spaces.ActionSpace.contains">contains() (rl_coach.spaces.ActionSpace method)</a>
<ul>
<li><a href="components/spaces.html#rl_coach.spaces.GoalsSpace.val_matches_space_definition">(rl_coach.spaces.GoalsSpace method)</a>
<li><a href="components/spaces.html#rl_coach.spaces.GoalsSpace.contains">(rl_coach.spaces.GoalsSpace method)</a>
</li>
<li><a href="components/spaces.html#rl_coach.spaces.ObservationSpace.val_matches_space_definition">(rl_coach.spaces.ObservationSpace method)</a>
<li><a href="components/spaces.html#rl_coach.spaces.ObservationSpace.contains">(rl_coach.spaces.ObservationSpace method)</a>
</li>
<li><a href="components/spaces.html#rl_coach.spaces.Space.val_matches_space_definition">(rl_coach.spaces.Space method)</a>
<li><a href="components/spaces.html#rl_coach.spaces.Space.contains">(rl_coach.spaces.Space method)</a>
</li>
</ul></li>
</ul></td>

File diff suppressed because one or more lines are too long

View File

@@ -280,7 +280,7 @@ class Environment(EnvironmentInterface):
:return: the environment response as returned in get_last_env_response
"""
action = self.action_space.clip_action_to_space(action)
if self.action_space and not self.action_space.val_matches_space_definition(action):
if self.action_space and not self.action_space.contains(action):
raise ValueError("The given action does not match the action space definition. "
"Action = {}, action space definition = {}".format(action, self.action_space))

View File

@@ -47,7 +47,7 @@ class ActionFilter(Filter):
:param action: an action to validate
:return: None
"""
if not self.output_action_space.val_matches_space_definition(action):
if not self.output_action_space.contains(action):
raise ValueError("The given action ({}) does not match the action space ({})"
.format(action, self.output_action_space))

View File

@@ -42,7 +42,7 @@ class PartialDiscreteActionSpaceMap(ActionFilter):
if not self.target_actions:
raise ValueError("The target actions were not set")
for v in self.target_actions:
if not output_action_space.val_matches_space_definition(v):
if not output_action_space.contains(v):
raise ValueError("The values in the output actions ({}) do not match the output action "
"space definition ({})".format(v, output_action_space))

View File

@@ -71,8 +71,8 @@ class ObservationCropFilter(ObservationFilter):
if np.any(crop_high > input_observation_space.shape) or \
np.any(crop_low > input_observation_space.shape):
raise ValueError("The cropping values are outside of the observation space")
if not input_observation_space.is_point_in_space_shape(crop_low) or \
not input_observation_space.is_point_in_space_shape(crop_high - 1):
if not input_observation_space.is_valid_index(crop_low) or \
not input_observation_space.is_valid_index(crop_high - 1):
raise ValueError("The cropping indices are outside of the observation space")
def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:

View File

@@ -79,7 +79,7 @@ class ObservationStackingFilter(ObservationFilter):
"functionality")
def validate_input_observation_space(self, input_observation_space: ObservationSpace):
if len(self.stack) > 0 and not input_observation_space.val_matches_space_definition(self.stack[-1]):
if len(self.stack) > 0 and not input_observation_space.contains(self.stack[-1]):
raise ValueError("The given input observation space is different than the observations already stored in"
"the filters memory")
if input_observation_space.num_dimensions <= self.stacking_axis:

View File

@@ -117,9 +117,10 @@ class Space(object):
if type(self._high) == int or type(self._high) == float:
self._high = np.ones(self.shape)*self._high
def val_matches_space_definition(self, val: Union[int, float, np.ndarray]) -> bool:
def contains(self, val: Union[int, float, np.ndarray]) -> bool:
"""
Checks if the given value matches the space definition in terms of shape and values
Checks if value is contained by this space. The shape must match and
all of the values must be within the low and high bounds.
:param val: a value to check
:return: True / False depending on if the val matches the space definition
@@ -134,16 +135,16 @@ class Space(object):
return False
return True
def is_point_in_space_shape(self, point: np.ndarray) -> bool:
def is_valid_index(self, index: np.ndarray) -> bool:
"""
Checks if a given multidimensional point is within the bounds of the shape of the space
Checks if a given multidimensional index is within the bounds of the shape of the space
:param point: a multidimensional point
:return: True if the point is within the shape of the space. False otherwise
:param index: a multidimensional index
:return: True if the index is within the shape of the space. False otherwise
"""
if len(point) != self.num_dimensions:
if len(index) != self.num_dimensions:
return False
if np.any(point < np.zeros(self.num_dimensions)) or np.any(point >= self.shape):
if np.any(index < np.zeros(self.num_dimensions)) or np.any(index >= self.shape):
return False
return True
@@ -160,6 +161,20 @@ class Space(object):
else:
return np.random.uniform(self.low, self.high, self.shape)
def val_matches_space_definition(self, val: Union[int, float, np.ndarray]) -> bool:
screen.warning(
"Space.val_matches_space_definition will be deprecated soon. Use "
"contains instead."
)
return self.contains(val)
def is_point_in_space_shape(self, point: np.ndarray) -> bool:
screen.warning(
"Space.is_point_in_space_shape will be deprecated soon. Use "
"is_valid_index instead."
)
return self.is_valid_index(point)
class RewardSpace(Space):
def __init__(self, shape: Union[int, np.ndarray], low: Union[None, int, float, np.ndarray]=-np.inf,

View File

@@ -32,7 +32,7 @@ def test_filter():
result = filter.filter(action)
assert np.all(result == np.array([[41.5, 0], [83., 41.5]]))
assert output_space.val_matches_space_definition(result)
assert output_space.contains(result)
# force int bins
filter = AttentionDiscretization(2, force_int_bins=True)

View File

@@ -26,7 +26,7 @@ def test_filter():
result = filter.filter(action)
assert result == [7.5]
assert output_space.val_matches_space_definition(result)
assert output_space.contains(result)
# 2 dimensional box
filter = BoxDiscretization(3)
@@ -42,4 +42,4 @@ def test_filter():
result = filter.filter(action)
assert result == [5., 15.]
assert output_space.val_matches_space_definition(result)
assert output_space.contains(result)

View File

@@ -23,5 +23,5 @@ def test_filter():
action = np.array([2])
result = filter.filter(action)
assert result == np.array([12])
assert output_space.val_matches_space_definition(result)
assert output_space.contains(result)

View File

@@ -25,5 +25,5 @@ def test_filter():
action = np.array([12])
result = filter.filter(action)
assert result == np.array([11])
assert output_space.val_matches_space_definition(result)
assert output_space.contains(result)

View File

@@ -132,18 +132,18 @@ def test_agent_selection():
def test_observation_space():
observation_space = ObservationSpace(np.array([1, 10]), -10, 10)
# testing that val_matches_space_definition works
assert observation_space.val_matches_space_definition(np.ones([1, 10]))
assert not observation_space.val_matches_space_definition(np.ones([2, 10]))
assert not observation_space.val_matches_space_definition(np.ones([1, 10]) * 100)
assert not observation_space.val_matches_space_definition(np.ones([1, 1, 10]))
# testing that contains works
assert observation_space.contains(np.ones([1, 10]))
assert not observation_space.contains(np.ones([2, 10]))
assert not observation_space.contains(np.ones([1, 10]) * 100)
assert not observation_space.contains(np.ones([1, 1, 10]))
# is_point_in_space_shape
assert observation_space.is_point_in_space_shape(np.array([0, 9]))
assert observation_space.is_point_in_space_shape(np.array([0, 0]))
assert not observation_space.is_point_in_space_shape(np.array([1, 8]))
assert not observation_space.is_point_in_space_shape(np.array([0, 10]))
assert not observation_space.is_point_in_space_shape(np.array([-1, 6]))
# is_valid_index
assert observation_space.is_valid_index(np.array([0, 9]))
assert observation_space.is_valid_index(np.array([0, 0]))
assert not observation_space.is_valid_index(np.array([1, 8]))
assert not observation_space.is_valid_index(np.array([0, 10]))
assert not observation_space.is_valid_index(np.array([-1, 6]))
@pytest.mark.unit_test