mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
TD3 (#338)
This commit is contained in:
@@ -208,6 +208,7 @@
|
||||
<span class="kn">import</span> <span class="nn">random</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">Transition</span><span class="p">,</span> <span class="n">Episode</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.filters.filter</span> <span class="k">import</span> <span class="n">InputFilter</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.logger</span> <span class="k">import</span> <span class="n">screen</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.memories.memory</span> <span class="k">import</span> <span class="n">Memory</span><span class="p">,</span> <span class="n">MemoryGranularity</span><span class="p">,</span> <span class="n">MemoryParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.utils</span> <span class="k">import</span> <span class="n">ReaderWriterLock</span><span class="p">,</span> <span class="n">ProgressBar</span>
|
||||
@@ -591,11 +592,12 @@
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing</span><span class="p">()</span>
|
||||
<span class="k">return</span> <span class="n">mean</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">load_csv</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">csv_dataset</span><span class="p">:</span> <span class="n">CsvDataset</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="k">def</span> <span class="nf">load_csv</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">csv_dataset</span><span class="p">:</span> <span class="n">CsvDataset</span><span class="p">,</span> <span class="n">input_filter</span><span class="p">:</span> <span class="n">InputFilter</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Restore the replay buffer contents from a csv file.</span>
|
||||
<span class="sd"> The csv file is assumed to include a list of transitions.</span>
|
||||
<span class="sd"> :param csv_dataset: A construct which holds the dataset parameters</span>
|
||||
<span class="sd"> :param input_filter: A filter used to filter the CSV data before feeding it to the memory.</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">assert_not_frozen</span><span class="p">()</span>
|
||||
|
||||
@@ -612,18 +614,30 @@
|
||||
<span class="k">for</span> <span class="n">e_id</span> <span class="ow">in</span> <span class="n">episode_ids</span><span class="p">:</span>
|
||||
<span class="n">progress_bar</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">e_id</span><span class="p">)</span>
|
||||
<span class="n">df_episode_transitions</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="n">df</span><span class="p">[</span><span class="s1">'episode_id'</span><span class="p">]</span> <span class="o">==</span> <span class="n">e_id</span><span class="p">]</span>
|
||||
<span class="n">input_filter</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
|
||||
|
||||
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">df_episode_transitions</span><span class="p">)</span> <span class="o"><</span> <span class="mi">2</span><span class="p">:</span>
|
||||
<span class="c1"># we have to have at least 2 rows in each episode for creating a transition</span>
|
||||
<span class="k">continue</span>
|
||||
|
||||
<span class="n">episode</span> <span class="o">=</span> <span class="n">Episode</span><span class="p">()</span>
|
||||
<span class="n">transitions</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="k">for</span> <span class="p">(</span><span class="n">_</span><span class="p">,</span> <span class="n">current_transition</span><span class="p">),</span> <span class="p">(</span><span class="n">_</span><span class="p">,</span> <span class="n">next_transition</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">df_episode_transitions</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">iterrows</span><span class="p">(),</span>
|
||||
<span class="n">df_episode_transitions</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span><span class="o">.</span><span class="n">iterrows</span><span class="p">()):</span>
|
||||
<span class="n">state</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">current_transition</span><span class="p">[</span><span class="n">col</span><span class="p">]</span> <span class="k">for</span> <span class="n">col</span> <span class="ow">in</span> <span class="n">state_columns</span><span class="p">])</span>
|
||||
<span class="n">next_state</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">next_transition</span><span class="p">[</span><span class="n">col</span><span class="p">]</span> <span class="k">for</span> <span class="n">col</span> <span class="ow">in</span> <span class="n">state_columns</span><span class="p">])</span>
|
||||
|
||||
<span class="n">episode</span><span class="o">.</span><span class="n">insert</span><span class="p">(</span>
|
||||
<span class="n">transitions</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
||||
<span class="n">Transition</span><span class="p">(</span><span class="n">state</span><span class="o">=</span><span class="p">{</span><span class="s1">'observation'</span><span class="p">:</span> <span class="n">state</span><span class="p">},</span>
|
||||
<span class="n">action</span><span class="o">=</span><span class="n">current_transition</span><span class="p">[</span><span class="s1">'action'</span><span class="p">],</span> <span class="n">reward</span><span class="o">=</span><span class="n">current_transition</span><span class="p">[</span><span class="s1">'reward'</span><span class="p">],</span>
|
||||
<span class="n">next_state</span><span class="o">=</span><span class="p">{</span><span class="s1">'observation'</span><span class="p">:</span> <span class="n">next_state</span><span class="p">},</span> <span class="n">game_over</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||||
<span class="n">info</span><span class="o">=</span><span class="p">{</span><span class="s1">'all_action_probabilities'</span><span class="p">:</span>
|
||||
<span class="n">ast</span><span class="o">.</span><span class="n">literal_eval</span><span class="p">(</span><span class="n">current_transition</span><span class="p">[</span><span class="s1">'all_action_probabilities'</span><span class="p">])}))</span>
|
||||
<span class="n">ast</span><span class="o">.</span><span class="n">literal_eval</span><span class="p">(</span><span class="n">current_transition</span><span class="p">[</span><span class="s1">'all_action_probabilities'</span><span class="p">])}),</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
<span class="n">transitions</span> <span class="o">=</span> <span class="n">input_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">transitions</span><span class="p">,</span> <span class="n">deep_copy</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">transitions</span><span class="p">:</span>
|
||||
<span class="n">episode</span><span class="o">.</span><span class="n">insert</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># Set the last transition to end the episode</span>
|
||||
<span class="k">if</span> <span class="n">csv_dataset</span><span class="o">.</span><span class="n">is_episodic</span><span class="p">:</span>
|
||||
@@ -635,8 +649,6 @@
|
||||
<span class="n">progress_bar</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">episode_ids</span><span class="p">))</span>
|
||||
<span class="n">progress_bar</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">shuffle_episodes</span><span class="p">()</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">freeze</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Freezing the replay buffer does not allow any new transitions to be added to the memory.</span>
|
||||
|
||||
Reference in New Issue
Block a user