1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

RL in Large Discrete Action Spaces - Wolpertinger Agent (#394)

* Currently this is specific to the case of discretizing a continuous action space. Can easily be adapted to other case by feeding the kNN otherwise, and removing the usage of a discretizing output action filter
This commit is contained in:
Gal Leibovich
2019-09-08 12:53:49 +03:00
committed by GitHub
parent fc50398544
commit 138ced23ba
46 changed files with 1193 additions and 51 deletions

View File

@@ -307,6 +307,11 @@
<span class="sd"> &quot;&quot;&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">memory_backend</span><span class="o">.</span><span class="n">deploy</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">&quot;redis&quot;</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">data_store</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">redis_address</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory_backend</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">redis_address</span>
<span class="bp">self</span><span class="o">.</span><span class="n">data_store</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">redis_port</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory_backend</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">redis_port</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_store</span><span class="o">.</span><span class="n">deploy</span><span class="p">():</span>
<span class="k">return</span> <span class="kc">False</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">&quot;nfs&quot;</span><span class="p">:</span>
@@ -329,6 +334,8 @@
<span class="n">trainer_params</span><span class="o">.</span><span class="n">command</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;--data_store_params&#39;</span><span class="p">,</span> <span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)]</span>
<span class="n">name</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="si">{}</span><span class="s2">-</span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">run_type</span><span class="p">,</span> <span class="n">uuid</span><span class="o">.</span><span class="n">uuid4</span><span class="p">())</span>
<span class="c1"># TODO: instead of defining each container and template spec from scratch, loaded default</span>
<span class="c1"># configuration and modify them as necessary depending on the store type</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">&quot;nfs&quot;</span><span class="p">:</span>
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
@@ -354,7 +361,7 @@
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">&#39;Never&#39;</span>
<span class="p">),</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">&quot;s3&quot;</span><span class="p">:</span>
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
<span class="n">image</span><span class="o">=</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">image</span><span class="p">,</span>
@@ -373,6 +380,34 @@
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">&#39;Never&#39;</span>
<span class="p">),</span>
<span class="p">)</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">&quot;redis&quot;</span><span class="p">:</span>
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
<span class="n">image</span><span class="o">=</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">image</span><span class="p">,</span>
<span class="n">command</span><span class="o">=</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">command</span><span class="p">,</span>
<span class="n">args</span><span class="o">=</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">arguments</span><span class="p">,</span>
<span class="n">image_pull_policy</span><span class="o">=</span><span class="s1">&#39;Always&#39;</span><span class="p">,</span>
<span class="n">stdin</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">tty</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">resources</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1ResourceRequirements</span><span class="p">(</span>
<span class="n">limits</span><span class="o">=</span><span class="p">{</span>
<span class="s2">&quot;cpu&quot;</span><span class="p">:</span> <span class="s2">&quot;40&quot;</span><span class="p">,</span>
<span class="s2">&quot;memory&quot;</span><span class="p">:</span> <span class="s2">&quot;4Gi&quot;</span><span class="p">,</span>
<span class="s2">&quot;nvidia.com/gpu&quot;</span><span class="p">:</span> <span class="s2">&quot;1&quot;</span><span class="p">,</span>
<span class="p">}</span>
<span class="p">),</span>
<span class="p">)</span>
<span class="n">template</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1PodTemplateSpec</span><span class="p">(</span>
<span class="n">metadata</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1ObjectMeta</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;app&#39;</span><span class="p">:</span> <span class="n">name</span><span class="p">}),</span>
<span class="n">spec</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1PodSpec</span><span class="p">(</span>
<span class="n">containers</span><span class="o">=</span><span class="p">[</span><span class="n">container</span><span class="p">],</span>
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">&#39;Never&#39;</span>
<span class="p">),</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;unexpected store_type </span><span class="si">{}</span><span class="s2">. expected &#39;s3&#39;, &#39;nfs&#39;, &#39;redis&#39;&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span>
<span class="p">))</span>
<span class="n">job_spec</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1JobSpec</span><span class="p">(</span>
<span class="n">completions</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
@@ -404,12 +439,17 @@
<span class="k">if</span> <span class="ow">not</span> <span class="n">worker_params</span><span class="p">:</span>
<span class="k">return</span> <span class="kc">False</span>
<span class="c1"># At this point, the memory backend and data store have been deployed and in the process,</span>
<span class="c1"># these parameters have been updated to include things like the hostname and port the</span>
<span class="c1"># service can be found at.</span>
<span class="n">worker_params</span><span class="o">.</span><span class="n">command</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;--memory_backend_params&#39;</span><span class="p">,</span> <span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">memory_backend_parameters</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)]</span>
<span class="n">worker_params</span><span class="o">.</span><span class="n">command</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;--data_store_params&#39;</span><span class="p">,</span> <span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)]</span>
<span class="n">worker_params</span><span class="o">.</span><span class="n">command</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;--num_workers&#39;</span><span class="p">,</span> <span class="s1">&#39;</span><span class="si">{}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">worker_params</span><span class="o">.</span><span class="n">num_replicas</span><span class="p">)]</span>
<span class="n">name</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="si">{}</span><span class="s2">-</span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">worker_params</span><span class="o">.</span><span class="n">run_type</span><span class="p">,</span> <span class="n">uuid</span><span class="o">.</span><span class="n">uuid4</span><span class="p">())</span>
<span class="c1"># TODO: instead of defining each container and template spec from scratch, loaded default</span>
<span class="c1"># configuration and modify them as necessary depending on the store type</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">&quot;nfs&quot;</span><span class="p">:</span>
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
@@ -435,7 +475,7 @@
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">&#39;Never&#39;</span>
<span class="p">),</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">&quot;s3&quot;</span><span class="p">:</span>
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
<span class="n">image</span><span class="o">=</span><span class="n">worker_params</span><span class="o">.</span><span class="n">image</span><span class="p">,</span>
@@ -454,6 +494,32 @@
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">&#39;Never&#39;</span>
<span class="p">)</span>
<span class="p">)</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">&quot;redis&quot;</span><span class="p">:</span>
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
<span class="n">image</span><span class="o">=</span><span class="n">worker_params</span><span class="o">.</span><span class="n">image</span><span class="p">,</span>
<span class="n">command</span><span class="o">=</span><span class="n">worker_params</span><span class="o">.</span><span class="n">command</span><span class="p">,</span>
<span class="n">args</span><span class="o">=</span><span class="n">worker_params</span><span class="o">.</span><span class="n">arguments</span><span class="p">,</span>
<span class="n">image_pull_policy</span><span class="o">=</span><span class="s1">&#39;Always&#39;</span><span class="p">,</span>
<span class="n">stdin</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">tty</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">resources</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1ResourceRequirements</span><span class="p">(</span>
<span class="n">limits</span><span class="o">=</span><span class="p">{</span>
<span class="s2">&quot;cpu&quot;</span><span class="p">:</span> <span class="s2">&quot;8&quot;</span><span class="p">,</span>
<span class="s2">&quot;memory&quot;</span><span class="p">:</span> <span class="s2">&quot;4Gi&quot;</span><span class="p">,</span>
<span class="c1"># &quot;nvidia.com/gpu&quot;: &quot;0&quot;,</span>
<span class="p">}</span>
<span class="p">),</span>
<span class="p">)</span>
<span class="n">template</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1PodTemplateSpec</span><span class="p">(</span>
<span class="n">metadata</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1ObjectMeta</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;app&#39;</span><span class="p">:</span> <span class="n">name</span><span class="p">}),</span>
<span class="n">spec</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1PodSpec</span><span class="p">(</span>
<span class="n">containers</span><span class="o">=</span><span class="p">[</span><span class="n">container</span><span class="p">],</span>
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">&#39;Never&#39;</span>
<span class="p">)</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;unexpected store type </span><span class="si">{}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span><span class="p">))</span>
<span class="n">job_spec</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1JobSpec</span><span class="p">(</span>
<span class="n">completions</span><span class="o">=</span><span class="n">worker_params</span><span class="o">.</span><span class="n">num_replicas</span><span class="p">,</span>