1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

SAC algorithm (#282)

* SAC algorithm

* SAC - updates to agent (learn_from_batch), sac_head and sac_q_head to fix problem in gradient calculation. Now SAC agents is able to train.
gym_environment - fixing an error in access to gym.spaces

* Soft Actor Critic - code cleanup

* code cleanup

* V-head initialization fix

* SAC benchmarks

* SAC Documentation

* typo fix

* documentation fixes

* documentation and version update

* README typo
This commit is contained in:
guyk1971
2019-05-01 18:37:49 +03:00
committed by shadiendrawis
parent 33dc29ee99
commit 74db141d5e
92 changed files with 2812 additions and 402 deletions

View File

@@ -240,9 +240,12 @@
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">algorithm</span> <span class="o">=</span> <span class="n">RainbowDQNAlgorithmParameters</span><span class="p">()</span>
<span class="c1"># ParameterNoiseParameters is changing the network wrapper parameters. This line needs to be done first.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">network_wrappers</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;main&quot;</span><span class="p">:</span> <span class="n">RainbowDQNNetworkParameters</span><span class="p">()}</span>
<span class="bp">self</span><span class="o">.</span><span class="n">exploration</span> <span class="o">=</span> <span class="n">ParameterNoiseParameters</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span> <span class="o">=</span> <span class="n">PrioritizedExperienceReplayParameters</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">network_wrappers</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;main&quot;</span><span class="p">:</span> <span class="n">RainbowDQNNetworkParameters</span><span class="p">()}</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
@@ -275,11 +278,14 @@
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))</span>
<span class="p">])</span>
<span class="c1"># add Q value samples for logging</span>
<span class="bp">self</span><span class="o">.</span><span class="n">q_values</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">distribution_prediction_to_q_values</span><span class="p">(</span><span class="n">TD_targets</span><span class="p">))</span>
<span class="c1"># only update the action that we have actually done in this transition (using the Double-DQN selected actions)</span>
<span class="n">target_actions</span> <span class="o">=</span> <span class="n">ddqn_selected_actions</span>
<span class="n">m</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="o">.</span><span class="n">size</span><span class="p">))</span>
<span class="n">m</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="o">.</span><span class="n">size</span><span class="p">))</span>
<span class="n">batches</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)</span>
<span class="n">batches</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">)</span>
<span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="o">.</span><span class="n">size</span><span class="p">):</span>
<span class="c1"># we use batch.info(&#39;should_bootstrap_next_state&#39;) instead of (1 - batch.game_overs()) since with n-step,</span>
<span class="c1"># we will not bootstrap for the last n-step transitions in the episode</span>