1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

SAC algorithm (#282)

* SAC algorithm

* SAC - updates to agent (learn_from_batch), sac_head and sac_q_head to fix problem in gradient calculation. Now SAC agents is able to train.
gym_environment - fixing an error in access to gym.spaces

* Soft Actor Critic - code cleanup

* code cleanup

* V-head initialization fix

* SAC benchmarks

* SAC Documentation

* typo fix

* documentation fixes

* documentation and version update

* README typo
This commit is contained in:
guyk1971
2019-05-01 18:37:49 +03:00
committed by shadiendrawis
parent 33dc29ee99
commit 74db141d5e
92 changed files with 2812 additions and 402 deletions

View File

@@ -480,7 +480,6 @@
<span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">seed</span><span class="p">)</span>
<span class="c1"># frame skip and max between consecutive frames</span>
<span class="bp">self</span><span class="o">.</span><span class="n">is_robotics_env</span> <span class="o">=</span> <span class="s1">&#39;robotics&#39;</span> <span class="ow">in</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">unwrapped</span><span class="o">.</span><span class="vm">__class__</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">is_mujoco_env</span> <span class="o">=</span> <span class="s1">&#39;mujoco&#39;</span> <span class="ow">in</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">unwrapped</span><span class="o">.</span><span class="vm">__class__</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">is_roboschool_env</span> <span class="o">=</span> <span class="s1">&#39;roboschool&#39;</span> <span class="ow">in</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">unwrapped</span><span class="o">.</span><span class="vm">__class__</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">is_atari_env</span> <span class="o">=</span> <span class="s1">&#39;Atari&#39;</span> <span class="ow">in</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">unwrapped</span><span class="o">.</span><span class="vm">__class__</span><span class="p">)</span>
@@ -501,7 +500,7 @@
<span class="bp">self</span><span class="o">.</span><span class="n">state_space</span> <span class="o">=</span> <span class="n">StateSpace</span><span class="p">({})</span>
<span class="c1"># observations</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">observation_space</span><span class="p">,</span> <span class="n">gym</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">dict_space</span><span class="o">.</span><span class="n">Dict</span><span class="p">):</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">observation_space</span><span class="p">,</span> <span class="n">gym</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">dict</span><span class="o">.</span><span class="n">Dict</span><span class="p">):</span>
<span class="n">state_space</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;observation&#39;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">observation_space</span><span class="p">}</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">state_space</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">observation_space</span><span class="o">.</span><span class="n">spaces</span>
@@ -665,38 +664,11 @@
<span class="c1"># initialize the number of lives</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_update_ale_lives</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">_set_mujoco_camera</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">camera_idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> This function can be used to set the camera for rendering the mujoco simulator</span>
<span class="sd"> :param camera_idx: The index of the camera to use. Should be defined in the model</span>
<span class="sd"> :return: None</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">unwrapped</span><span class="o">.</span><span class="n">viewer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">unwrapped</span><span class="o">.</span><span class="n">viewer</span><span class="o">.</span><span class="n">cam</span><span class="o">.</span><span class="n">fixedcamid</span> <span class="o">!=</span> <span class="n">camera_idx</span><span class="p">:</span>
<span class="kn">from</span> <span class="nn">mujoco_py.generated</span> <span class="k">import</span> <span class="n">const</span>
<span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">unwrapped</span><span class="o">.</span><span class="n">viewer</span><span class="o">.</span><span class="n">cam</span><span class="o">.</span><span class="n">type</span> <span class="o">=</span> <span class="n">const</span><span class="o">.</span><span class="n">CAMERA_FIXED</span>
<span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">unwrapped</span><span class="o">.</span><span class="n">viewer</span><span class="o">.</span><span class="n">cam</span><span class="o">.</span><span class="n">fixedcamid</span> <span class="o">=</span> <span class="n">camera_idx</span>
<span class="k">def</span> <span class="nf">_get_robotics_image</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">render</span><span class="p">()</span>
<span class="n">image</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">unwrapped</span><span class="o">.</span><span class="n">_get_viewer</span><span class="p">()</span><span class="o">.</span><span class="n">read_pixels</span><span class="p">(</span><span class="mi">1600</span><span class="p">,</span> <span class="mi">900</span><span class="p">,</span> <span class="n">depth</span><span class="o">=</span><span class="kc">False</span><span class="p">)[::</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:]</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">scipy</span><span class="o">.</span><span class="n">misc</span><span class="o">.</span><span class="n">imresize</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="p">(</span><span class="mi">270</span><span class="p">,</span> <span class="mi">480</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="k">return</span> <span class="n">image</span>
<span class="k">def</span> <span class="nf">_render</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">render</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s1">&#39;human&#39;</span><span class="p">)</span>
<span class="c1"># required for setting up a fixed camera for mujoco</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_mujoco_env</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_roboschool_env</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_set_mujoco_camera</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">get_rendered_image</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_robotics_env</span><span class="p">:</span>
<span class="c1"># necessary for fetch since the rendered image is cropped to an irrelevant part of the simulator</span>
<span class="n">image</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_robotics_image</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">image</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">render</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s1">&#39;rgb_array&#39;</span><span class="p">)</span>
<span class="c1"># required for setting up a fixed camera for mujoco</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_mujoco_env</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_roboschool_env</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_set_mujoco_camera</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">image</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">render</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s1">&#39;rgb_array&#39;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">image</span>
<span class="k">def</span> <span class="nf">get_target_success_rate</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>