mirror of
https://github.com/gryf/coach.git
synced 2026-01-06 22:04:15 +01:00
SAC algorithm (#282)
* SAC algorithm * SAC - updates to agent (learn_from_batch), sac_head and sac_q_head to fix problem in gradient calculation. Now SAC agents is able to train. gym_environment - fixing an error in access to gym.spaces * Soft Actor Critic - code cleanup * code cleanup * V-head initialization fix * SAC benchmarks * SAC Documentation * typo fix * documentation fixes * documentation and version update * README typo
This commit is contained in:
@@ -194,10 +194,10 @@
|
||||
<span class="c1"># limitations under the License.</span>
|
||||
<span class="c1">#</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Any</span>
|
||||
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
|
||||
<span class="kn">import</span> <span class="nn">pickle</span>
|
||||
<span class="kn">import</span> <span class="nn">sys</span>
|
||||
<span class="kn">import</span> <span class="nn">time</span>
|
||||
<span class="kn">import</span> <span class="nn">random</span>
|
||||
<span class="kn">import</span> <span class="nn">math</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
||||
|
||||
@@ -252,7 +252,6 @@
|
||||
<span class="sd"> Sample a batch of transitions form the replay buffer. If the requested size is larger than the number</span>
|
||||
<span class="sd"> of samples available in the replay buffer then the batch will return empty.</span>
|
||||
<span class="sd"> :param size: the size of the batch to sample</span>
|
||||
<span class="sd"> :param beta: the beta parameter used for importance sampling</span>
|
||||
<span class="sd"> :return: a batch (list) of selected transitions from the replay buffer</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">lock_writing</span><span class="p">()</span>
|
||||
@@ -272,6 +271,28 @@
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing</span><span class="p">()</span>
|
||||
<span class="k">return</span> <span class="n">batch</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">get_shuffled_data_generator</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">size</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">Transition</span><span class="p">]:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Get an generator for iterating through the shuffled replay buffer, for processing the data in epochs.</span>
|
||||
<span class="sd"> If the requested size is larger than the number of samples available in the replay buffer then the batch will</span>
|
||||
<span class="sd"> return empty. The last returned batch may be smaller than the size requested, to accommodate for all the</span>
|
||||
<span class="sd"> transitions in the replay buffer.</span>
|
||||
|
||||
<span class="sd"> :param size: the size of the batch to return</span>
|
||||
<span class="sd"> :return: a batch (list) of selected transitions from the replay buffer</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">lock_writing</span><span class="p">()</span>
|
||||
<span class="n">shuffled_transition_indices</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">)))</span>
|
||||
<span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">shuffled_transition_indices</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># we deliberately drop some of the ending data which is left after dividing to batches of size `size`</span>
|
||||
<span class="c1"># for i in range(math.ceil(len(shuffled_transition_indices) / size)):</span>
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">shuffled_transition_indices</span><span class="p">)</span> <span class="o">/</span> <span class="n">size</span><span class="p">)):</span>
|
||||
<span class="n">sample_data</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="n">shuffled_transition_indices</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">size</span><span class="p">:</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">size</span><span class="p">]]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing</span><span class="p">()</span>
|
||||
|
||||
<span class="k">yield</span> <span class="n">sample_data</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">_enforce_max_length</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Make sure that the size of the replay buffer does not pass the maximum size allowed.</span>
|
||||
@@ -395,7 +416,7 @@
|
||||
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">file_path</span><span class="p">,</span> <span class="s1">'wb'</span><span class="p">)</span> <span class="k">as</span> <span class="n">file</span><span class="p">:</span>
|
||||
<span class="n">pickle</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">,</span> <span class="n">file</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">file_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="k">def</span> <span class="nf">load_pickled</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">file_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Restore the replay buffer contents from a pickle file.</span>
|
||||
<span class="sd"> The pickle file is assumed to include a list of transitions.</span>
|
||||
@@ -418,6 +439,7 @@
|
||||
<span class="n">progress_bar</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">transition_idx</span><span class="p">)</span>
|
||||
|
||||
<span class="n">progress_bar</span><span class="o">.</span><span class="n">close</span><span class="p">()</span></div>
|
||||
|
||||
</pre></div>
|
||||
|
||||
</div>
|
||||
|
||||
Reference in New Issue
Block a user