
RL in Large Discrete Action Spaces - Wolpertinger Agent (#394)

* Currently this is specific to discretizing a continuous action space. It can easily be adapted to other cases by feeding the kNN a different candidate action set and removing the discretizing output action filter (see the sketch after the commit metadata below).
Gal Leibovich
2019-09-08 12:53:49 +03:00
committed by GitHub
parent fc50398544
commit 138ced23ba
46 changed files with 1193 additions and 51 deletions
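
For context on the change itself: a Wolpertinger agent (Dulac-Arnold et al., 2015) acts by letting the actor emit a continuous proto-action, mapping it through a kNN to nearby discrete actions, and letting the critic re-rank those candidates. A minimal sketch of that action-selection loop, assuming illustrative actor/critic callables and a discrete_actions array rather than Coach's actual classes:

import numpy as np
from sklearn.neighbors import NearestNeighbors

def wolpertinger_act(state, actor, critic, discrete_actions, k=5):
    # 1. The actor proposes a continuous proto-action for this state.
    proto_action = actor(state)                       # shape: (action_dim,)

    # 2. A kNN maps the proto-action to its k nearest discrete actions.
    #    Feeding the kNN a different candidate set (instead of a discretized
    #    continuous space) is what adapts the agent to other discrete problems.
    #    In practice the index would be built once, not on every step.
    knn = NearestNeighbors(n_neighbors=k).fit(discrete_actions)
    _, idx = knn.kneighbors(proto_action.reshape(1, -1))
    candidates = discrete_actions[idx[0]]             # shape: (k, action_dim)

    # 3. The critic re-ranks the candidates; act with the highest-valued one.
    q_values = [critic(state, a) for a in candidates]
    return candidates[int(np.argmax(q_values))]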

@@ -295,7 +295,9 @@
         self.optimization_epochs = 10
         self.normalization_stats = None
         self.clipping_decay_schedule = ConstantSchedule(1)
-        self.act_for_full_episodes = True
+        self.act_for_full_episodes = True
+        self.update_pre_network_filters_state_on_train = True
+        self.update_pre_network_filters_state_on_inference = False


 class ClippedPPOAgentParameters(AgentParameters):
@@ -486,7 +488,9 @@
         network.set_is_training(True)
         dataset = self.memory.transitions
-        dataset = self.pre_network_filter.filter(dataset, deep_copy=False)
+        update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_train
+        dataset = self.pre_network_filter.filter(dataset, deep_copy=False,
+                                                 update_internal_state=update_internal_state)
         batch = Batch(dataset)
         for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
@@ -512,7 +516,9 @@
     def run_pre_network_filter_for_inference(self, state: StateType, update_internal_state: bool=False):
         dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
-        return self.pre_network_filter.filter(dummy_env_response, update_internal_state=False)[0].next_state
+        update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_inference
+        return self.pre_network_filter.filter(
+            dummy_env_response, update_internal_state=update_internal_state)[0].next_state

     def choose_action(self, curr_state):
         self.ap.algorithm.clipping_decay_schedule.step()