mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
SAC algorithm (#282)
* SAC algorithm * SAC - updates to agent (learn_from_batch), sac_head and sac_q_head to fix problem in gradient calculation. Now SAC agents is able to train. gym_environment - fixing an error in access to gym.spaces * Soft Actor Critic - code cleanup * code cleanup * V-head initialization fix * SAC benchmarks * SAC Documentation * typo fix * documentation fixes * documentation and version update * README typo
This commit is contained in:
@@ -298,7 +298,7 @@
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">s3_access_key</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'ACCESS_KEY_ID'</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">s3_secret_key</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'SECRET_ACCESS_KEY'</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||||
<span class="k">def</span> <span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">crd</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Deploys the memory backend and data stores if required.</span>
|
||||
<span class="sd"> """</span>
|
||||
@@ -308,6 +308,9 @@
|
||||
<span class="k">return</span> <span class="kc">False</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">"nfs"</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">nfs_pvc</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_store</span><span class="o">.</span><span class="n">get_info</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># Upload checkpoints in checkpoint_restore_dir (if provided) to the data store</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">data_store</span><span class="o">.</span><span class="n">setup_checkpoint_dir</span><span class="p">(</span><span class="n">crd</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="kc">True</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">deploy_trainer</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||||
@@ -321,7 +324,6 @@
|
||||
|
||||
<span class="n">trainer_params</span><span class="o">.</span><span class="n">command</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'--memory_backend_params'</span><span class="p">,</span> <span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">memory_backend_parameters</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)]</span>
|
||||
<span class="n">trainer_params</span><span class="o">.</span><span class="n">command</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'--data_store_params'</span><span class="p">,</span> <span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)]</span>
|
||||
|
||||
<span class="n">name</span> <span class="o">=</span> <span class="s2">"</span><span class="si">{}</span><span class="s2">-</span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">run_type</span><span class="p">,</span> <span class="n">uuid</span><span class="o">.</span><span class="n">uuid4</span><span class="p">())</span>
|
||||
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">"nfs"</span><span class="p">:</span>
|
||||
@@ -346,7 +348,7 @@
|
||||
<span class="n">name</span><span class="o">=</span><span class="s2">"nfs-pvc"</span><span class="p">,</span>
|
||||
<span class="n">persistent_volume_claim</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">nfs_pvc</span>
|
||||
<span class="p">)],</span>
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'OnFailure'</span>
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'Never'</span>
|
||||
<span class="p">),</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
@@ -365,7 +367,7 @@
|
||||
<span class="n">metadata</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1ObjectMeta</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="p">{</span><span class="s1">'app'</span><span class="p">:</span> <span class="n">name</span><span class="p">}),</span>
|
||||
<span class="n">spec</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1PodSpec</span><span class="p">(</span>
|
||||
<span class="n">containers</span><span class="o">=</span><span class="p">[</span><span class="n">container</span><span class="p">],</span>
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'OnFailure'</span>
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'Never'</span>
|
||||
<span class="p">),</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
@@ -427,7 +429,7 @@
|
||||
<span class="n">name</span><span class="o">=</span><span class="s2">"nfs-pvc"</span><span class="p">,</span>
|
||||
<span class="n">persistent_volume_claim</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">nfs_pvc</span>
|
||||
<span class="p">)],</span>
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'OnFailure'</span>
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'Never'</span>
|
||||
<span class="p">),</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
@@ -446,7 +448,7 @@
|
||||
<span class="n">metadata</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1ObjectMeta</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="p">{</span><span class="s1">'app'</span><span class="p">:</span> <span class="n">name</span><span class="p">}),</span>
|
||||
<span class="n">spec</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1PodSpec</span><span class="p">(</span>
|
||||
<span class="n">containers</span><span class="o">=</span><span class="p">[</span><span class="n">container</span><span class="p">],</span>
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'OnFailure'</span>
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'Never'</span>
|
||||
<span class="p">)</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
@@ -496,7 +498,7 @@
|
||||
<span class="k">return</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">pod</span> <span class="ow">in</span> <span class="n">pods</span><span class="o">.</span><span class="n">items</span><span class="p">:</span>
|
||||
<span class="n">Process</span><span class="p">(</span><span class="n">target</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tail_log_file</span><span class="p">,</span> <span class="n">args</span><span class="o">=</span><span class="p">(</span><span class="n">pod</span><span class="o">.</span><span class="n">metadata</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">api_client</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">namespace</span><span class="p">,</span> <span class="n">path</span><span class="p">))</span><span class="o">.</span><span class="n">start</span><span class="p">()</span>
|
||||
<span class="n">Process</span><span class="p">(</span><span class="n">target</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tail_log_file</span><span class="p">,</span> <span class="n">args</span><span class="o">=</span><span class="p">(</span><span class="n">pod</span><span class="o">.</span><span class="n">metadata</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">api_client</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">namespace</span><span class="p">,</span> <span class="n">path</span><span class="p">),</span> <span class="n">daemon</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span><span class="o">.</span><span class="n">start</span><span class="p">()</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">_tail_log_file</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pod_name</span><span class="p">,</span> <span class="n">api_client</span><span class="p">,</span> <span class="n">namespace</span><span class="p">,</span> <span class="n">path</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">path</span><span class="p">):</span>
|
||||
@@ -528,7 +530,7 @@
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="n">pod</span><span class="p">:</span>
|
||||
<span class="k">return</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">tail_log</span><span class="p">(</span><span class="n">pod</span><span class="o">.</span><span class="n">metadata</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">api_client</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">tail_log</span><span class="p">(</span><span class="n">pod</span><span class="o">.</span><span class="n">metadata</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">api_client</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">tail_log</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pod_name</span><span class="p">,</span> <span class="n">corev1_api</span><span class="p">):</span>
|
||||
<span class="k">while</span> <span class="kc">True</span><span class="p">:</span>
|
||||
@@ -562,9 +564,9 @@
|
||||
<span class="n">container_status</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">waiting</span><span class="o">.</span><span class="n">reason</span> <span class="o">==</span> <span class="s1">'CrashLoopBackOff'</span> <span class="ow">or</span> \
|
||||
<span class="n">container_status</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">waiting</span><span class="o">.</span><span class="n">reason</span> <span class="o">==</span> <span class="s1">'ImagePullBackOff'</span> <span class="ow">or</span> \
|
||||
<span class="n">container_status</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">waiting</span><span class="o">.</span><span class="n">reason</span> <span class="o">==</span> <span class="s1">'ErrImagePull'</span><span class="p">:</span>
|
||||
<span class="k">return</span>
|
||||
<span class="k">return</span> <span class="mi">1</span>
|
||||
<span class="k">if</span> <span class="n">container_status</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">terminated</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="k">return</span>
|
||||
<span class="k">return</span> <span class="n">container_status</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">terminated</span><span class="o">.</span><span class="n">exit_code</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">undeploy</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
|
||||
Reference in New Issue
Block a user