TD3 (#338)
@@ -199,7 +199,7 @@

 import os
 import pickle
-from typing import Union
+from typing import Union, List

 import numpy as np

@@ -223,6 +223,7 @@
         self.middleware_parameters = FCMiddlewareParameters()
         self.heads_parameters = [DNDQHeadParameters()]
         self.optimizer_type = 'Adam'
+        self.should_get_softmax_probabilities = False


 class NECAlgorithmParameters(AlgorithmParameters):
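Note on the new should_get_softmax_probabilities flag: the diff only shows it being set alongside the network's middleware, head, and optimizer parameters, with a default of False. How the head consumes it is not shown here, so the following is a minimal, hypothetical sketch of one way such a flag can gate an optional softmax output next to the Q values (stable_softmax and IllustrativeQHead are illustrative names, not coach API):

import numpy as np

def stable_softmax(x):
    # numerically stable softmax over the last axis
    z = x - np.max(x, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)

class IllustrativeQHead:
    # hypothetical head: always returns Q values, adds softmax probabilities only when asked
    def __init__(self, should_get_softmax_probabilities=False):
        self.should_get_softmax_probabilities = should_get_softmax_probabilities

    def outputs(self, q_values):
        if self.should_get_softmax_probabilities:
            return q_values, stable_softmax(q_values)
        return (q_values,)

Example: head = IllustrativeQHead(should_get_softmax_probabilities=True); q, p = head.outputs(np.array([1.0, 2.0, 0.5])) yields the raw Q values plus a probability vector summing to 1.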
@@ -349,11 +350,25 @@

         return super().act()

-    def get_all_q_values_for_states(self, states: StateType):
+    def get_all_q_values_for_states(self, states: StateType, additional_outputs: List = None):
         # we need to store the state embeddings regardless if the action is random or not
-        return self.get_prediction(states)
+        return self.get_prediction_and_update_embeddings(states)

-    def get_prediction(self, states):
+    def get_all_q_values_for_states_and_softmax_probabilities(self, states: StateType):
+        # get the actions q values and the state embedding
+        embedding, actions_q_values, softmax_probabilities = self.networks['main'].online_network.predict(
+            self.prepare_batch_for_inference(states, 'main'),
+            outputs=[self.networks['main'].online_network.state_embedding,
+                     self.networks['main'].online_network.output_heads[0].output,
+                     self.networks['main'].online_network.output_heads[0].softmax]
+        )
+        if self.phase != RunPhase.TEST:
+            # store the state embedding for inserting it to the DND later
+            self.current_episode_state_embeddings.append(embedding.squeeze())
+        actions_q_values = actions_q_values[0][0]
+        return actions_q_values, softmax_probabilities
+
+    def get_prediction_and_update_embeddings(self, states):
         # get the actions q values and the state embedding
         embedding, actions_q_values = self.networks['main'].online_network.predict(
             self.prepare_batch_for_inference(states, 'main'),
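The new get_all_q_values_for_states_and_softmax_probabilities requests three outputs from the online network (state embedding, the Q head's output, and its softmax), records the embedding whenever the phase is not TEST, and returns both the Q values and the probabilities. The embedding bookkeeping itself is simple; a self-contained toy sketch of the same pattern, with the phase check reduced to a boolean (EpisodeEmbeddingBuffer is an illustrative helper, not a coach class):

import numpy as np

class EpisodeEmbeddingBuffer:
    # hypothetical helper mirroring the current_episode_state_embeddings bookkeeping
    def __init__(self):
        self.current_episode_state_embeddings = []

    def record(self, embedding, is_test_phase):
        # embeddings are collected only outside the TEST phase, so they can be
        # written into the DND when the episode ends
        if not is_test_phase:
            self.current_episode_state_embeddings.append(np.asarray(embedding).squeeze())

    def flush(self):
        # hand back the collected embeddings and start a fresh episode
        embeddings, self.current_episode_state_embeddings = self.current_episode_state_embeddings, []
        return embeddings

Usage: buffer.record(np.ones((1, 512)), is_test_phase=False) during the episode, then episode_embeddings = buffer.flush() when the episode is inserted into the DND.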
@@ -362,7 +377,7 @@
         )
         if self.phase != RunPhase.TEST:
             # store the state embedding for inserting it to the DND later
-            self.current_episode_state_embeddings.append(embedding.squeeze())
+            self.current_episode_state_embeddings.append(embedding[0].squeeze())
         actions_q_values = actions_q_values[0][0]
         return actions_q_values

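For background on why these embeddings are kept at all: in Neural Episodic Control they become the keys of the differentiable neural dictionary (DND), which is queried with an inverse-squared-distance kernel over the nearest stored keys (Pritzel et al., 2017). A toy sketch of that lookup, assuming the NEC paper's 1 / (||h - h_i||^2 + delta) weighting and not coach's internal DND classes:

import numpy as np

class SimpleDND:
    # toy DND-style store: embedding keys, Q-return values
    def __init__(self):
        self.keys, self.values = [], []

    def add(self, embeddings, returns):
        self.keys.extend(embeddings)
        self.values.extend(returns)

    def lookup(self, query, k=5):
        # weighted average over the k nearest stored keys, NEC-style kernel
        keys = np.stack(self.keys)
        dists = np.linalg.norm(keys - query, axis=1)
        nearest = np.argsort(dists)[:k]
        weights = 1.0 / (dists[nearest] ** 2 + 1e-3)
        return float(np.average(np.asarray(self.values)[nearest], weights=weights))

Example: dnd.add([np.zeros(4), np.ones(4)], [0.0, 1.0]); dnd.lookup(np.full(4, 0.9), k=2) returns a value close to 1.0, since the query is near the second key.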