mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

RL in Large Discrete Action Spaces - Wolpertinger Agent (#394)

* Currently this is specific to the case of discretizing a continuous action space. It can easily be adapted to other cases by feeding the kNN differently and removing the discretizing output action filter.
This commit is contained in:
Gal Leibovich
2019-09-08 12:53:49 +03:00
committed by GitHub
parent fc50398544
commit 138ced23ba
46 changed files with 1193 additions and 51 deletions
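For context, the action-selection step the commit message describes — take the actor's continuous proto-action, look up its k nearest neighbors in the discretized action set, then refine with the critic — can be sketched roughly as follows. This is a minimal NumPy sketch under assumed shapes; `wolpertinger_select` and its arguments are hypothetical stand-ins, not Coach's actual API:

```python
import numpy as np

def wolpertinger_select(proto_action, actions, q_values_fn, k=5):
    """Pick the discrete action with the highest Q-value among the k
    nearest neighbors of the actor's continuous proto-action.

    actions: (n, d) array of discretized actions; requires k < n.
    q_values_fn: maps a (k, d) batch of candidate actions to k Q-values.
    """
    # distances from the proto-action to every discretized action
    dists = np.linalg.norm(actions - proto_action, axis=1)
    # indices of the k nearest discrete actions (the kNN lookup)
    knn_idx = np.argpartition(dists, k)[:k]
    # refine: evaluate the critic on the candidates and take the argmax
    q_values = q_values_fn(actions[knn_idx])
    return actions[knn_idx[np.argmax(q_values)]]
```

The kNN step keeps the critic evaluation cheap (k candidates instead of the full discretized set), which is the point of the Wolpertinger architecture in large action spaces.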


@@ -756,6 +756,9 @@
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">!=</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="p">,</span> <span class="n">EpisodicExperienceReplay</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">override_episode_rewards_with_the_last_transition_reward</span><span class="p">:</span>
<span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode_buffer</span><span class="o">.</span><span class="n">transitions</span><span class="p">:</span>
<span class="n">t</span><span class="o">.</span><span class="n">reward</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode_buffer</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">reward</span>
<span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">&#39;store_episode&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode_buffer</span><span class="p">)</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">store_transitions_only_when_episodes_are_terminated</span><span class="p">:</span>
<span class="k">for</span> <span class="n">transition</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode_buffer</span><span class="o">.</span><span class="n">transitions</span><span class="p">:</span>
@@ -910,7 +913,8 @@
<span class="c1"># update counters</span>
<span class="bp">self</span><span class="o">.</span><span class="n">training_iteration</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">batch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">deep_copy</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">update_internal_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_train</span>
<span class="n">batch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="o">=</span><span class="n">update_internal_state</span><span class="p">,</span> <span class="n">deep_copy</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="c1"># if the batch returned empty then there are not enough samples in the replay buffer -&gt; skip</span>
<span class="c1"># training step</span>
@@ -1020,7 +1024,8 @@
<span class="c1"># informed action</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># before choosing an action, first use the pre_network_filter to filter out the current state</span>
<span class="n">update_filter_internal_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span>
<span class="n">update_filter_internal_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_inference</span> <span class="ow">and</span> \
<span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span>
<span class="n">curr_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">run_pre_network_filter_for_inference</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">curr_state</span><span class="p">,</span> <span class="n">update_filter_internal_state</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
@@ -1048,6 +1053,10 @@
<span class="sd"> :return: The filtered state</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">dummy_env_response</span> <span class="o">=</span> <span class="n">EnvResponse</span><span class="p">(</span><span class="n">next_state</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">reward</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">game_over</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="c1"># TODO actually we only want to run the observation filters. No point in running the reward filters as the</span>
<span class="c1"># filtered reward is being ignored anyway (and it might unnecessarily affect the reward filters&#39; internal</span>
<span class="c1"># state).</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">dummy_env_response</span><span class="p">,</span>
<span class="n">update_internal_state</span><span class="o">=</span><span class="n">update_filter_internal_state</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">next_state</span></div>
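The hunk above wraps a raw state in a dummy `EnvResponse` so the same pre-network filter stack can be reused at inference time, taking only the filtered `next_state`. The pattern is roughly the following — a simplified sketch with toy stand-in classes, not Coach's real `EnvResponse` or filter implementations:

```python
class EnvResponse:
    """Toy stand-in: bundles a state transition as filters expect it."""
    def __init__(self, next_state, reward, game_over):
        self.next_state = next_state
        self.reward = reward
        self.game_over = game_over


class ObservationScaleFilter:
    """Toy stand-in for a pre-network observation filter."""
    def __init__(self, scale):
        self.scale = scale

    def filter(self, env_response, update_internal_state=True):
        # Returns a list of filtered responses, mirroring the filter API.
        return [EnvResponse(
            next_state=env_response.next_state * self.scale,
            reward=env_response.reward,  # wasted work, as the TODO notes
            game_over=env_response.game_over,
        )]


def run_pre_network_filter_for_inference(filt, state, update_filter_internal_state):
    # Wrap the bare state so the filter (which expects an EnvResponse)
    # can be reused; only the filtered next_state is of interest.
    dummy = EnvResponse(next_state=state, reward=0, game_over=False)
    return filt.filter(dummy,
                       update_internal_state=update_filter_internal_state)[0].next_state
```

As the TODO in the diff notes, this also runs the reward filters on a dummy reward of 0, which is wasted work and may perturb their internal state.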
@@ -1177,7 +1186,7 @@
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Allows setting a directive for the agent to follow. This is useful in hierarchy structures, where the agent</span>
<span class="sd"> has another master agent that is controlling it. In such cases, the master agent can define the goals for the</span>
<span class="sd"> slave agent, define it&#39;s observation, possible actions, etc. The directive type is defined by the agent</span>
<span class="sd"> slave agent, define its observation, possible actions, etc. The directive type is defined by the agent</span>
<span class="sd"> in-action-space.</span>
<span class="sd"> :param action: The action that should be set as the directive</span>
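The hierarchical control flow this docstring describes — a master agent issuing directives (goals, observations, allowed actions) that a slave agent follows — can be illustrated with a minimal sketch. `MasterAgent` and `SlaveAgent` here are hypothetical illustrations, not Coach's actual classes:

```python
class SlaveAgent:
    """Follows directives handed down by a master agent."""
    def __init__(self):
        self.current_directive = None

    def set_directive(self, action):
        # The master's chosen action becomes this agent's directive,
        # interpreted according to the agent's in-action-space
        # (e.g. as a goal state to reach).
        self.current_directive = action


class MasterAgent:
    """Controls a slave agent by issuing directives instead of
    acting on the environment directly."""
    def __init__(self, slave):
        self.slave = slave

    def act(self, goal):
        self.slave.set_directive(goal)
```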