mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
TD3 (#338)
This commit is contained in:
@@ -197,7 +197,7 @@
|
||||
<span class="c1"># limitations under the License.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="kn">from</span> <span class="nn">collections</span> <span class="k">import</span> <span class="n">OrderedDict</span>
|
||||
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span>
|
||||
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span><span class="p">,</span> <span class="n">List</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
||||
|
||||
@@ -207,7 +207,8 @@
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.logger</span> <span class="k">import</span> <span class="n">screen</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.memories.non_episodic.prioritized_experience_replay</span> <span class="k">import</span> <span class="n">PrioritizedExperienceReplay</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">DiscreteActionSpace</span>
|
||||
<span class="kn">from</span> <span class="nn">copy</span> <span class="k">import</span> <span class="n">deepcopy</span>
|
||||
<span class="kn">from</span> <span class="nn">copy</span> <span class="k">import</span> <span class="n">deepcopy</span><span class="p">,</span> <span class="n">copy</span>
|
||||
|
||||
|
||||
<span class="c1">## This is an abstract agent - there is no learn_from_batch method ##</span>
|
||||
|
||||
@@ -218,6 +219,12 @@
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">q_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s2">"Q"</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">q_value_for_action</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
|
||||
<span class="c1"># currently we use softmax action probabilities only in batch-rl,</span>
|
||||
<span class="c1"># but we might want to extend this later at some point.</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">should_get_softmax_probabilities</span> <span class="o">=</span> \
|
||||
<span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">],</span> <span class="s1">'should_get_softmax_probabilities'</span><span class="p">)</span> <span class="ow">and</span> \
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">should_get_softmax_probabilities</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">init_environment_dependent_modules</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">init_environment_dependent_modules</span><span class="p">()</span>
|
||||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">action</span><span class="p">,</span> <span class="n">DiscreteActionSpace</span><span class="p">):</span>
|
||||
@@ -228,12 +235,21 @@
|
||||
|
||||
<span class="c1"># Algorithms for which q_values are calculated from predictions will override this function</span>
|
||||
<span class="k">def</span> <span class="nf">get_all_q_values_for_states</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">:</span> <span class="n">StateType</span><span class="p">):</span>
|
||||
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">requires_action_values</span><span class="p">():</span>
|
||||
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_prediction</span><span class="p">(</span><span class="n">states</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">actions_q_values</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">get_all_q_values_for_states_and_softmax_probabilities</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">:</span> <span class="n">StateType</span><span class="p">):</span>
|
||||
<span class="n">actions_q_values</span><span class="p">,</span> <span class="n">softmax_probabilities</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">requires_action_values</span><span class="p">():</span>
|
||||
<span class="n">outputs</span> <span class="o">=</span> <span class="n">copy</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">outputs</span><span class="p">)</span>
|
||||
<span class="n">outputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">softmax</span><span class="p">)</span>
|
||||
|
||||
<span class="n">actions_q_values</span><span class="p">,</span> <span class="n">softmax_probabilities</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_prediction</span><span class="p">(</span><span class="n">states</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">outputs</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">actions_q_values</span><span class="p">,</span> <span class="n">softmax_probabilities</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">get_prediction</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">prepare_batch_for_inference</span><span class="p">(</span><span class="n">states</span><span class="p">,</span> <span class="s1">'main'</span><span class="p">),</span>
|
||||
<span class="n">outputs</span><span class="o">=</span><span class="n">outputs</span><span class="p">)</span>
|
||||
@@ -255,10 +271,19 @@
|
||||
<span class="p">)</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">policy</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">))</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">choose_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">curr_state</span><span class="p">):</span>
|
||||
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_all_q_values_for_states</span><span class="p">(</span><span class="n">curr_state</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">should_get_softmax_probabilities</span><span class="p">:</span>
|
||||
<span class="n">actions_q_values</span><span class="p">,</span> <span class="n">softmax_probabilities</span> <span class="o">=</span> \
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">get_all_q_values_for_states_and_softmax_probabilities</span><span class="p">(</span><span class="n">curr_state</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_all_q_values_for_states</span><span class="p">(</span><span class="n">curr_state</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># choose action according to the exploration policy and the current phase (evaluating or training the agent)</span>
|
||||
<span class="n">action</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">get_action</span><span class="p">(</span><span class="n">actions_q_values</span><span class="p">)</span>
|
||||
<span class="n">action</span><span class="p">,</span> <span class="n">action_probabilities</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">get_action</span><span class="p">(</span><span class="n">actions_q_values</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">should_get_softmax_probabilities</span> <span class="ow">and</span> <span class="n">softmax_probabilities</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="c1"># override the exploration policy's generated probabilities when an action was taken</span>
|
||||
<span class="c1"># with the agent's actual policy</span>
|
||||
<span class="n">action_probabilities</span> <span class="o">=</span> <span class="n">softmax_probabilities</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_validate_action</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="p">,</span> <span class="n">action</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">actions_q_values</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
@@ -270,15 +295,18 @@
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">q_values</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">actions_q_values</span><span class="p">)</span>
|
||||
|
||||
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="n">actions_q_values</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
|
||||
<span class="n">action_probabilities</span> <span class="o">=</span> <span class="n">action_probabilities</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">q_value</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">actions_q_values</span><span class="p">):</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">q_value_for_action</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">q_value</span><span class="p">)</span>
|
||||
|
||||
<span class="n">action_info</span> <span class="o">=</span> <span class="n">ActionInfo</span><span class="p">(</span><span class="n">action</span><span class="o">=</span><span class="n">action</span><span class="p">,</span>
|
||||
<span class="n">action_value</span><span class="o">=</span><span class="n">actions_q_values</span><span class="p">[</span><span class="n">action</span><span class="p">],</span>
|
||||
<span class="n">max_action_value</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">actions_q_values</span><span class="p">))</span>
|
||||
<span class="n">max_action_value</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">actions_q_values</span><span class="p">),</span>
|
||||
<span class="n">all_action_probabilities</span><span class="o">=</span><span class="n">action_probabilities</span><span class="p">)</span>
|
||||
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">action_info</span> <span class="o">=</span> <span class="n">ActionInfo</span><span class="p">(</span><span class="n">action</span><span class="o">=</span><span class="n">action</span><span class="p">)</span>
|
||||
<span class="n">action_info</span> <span class="o">=</span> <span class="n">ActionInfo</span><span class="p">(</span><span class="n">action</span><span class="o">=</span><span class="n">action</span><span class="p">,</span> <span class="n">all_action_probabilities</span><span class="o">=</span><span class="n">action_probabilities</span><span class="p">)</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">action_info</span>
|
||||
|
||||
|
||||
Reference in New Issue
Block a user