mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
SAC algorithm (#282)
* SAC algorithm * SAC - updates to agent (learn_from_batch), sac_head and sac_q_head to fix problem in gradient calculation. Now SAC agents is able to train. gym_environment - fixing an error in access to gym.spaces * Soft Actor Critic - code cleanup * code cleanup * V-head initialization fix * SAC benchmarks * SAC Documentation * typo fix * documentation fixes * documentation and version update * README typo
This commit is contained in:
@@ -313,7 +313,7 @@
|
||||
<span class="n">TD_targets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))</span>
|
||||
<span class="n">bootstrapped_return_from_old_policy</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">n_step_discounted_rewards</span><span class="p">()</span>
|
||||
<span class="c1"># only update the action that we have actually done in this transition</span>
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">):</span>
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">):</span>
|
||||
<span class="n">TD_targets</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()[</span><span class="n">i</span><span class="p">]]</span> <span class="o">=</span> <span class="n">bootstrapped_return_from_old_policy</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
|
||||
|
||||
<span class="c1"># set the gradients to fetch for the DND update</span>
|
||||
@@ -342,7 +342,7 @@
|
||||
<span class="n">embedding</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">prepare_batch_for_inference</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">curr_state</span><span class="p">,</span> <span class="s1">'main'</span><span class="p">),</span>
|
||||
<span class="n">outputs</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">state_embedding</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">current_episode_state_embeddings</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">embedding</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">current_episode_state_embeddings</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">embedding</span><span class="o">.</span><span class="n">squeeze</span><span class="p">())</span>
|
||||
|
||||
<span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">act</span><span class="p">()</span>
|
||||
|
||||
|
||||
Reference in New Issue
Block a user