mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
TD3 (#338)
This commit is contained in:
@@ -278,19 +278,6 @@
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">memory_backend_params</span><span class="o">.</span><span class="n">run_type</span> <span class="o">!=</span> <span class="s1">'trainer'</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">set_memory_backend</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory_backend</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">,</span> <span class="n">PickledReplayBuffer</span><span class="p">):</span>
|
||||
<span class="n">screen</span><span class="o">.</span><span class="n">log_title</span><span class="p">(</span><span class="s2">"Loading a pickled replay buffer. Pickled file path: </span><span class="si">{}</span><span class="s2">"</span>
|
||||
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="o">.</span><span class="n">filepath</span><span class="p">))</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_pickled</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="o">.</span><span class="n">filepath</span><span class="p">)</span>
|
||||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">,</span> <span class="n">CsvDataset</span><span class="p">):</span>
|
||||
<span class="n">screen</span><span class="o">.</span><span class="n">log_title</span><span class="p">(</span><span class="s2">"Loading a replay buffer from a CSV file. CSV file path: </span><span class="si">{}</span><span class="s2">"</span>
|
||||
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="o">.</span><span class="n">filepath</span><span class="p">))</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_csv</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">'Trying to load a replay buffer using an unsupported method - </span><span class="si">{}</span><span class="s1">. '</span>
|
||||
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">))</span>
|
||||
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">shared_memory</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_chief</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">shared_memory_scratchpad</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory_lookup_name</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="p">)</span>
|
||||
|
||||
@@ -444,7 +431,39 @@
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">input_filter</span><span class="o">.</span><span class="n">set_session</span><span class="p">(</span><span class="n">sess</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">output_filter</span><span class="o">.</span><span class="n">set_session</span><span class="p">(</span><span class="n">sess</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">set_session</span><span class="p">(</span><span class="n">sess</span><span class="p">)</span>
|
||||
<span class="p">[</span><span class="n">network</span><span class="o">.</span><span class="n">set_session</span><span class="p">(</span><span class="n">sess</span><span class="p">)</span> <span class="k">for</span> <span class="n">network</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="o">.</span><span class="n">values</span><span class="p">()]</span></div>
|
||||
<span class="p">[</span><span class="n">network</span><span class="o">.</span><span class="n">set_session</span><span class="p">(</span><span class="n">sess</span><span class="p">)</span> <span class="k">for</span> <span class="n">network</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="o">.</span><span class="n">values</span><span class="p">()]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">initialize_session_dependent_components</span><span class="p">()</span></div>
|
||||
|
||||
<div class="viewcode-block" id="Agent.initialize_session_dependent_components"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent.initialize_session_dependent_components">[docs]</a> <span class="k">def</span> <span class="nf">initialize_session_dependent_components</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Initialize components which require a session as part of their initialization.</span>
|
||||
|
||||
<span class="sd"> :return: None</span>
|
||||
<span class="sd"> """</span>
|
||||
|
||||
<span class="c1"># Loading a memory from a CSV file, requires an input filter to filter through the data.</span>
|
||||
<span class="c1"># The filter needs a session before it can be used.</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">load_memory_from_file</span><span class="p">()</span></div>
|
||||
|
||||
<div class="viewcode-block" id="Agent.load_memory_from_file"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent.load_memory_from_file">[docs]</a> <span class="k">def</span> <span class="nf">load_memory_from_file</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Load memory transitions from a file.</span>
|
||||
|
||||
<span class="sd"> :return: None</span>
|
||||
<span class="sd"> """</span>
|
||||
|
||||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">,</span> <span class="n">PickledReplayBuffer</span><span class="p">):</span>
|
||||
<span class="n">screen</span><span class="o">.</span><span class="n">log_title</span><span class="p">(</span><span class="s2">"Loading a pickled replay buffer. Pickled file path: </span><span class="si">{}</span><span class="s2">"</span>
|
||||
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="o">.</span><span class="n">filepath</span><span class="p">))</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_pickled</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="o">.</span><span class="n">filepath</span><span class="p">)</span>
|
||||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">,</span> <span class="n">CsvDataset</span><span class="p">):</span>
|
||||
<span class="n">screen</span><span class="o">.</span><span class="n">log_title</span><span class="p">(</span><span class="s2">"Loading a replay buffer from a CSV file. CSV file path: </span><span class="si">{}</span><span class="s2">"</span>
|
||||
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="o">.</span><span class="n">filepath</span><span class="p">))</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_csv</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_filter</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">'Trying to load a replay buffer using an unsupported method - </span><span class="si">{}</span><span class="s1">. '</span>
|
||||
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">))</span></div>
|
||||
|
||||
<div class="viewcode-block" id="Agent.register_signal"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent.register_signal">[docs]</a> <span class="k">def</span> <span class="nf">register_signal</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">signal_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">dump_one_value_per_episode</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||||
<span class="n">dump_one_value_per_step</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">Signal</span><span class="p">:</span>
|
||||
@@ -868,7 +887,10 @@
|
||||
<span class="sd"> """</span>
|
||||
<span class="n">loss</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_should_train</span><span class="p">():</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">training_epoch</span> <span class="o">+=</span> <span class="mi">1</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">is_batch_rl_training</span><span class="p">:</span>
|
||||
<span class="c1"># when training an agent for generating a dataset in batch-rl, we don't want it to be counted as part of</span>
|
||||
<span class="c1"># the training epochs. we only care for training epochs in batch-rl anyway.</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">training_epoch</span> <span class="o">+=</span> <span class="mi">1</span>
|
||||
<span class="k">for</span> <span class="n">network</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
|
||||
<span class="n">network</span><span class="o">.</span><span class="n">set_is_training</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
|
||||
|
||||
@@ -1229,7 +1251,15 @@
|
||||
<span class="n">TimeTypes</span><span class="o">.</span><span class="n">TrainingIteration</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_iteration</span><span class="p">,</span>
|
||||
<span class="n">TimeTypes</span><span class="o">.</span><span class="n">EnvironmentSteps</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_steps_counter</span><span class="p">,</span>
|
||||
<span class="n">TimeTypes</span><span class="o">.</span><span class="n">WallClockTime</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">get_current_wall_clock_time</span><span class="p">(),</span>
|
||||
<span class="n">TimeTypes</span><span class="o">.</span><span class="n">Epoch</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_epoch</span><span class="p">}[</span><span class="bp">self</span><span class="o">.</span><span class="n">parent_level_manager</span><span class="o">.</span><span class="n">parent_graph_manager</span><span class="o">.</span><span class="n">time_metric</span><span class="p">]</span></div>
|
||||
<span class="n">TimeTypes</span><span class="o">.</span><span class="n">Epoch</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_epoch</span><span class="p">}[</span><span class="bp">self</span><span class="o">.</span><span class="n">parent_level_manager</span><span class="o">.</span><span class="n">parent_graph_manager</span><span class="o">.</span><span class="n">time_metric</span><span class="p">]</span>
|
||||
|
||||
<div class="viewcode-block" id="Agent.freeze_memory"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent.freeze_memory">[docs]</a> <span class="k">def</span> <span class="nf">freeze_memory</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Shuffle episodes in the memory and freeze it to make sure that no extra data is being pushed anymore.</span>
|
||||
<span class="sd"> :return: None</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">'shuffle_episodes'</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">'freeze'</span><span class="p">)</span></div></div>
|
||||
</pre></div>
|
||||
|
||||
</div>
|
||||
|
||||
Reference in New Issue
Block a user