mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
RL in Large Discrete Action Spaces - Wolpertinger Agent (#394)
* Currently this is specific to the case of discretizing a continuous action space. It can easily be adapted to other cases by feeding the kNN differently and removing the use of a discretizing output action filter.
This commit is contained in:
@@ -307,6 +307,11 @@
|
||||
<span class="sd"> """</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">memory_backend</span><span class="o">.</span><span class="n">deploy</span><span class="p">()</span>
|
||||
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">"redis"</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">data_store</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">redis_address</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory_backend</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">redis_address</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">data_store</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">redis_port</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory_backend</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">redis_port</span>
|
||||
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_store</span><span class="o">.</span><span class="n">deploy</span><span class="p">():</span>
|
||||
<span class="k">return</span> <span class="kc">False</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">"nfs"</span><span class="p">:</span>
|
||||
@@ -329,6 +334,8 @@
|
||||
<span class="n">trainer_params</span><span class="o">.</span><span class="n">command</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'--data_store_params'</span><span class="p">,</span> <span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)]</span>
|
||||
<span class="n">name</span> <span class="o">=</span> <span class="s2">"</span><span class="si">{}</span><span class="s2">-</span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">run_type</span><span class="p">,</span> <span class="n">uuid</span><span class="o">.</span><span class="n">uuid4</span><span class="p">())</span>
|
||||
|
||||
<span class="c1"># TODO: instead of defining each container and template spec from scratch, load default</span>
|
||||
<span class="c1"># configuration and modify them as necessary depending on the store type</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">"nfs"</span><span class="p">:</span>
|
||||
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
|
||||
@@ -354,7 +361,7 @@
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'Never'</span>
|
||||
<span class="p">),</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">"s3"</span><span class="p">:</span>
|
||||
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
|
||||
<span class="n">image</span><span class="o">=</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">image</span><span class="p">,</span>
|
||||
@@ -373,6 +380,34 @@
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'Never'</span>
|
||||
<span class="p">),</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">"redis"</span><span class="p">:</span>
|
||||
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
|
||||
<span class="n">image</span><span class="o">=</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">image</span><span class="p">,</span>
|
||||
<span class="n">command</span><span class="o">=</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">command</span><span class="p">,</span>
|
||||
<span class="n">args</span><span class="o">=</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">arguments</span><span class="p">,</span>
|
||||
<span class="n">image_pull_policy</span><span class="o">=</span><span class="s1">'Always'</span><span class="p">,</span>
|
||||
<span class="n">stdin</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||||
<span class="n">tty</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||||
<span class="n">resources</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1ResourceRequirements</span><span class="p">(</span>
|
||||
<span class="n">limits</span><span class="o">=</span><span class="p">{</span>
|
||||
<span class="s2">"cpu"</span><span class="p">:</span> <span class="s2">"40"</span><span class="p">,</span>
|
||||
<span class="s2">"memory"</span><span class="p">:</span> <span class="s2">"4Gi"</span><span class="p">,</span>
|
||||
<span class="s2">"nvidia.com/gpu"</span><span class="p">:</span> <span class="s2">"1"</span><span class="p">,</span>
|
||||
<span class="p">}</span>
|
||||
<span class="p">),</span>
|
||||
<span class="p">)</span>
|
||||
<span class="n">template</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1PodTemplateSpec</span><span class="p">(</span>
|
||||
<span class="n">metadata</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1ObjectMeta</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="p">{</span><span class="s1">'app'</span><span class="p">:</span> <span class="n">name</span><span class="p">}),</span>
|
||||
<span class="n">spec</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1PodSpec</span><span class="p">(</span>
|
||||
<span class="n">containers</span><span class="o">=</span><span class="p">[</span><span class="n">container</span><span class="p">],</span>
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'Never'</span>
|
||||
<span class="p">),</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"unexpected store_type </span><span class="si">{}</span><span class="s2">. expected 's3', 'nfs', 'redis'"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span>
|
||||
<span class="p">))</span>
|
||||
|
||||
<span class="n">job_spec</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1JobSpec</span><span class="p">(</span>
|
||||
<span class="n">completions</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||||
@@ -404,12 +439,17 @@
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="n">worker_params</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="kc">False</span>
|
||||
|
||||
<span class="c1"># At this point, the memory backend and data store have been deployed and in the process,</span>
|
||||
<span class="c1"># these parameters have been updated to include things like the hostname and port the</span>
|
||||
<span class="c1"># service can be found at.</span>
|
||||
<span class="n">worker_params</span><span class="o">.</span><span class="n">command</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'--memory_backend_params'</span><span class="p">,</span> <span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">memory_backend_parameters</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)]</span>
|
||||
<span class="n">worker_params</span><span class="o">.</span><span class="n">command</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'--data_store_params'</span><span class="p">,</span> <span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)]</span>
|
||||
<span class="n">worker_params</span><span class="o">.</span><span class="n">command</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'--num_workers'</span><span class="p">,</span> <span class="s1">'</span><span class="si">{}</span><span class="s1">'</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">worker_params</span><span class="o">.</span><span class="n">num_replicas</span><span class="p">)]</span>
|
||||
|
||||
<span class="n">name</span> <span class="o">=</span> <span class="s2">"</span><span class="si">{}</span><span class="s2">-</span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">worker_params</span><span class="o">.</span><span class="n">run_type</span><span class="p">,</span> <span class="n">uuid</span><span class="o">.</span><span class="n">uuid4</span><span class="p">())</span>
|
||||
|
||||
<span class="c1"># TODO: instead of defining each container and template spec from scratch, load default</span>
|
||||
<span class="c1"># configuration and modify them as necessary depending on the store type</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">"nfs"</span><span class="p">:</span>
|
||||
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
|
||||
@@ -435,7 +475,7 @@
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'Never'</span>
|
||||
<span class="p">),</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">"s3"</span><span class="p">:</span>
|
||||
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
|
||||
<span class="n">image</span><span class="o">=</span><span class="n">worker_params</span><span class="o">.</span><span class="n">image</span><span class="p">,</span>
|
||||
@@ -454,6 +494,32 @@
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'Never'</span>
|
||||
<span class="p">)</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">"redis"</span><span class="p">:</span>
|
||||
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
|
||||
<span class="n">image</span><span class="o">=</span><span class="n">worker_params</span><span class="o">.</span><span class="n">image</span><span class="p">,</span>
|
||||
<span class="n">command</span><span class="o">=</span><span class="n">worker_params</span><span class="o">.</span><span class="n">command</span><span class="p">,</span>
|
||||
<span class="n">args</span><span class="o">=</span><span class="n">worker_params</span><span class="o">.</span><span class="n">arguments</span><span class="p">,</span>
|
||||
<span class="n">image_pull_policy</span><span class="o">=</span><span class="s1">'Always'</span><span class="p">,</span>
|
||||
<span class="n">stdin</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||||
<span class="n">tty</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||||
<span class="n">resources</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1ResourceRequirements</span><span class="p">(</span>
|
||||
<span class="n">limits</span><span class="o">=</span><span class="p">{</span>
|
||||
<span class="s2">"cpu"</span><span class="p">:</span> <span class="s2">"8"</span><span class="p">,</span>
|
||||
<span class="s2">"memory"</span><span class="p">:</span> <span class="s2">"4Gi"</span><span class="p">,</span>
|
||||
<span class="c1"># "nvidia.com/gpu": "0",</span>
|
||||
<span class="p">}</span>
|
||||
<span class="p">),</span>
|
||||
<span class="p">)</span>
|
||||
<span class="n">template</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1PodTemplateSpec</span><span class="p">(</span>
|
||||
<span class="n">metadata</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1ObjectMeta</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="p">{</span><span class="s1">'app'</span><span class="p">:</span> <span class="n">name</span><span class="p">}),</span>
|
||||
<span class="n">spec</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1PodSpec</span><span class="p">(</span>
|
||||
<span class="n">containers</span><span class="o">=</span><span class="p">[</span><span class="n">container</span><span class="p">],</span>
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'Never'</span>
|
||||
<span class="p">)</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">'unexpected store type </span><span class="si">{}</span><span class="s1">'</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span><span class="p">))</span>
|
||||
|
||||
<span class="n">job_spec</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1JobSpec</span><span class="p">(</span>
|
||||
<span class="n">completions</span><span class="o">=</span><span class="n">worker_params</span><span class="o">.</span><span class="n">num_replicas</span><span class="p">,</span>
|
||||
|
||||
Reference in New Issue
Block a user