1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

RL in Large Discrete Action Spaces - Wolpertinger Agent (#394)

* Currently this is specific to the case of discretizing a continuous action space. Can easily be adapted to other case by feeding the kNN otherwise, and removing the usage of a discretizing output action filter
This commit is contained in:
Gal Leibovich
2019-09-08 12:53:49 +03:00
committed by GitHub
parent fc50398544
commit 138ced23ba
46 changed files with 1193 additions and 51 deletions

View File

@@ -200,15 +200,17 @@
<span class="kn">import</span> <span class="nn">uuid</span>
<span class="kn">from</span> <span class="nn">rl_coach.data_stores.data_store</span> <span class="k">import</span> <span class="n">DataStore</span><span class="p">,</span> <span class="n">DataStoreParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.data_stores.data_store</span> <span class="k">import</span> <span class="n">DataStoreParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.data_stores.checkpoint_data_store</span> <span class="k">import</span> <span class="n">CheckpointDataStore</span>
<span class="k">class</span> <span class="nc">NFSDataStoreParameters</span><span class="p">(</span><span class="n">DataStoreParameters</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ds_params</span><span class="p">,</span> <span class="n">deployed</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">server</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">path</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ds_params</span><span class="p">,</span> <span class="n">deployed</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">server</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">path</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">checkpoint_dir</span><span class="p">:</span> <span class="nb">str</span><span class="o">=</span><span class="s2">&quot;&quot;</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">ds_params</span><span class="o">.</span><span class="n">store_type</span><span class="p">,</span> <span class="n">ds_params</span><span class="o">.</span><span class="n">orchestrator_type</span><span class="p">,</span> <span class="n">ds_params</span><span class="o">.</span><span class="n">orchestrator_params</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">namespace</span> <span class="o">=</span> <span class="s2">&quot;default&quot;</span>
<span class="k">if</span> <span class="s2">&quot;namespace&quot;</span> <span class="ow">in</span> <span class="n">ds_params</span><span class="o">.</span><span class="n">orchestrator_params</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">namespace</span> <span class="o">=</span> <span class="n">ds_params</span><span class="o">.</span><span class="n">orchestrator_params</span><span class="p">[</span><span class="s2">&quot;namespace&quot;</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_dir</span> <span class="o">=</span> <span class="n">checkpoint_dir</span>
<span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">pvc_name</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">pv_name</span> <span class="o">=</span> <span class="kc">None</span>
@@ -221,7 +223,7 @@
<span class="bp">self</span><span class="o">.</span><span class="n">path</span> <span class="o">=</span> <span class="n">path</span>
<div class="viewcode-block" id="NFSDataStore"><a class="viewcode-back" href="../../../components/data_stores/index.html#rl_coach.data_stores.nfs_data_store.NFSDataStore">[docs]</a><span class="k">class</span> <span class="nc">NFSDataStore</span><span class="p">(</span><span class="n">DataStore</span><span class="p">):</span>
<div class="viewcode-block" id="NFSDataStore"><a class="viewcode-back" href="../../../components/data_stores/index.html#rl_coach.data_stores.nfs_data_store.NFSDataStore">[docs]</a><span class="k">class</span> <span class="nc">NFSDataStore</span><span class="p">(</span><span class="n">CheckpointDataStore</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> An implementation of data store which uses NFS for storing policy checkpoints when using Coach in distributed mode.</span>
<span class="sd"> The policy checkpoints are written by the trainer and read by the rollout worker.</span>

View File

@@ -198,7 +198,8 @@
<span class="c1">#</span>
<span class="kn">from</span> <span class="nn">rl_coach.data_stores.data_store</span> <span class="k">import</span> <span class="n">DataStore</span><span class="p">,</span> <span class="n">DataStoreParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.data_stores.data_store</span> <span class="k">import</span> <span class="n">DataStoreParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.data_stores.checkpoint_data_store</span> <span class="k">import</span> <span class="n">CheckpointDataStore</span>
<span class="kn">from</span> <span class="nn">minio</span> <span class="k">import</span> <span class="n">Minio</span>
<span class="kn">from</span> <span class="nn">minio.error</span> <span class="k">import</span> <span class="n">ResponseError</span>
<span class="kn">from</span> <span class="nn">configparser</span> <span class="k">import</span> <span class="n">ConfigParser</span><span class="p">,</span> <span class="n">Error</span>
@@ -222,7 +223,7 @@
<span class="bp">self</span><span class="o">.</span><span class="n">expt_dir</span> <span class="o">=</span> <span class="n">expt_dir</span>
<div class="viewcode-block" id="S3DataStore"><a class="viewcode-back" href="../../../components/data_stores/index.html#rl_coach.data_stores.s3_data_store.S3DataStore">[docs]</a><span class="k">class</span> <span class="nc">S3DataStore</span><span class="p">(</span><span class="n">DataStore</span><span class="p">):</span>
<div class="viewcode-block" id="S3DataStore"><a class="viewcode-back" href="../../../components/data_stores/index.html#rl_coach.data_stores.s3_data_store.S3DataStore">[docs]</a><span class="k">class</span> <span class="nc">S3DataStore</span><span class="p">(</span><span class="n">CheckpointDataStore</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> An implementation of the data store using S3 for storing policy checkpoints when using Coach in distributed mode.</span>
<span class="sd"> The policy checkpoints are written by the trainer and read by the rollout worker.</span>