mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
SAC algorithm (#282)
* SAC algorithm * SAC - updates to agent (learn_from_batch), sac_head and sac_q_head to fix problem in gradient calculation. Now SAC agents is able to train. gym_environment - fixing an error in access to gym.spaces * Soft Actor Critic - code cleanup * code cleanup * V-head initialization fix * SAC benchmarks * SAC Documentation * typo fix * documentation fixes * documentation and version update * README typo
This commit is contained in:
@@ -179,6 +179,7 @@
|
||||
|
||||
<h1>Source code for rl_coach.memories.episodic.episodic_experience_replay</h1><div class="highlight"><pre>
|
||||
<span></span><span class="c1">#</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Copyright (c) 2017 Intel Corporation</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Licensed under the Apache License, Version 2.0 (the "License");</span>
|
||||
@@ -193,14 +194,19 @@
|
||||
<span class="c1"># See the License for the specific language governing permissions and</span>
|
||||
<span class="c1"># limitations under the License.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="kn">import</span> <span class="nn">ast</span>
|
||||
<span class="kn">import</span> <span class="nn">math</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Any</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span>
|
||||
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
|
||||
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
||||
<span class="kn">import</span> <span class="nn">random</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">Transition</span><span class="p">,</span> <span class="n">Episode</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.logger</span> <span class="k">import</span> <span class="n">screen</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.memories.memory</span> <span class="k">import</span> <span class="n">Memory</span><span class="p">,</span> <span class="n">MemoryGranularity</span><span class="p">,</span> <span class="n">MemoryParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.utils</span> <span class="k">import</span> <span class="n">ReaderWriterLock</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.utils</span> <span class="k">import</span> <span class="n">ReaderWriterLock</span><span class="p">,</span> <span class="n">ProgressBar</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">CsvDataset</span>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">EpisodicExperienceReplayParameters</span><span class="p">(</span><span class="n">MemoryParameters</span><span class="p">):</span>
|
||||
@@ -208,6 +214,7 @@
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">max_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">MemoryGranularity</span><span class="o">.</span><span class="n">Transitions</span><span class="p">,</span> <span class="mi">1000000</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">n_step</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">train_to_eval_ratio</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># for OPE we'll want a value < 1</span>
|
||||
|
||||
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
@@ -220,7 +227,9 @@
|
||||
<span class="sd"> calculations of total return and other values that depend on the sequential behavior of the transitions</span>
|
||||
<span class="sd"> in the episode.</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">max_size</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">MemoryGranularity</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span><span class="o">=</span><span class="p">(</span><span class="n">MemoryGranularity</span><span class="o">.</span><span class="n">Transitions</span><span class="p">,</span> <span class="mi">1000000</span><span class="p">),</span> <span class="n">n_step</span><span class="o">=-</span><span class="mi">1</span><span class="p">):</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">max_size</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">MemoryGranularity</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">MemoryGranularity</span><span class="o">.</span><span class="n">Transitions</span><span class="p">,</span> <span class="mi">1000000</span><span class="p">),</span> <span class="n">n_step</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span>
|
||||
<span class="n">train_to_eval_ratio</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> :param max_size: the maximum number of transitions or episodes to hold in the memory</span>
|
||||
<span class="sd"> """</span>
|
||||
@@ -232,8 +241,11 @@
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_num_transitions</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_num_transitions_in_complete_episodes</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span> <span class="o">=</span> <span class="n">ReaderWriterLock</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_episode_id</span> <span class="o">=</span> <span class="kc">None</span> <span class="c1"># used in batch-rl</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_transition_id</span> <span class="o">=</span> <span class="kc">None</span> <span class="c1"># used in batch-rl</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">train_to_eval_ratio</span> <span class="o">=</span> <span class="n">train_to_eval_ratio</span> <span class="c1"># used in batch-rl</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">length</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span>
|
||||
<span class="k">def</span> <span class="nf">length</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Get the number of episodes in the ER (even if they are not complete)</span>
|
||||
<span class="sd"> """</span>
|
||||
@@ -255,6 +267,9 @@
|
||||
<span class="k">def</span> <span class="nf">num_transitions_in_complete_episodes</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_num_transitions_in_complete_episodes</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">get_last_training_set_episode_id</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_episode_id</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">is_consecutive_transitions</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">Transition</span><span class="p">]:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Sample a batch of transitions from the replay buffer. If the requested size is larger than the number</span>
|
||||
@@ -272,7 +287,7 @@
|
||||
<span class="n">batch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">[</span><span class="n">episode_idx</span><span class="p">]</span><span class="o">.</span><span class="n">transitions</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">transition_idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">[</span><span class="n">episode_idx</span><span class="p">]</span><span class="o">.</span><span class="n">length</span><span class="p">())</span>
|
||||
<span class="n">batch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">[</span><span class="n">episode_idx</span><span class="p">]</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="n">transition_idx</span><span class="o">-</span><span class="n">size</span><span class="p">:</span><span class="n">transition_idx</span><span class="p">]</span>
|
||||
<span class="n">batch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">[</span><span class="n">episode_idx</span><span class="p">]</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="n">transition_idx</span> <span class="o">-</span> <span class="n">size</span><span class="p">:</span><span class="n">transition_idx</span><span class="p">]</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">transitions_idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_transitions_in_complete_episodes</span><span class="p">(),</span> <span class="n">size</span><span class="o">=</span><span class="n">size</span><span class="p">)</span>
|
||||
<span class="n">batch</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">transitions_idx</span><span class="p">]</span>
|
||||
@@ -285,6 +300,78 @@
|
||||
|
||||
<span class="k">return</span> <span class="n">batch</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">get_episode_for_transition</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">transition</span><span class="p">:</span> <span class="n">Transition</span><span class="p">)</span> <span class="o">-></span> <span class="p">(</span><span class="nb">int</span><span class="p">,</span> <span class="n">Episode</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Get the episode from which that transition came from.</span>
|
||||
<span class="sd"> :param transition: The transition to lookup the episode for</span>
|
||||
<span class="sd"> :return: (Episode number, the episode) or (-1, None) if could not find a matching episode.</span>
|
||||
<span class="sd"> """</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">episode</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="n">transition</span> <span class="ow">in</span> <span class="n">episode</span><span class="o">.</span><span class="n">transitions</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">i</span><span class="p">,</span> <span class="n">episode</span>
|
||||
<span class="k">return</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="kc">None</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">shuffle_episodes</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Shuffle all the episodes in the replay buffer</span>
|
||||
<span class="sd"> :return:</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">transitions</span> <span class="o">=</span> <span class="p">[</span><span class="n">t</span> <span class="k">for</span> <span class="n">e</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">e</span><span class="o">.</span><span class="n">transitions</span><span class="p">]</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">get_shuffled_data_generator</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">size</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">Transition</span><span class="p">]:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Get an generator for iterating through the shuffled replay buffer, for processing the data in epochs.</span>
|
||||
<span class="sd"> If the requested size is larger than the number of samples available in the replay buffer then the batch will</span>
|
||||
<span class="sd"> return empty. The last returned batch may be smaller than the size requested, to accommodate for all the</span>
|
||||
<span class="sd"> transitions in the replay buffer.</span>
|
||||
|
||||
<span class="sd"> :param size: the size of the batch to return</span>
|
||||
<span class="sd"> :return: a batch (list) of selected transitions from the replay buffer</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">lock_writing</span><span class="p">()</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_transition_id</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_to_eval_ratio</span> <span class="o"><</span> <span class="mi">0</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_to_eval_ratio</span> <span class="o">>=</span> <span class="mi">1</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">'train_to_eval_ratio should be in the (0, 1] range.'</span><span class="p">)</span>
|
||||
|
||||
<span class="n">transition</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="nb">round</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_to_eval_ratio</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_transitions_in_complete_episodes</span><span class="p">())]</span>
|
||||
<span class="n">episode_num</span><span class="p">,</span> <span class="n">episode</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_episode_for_transition</span><span class="p">(</span><span class="n">transition</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_episode_id</span> <span class="o">=</span> <span class="n">episode_num</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_transition_id</span> <span class="o">=</span> \
|
||||
<span class="nb">len</span><span class="p">([</span><span class="n">t</span> <span class="k">for</span> <span class="n">e</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_all_complete_episodes_from_to</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_episode_id</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">e</span><span class="p">])</span>
|
||||
|
||||
<span class="n">shuffled_transition_indices</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_transition_id</span><span class="p">))</span>
|
||||
<span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">shuffled_transition_indices</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># The last batch drawn will usually be < batch_size (=the size variable)</span>
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">shuffled_transition_indices</span><span class="p">)</span> <span class="o">/</span> <span class="n">size</span><span class="p">)):</span>
|
||||
<span class="n">sample_data</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="n">shuffled_transition_indices</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">size</span><span class="p">:</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">size</span><span class="p">]]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing</span><span class="p">()</span>
|
||||
|
||||
<span class="k">yield</span> <span class="n">sample_data</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">get_all_complete_episodes_transitions</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">Transition</span><span class="p">]:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Get all the transitions from all the complete episodes in the buffer</span>
|
||||
<span class="sd"> :return: a list of transitions</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">[:</span><span class="bp">self</span><span class="o">.</span><span class="n">num_transitions_in_complete_episodes</span><span class="p">()]</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">get_all_complete_episodes</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">Episode</span><span class="p">]:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Get all the transitions from all the complete episodes in the buffer</span>
|
||||
<span class="sd"> :return: a list of transitions</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_all_complete_episodes_from_to</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_complete_episodes</span><span class="p">())</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">get_all_complete_episodes_from_to</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">start_episode_id</span><span class="p">,</span> <span class="n">end_episode_id</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">Episode</span><span class="p">]:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Get all the transitions from all the complete episodes in the buffer matching the given episode range</span>
|
||||
<span class="sd"> :return: a list of transitions</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">[</span><span class="n">start_episode_id</span><span class="p">:</span><span class="n">end_episode_id</span><span class="p">]</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">_enforce_max_length</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Make sure that the size of the replay buffer does not pass the maximum size allowed.</span>
|
||||
@@ -368,7 +455,7 @@
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing_and_reading</span><span class="p">()</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">store_episode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">episode</span><span class="p">:</span> <span class="n">Episode</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="k">def</span> <span class="nf">store_episode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">episode</span><span class="p">:</span> <span class="n">Episode</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Store a new episode in the memory.</span>
|
||||
<span class="sd"> :param episode: the new episode to store</span>
|
||||
@@ -391,7 +478,7 @@
|
||||
<span class="k">if</span> <span class="n">lock</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing_and_reading</span><span class="p">()</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">get_episode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">episode_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="n">Episode</span><span class="p">]:</span>
|
||||
<span class="k">def</span> <span class="nf">get_episode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">episode_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="n">Episode</span><span class="p">]:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Returns the episode in the given index. If the episode does not exist, returns None instead.</span>
|
||||
<span class="sd"> :param episode_index: the index of the episode to return</span>
|
||||
@@ -436,7 +523,7 @@
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing_and_reading</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># for API compatibility</span>
|
||||
<span class="k">def</span> <span class="nf">get</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">episode_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="n">Episode</span><span class="p">]:</span>
|
||||
<span class="k">def</span> <span class="nf">get</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">episode_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="n">Episode</span><span class="p">]:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Returns the episode in the given index. If the episode does not exist, returns None instead.</span>
|
||||
<span class="sd"> :param episode_index: the index of the episode to return</span>
|
||||
@@ -494,7 +581,51 @@
|
||||
<span class="n">mean</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">([</span><span class="n">transition</span><span class="o">.</span><span class="n">reward</span> <span class="k">for</span> <span class="n">transition</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">])</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing</span><span class="p">()</span>
|
||||
<span class="k">return</span> <span class="n">mean</span></div>
|
||||
<span class="k">return</span> <span class="n">mean</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">load_csv</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">csv_dataset</span><span class="p">:</span> <span class="n">CsvDataset</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Restore the replay buffer contents from a csv file.</span>
|
||||
<span class="sd"> The csv file is assumed to include a list of transitions.</span>
|
||||
<span class="sd"> :param csv_dataset: A construct which holds the dataset parameters</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="n">df</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_csv</span><span class="p">(</span><span class="n">csv_dataset</span><span class="o">.</span><span class="n">filepath</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">df</span><span class="p">)</span> <span class="o">></span> <span class="bp">self</span><span class="o">.</span><span class="n">max_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]:</span>
|
||||
<span class="n">screen</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">"Warning! The number of transitions to load into the replay buffer (</span><span class="si">{}</span><span class="s2">) is "</span>
|
||||
<span class="s2">"bigger than the max size of the replay buffer (</span><span class="si">{}</span><span class="s2">). The excessive transitions will "</span>
|
||||
<span class="s2">"not be stored."</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">df</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span>
|
||||
|
||||
<span class="n">episode_ids</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="s1">'episode_id'</span><span class="p">]</span><span class="o">.</span><span class="n">unique</span><span class="p">()</span>
|
||||
<span class="n">progress_bar</span> <span class="o">=</span> <span class="n">ProgressBar</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">episode_ids</span><span class="p">))</span>
|
||||
<span class="n">state_columns</span> <span class="o">=</span> <span class="p">[</span><span class="n">col</span> <span class="k">for</span> <span class="n">col</span> <span class="ow">in</span> <span class="n">df</span><span class="o">.</span><span class="n">columns</span> <span class="k">if</span> <span class="n">col</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s1">'state_feature'</span><span class="p">)]</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">e_id</span> <span class="ow">in</span> <span class="n">episode_ids</span><span class="p">:</span>
|
||||
<span class="n">progress_bar</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">e_id</span><span class="p">)</span>
|
||||
<span class="n">df_episode_transitions</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="n">df</span><span class="p">[</span><span class="s1">'episode_id'</span><span class="p">]</span> <span class="o">==</span> <span class="n">e_id</span><span class="p">]</span>
|
||||
<span class="n">episode</span> <span class="o">=</span> <span class="n">Episode</span><span class="p">()</span>
|
||||
<span class="k">for</span> <span class="p">(</span><span class="n">_</span><span class="p">,</span> <span class="n">current_transition</span><span class="p">),</span> <span class="p">(</span><span class="n">_</span><span class="p">,</span> <span class="n">next_transition</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">df_episode_transitions</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">iterrows</span><span class="p">(),</span>
|
||||
<span class="n">df_episode_transitions</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span><span class="o">.</span><span class="n">iterrows</span><span class="p">()):</span>
|
||||
<span class="n">state</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">current_transition</span><span class="p">[</span><span class="n">col</span><span class="p">]</span> <span class="k">for</span> <span class="n">col</span> <span class="ow">in</span> <span class="n">state_columns</span><span class="p">])</span>
|
||||
<span class="n">next_state</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">next_transition</span><span class="p">[</span><span class="n">col</span><span class="p">]</span> <span class="k">for</span> <span class="n">col</span> <span class="ow">in</span> <span class="n">state_columns</span><span class="p">])</span>
|
||||
|
||||
<span class="n">episode</span><span class="o">.</span><span class="n">insert</span><span class="p">(</span>
|
||||
<span class="n">Transition</span><span class="p">(</span><span class="n">state</span><span class="o">=</span><span class="p">{</span><span class="s1">'observation'</span><span class="p">:</span> <span class="n">state</span><span class="p">},</span>
|
||||
<span class="n">action</span><span class="o">=</span><span class="n">current_transition</span><span class="p">[</span><span class="s1">'action'</span><span class="p">],</span> <span class="n">reward</span><span class="o">=</span><span class="n">current_transition</span><span class="p">[</span><span class="s1">'reward'</span><span class="p">],</span>
|
||||
<span class="n">next_state</span><span class="o">=</span><span class="p">{</span><span class="s1">'observation'</span><span class="p">:</span> <span class="n">next_state</span><span class="p">},</span> <span class="n">game_over</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||||
<span class="n">info</span><span class="o">=</span><span class="p">{</span><span class="s1">'all_action_probabilities'</span><span class="p">:</span>
|
||||
<span class="n">ast</span><span class="o">.</span><span class="n">literal_eval</span><span class="p">(</span><span class="n">current_transition</span><span class="p">[</span><span class="s1">'all_action_probabilities'</span><span class="p">])}))</span>
|
||||
|
||||
<span class="c1"># Set the last transition to end the episode</span>
|
||||
<span class="k">if</span> <span class="n">csv_dataset</span><span class="o">.</span><span class="n">is_episodic</span><span class="p">:</span>
|
||||
<span class="n">episode</span><span class="o">.</span><span class="n">get_last_transition</span><span class="p">()</span><span class="o">.</span><span class="n">game_over</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">store_episode</span><span class="p">(</span><span class="n">episode</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># close the progress bar</span>
|
||||
<span class="n">progress_bar</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">episode_ids</span><span class="p">))</span>
|
||||
<span class="n">progress_bar</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">shuffle_episodes</span><span class="p">()</span></div>
|
||||
</pre></div>
|
||||
|
||||
</div>
|
||||
|
||||
@@ -194,10 +194,10 @@
|
||||
<span class="c1"># limitations under the License.</span>
|
||||
<span class="c1">#</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Any</span>
|
||||
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
|
||||
<span class="kn">import</span> <span class="nn">pickle</span>
|
||||
<span class="kn">import</span> <span class="nn">sys</span>
|
||||
<span class="kn">import</span> <span class="nn">time</span>
|
||||
<span class="kn">import</span> <span class="nn">random</span>
|
||||
<span class="kn">import</span> <span class="nn">math</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
||||
|
||||
@@ -252,7 +252,6 @@
|
||||
<span class="sd"> Sample a batch of transitions form the replay buffer. If the requested size is larger than the number</span>
|
||||
<span class="sd"> of samples available in the replay buffer then the batch will return empty.</span>
|
||||
<span class="sd"> :param size: the size of the batch to sample</span>
|
||||
<span class="sd"> :param beta: the beta parameter used for importance sampling</span>
|
||||
<span class="sd"> :return: a batch (list) of selected transitions from the replay buffer</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">lock_writing</span><span class="p">()</span>
|
||||
@@ -272,6 +271,28 @@
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing</span><span class="p">()</span>
|
||||
<span class="k">return</span> <span class="n">batch</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">get_shuffled_data_generator</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">size</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">Transition</span><span class="p">]:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Get an generator for iterating through the shuffled replay buffer, for processing the data in epochs.</span>
|
||||
<span class="sd"> If the requested size is larger than the number of samples available in the replay buffer then the batch will</span>
|
||||
<span class="sd"> return empty. The last returned batch may be smaller than the size requested, to accommodate for all the</span>
|
||||
<span class="sd"> transitions in the replay buffer.</span>
|
||||
|
||||
<span class="sd"> :param size: the size of the batch to return</span>
|
||||
<span class="sd"> :return: a batch (list) of selected transitions from the replay buffer</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">lock_writing</span><span class="p">()</span>
|
||||
<span class="n">shuffled_transition_indices</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">)))</span>
|
||||
<span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">shuffled_transition_indices</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># we deliberately drop some of the ending data which is left after dividing to batches of size `size`</span>
|
||||
<span class="c1"># for i in range(math.ceil(len(shuffled_transition_indices) / size)):</span>
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">shuffled_transition_indices</span><span class="p">)</span> <span class="o">/</span> <span class="n">size</span><span class="p">)):</span>
|
||||
<span class="n">sample_data</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="n">shuffled_transition_indices</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">size</span><span class="p">:</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">size</span><span class="p">]]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing</span><span class="p">()</span>
|
||||
|
||||
<span class="k">yield</span> <span class="n">sample_data</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">_enforce_max_length</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Make sure that the size of the replay buffer does not pass the maximum size allowed.</span>
|
||||
@@ -395,7 +416,7 @@
|
||||
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">file_path</span><span class="p">,</span> <span class="s1">'wb'</span><span class="p">)</span> <span class="k">as</span> <span class="n">file</span><span class="p">:</span>
|
||||
<span class="n">pickle</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">,</span> <span class="n">file</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">file_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="k">def</span> <span class="nf">load_pickled</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">file_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Restore the replay buffer contents from a pickle file.</span>
|
||||
<span class="sd"> The pickle file is assumed to include a list of transitions.</span>
|
||||
@@ -418,6 +439,7 @@
|
||||
<span class="n">progress_bar</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">transition_idx</span><span class="p">)</span>
|
||||
|
||||
<span class="n">progress_bar</span><span class="o">.</span><span class="n">close</span><span class="p">()</span></div>
|
||||
|
||||
</pre></div>
|
||||
|
||||
</div>
|
||||
|
||||
Reference in New Issue
Block a user