mirror of
https://github.com/gryf/coach.git
synced 2026-02-20 16:55:48 +01:00
SAC algorithm (#282)
* SAC algorithm * SAC - updates to agent (learn_from_batch), sac_head and sac_q_head to fix problem in gradient calculation. Now SAC agents is able to train. gym_environment - fixing an error in access to gym.spaces * Soft Actor Critic - code cleanup * code cleanup * V-head initialization fix * SAC benchmarks * SAC Documentation * typo fix * documentation fixes * documentation and version update * README typo
This commit is contained in:
@@ -268,11 +268,14 @@
|
||||
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))</span>
|
||||
<span class="p">])</span>
|
||||
|
||||
<span class="c1"># add Q value samples for logging</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">q_values</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">get_q_values</span><span class="p">(</span><span class="n">current_quantiles</span><span class="p">))</span>
|
||||
|
||||
<span class="c1"># get the optimal actions to take for the next states</span>
|
||||
<span class="n">target_actions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">get_q_values</span><span class="p">(</span><span class="n">next_state_quantiles</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># calculate the Bellman update</span>
|
||||
<span class="n">batch_idx</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">))</span>
|
||||
<span class="n">batch_idx</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">))</span>
|
||||
|
||||
<span class="n">TD_targets</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">rewards</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">batch</span><span class="o">.</span><span class="n">game_overs</span><span class="p">(</span><span class="kc">True</span><span class="p">))</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">discount</span> \
|
||||
<span class="o">*</span> <span class="n">next_state_quantiles</span><span class="p">[</span><span class="n">batch_idx</span><span class="p">,</span> <span class="n">target_actions</span><span class="p">]</span>
|
||||
@@ -283,9 +286,9 @@
|
||||
<span class="c1"># calculate the cumulative quantile probabilities and reorder them to fit the sorted quantiles order</span>
|
||||
<span class="n">cumulative_probabilities</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">atoms</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="nb">float</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">atoms</span><span class="p">)</span> <span class="c1"># tau_i</span>
|
||||
<span class="n">quantile_midpoints</span> <span class="o">=</span> <span class="mf">0.5</span><span class="o">*</span><span class="p">(</span><span class="n">cumulative_probabilities</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="o">+</span> <span class="n">cumulative_probabilities</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span> <span class="c1"># tau^hat_i</span>
|
||||
<span class="n">quantile_midpoints</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">quantile_midpoints</span><span class="p">,</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
|
||||
<span class="n">quantile_midpoints</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">quantile_midpoints</span><span class="p">,</span> <span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
|
||||
<span class="n">sorted_quantiles</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">current_quantiles</span><span class="p">[</span><span class="n">batch_idx</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()])</span>
|
||||
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">):</span>
|
||||
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">):</span>
|
||||
<span class="n">quantile_midpoints</span><span class="p">[</span><span class="n">idx</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">quantile_midpoints</span><span class="p">[</span><span class="n">idx</span><span class="p">,</span> <span class="n">sorted_quantiles</span><span class="p">[</span><span class="n">idx</span><span class="p">]]</span>
|
||||
|
||||
<span class="c1"># train</span>
|
||||
|
||||
Reference in New Issue
Block a user