1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

SAC algorithm (#282)

* SAC algorithm

* SAC - updates to agent (learn_from_batch), sac_head and sac_q_head to fix problem in gradient calculation. Now SAC agents is able to train.
gym_environment - fixing an error in access to gym.spaces

* Soft Actor Critic - code cleanup

* code cleanup

* V-head initialization fix

* SAC benchmarks

* SAC Documentation

* typo fix

* documentation fixes

* documentation and version update

* README typo
This commit is contained in:
guyk1971
2019-05-01 18:37:49 +03:00
committed by shadiendrawis
parent 33dc29ee99
commit 74db141d5e
92 changed files with 2812 additions and 402 deletions

View File

@@ -463,7 +463,12 @@
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Got exception: </span><span class="si">%s</span><span class="se">\n</span><span class="s2"> while deleting NFS PVC&quot;</span><span class="p">,</span> <span class="n">e</span><span class="p">)</span>
<span class="k">return</span> <span class="kc">False</span>
<span class="k">return</span> <span class="kc">True</span></div>
<span class="k">return</span> <span class="kc">True</span>
<span class="k">def</span> <span class="nf">setup_checkpoint_dir</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">crd</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="k">if</span> <span class="n">crd</span><span class="p">:</span>
<span class="c1"># TODO: find a way to upload this to the deployed nfs store.</span>
<span class="k">pass</span></div>
</pre></div>
</div>

View File

@@ -257,6 +257,9 @@
<span class="k">return</span> <span class="kc">True</span>
<span class="k">def</span> <span class="nf">save_to_store</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_save_to_store</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">checkpoint_dir</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">_save_to_store</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">checkpoint_dir</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> save_to_store() uploads the policy checkpoint, gifs and videos to the S3 data store. It reads the checkpoint state files and</span>
<span class="sd"> uploads only the latest checkpoint files to S3. It is used by the trainer in Coach when used in the distributed mode.</span>
@@ -268,24 +271,32 @@
<span class="c1"># Acquire lock</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">put_object</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">SyncFiles</span><span class="o">.</span><span class="n">LOCKFILE</span><span class="o">.</span><span class="n">value</span><span class="p">,</span> <span class="n">io</span><span class="o">.</span><span class="n">BytesIO</span><span class="p">(</span><span class="sa">b</span><span class="s1">&#39;&#39;</span><span class="p">),</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">state_file</span> <span class="o">=</span> <span class="n">CheckpointStateFile</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">checkpoint_dir</span><span class="p">))</span>
<span class="n">state_file</span> <span class="o">=</span> <span class="n">CheckpointStateFile</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">))</span>
<span class="k">if</span> <span class="n">state_file</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span>
<span class="n">ckpt_state</span> <span class="o">=</span> <span class="n">state_file</span><span class="o">.</span><span class="n">read</span><span class="p">()</span>
<span class="n">checkpoint_file</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">for</span> <span class="n">root</span><span class="p">,</span> <span class="n">dirs</span><span class="p">,</span> <span class="n">files</span> <span class="ow">in</span> <span class="n">os</span><span class="o">.</span><span class="n">walk</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">checkpoint_dir</span><span class="p">):</span>
<span class="k">for</span> <span class="n">root</span><span class="p">,</span> <span class="n">dirs</span><span class="p">,</span> <span class="n">files</span> <span class="ow">in</span> <span class="n">os</span><span class="o">.</span><span class="n">walk</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">):</span>
<span class="k">for</span> <span class="n">filename</span> <span class="ow">in</span> <span class="n">files</span><span class="p">:</span>
<span class="k">if</span> <span class="n">filename</span> <span class="o">==</span> <span class="n">CheckpointStateFile</span><span class="o">.</span><span class="n">checkpoint_state_filename</span><span class="p">:</span>
<span class="n">checkpoint_file</span> <span class="o">=</span> <span class="p">(</span><span class="n">root</span><span class="p">,</span> <span class="n">filename</span><span class="p">)</span>
<span class="k">continue</span>
<span class="k">if</span> <span class="n">filename</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="n">ckpt_state</span><span class="o">.</span><span class="n">name</span><span class="p">):</span>
<span class="n">abs_name</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">root</span><span class="p">,</span> <span class="n">filename</span><span class="p">))</span>
<span class="n">rel_name</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">relpath</span><span class="p">(</span><span class="n">abs_name</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">checkpoint_dir</span><span class="p">)</span>
<span class="n">rel_name</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">relpath</span><span class="p">(</span><span class="n">abs_name</span><span class="p">,</span> <span class="n">checkpoint_dir</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">fput_object</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">rel_name</span><span class="p">,</span> <span class="n">abs_name</span><span class="p">)</span>
<span class="n">abs_name</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">checkpoint_file</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">checkpoint_file</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span>
<span class="n">rel_name</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">relpath</span><span class="p">(</span><span class="n">abs_name</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">checkpoint_dir</span><span class="p">)</span>
<span class="n">rel_name</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">relpath</span><span class="p">(</span><span class="n">abs_name</span><span class="p">,</span> <span class="n">checkpoint_dir</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">fput_object</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">rel_name</span><span class="p">,</span> <span class="n">abs_name</span><span class="p">)</span>
<span class="c1"># upload Finished if present</span>
<span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">,</span> <span class="n">SyncFiles</span><span class="o">.</span><span class="n">FINISHED</span><span class="o">.</span><span class="n">value</span><span class="p">)):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">put_object</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">SyncFiles</span><span class="o">.</span><span class="n">FINISHED</span><span class="o">.</span><span class="n">value</span><span class="p">,</span> <span class="n">io</span><span class="o">.</span><span class="n">BytesIO</span><span class="p">(</span><span class="sa">b</span><span class="s1">&#39;&#39;</span><span class="p">),</span> <span class="mi">0</span><span class="p">)</span>
<span class="c1"># upload Ready if present</span>
<span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">,</span> <span class="n">SyncFiles</span><span class="o">.</span><span class="n">TRAINER_READY</span><span class="o">.</span><span class="n">value</span><span class="p">)):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">put_object</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">SyncFiles</span><span class="o">.</span><span class="n">TRAINER_READY</span><span class="o">.</span><span class="n">value</span><span class="p">,</span> <span class="n">io</span><span class="o">.</span><span class="n">BytesIO</span><span class="p">(</span><span class="sa">b</span><span class="s1">&#39;&#39;</span><span class="p">),</span> <span class="mi">0</span><span class="p">)</span>
<span class="c1"># release lock</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">remove_object</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">SyncFiles</span><span class="o">.</span><span class="n">LOCKFILE</span><span class="o">.</span><span class="n">value</span><span class="p">)</span>
@@ -301,6 +312,7 @@
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">expt_dir</span> <span class="ow">and</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">expt_dir</span><span class="p">,</span> <span class="s1">&#39;gifs&#39;</span><span class="p">)):</span>
<span class="k">for</span> <span class="n">filename</span> <span class="ow">in</span> <span class="n">os</span><span class="o">.</span><span class="n">listdir</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">expt_dir</span><span class="p">,</span> <span class="s1">&#39;gifs&#39;</span><span class="p">)):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">fput_object</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">filename</span><span class="p">,</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">expt_dir</span><span class="p">,</span> <span class="s1">&#39;gifs&#39;</span><span class="p">,</span> <span class="n">filename</span><span class="p">))</span>
<span class="k">except</span> <span class="n">ResponseError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Got exception: </span><span class="si">%s</span><span class="se">\n</span><span class="s2"> while saving to S3&quot;</span><span class="p">,</span> <span class="n">e</span><span class="p">)</span>
@@ -337,6 +349,18 @@
<span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
<span class="k">pass</span>
<span class="c1"># Check if there&#39;s a ready file</span>
<span class="n">objects</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">list_objects_v2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">SyncFiles</span><span class="o">.</span><span class="n">TRAINER_READY</span><span class="o">.</span><span class="n">value</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">next</span><span class="p">(</span><span class="n">objects</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">try</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">fget_object</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">SyncFiles</span><span class="o">.</span><span class="n">TRAINER_READY</span><span class="o">.</span><span class="n">value</span><span class="p">,</span>
<span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">checkpoint_dir</span><span class="p">,</span> <span class="n">SyncFiles</span><span class="o">.</span><span class="n">TRAINER_READY</span><span class="o">.</span><span class="n">value</span><span class="p">))</span>
<span class="p">)</span>
<span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
<span class="k">pass</span>
<span class="n">checkpoint_state</span> <span class="o">=</span> <span class="n">state_file</span><span class="o">.</span><span class="n">read</span><span class="p">()</span>
<span class="k">if</span> <span class="n">checkpoint_state</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">objects</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">list_objects_v2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">prefix</span><span class="o">=</span><span class="n">checkpoint_state</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">recursive</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
@@ -346,7 +370,11 @@
<span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">fget_object</span><span class="p">(</span><span class="n">obj</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">obj</span><span class="o">.</span><span class="n">object_name</span><span class="p">,</span> <span class="n">filename</span><span class="p">)</span>
<span class="k">except</span> <span class="n">ResponseError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Got exception: </span><span class="si">%s</span><span class="se">\n</span><span class="s2"> while loading from S3&quot;</span><span class="p">,</span> <span class="n">e</span><span class="p">)</span></div>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Got exception: </span><span class="si">%s</span><span class="se">\n</span><span class="s2"> while loading from S3&quot;</span><span class="p">,</span> <span class="n">e</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setup_checkpoint_dir</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">crd</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="k">if</span> <span class="n">crd</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_save_to_store</span><span class="p">(</span><span class="n">crd</span><span class="p">)</span></div>
</pre></div>
</div>