1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00
This commit is contained in:
Gal Leibovich
2019-06-16 11:11:21 +03:00
committed by GitHub
parent 8df3c46756
commit 7eb884c5b2
107 changed files with 2200 additions and 495 deletions

View File

@@ -278,19 +278,6 @@
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">memory_backend_params</span><span class="o">.</span><span class="n">run_type</span> <span class="o">!=</span> <span class="s1">&#39;trainer&#39;</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">set_memory_backend</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory_backend</span><span class="p">)</span>
<span class="k">if</span> <span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">,</span> <span class="n">PickledReplayBuffer</span><span class="p">):</span>
<span class="n">screen</span><span class="o">.</span><span class="n">log_title</span><span class="p">(</span><span class="s2">&quot;Loading a pickled replay buffer. Pickled file path: </span><span class="si">{}</span><span class="s2">&quot;</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="o">.</span><span class="n">filepath</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_pickled</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="o">.</span><span class="n">filepath</span><span class="p">)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">,</span> <span class="n">CsvDataset</span><span class="p">):</span>
<span class="n">screen</span><span class="o">.</span><span class="n">log_title</span><span class="p">(</span><span class="s2">&quot;Loading a replay buffer from a CSV file. CSV file path: </span><span class="si">{}</span><span class="s2">&quot;</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="o">.</span><span class="n">filepath</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_csv</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;Trying to load a replay buffer using an unsupported method - </span><span class="si">{}</span><span class="s1">. &#39;</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">))</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">shared_memory</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_chief</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">shared_memory_scratchpad</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory_lookup_name</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="p">)</span>
@@ -444,7 +431,39 @@
<span class="bp">self</span><span class="o">.</span><span class="n">input_filter</span><span class="o">.</span><span class="n">set_session</span><span class="p">(</span><span class="n">sess</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">output_filter</span><span class="o">.</span><span class="n">set_session</span><span class="p">(</span><span class="n">sess</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">set_session</span><span class="p">(</span><span class="n">sess</span><span class="p">)</span>
<span class="p">[</span><span class="n">network</span><span class="o">.</span><span class="n">set_session</span><span class="p">(</span><span class="n">sess</span><span class="p">)</span> <span class="k">for</span> <span class="n">network</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="o">.</span><span class="n">values</span><span class="p">()]</span></div>
<span class="p">[</span><span class="n">network</span><span class="o">.</span><span class="n">set_session</span><span class="p">(</span><span class="n">sess</span><span class="p">)</span> <span class="k">for</span> <span class="n">network</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="o">.</span><span class="n">values</span><span class="p">()]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">initialize_session_dependent_components</span><span class="p">()</span></div>
<div class="viewcode-block" id="Agent.initialize_session_dependent_components"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent.initialize_session_dependent_components">[docs]</a> <span class="k">def</span> <span class="nf">initialize_session_dependent_components</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Initialize components which require a session as part of their initialization.</span>
<span class="sd"> :return: None</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="c1"># Loading a memory from a CSV file, requires an input filter to filter through the data.</span>
<span class="c1"># The filter needs a session before it can be used.</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">load_memory_from_file</span><span class="p">()</span></div>
<div class="viewcode-block" id="Agent.load_memory_from_file"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent.load_memory_from_file">[docs]</a> <span class="k">def</span> <span class="nf">load_memory_from_file</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Load memory transitions from a file.</span>
<span class="sd"> :return: None</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">,</span> <span class="n">PickledReplayBuffer</span><span class="p">):</span>
<span class="n">screen</span><span class="o">.</span><span class="n">log_title</span><span class="p">(</span><span class="s2">&quot;Loading a pickled replay buffer. Pickled file path: </span><span class="si">{}</span><span class="s2">&quot;</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="o">.</span><span class="n">filepath</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_pickled</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="o">.</span><span class="n">filepath</span><span class="p">)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">,</span> <span class="n">CsvDataset</span><span class="p">):</span>
<span class="n">screen</span><span class="o">.</span><span class="n">log_title</span><span class="p">(</span><span class="s2">&quot;Loading a replay buffer from a CSV file. CSV file path: </span><span class="si">{}</span><span class="s2">&quot;</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="o">.</span><span class="n">filepath</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_csv</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_filter</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;Trying to load a replay buffer using an unsupported method - </span><span class="si">{}</span><span class="s1">. &#39;</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">))</span></div>
<div class="viewcode-block" id="Agent.register_signal"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent.register_signal">[docs]</a> <span class="k">def</span> <span class="nf">register_signal</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">signal_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">dump_one_value_per_episode</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">dump_one_value_per_step</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Signal</span><span class="p">:</span>
@@ -868,7 +887,10 @@
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">loss</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_should_train</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">training_epoch</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">is_batch_rl_training</span><span class="p">:</span>
<span class="c1"># when training an agent for generating a dataset in batch-rl, we don&#39;t want it to be counted as part of</span>
<span class="c1"># the training epochs. we only care for training epochs in batch-rl anyway.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">training_epoch</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="k">for</span> <span class="n">network</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
<span class="n">network</span><span class="o">.</span><span class="n">set_is_training</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
@@ -1229,7 +1251,15 @@
<span class="n">TimeTypes</span><span class="o">.</span><span class="n">TrainingIteration</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_iteration</span><span class="p">,</span>
<span class="n">TimeTypes</span><span class="o">.</span><span class="n">EnvironmentSteps</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_steps_counter</span><span class="p">,</span>
<span class="n">TimeTypes</span><span class="o">.</span><span class="n">WallClockTime</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">get_current_wall_clock_time</span><span class="p">(),</span>
<span class="n">TimeTypes</span><span class="o">.</span><span class="n">Epoch</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_epoch</span><span class="p">}[</span><span class="bp">self</span><span class="o">.</span><span class="n">parent_level_manager</span><span class="o">.</span><span class="n">parent_graph_manager</span><span class="o">.</span><span class="n">time_metric</span><span class="p">]</span></div>
<span class="n">TimeTypes</span><span class="o">.</span><span class="n">Epoch</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_epoch</span><span class="p">}[</span><span class="bp">self</span><span class="o">.</span><span class="n">parent_level_manager</span><span class="o">.</span><span class="n">parent_graph_manager</span><span class="o">.</span><span class="n">time_metric</span><span class="p">]</span>
<div class="viewcode-block" id="Agent.freeze_memory"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent.freeze_memory">[docs]</a> <span class="k">def</span> <span class="nf">freeze_memory</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Shuffle episodes in the memory and freeze it to make sure that no extra data is being pushed anymore.</span>
<span class="sd"> :return: None</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">&#39;shuffle_episodes&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">&#39;freeze&#39;</span><span class="p">)</span></div></div>
</pre></div>
</div>