1
0
mirror of https://github.com/gryf/coach.git synced 2026-04-25 01:51:28 +02:00

SAC algorithm (#282)

* SAC algorithm

* SAC - updates to agent (learn_from_batch), sac_head and sac_q_head to fix problem in gradient calculation. Now SAC agents is able to train.
gym_environment - fixing an error in access to gym.spaces

* Soft Actor Critic - code cleanup

* code cleanup

* V-head initialization fix

* SAC benchmarks

* SAC Documentation

* typo fix

* documentation fixes

* documentation and version update

* README typo
This commit is contained in:
guyk1971
2019-05-01 18:37:49 +03:00
committed by shadiendrawis
parent 33dc29ee99
commit 74db141d5e
92 changed files with 2812 additions and 402 deletions
Binary file not shown.

Before

Width:  |  Height:  |  Size: 51 KiB

After

Width:  |  Height:  |  Size: 59 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 109 KiB

+2
View File
@@ -179,6 +179,7 @@
<ul><li><a href="rl_coach/agents/acer_agent.html">rl_coach.agents.acer_agent</a></li>
<li><a href="rl_coach/agents/actor_critic_agent.html">rl_coach.agents.actor_critic_agent</a></li>
<li><a href="rl_coach/agents/agent.html">rl_coach.agents.agent</a></li>
<li><a href="rl_coach/agents/agent_interface.html">rl_coach.agents.agent_interface</a></li>
<li><a href="rl_coach/agents/bc_agent.html">rl_coach.agents.bc_agent</a></li>
<li><a href="rl_coach/agents/categorical_dqn_agent.html">rl_coach.agents.categorical_dqn_agent</a></li>
<li><a href="rl_coach/agents/cil_agent.html">rl_coach.agents.cil_agent</a></li>
@@ -195,6 +196,7 @@
<li><a href="rl_coach/agents/ppo_agent.html">rl_coach.agents.ppo_agent</a></li>
<li><a href="rl_coach/agents/qr_dqn_agent.html">rl_coach.agents.qr_dqn_agent</a></li>
<li><a href="rl_coach/agents/rainbow_dqn_agent.html">rl_coach.agents.rainbow_dqn_agent</a></li>
<li><a href="rl_coach/agents/soft_actor_critic_agent.html">rl_coach.agents.soft_actor_critic_agent</a></li>
<li><a href="rl_coach/agents/value_optimization_agent.html">rl_coach.agents.value_optimization_agent</a></li>
<li><a href="rl_coach/architectures/architecture.html">rl_coach.architectures.architecture</a></li>
<li><a href="rl_coach/architectures/network_wrapper.html">rl_coach.architectures.network_wrapper</a></li>
@@ -248,7 +248,7 @@
<span class="bp">self</span><span class="o">.</span><span class="n">num_steps_between_gradient_updates</span> <span class="o">=</span> <span class="mi">5000</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ratio_of_replay</span> <span class="o">=</span> <span class="mi">4</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_transitions_to_start_replay</span> <span class="o">=</span> <span class="mi">10000</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rate_for_copying_weights_to_target</span> <span class="o">=</span> <span class="mf">0.99</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rate_for_copying_weights_to_target</span> <span class="o">=</span> <span class="mf">0.01</span>
<span class="bp">self</span><span class="o">.</span><span class="n">importance_weight_truncation</span> <span class="o">=</span> <span class="mf">10.0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">use_trust_region_optimization</span> <span class="o">=</span> <span class="kc">True</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_KL_divergence</span> <span class="o">=</span> <span class="mf">1.0</span>
+140 -128
View File
@@ -214,6 +214,9 @@
<span class="kn">from</span> <span class="nn">rl_coach.utils</span> <span class="k">import</span> <span class="n">Signal</span><span class="p">,</span> <span class="n">force_list</span>
<span class="kn">from</span> <span class="nn">rl_coach.utils</span> <span class="k">import</span> <span class="n">dynamic_import_and_instantiate_module_from_params</span>
<span class="kn">from</span> <span class="nn">rl_coach.memories.backend.memory_impl</span> <span class="k">import</span> <span class="n">get_memory_backend</span>
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">TimeTypes</span>
<span class="kn">from</span> <span class="nn">rl_coach.off_policy_evaluators.ope_manager</span> <span class="k">import</span> <span class="n">OpeManager</span>
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">PickledReplayBuffer</span><span class="p">,</span> <span class="n">CsvDataset</span>
<div class="viewcode-block" id="Agent"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent">[docs]</a><span class="k">class</span> <span class="nc">Agent</span><span class="p">(</span><span class="n">AgentInterface</span><span class="p">):</span>
@@ -222,6 +225,15 @@
<span class="sd"> :param agent_parameters: A AgentParameters class instance with all the agent parameters</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="c1"># use seed</span>
<span class="k">if</span> <span class="n">agent_parameters</span><span class="o">.</span><span class="n">task_parameters</span><span class="o">.</span><span class="n">seed</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">task_parameters</span><span class="o">.</span><span class="n">seed</span><span class="p">)</span>
<span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">task_parameters</span><span class="o">.</span><span class="n">seed</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># we need to seed the RNG since the different processes are initialized with the same parent seed</span>
<span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">()</span>
<span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ap</span> <span class="o">=</span> <span class="n">agent_parameters</span>
<span class="bp">self</span><span class="o">.</span><span class="n">task_id</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">task_parameters</span><span class="o">.</span><span class="n">task_index</span>
<span class="bp">self</span><span class="o">.</span><span class="n">is_chief</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">task_id</span> <span class="o">==</span> <span class="mi">0</span>
@@ -229,10 +241,10 @@
<span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">shared_memory</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">shared_memory</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">shared_memory_scratchpad</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">task_parameters</span><span class="o">.</span><span class="n">shared_memory_scratchpad</span>
<span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="n">agent_parameters</span><span class="o">.</span><span class="n">name</span>
<span class="bp">self</span><span class="o">.</span><span class="n">parent</span> <span class="o">=</span> <span class="n">parent</span>
<span class="bp">self</span><span class="o">.</span><span class="n">parent_level_manager</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">full_name_id</span> <span class="o">=</span> <span class="n">agent_parameters</span><span class="o">.</span><span class="n">full_name_id</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span>
<span class="c1"># TODO this needs to be sorted out. Why the duplicates for the agent&#39;s name?</span>
<span class="bp">self</span><span class="o">.</span><span class="n">full_name_id</span> <span class="o">=</span> <span class="n">agent_parameters</span><span class="o">.</span><span class="n">full_name_id</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="n">agent_parameters</span><span class="o">.</span><span class="n">name</span>
<span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">task_parameters</span><span class="p">)</span> <span class="o">==</span> <span class="n">DistributedTaskParameters</span><span class="p">:</span>
<span class="n">screen</span><span class="o">.</span><span class="n">log_title</span><span class="p">(</span><span class="s2">&quot;Creating agent - name: </span><span class="si">{}</span><span class="s2"> task id: </span><span class="si">{}</span><span class="s2"> (may take up to 30 seconds due to &quot;</span>
@@ -264,9 +276,17 @@
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">set_memory_backend</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory_backend</span><span class="p">)</span>
<span class="k">if</span> <span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">:</span>
<span class="n">screen</span><span class="o">.</span><span class="n">log_title</span><span class="p">(</span><span class="s2">&quot;Loading replay buffer from pickle. Pickle path: </span><span class="si">{}</span><span class="s2">&quot;</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">,</span> <span class="n">PickledReplayBuffer</span><span class="p">):</span>
<span class="n">screen</span><span class="o">.</span><span class="n">log_title</span><span class="p">(</span><span class="s2">&quot;Loading a pickled replay buffer. Pickled file path: </span><span class="si">{}</span><span class="s2">&quot;</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="o">.</span><span class="n">filepath</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_pickled</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="o">.</span><span class="n">filepath</span><span class="p">)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">,</span> <span class="n">CsvDataset</span><span class="p">):</span>
<span class="n">screen</span><span class="o">.</span><span class="n">log_title</span><span class="p">(</span><span class="s2">&quot;Loading a replay buffer from a CSV file. CSV file path: </span><span class="si">{}</span><span class="s2">&quot;</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="o">.</span><span class="n">filepath</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_csv</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;Trying to load a replay buffer using an unsupported method - </span><span class="si">{}</span><span class="s1">. &#39;</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">))</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">shared_memory</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_chief</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">shared_memory_scratchpad</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory_lookup_name</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="p">)</span>
@@ -327,6 +347,7 @@
<span class="bp">self</span><span class="o">.</span><span class="n">total_steps_counter</span> <span class="o">=</span> <span class="mi">0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">running_reward</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">training_iteration</span> <span class="o">=</span> <span class="mi">0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">training_epoch</span> <span class="o">=</span> <span class="mi">0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">last_target_network_update_step</span> <span class="o">=</span> <span class="mi">0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">last_training_phase_step</span> <span class="o">=</span> <span class="mi">0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">current_episode</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">current_episode</span> <span class="o">=</span> <span class="mi">0</span>
@@ -364,14 +385,9 @@
<span class="bp">self</span><span class="o">.</span><span class="n">discounted_return</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">&#39;Discounted Return&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">in_action_space</span><span class="p">,</span> <span class="n">GoalsSpace</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">distance_from_goal</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">&#39;Distance From Goal&#39;</span><span class="p">,</span> <span class="n">dump_one_value_per_step</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># use seed</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">task_parameters</span><span class="o">.</span><span class="n">seed</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">task_parameters</span><span class="o">.</span><span class="n">seed</span><span class="p">)</span>
<span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">task_parameters</span><span class="o">.</span><span class="n">seed</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># we need to seed the RNG since the different processes are initialized with the same parent seed</span>
<span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">()</span>
<span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">()</span>
<span class="c1"># batch rl</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ope_manager</span> <span class="o">=</span> <span class="n">OpeManager</span><span class="p">()</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">is_batch_rl_training</span> <span class="k">else</span> <span class="kc">None</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">parent</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s1">&#39;LevelManager&#39;</span><span class="p">:</span>
@@ -408,6 +424,7 @@
<span class="nb">format</span><span class="p">(</span><span class="n">graph_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">parent_level_manager</span><span class="o">.</span><span class="n">parent_graph_manager</span><span class="o">.</span><span class="n">name</span><span class="p">,</span>
<span class="n">level_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">parent_level_manager</span><span class="o">.</span><span class="n">name</span><span class="p">,</span>
<span class="n">agent_full_id</span><span class="o">=</span><span class="s1">&#39;.&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">full_name_id</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">&#39;/&#39;</span><span class="p">)))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">set_index_name</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">parent_level_manager</span><span class="o">.</span><span class="n">parent_graph_manager</span><span class="o">.</span><span class="n">time_metric</span><span class="o">.</span><span class="n">value</span><span class="o">.</span><span class="n">name</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">set_logger_filenames</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">task_parameters</span><span class="o">.</span><span class="n">experiment_path</span><span class="p">,</span> <span class="n">logger_prefix</span><span class="o">=</span><span class="n">logger_prefix</span><span class="p">,</span>
<span class="n">add_timestamp</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">task_id</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">task_id</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">visualization</span><span class="o">.</span><span class="n">dump_in_episode_signals</span><span class="p">:</span>
@@ -561,13 +578,17 @@
<span class="bp">self</span><span class="o">.</span><span class="n">accumulated_shaped_rewards_across_evaluation_episodes</span> <span class="o">=</span> <span class="mi">0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_successes_across_evaluation_episodes</span> <span class="o">=</span> <span class="mi">0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_evaluation_episodes_completed</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">is_a_highest_level_agent</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">task_parameters</span><span class="o">.</span><span class="n">verbosity</span> <span class="o">==</span> <span class="s2">&quot;high&quot;</span><span class="p">:</span>
<span class="c1"># TODO verbosity was mistakenly removed from task_parameters on release 0.11.0, need to bring it back</span>
<span class="c1"># if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == &quot;high&quot;:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">is_a_highest_level_agent</span><span class="p">:</span>
<span class="n">screen</span><span class="o">.</span><span class="n">log_title</span><span class="p">(</span><span class="s2">&quot;</span><span class="si">{}</span><span class="s2">: Starting evaluation phase&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">))</span>
<span class="k">elif</span> <span class="n">ending_evaluation</span><span class="p">:</span>
<span class="c1"># we write to the next episode, because it could be that the current episode was already written</span>
<span class="c1"># to disk and then we won&#39;t write it again</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">set_current_time</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">current_episode</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">set_current_time</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">get_current_time</span><span class="p">()</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">evaluation_reward</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">accumulated_rewards_across_evaluation_episodes</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_evaluation_episodes_completed</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span>
<span class="s1">&#39;Evaluation Reward&#39;</span><span class="p">,</span> <span class="n">evaluation_reward</span><span class="p">)</span>
@@ -577,9 +598,11 @@
<span class="n">success_rate</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_successes_across_evaluation_episodes</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_evaluation_episodes_completed</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span>
<span class="s2">&quot;Success Rate&quot;</span><span class="p">,</span>
<span class="n">success_rate</span>
<span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">is_a_highest_level_agent</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">task_parameters</span><span class="o">.</span><span class="n">verbosity</span> <span class="o">==</span> <span class="s2">&quot;high&quot;</span><span class="p">:</span>
<span class="n">success_rate</span><span class="p">)</span>
<span class="c1"># TODO verbosity was mistakenly removed from task_parameters on release 0.11.0, need to bring it back</span>
<span class="c1"># if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == &quot;high&quot;:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">is_a_highest_level_agent</span><span class="p">:</span>
<span class="n">screen</span><span class="o">.</span><span class="n">log_title</span><span class="p">(</span><span class="s2">&quot;</span><span class="si">{}</span><span class="s2">: Finished evaluation phase. Success rate = </span><span class="si">{}</span><span class="s2">, Avg Total Reward = </span><span class="si">{}</span><span class="s2">&quot;</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">round</span><span class="p">(</span><span class="n">success_rate</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">round</span><span class="p">(</span><span class="n">evaluation_reward</span><span class="p">,</span> <span class="mi">2</span><span class="p">)))</span></div>
@@ -652,8 +675,11 @@
<span class="sd"> :return: None</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="c1"># log all the signals to file</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">set_current_time</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">current_episode</span><span class="p">)</span>
<span class="n">current_time</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_current_time</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">set_current_time</span><span class="p">(</span><span class="n">current_time</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">&#39;Training Iter&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_iteration</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">&#39;Episode #&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">&#39;Epoch&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_epoch</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">&#39;In Heatup&#39;</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_phase</span> <span class="o">==</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">HEATUP</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">&#39;ER #Transitions&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">&#39;num_transitions&#39;</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">&#39;ER #Episodes&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">&#39;length&#39;</span><span class="p">))</span>
@@ -666,12 +692,17 @@
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_phase</span> <span class="o">==</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TRAIN</span> <span class="k">else</span> <span class="n">np</span><span class="o">.</span><span class="n">nan</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">&#39;Update Target Network&#39;</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">overwrite</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">update_wall_clock_time</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">current_episode</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">update_wall_clock_time</span><span class="p">(</span><span class="n">current_time</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_phase</span> <span class="o">!=</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">&#39;Evaluation Reward&#39;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">nan</span><span class="p">,</span> <span class="n">overwrite</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">&#39;Shaped Evaluation Reward&#39;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">nan</span><span class="p">,</span> <span class="n">overwrite</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">&#39;Success Rate&#39;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">nan</span><span class="p">,</span> <span class="n">overwrite</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="c1"># The following signals are created with meaningful values only when an evaluation phase is completed.</span>
<span class="c1"># Creating with default NaNs for any HEATUP/TRAIN/TEST episode which is not the last in an evaluation phase</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">&#39;Evaluation Reward&#39;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">nan</span><span class="p">,</span> <span class="n">overwrite</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">&#39;Shaped Evaluation Reward&#39;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">nan</span><span class="p">,</span> <span class="n">overwrite</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">&#39;Success Rate&#39;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">nan</span><span class="p">,</span> <span class="n">overwrite</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">&#39;Inverse Propensity Score&#39;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">nan</span><span class="p">,</span> <span class="n">overwrite</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">&#39;Direct Method Reward&#39;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">nan</span><span class="p">,</span> <span class="n">overwrite</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">&#39;Doubly Robust&#39;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">nan</span><span class="p">,</span> <span class="n">overwrite</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">&#39;Sequential Doubly Robust&#39;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">nan</span><span class="p">,</span> <span class="n">overwrite</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="k">for</span> <span class="n">signal</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">episode_signals</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s2">&quot;</span><span class="si">{}</span><span class="s2">/Mean&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">signal</span><span class="o">.</span><span class="n">name</span><span class="p">),</span> <span class="n">signal</span><span class="o">.</span><span class="n">get_mean</span><span class="p">())</span>
@@ -680,8 +711,7 @@
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s2">&quot;</span><span class="si">{}</span><span class="s2">/Min&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">signal</span><span class="o">.</span><span class="n">name</span><span class="p">),</span> <span class="n">signal</span><span class="o">.</span><span class="n">get_min</span><span class="p">())</span>
<span class="c1"># dump</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">visualization</span><span class="o">.</span><span class="n">dump_signals_to_csv_every_x_episodes</span> <span class="o">==</span> <span class="mi">0</span> \
<span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">visualization</span><span class="o">.</span><span class="n">dump_signals_to_csv_every_x_episodes</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">dump_output_csv</span><span class="p">()</span></div>
<div class="viewcode-block" id="Agent.handle_episode_ended"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent.handle_episode_ended">[docs]</a> <span class="k">def</span> <span class="nf">handle_episode_ended</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
@@ -717,10 +747,13 @@
<span class="bp">self</span><span class="o">.</span><span class="n">total_reward_in_current_episode</span> <span class="o">&gt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">reward</span><span class="o">.</span><span class="n">reward_success_threshold</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_successes_across_evaluation_episodes</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">visualization</span><span class="o">.</span><span class="n">dump_csv</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">visualization</span><span class="o">.</span><span class="n">dump_csv</span> <span class="ow">and</span> \
<span class="bp">self</span><span class="o">.</span><span class="n">parent_level_manager</span><span class="o">.</span><span class="n">parent_graph_manager</span><span class="o">.</span><span class="n">time_metric</span> <span class="o">==</span> <span class="n">TimeTypes</span><span class="o">.</span><span class="n">EpisodeNumber</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">update_log</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">is_a_highest_level_agent</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">task_parameters</span><span class="o">.</span><span class="n">verbosity</span> <span class="o">==</span> <span class="s2">&quot;high&quot;</span><span class="p">:</span>
<span class="c1"># TODO verbosity was mistakenly removed from task_parameters on release 0.11.0, need to bring it back</span>
<span class="c1"># if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == &quot;high&quot;:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">is_a_highest_level_agent</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">log_to_screen</span><span class="p">()</span></div>
<div class="viewcode-block" id="Agent.reset_internal_state"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent.reset_internal_state">[docs]</a> <span class="k">def</span> <span class="nf">reset_internal_state</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
@@ -831,18 +864,25 @@
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">loss</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_should_train</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">training_epoch</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="k">for</span> <span class="n">network</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
<span class="n">network</span><span class="o">.</span><span class="n">set_is_training</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
<span class="k">for</span> <span class="n">training_step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">num_consecutive_training_steps</span><span class="p">):</span>
<span class="c1"># TODO: this should be network dependent</span>
<span class="n">network_parameters</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="o">.</span><span class="n">values</span><span class="p">())[</span><span class="mi">0</span><span class="p">]</span>
<span class="c1"># At the moment we only support a single batch size for all the networks</span>
<span class="n">networks_parameters</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="o">.</span><span class="n">values</span><span class="p">())</span>
<span class="k">assert</span> <span class="nb">all</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">==</span> <span class="n">networks_parameters</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span> <span class="k">for</span> <span class="n">net</span> <span class="ow">in</span> <span class="n">networks_parameters</span><span class="p">)</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">networks_parameters</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span>
<span class="c1"># we either go sequentially through the entire replay buffer in the batch RL mode,</span>
<span class="c1"># or sample randomly for the basic RL case.</span>
<span class="n">training_schedule</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">&#39;get_shuffled_data_generator&#39;</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">)</span> <span class="k">if</span> \
<span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">is_batch_rl_training</span> <span class="k">else</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">&#39;sample&#39;</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span>
<span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">num_consecutive_training_steps</span><span class="p">)]</span>
<span class="k">for</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">training_schedule</span><span class="p">:</span>
<span class="c1"># update counters</span>
<span class="bp">self</span><span class="o">.</span><span class="n">training_iteration</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="c1"># sample a batch and train on it</span>
<span class="n">batch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">&#39;sample&#39;</span><span class="p">,</span> <span class="n">network_parameters</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">batch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">deep_copy</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
@@ -853,15 +893,19 @@
<span class="n">batch</span> <span class="o">=</span> <span class="n">Batch</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
<span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">learn_from_batch</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">+=</span> <span class="n">total_loss</span>
<span class="bp">self</span><span class="o">.</span><span class="n">unclipped_grads</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">unclipped_grads</span><span class="p">)</span>
<span class="c1"># TODO: the learning rate decay should be done through the network instead of here</span>
<span class="c1"># TODO: this only deals with the main network (if exists), need to do the same for other networks</span>
<span class="c1"># for instance, for DDPG, the LR signal is currently not shown. Probably should be done through the</span>
<span class="c1"># network directly instead of here</span>
<span class="c1"># decay learning rate</span>
<span class="k">if</span> <span class="n">network_parameters</span><span class="o">.</span><span class="n">learning_rate_decay_rate</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">if</span> <span class="s1">&#39;main&#39;</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span> <span class="ow">and</span> \
<span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">learning_rate_decay_rate</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">curr_learning_rate</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">sess</span><span class="o">.</span><span class="n">run</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">current_learning_rate</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">curr_learning_rate</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">network_parameters</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">curr_learning_rate</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">networks_parameters</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">any</span><span class="p">([</span><span class="n">network</span><span class="o">.</span><span class="n">has_target</span> <span class="k">for</span> <span class="n">network</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="o">.</span><span class="n">values</span><span class="p">()])</span> \
<span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">_should_update_online_weights_to_target</span><span class="p">():</span>
@@ -877,6 +921,12 @@
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">imitation</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">log_to_screen</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">visualization</span><span class="o">.</span><span class="n">dump_csv</span> <span class="ow">and</span> \
<span class="bp">self</span><span class="o">.</span><span class="n">parent_level_manager</span><span class="o">.</span><span class="n">parent_graph_manager</span><span class="o">.</span><span class="n">time_metric</span> <span class="o">==</span> <span class="n">TimeTypes</span><span class="o">.</span><span class="n">Epoch</span><span class="p">:</span>
<span class="c1"># in BatchRL, or imitation learning, the agent never acts, so we have to get the stats out here.</span>
<span class="c1"># we dump the data out every epoch</span>
<span class="bp">self</span><span class="o">.</span><span class="n">update_log</span><span class="p">()</span>
<span class="k">for</span> <span class="n">network</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
<span class="n">network</span><span class="o">.</span><span class="n">set_is_training</span><span class="p">(</span><span class="kc">False</span><span class="p">)</span>
@@ -919,10 +969,11 @@
<span class="k">return</span> <span class="n">batches_dict</span></div>
<div class="viewcode-block" id="Agent.act"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent.act">[docs]</a> <span class="k">def</span> <span class="nf">act</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">ActionInfo</span><span class="p">:</span>
<div class="viewcode-block" id="Agent.act"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent.act">[docs]</a> <span class="k">def</span> <span class="nf">act</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="n">ActionType</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">ActionInfo</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Given the agents current knowledge, decide on the next action to apply to the environment</span>
<span class="sd"> :param action: An action to take, overriding whatever the current policy is</span>
<span class="sd"> :return: An ActionInfo object, which contains the action and any additional info from the action decision process</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">==</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TRAIN</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">num_consecutive_playing_steps</span><span class="o">.</span><span class="n">num_steps</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
@@ -935,20 +986,28 @@
<span class="bp">self</span><span class="o">.</span><span class="n">current_episode_steps_counter</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="c1"># decide on the action</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">==</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">HEATUP</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">heatup_using_network_decisions</span><span class="p">:</span>
<span class="c1"># random action</span>
<span class="bp">self</span><span class="o">.</span><span class="n">last_action_info</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">action</span><span class="o">.</span><span class="n">sample_with_info</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># informed action</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># before choosing an action, first use the pre_network_filter to filter out the current state</span>
<span class="n">update_filter_internal_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span>
<span class="n">curr_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">run_pre_network_filter_for_inference</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">curr_state</span><span class="p">,</span> <span class="n">update_filter_internal_state</span><span class="p">)</span>
<span class="k">if</span> <span class="n">action</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">==</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">HEATUP</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">heatup_using_network_decisions</span><span class="p">:</span>
<span class="c1"># random action</span>
<span class="n">action</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">action</span><span class="o">.</span><span class="n">sample_with_info</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">curr_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">curr_state</span>
<span class="bp">self</span><span class="o">.</span><span class="n">last_action_info</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">choose_action</span><span class="p">(</span><span class="n">curr_state</span><span class="p">)</span>
<span class="c1"># informed action</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># before choosing an action, first use the pre_network_filter to filter out the current state</span>
<span class="n">update_filter_internal_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span>
<span class="n">curr_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">run_pre_network_filter_for_inference</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">curr_state</span><span class="p">,</span> <span class="n">update_filter_internal_state</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">curr_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">curr_state</span>
<span class="n">action</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">choose_action</span><span class="p">(</span><span class="n">curr_state</span><span class="p">)</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action</span><span class="p">,</span> <span class="n">ActionInfo</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">last_action_info</span> <span class="o">=</span> <span class="n">action</span>
<span class="c1"># output filters are explicitly applied after recording self.last_action_info. This is</span>
<span class="c1"># because the output filters may change the representation of the action so that the agent</span>
<span class="c1"># can no longer use the transition in it&#39;s replay buffer. It is possible that these filters</span>
<span class="c1"># could be moved to the environment instead, but they are here now for historical reasons.</span>
<span class="n">filtered_action_info</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">last_action_info</span><span class="p">)</span>
<span class="k">return</span> <span class="n">filtered_action_info</span></div>
@@ -1030,37 +1089,35 @@
<span class="c1"># make agent specific changes to the transition if needed</span>
<span class="n">transition</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">update_transition_before_adding_to_replay_buffer</span><span class="p">(</span><span class="n">transition</span><span class="p">)</span>
<span class="c1"># merge the intrinsic reward in</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">scale_external_reward_by_intrinsic_reward_value</span><span class="p">:</span>
<span class="n">transition</span><span class="o">.</span><span class="n">reward</span> <span class="o">=</span> <span class="n">transition</span><span class="o">.</span><span class="n">reward</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_action_info</span><span class="o">.</span><span class="n">action_intrinsic_reward</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">transition</span><span class="o">.</span><span class="n">reward</span> <span class="o">=</span> <span class="n">transition</span><span class="o">.</span><span class="n">reward</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_action_info</span><span class="o">.</span><span class="n">action_intrinsic_reward</span>
<span class="c1"># sum up the total shaped reward</span>
<span class="bp">self</span><span class="o">.</span><span class="n">total_shaped_reward_in_current_episode</span> <span class="o">+=</span> <span class="n">transition</span><span class="o">.</span><span class="n">reward</span>
<span class="bp">self</span><span class="o">.</span><span class="n">total_reward_in_current_episode</span> <span class="o">+=</span> <span class="n">env_response</span><span class="o">.</span><span class="n">reward</span>
<span class="bp">self</span><span class="o">.</span><span class="n">shaped_reward</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">transition</span><span class="o">.</span><span class="n">reward</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">reward</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">env_response</span><span class="o">.</span><span class="n">reward</span><span class="p">)</span>
<span class="c1"># add action info to transition</span>
<span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">parent</span><span class="p">)</span><span class="o">.</span><span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;CompositeAgent&#39;</span><span class="p">:</span>
<span class="n">transition</span><span class="o">.</span><span class="n">add_info</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">parent</span><span class="o">.</span><span class="n">last_action_info</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">transition</span><span class="o">.</span><span class="n">add_info</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">last_action_info</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)</span>
<span class="c1"># create and store the transition</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="ow">in</span> <span class="p">[</span><span class="n">RunPhase</span><span class="o">.</span><span class="n">TRAIN</span><span class="p">,</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">HEATUP</span><span class="p">]:</span>
<span class="c1"># for episodic memories we keep the transitions in a local buffer until the episode is ended.</span>
<span class="c1"># for regular memories we insert the transitions directly to the memory</span>
<span class="bp">self</span><span class="o">.</span><span class="n">current_episode_buffer</span><span class="o">.</span><span class="n">insert</span><span class="p">(</span><span class="n">transition</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="p">,</span> <span class="n">EpisodicExperienceReplay</span><span class="p">)</span> \
<span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">store_transitions_only_when_episodes_are_terminated</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">&#39;store&#39;</span><span class="p">,</span> <span class="n">transition</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">total_reward_in_current_episode</span> <span class="o">+=</span> <span class="n">env_response</span><span class="o">.</span><span class="n">reward</span>
<span class="bp">self</span><span class="o">.</span><span class="n">reward</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">env_response</span><span class="o">.</span><span class="n">reward</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">visualization</span><span class="o">.</span><span class="n">dump_in_episode_signals</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">update_step_in_episode_log</span><span class="p">()</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">observe_transition</span><span class="p">(</span><span class="n">transition</span><span class="p">)</span></div>
<span class="k">return</span> <span class="n">transition</span><span class="o">.</span><span class="n">game_over</span></div>
<span class="k">def</span> <span class="nf">observe_transition</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">transition</span><span class="p">):</span>
<span class="c1"># sum up the total shaped reward</span>
<span class="bp">self</span><span class="o">.</span><span class="n">total_shaped_reward_in_current_episode</span> <span class="o">+=</span> <span class="n">transition</span><span class="o">.</span><span class="n">reward</span>
<span class="bp">self</span><span class="o">.</span><span class="n">shaped_reward</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">transition</span><span class="o">.</span><span class="n">reward</span><span class="p">)</span>
<span class="c1"># create and store the transition</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="ow">in</span> <span class="p">[</span><span class="n">RunPhase</span><span class="o">.</span><span class="n">TRAIN</span><span class="p">,</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">HEATUP</span><span class="p">]:</span>
<span class="c1"># for episodic memories we keep the transitions in a local buffer until the episode is ended.</span>
<span class="c1"># for regular memories we insert the transitions directly to the memory</span>
<span class="bp">self</span><span class="o">.</span><span class="n">current_episode_buffer</span><span class="o">.</span><span class="n">insert</span><span class="p">(</span><span class="n">transition</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="p">,</span> <span class="n">EpisodicExperienceReplay</span><span class="p">)</span> \
<span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">store_transitions_only_when_episodes_are_terminated</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">&#39;store&#39;</span><span class="p">,</span> <span class="n">transition</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">visualization</span><span class="o">.</span><span class="n">dump_in_episode_signals</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">update_step_in_episode_log</span><span class="p">()</span>
<span class="k">return</span> <span class="n">transition</span><span class="o">.</span><span class="n">game_over</span>
<div class="viewcode-block" id="Agent.post_training_commands"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent.post_training_commands">[docs]</a> <span class="k">def</span> <span class="nf">post_training_commands</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
@@ -1145,60 +1202,6 @@
<span class="k">for</span> <span class="n">network</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
<span class="n">network</span><span class="o">.</span><span class="n">sync</span><span class="p">()</span></div>
<span class="c1"># TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create</span>
<span class="c1"># an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]</span>
<div class="viewcode-block" id="Agent.emulate_observe_on_trainer"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent.emulate_observe_on_trainer">[docs]</a> <span class="k">def</span> <span class="nf">emulate_observe_on_trainer</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">transition</span><span class="p">:</span> <span class="n">Transition</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> This emulates the observe using the transition obtained from the rollout worker on the training worker</span>
<span class="sd"> in case of distributed training.</span>
<span class="sd"> Given a response from the environment, distill the observation from it and store it for later use.</span>
<span class="sd"> The response should be a dictionary containing the performed action, the new observation and measurements,</span>
<span class="sd"> the reward, a game over flag and any additional information necessary.</span>
<span class="sd"> :return:</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="c1"># sum up the total shaped reward</span>
<span class="bp">self</span><span class="o">.</span><span class="n">total_shaped_reward_in_current_episode</span> <span class="o">+=</span> <span class="n">transition</span><span class="o">.</span><span class="n">reward</span>
<span class="bp">self</span><span class="o">.</span><span class="n">total_reward_in_current_episode</span> <span class="o">+=</span> <span class="n">transition</span><span class="o">.</span><span class="n">reward</span>
<span class="bp">self</span><span class="o">.</span><span class="n">shaped_reward</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">transition</span><span class="o">.</span><span class="n">reward</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">reward</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">transition</span><span class="o">.</span><span class="n">reward</span><span class="p">)</span>
<span class="c1"># create and store the transition</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="ow">in</span> <span class="p">[</span><span class="n">RunPhase</span><span class="o">.</span><span class="n">TRAIN</span><span class="p">,</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">HEATUP</span><span class="p">]:</span>
<span class="c1"># for episodic memories we keep the transitions in a local buffer until the episode is ended.</span>
<span class="c1"># for regular memories we insert the transitions directly to the memory</span>
<span class="bp">self</span><span class="o">.</span><span class="n">current_episode_buffer</span><span class="o">.</span><span class="n">insert</span><span class="p">(</span><span class="n">transition</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="p">,</span> <span class="n">EpisodicExperienceReplay</span><span class="p">)</span> \
<span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">store_transitions_only_when_episodes_are_terminated</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">&#39;store&#39;</span><span class="p">,</span> <span class="n">transition</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">visualization</span><span class="o">.</span><span class="n">dump_in_episode_signals</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">update_step_in_episode_log</span><span class="p">()</span>
<span class="k">return</span> <span class="n">transition</span><span class="o">.</span><span class="n">game_over</span></div>
<span class="c1"># TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create</span>
<span class="c1"># an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]</span>
<div class="viewcode-block" id="Agent.emulate_act_on_trainer"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent.emulate_act_on_trainer">[docs]</a> <span class="k">def</span> <span class="nf">emulate_act_on_trainer</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">transition</span><span class="p">:</span> <span class="n">Transition</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">ActionInfo</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> This emulates the act using the transition obtained from the rollout worker on the training worker</span>
<span class="sd"> in case of distributed training.</span>
<span class="sd"> Given the agents current knowledge, decide on the next action to apply to the environment</span>
<span class="sd"> :return: an action and a dictionary containing any additional info from the action decision process</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">==</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TRAIN</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">num_consecutive_playing_steps</span><span class="o">.</span><span class="n">num_steps</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="c1"># This agent never plays while training (e.g. behavioral cloning)</span>
<span class="k">return</span> <span class="kc">None</span>
<span class="c1"># count steps (only when training or if we are in the evaluation worker)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">!=</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">task_parameters</span><span class="o">.</span><span class="n">evaluate_only</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">total_steps_counter</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="bp">self</span><span class="o">.</span><span class="n">current_episode_steps_counter</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="bp">self</span><span class="o">.</span><span class="n">last_action_info</span> <span class="o">=</span> <span class="n">transition</span><span class="o">.</span><span class="n">action</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_action_info</span></div>
<span class="k">def</span> <span class="nf">get_success_rate</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_successes_across_evaluation_episodes</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_evaluation_episodes_completed</span>
@@ -1213,7 +1216,16 @@
<span class="n">savers</span> <span class="o">=</span> <span class="n">SaverCollection</span><span class="p">()</span>
<span class="k">for</span> <span class="n">network</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
<span class="n">savers</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">network</span><span class="o">.</span><span class="n">collect_savers</span><span class="p">(</span><span class="n">parent_path_suffix</span><span class="p">))</span>
<span class="k">return</span> <span class="n">savers</span></div></div>
<span class="k">return</span> <span class="n">savers</span></div>
<span class="k">def</span> <span class="nf">get_current_time</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">pass</span>
<span class="k">return</span> <span class="p">{</span>
<span class="n">TimeTypes</span><span class="o">.</span><span class="n">EpisodeNumber</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode</span><span class="p">,</span>
<span class="n">TimeTypes</span><span class="o">.</span><span class="n">TrainingIteration</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_iteration</span><span class="p">,</span>
<span class="n">TimeTypes</span><span class="o">.</span><span class="n">EnvironmentSteps</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_steps_counter</span><span class="p">,</span>
<span class="n">TimeTypes</span><span class="o">.</span><span class="n">WallClockTime</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">get_current_wall_clock_time</span><span class="p">(),</span>
<span class="n">TimeTypes</span><span class="o">.</span><span class="n">Epoch</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_epoch</span><span class="p">}[</span><span class="bp">self</span><span class="o">.</span><span class="n">parent_level_manager</span><span class="o">.</span><span class="n">parent_graph_manager</span><span class="o">.</span><span class="n">time_metric</span><span class="p">]</span></div>
</pre></div>
</div>
@@ -0,0 +1,387 @@
<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>rl_coach.agents.agent_interface &mdash; Reinforcement Learning Coach 0.11.0 documentation</title>
<link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../../_static/css/custom.css" type="text/css" />
<link rel="index" title="Index" href="../../../genindex.html" />
<link rel="search" title="Search" href="../../../search.html" />
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
<script src="../../../_static/js/modernizr.min.js"></script>
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search">
<a href="../../../index.html" class="icon icon-home"> Reinforcement Learning Coach
<img src="../../../_static/dark_logo.png" class="logo" alt="Logo"/>
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption"><span class="caption-text">Intro</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../dist_usage.html">Usage - Distributed Coach</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../features/index.html">Features</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../selecting_an_algorithm.html">Selecting an Algorithm</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../dashboard.html">Coach Dashboard</a></li>
</ul>
<p class="caption"><span class="caption-text">Design</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../design/control_flow.html">Control Flow</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../design/network.html">Network Design</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../design/horizontal_scaling.html">Distributed Coach - Horizontal Scale-Out</a></li>
</ul>
<p class="caption"><span class="caption-text">Contributing</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_agent.html">Adding a New Agent</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_env.html">Adding a New Environment</a></li>
</ul>
<p class="caption"><span class="caption-text">Components</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../components/agents/index.html">Agents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/architectures/index.html">Architectures</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/data_stores/index.html">Data Stores</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/environments/index.html">Environments</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/exploration_policies/index.html">Exploration Policies</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/filters/index.html">Filters</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/memories/index.html">Memories</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/memory_backends/index.html">Memory Backends</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/orchestrators/index.html">Orchestrators</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/core_types.html">Core Types</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/spaces.html">Spaces</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/additional_parameters.html">Additional Parameters</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../../index.html">Reinforcement Learning Coach</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../../index.html">Docs</a> &raquo;</li>
<li><a href="../../index.html">Module code</a> &raquo;</li>
<li>rl_coach.agents.agent_interface</li>
<li class="wy-breadcrumbs-aside">
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<h1>Source code for rl_coach.agents.agent_interface</h1><div class="highlight"><pre>
<span></span><span class="c1">#</span>
<span class="c1"># Copyright (c) 2017 Intel Corporation</span>
<span class="c1">#</span>
<span class="c1"># Licensed under the Apache License, Version 2.0 (the &quot;License&quot;);</span>
<span class="c1"># you may not use this file except in compliance with the License.</span>
<span class="c1"># You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
<span class="c1"># distributed under the License is distributed on an &quot;AS IS&quot; BASIS,</span>
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="c1">#</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Dict</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">EnvResponse</span><span class="p">,</span> <span class="n">ActionInfo</span><span class="p">,</span> <span class="n">RunPhase</span><span class="p">,</span> <span class="n">PredictionType</span><span class="p">,</span> <span class="n">ActionType</span><span class="p">,</span> <span class="n">Transition</span>
<span class="kn">from</span> <span class="nn">rl_coach.saver</span> <span class="k">import</span> <span class="n">SaverCollection</span>
<span class="k">class</span> <span class="nc">AgentInterface</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_phase</span> <span class="o">=</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">HEATUP</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_parent</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">spaces</span> <span class="o">=</span> <span class="kc">None</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">parent</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Get the parent class of the agent</span>
<span class="sd"> :return: the current phase</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_parent</span>
<span class="nd">@parent</span><span class="o">.</span><span class="n">setter</span>
<span class="k">def</span> <span class="nf">parent</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">val</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Change the parent class of the agent</span>
<span class="sd"> :param val: the new parent</span>
<span class="sd"> :return: None</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_parent</span> <span class="o">=</span> <span class="n">val</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">phase</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">RunPhase</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Get the phase of the agent</span>
<span class="sd"> :return: the current phase</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_phase</span>
<span class="nd">@phase</span><span class="o">.</span><span class="n">setter</span>
<span class="k">def</span> <span class="nf">phase</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">val</span><span class="p">:</span> <span class="n">RunPhase</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Change the phase of the agent</span>
<span class="sd"> :param val: the new phase</span>
<span class="sd"> :return: None</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_phase</span> <span class="o">=</span> <span class="n">val</span>
<span class="k">def</span> <span class="nf">reset_internal_state</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Reset the episode parameters for the agent</span>
<span class="sd"> :return: None</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">List</span><span class="p">]:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Train the agents network</span>
<span class="sd"> :return: The loss of the training</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">act</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">ActionInfo</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Get a decision of the next action to take.</span>
<span class="sd"> The action is dependent on the current state which the agent holds from resetting the environment or from</span>
<span class="sd"> the observe function.</span>
<span class="sd"> :return: A tuple containing the actual action and additional info on the action</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">observe</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">env_response</span><span class="p">:</span> <span class="n">EnvResponse</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets a response from the environment.</span>
<span class="sd"> Processes this information for later use. For example, create a transition and store it in memory.</span>
<span class="sd"> The action info (a class containing any info the agent wants to store regarding its action decision process) is</span>
<span class="sd"> stored by the agent itself when deciding on the action.</span>
<span class="sd"> :param env_response: a EnvResponse containing the response from the environment</span>
<span class="sd"> :return: a done signal which is based on the agent knowledge. This can be different from the done signal from</span>
<span class="sd"> the environment. For example, an agent can decide to finish the episode each time it gets some</span>
<span class="sd"> intrinsic reward</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">save_checkpoint</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">checkpoint_prefix</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Save the model of the agent to the disk. This can contain the network parameters, the memory of the agent, etc.</span>
<span class="sd"> :param checkpoint_prefix: The prefix of the checkpoint file to save</span>
<span class="sd"> :return: None</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">get_predictions</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">:</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">prediction_type</span><span class="p">:</span> <span class="n">PredictionType</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Get a prediction from the agent with regard to the requested prediction_type. If the agent cannot predict this</span>
<span class="sd"> type of prediction_type, or if there is more than possible way to do so, raise a ValueException.</span>
<span class="sd"> :param states:</span>
<span class="sd"> :param prediction_type:</span>
<span class="sd"> :return: the agent&#39;s prediction</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">set_incoming_directive</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action</span><span class="p">:</span> <span class="n">ActionType</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Pass a higher level command (directive) to the agent.</span>
<span class="sd"> For example, a higher level agent can set the goal of the agent.</span>
<span class="sd"> :param action: the directive to pass to the agent</span>
<span class="sd"> :return: None</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">collect_savers</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">parent_path_suffix</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">SaverCollection</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Collect all of agent savers</span>
<span class="sd"> :param parent_path_suffix: path suffix of the parent of the agent</span>
<span class="sd"> (could be name of level manager or composite agent)</span>
<span class="sd"> :return: collection of all agent savers</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">handle_episode_ended</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Make any changes needed when each episode is ended.</span>
<span class="sd"> This includes incrementing counters, updating full episode dependent values, updating logs, etc.</span>
<span class="sd"> This function is called right after each episode is ended.</span>
<span class="sd"> :return: None</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">run_off_policy_evaluation</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Run off-policy evaluation estimators to evaluate the trained policy performance against a dataset.</span>
<span class="sd"> Should only be implemented for off-policy RL algorithms.</span>
<span class="sd"> :return: None</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;&quot;</span><span class="p">)</span>
</pre></div>
</div>
</div>
<footer>
<hr/>
<div role="contentinfo">
<p>
&copy; Copyright 2018, Intel AI Lab
</p>
</div>
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/js/theme.js"></script>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>
@@ -264,8 +264,8 @@
<span class="c1"># prediction&#39;s format is (batch,actions,atoms)</span>
<span class="k">def</span> <span class="nf">get_all_q_values_for_states</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">:</span> <span class="n">StateType</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">requires_action_values</span><span class="p">():</span>
<span class="n">prediction</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_prediction</span><span class="p">(</span><span class="n">states</span><span class="p">)</span>
<span class="n">q_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">distribution_prediction_to_q_values</span><span class="p">(</span><span class="n">prediction</span><span class="p">)</span>
<span class="n">q_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_prediction</span><span class="p">(</span><span class="n">states</span><span class="p">,</span>
<span class="n">outputs</span><span class="o">=</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">q_values</span><span class="p">])</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">q_values</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">return</span> <span class="n">q_values</span>
@@ -280,11 +280,14 @@
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))</span>
<span class="p">])</span>
<span class="c1"># add Q value samples for logging</span>
<span class="bp">self</span><span class="o">.</span><span class="n">q_values</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">distribution_prediction_to_q_values</span><span class="p">(</span><span class="n">TD_targets</span><span class="p">))</span>
<span class="c1"># select the optimal actions for the next state</span>
<span class="n">target_actions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">distribution_prediction_to_q_values</span><span class="p">(</span><span class="n">distributional_q_st_plus_1</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">m</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="o">.</span><span class="n">size</span><span class="p">))</span>
<span class="n">m</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="o">.</span><span class="n">size</span><span class="p">))</span>
<span class="n">batches</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)</span>
<span class="n">batches</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">)</span>
<span class="c1"># an alternative to the for loop. 3.7x perf improvement vs. the same code done with for looping.</span>
<span class="c1"># only 10% speedup overall - leaving commented out as the code is not as clear.</span>
@@ -297,7 +300,7 @@
<span class="c1"># bj_ = (tzj_ - self.z_values[0]) / (self.z_values[1] - self.z_values[0])</span>
<span class="c1"># u_ = (np.ceil(bj_)).astype(int)</span>
<span class="c1"># l_ = (np.floor(bj_)).astype(int)</span>
<span class="c1"># m_ = np.zeros((self.ap.network_wrappers[&#39;main&#39;].batch_size, self.z_values.size))</span>
<span class="c1"># m_ = np.zeros((batch.size, self.z_values.size))</span>
<span class="c1"># np.add.at(m_, [batches, l_],</span>
<span class="c1"># np.transpose(distributional_q_st_plus_1[batches, target_actions], (1, 0)) * (u_ - bj_))</span>
<span class="c1"># np.add.at(m_, [batches, u_],</span>
@@ -387,6 +387,8 @@
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">likelihood_ratio</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">clipped_likelihood_ratio</span><span class="p">]</span>
<span class="c1"># TODO-fixme if batch.size / self.ap.network_wrappers[&#39;main&#39;].batch_size is not an integer, we do not train on</span>
<span class="c1"># some of the data</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)):</span>
<span class="n">start</span> <span class="o">=</span> <span class="n">i</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span>
<span class="n">end</span> <span class="o">=</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span>
+2 -2
View File
@@ -326,13 +326,13 @@
<span class="n">network_inputs</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">)</span>
<span class="n">network_inputs</span><span class="p">[</span><span class="s1">&#39;goal&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">current_goal</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="c1"># get the current outputs of the network</span>
<span class="n">targets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">network_inputs</span><span class="p">)</span>
<span class="c1"># change the targets for the taken actions</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">):</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">):</span>
<span class="n">targets</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()[</span><span class="n">i</span><span class="p">]]</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">info</span><span class="p">[</span><span class="s1">&#39;future_measurements&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span>
<span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">train_and_sync_networks</span><span class="p">(</span><span class="n">network_inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span>
+10 -2
View File
@@ -250,6 +250,9 @@
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s1">&#39;LevelManager&#39;</span><span class="p">,</span> <span class="s1">&#39;CompositeAgent&#39;</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">select_actions</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">next_states</span><span class="p">,</span> <span class="n">q_st_plus_1</span><span class="p">):</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">q_st_plus_1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<div class="viewcode-block" id="DQNAgent.learn_from_batch"><a class="viewcode-back" href="../../../test.html#rl_coach.agents.dqn_agent.DQNAgent.learn_from_batch">[docs]</a> <span class="k">def</span> <span class="nf">learn_from_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
<span class="n">network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
@@ -261,11 +264,16 @@
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))</span>
<span class="p">])</span>
<span class="n">selected_actions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">select_actions</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">next_states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">),</span> <span class="n">q_st_plus_1</span><span class="p">)</span>
<span class="c1"># add Q value samples for logging</span>
<span class="bp">self</span><span class="o">.</span><span class="n">q_values</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">TD_targets</span><span class="p">)</span>
<span class="c1"># only update the action that we have actually done in this transition</span>
<span class="n">TD_errors</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">):</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">):</span>
<span class="n">new_target</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">rewards</span><span class="p">()[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span>\
<span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">batch</span><span class="o">.</span><span class="n">game_overs</span><span class="p">()[</span><span class="n">i</span><span class="p">])</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">discount</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">q_st_plus_1</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="mi">0</span><span class="p">)</span>
<span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">batch</span><span class="o">.</span><span class="n">game_overs</span><span class="p">()[</span><span class="n">i</span><span class="p">])</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">discount</span> <span class="o">*</span> <span class="n">q_st_plus_1</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">selected_actions</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span>
<span class="n">TD_errors</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">new_target</span> <span class="o">-</span> <span class="n">TD_targets</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()[</span><span class="n">i</span><span class="p">]]))</span>
<span class="n">TD_targets</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()[</span><span class="n">i</span><span class="p">]]</span> <span class="o">=</span> <span class="n">new_target</span>
+1 -1
View File
@@ -245,7 +245,7 @@
<span class="n">total_returns</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">n_step_discounted_rewards</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">):</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">):</span>
<span class="n">one_step_target</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">rewards</span><span class="p">()[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> \
<span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">batch</span><span class="o">.</span><span class="n">game_overs</span><span class="p">()[</span><span class="n">i</span><span class="p">])</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">discount</span> <span class="o">*</span> \
<span class="n">q_st_plus_1</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">selected_actions</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span>
@@ -303,6 +303,9 @@
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="kc">True</span><span class="p">,</span> <span class="s1">&#39;The available values for targets_horizon are: 1-Step, N-Step&#39;</span>
<span class="c1"># add Q value samples for logging</span>
<span class="bp">self</span><span class="o">.</span><span class="n">q_values</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">state_value_head_targets</span><span class="p">)</span>
<span class="c1"># train</span>
<span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">accumulate_gradients</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">),</span> <span class="p">[</span><span class="n">state_value_head_targets</span><span class="p">])</span>
+2 -2
View File
@@ -313,7 +313,7 @@
<span class="n">TD_targets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))</span>
<span class="n">bootstrapped_return_from_old_policy</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">n_step_discounted_rewards</span><span class="p">()</span>
<span class="c1"># only update the action that we have actually done in this transition</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">):</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">):</span>
<span class="n">TD_targets</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()[</span><span class="n">i</span><span class="p">]]</span> <span class="o">=</span> <span class="n">bootstrapped_return_from_old_policy</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="c1"># set the gradients to fetch for the DND update</span>
@@ -342,7 +342,7 @@
<span class="n">embedding</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">prepare_batch_for_inference</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">curr_state</span><span class="p">,</span> <span class="s1">&#39;main&#39;</span><span class="p">),</span>
<span class="n">outputs</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">state_embedding</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">current_episode_state_embeddings</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">embedding</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">current_episode_state_embeddings</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">embedding</span><span class="o">.</span><span class="n">squeeze</span><span class="p">())</span>
<span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">act</span><span class="p">()</span>
+1 -1
View File
@@ -264,7 +264,7 @@
<span class="c1"># calculate TD error</span>
<span class="n">TD_targets</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">q_st_online</span><span class="p">)</span>
<span class="n">total_returns</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">n_step_discounted_rewards</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">):</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">):</span>
<span class="n">TD_targets</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()[</span><span class="n">i</span><span class="p">]]</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">rewards</span><span class="p">()[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> \
<span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">batch</span><span class="o">.</span><span class="n">game_overs</span><span class="p">()[</span><span class="n">i</span><span class="p">])</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">discount</span> <span class="o">*</span> \
<span class="n">q_st_plus_1_target</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">selected_actions</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span>
@@ -268,11 +268,14 @@
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))</span>
<span class="p">])</span>
<span class="c1"># add Q value samples for logging</span>
<span class="bp">self</span><span class="o">.</span><span class="n">q_values</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">get_q_values</span><span class="p">(</span><span class="n">current_quantiles</span><span class="p">))</span>
<span class="c1"># get the optimal actions to take for the next states</span>
<span class="n">target_actions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">get_q_values</span><span class="p">(</span><span class="n">next_state_quantiles</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># calculate the Bellman update</span>
<span class="n">batch_idx</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">))</span>
<span class="n">batch_idx</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">))</span>
<span class="n">TD_targets</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">rewards</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">batch</span><span class="o">.</span><span class="n">game_overs</span><span class="p">(</span><span class="kc">True</span><span class="p">))</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">discount</span> \
<span class="o">*</span> <span class="n">next_state_quantiles</span><span class="p">[</span><span class="n">batch_idx</span><span class="p">,</span> <span class="n">target_actions</span><span class="p">]</span>
@@ -283,9 +286,9 @@
<span class="c1"># calculate the cumulative quantile probabilities and reorder them to fit the sorted quantiles order</span>
<span class="n">cumulative_probabilities</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">atoms</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="nb">float</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">atoms</span><span class="p">)</span> <span class="c1"># tau_i</span>
<span class="n">quantile_midpoints</span> <span class="o">=</span> <span class="mf">0.5</span><span class="o">*</span><span class="p">(</span><span class="n">cumulative_probabilities</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="o">+</span> <span class="n">cumulative_probabilities</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span> <span class="c1"># tau^hat_i</span>
<span class="n">quantile_midpoints</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">quantile_midpoints</span><span class="p">,</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">quantile_midpoints</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">quantile_midpoints</span><span class="p">,</span> <span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">sorted_quantiles</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">current_quantiles</span><span class="p">[</span><span class="n">batch_idx</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()])</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">):</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">):</span>
<span class="n">quantile_midpoints</span><span class="p">[</span><span class="n">idx</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">quantile_midpoints</span><span class="p">[</span><span class="n">idx</span><span class="p">,</span> <span class="n">sorted_quantiles</span><span class="p">[</span><span class="n">idx</span><span class="p">]]</span>
<span class="c1"># train</span>
@@ -240,9 +240,12 @@
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">algorithm</span> <span class="o">=</span> <span class="n">RainbowDQNAlgorithmParameters</span><span class="p">()</span>
<span class="c1"># ParameterNoiseParameters is changing the network wrapper parameters. This line needs to be done first.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">network_wrappers</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;main&quot;</span><span class="p">:</span> <span class="n">RainbowDQNNetworkParameters</span><span class="p">()}</span>
<span class="bp">self</span><span class="o">.</span><span class="n">exploration</span> <span class="o">=</span> <span class="n">ParameterNoiseParameters</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span> <span class="o">=</span> <span class="n">PrioritizedExperienceReplayParameters</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">network_wrappers</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;main&quot;</span><span class="p">:</span> <span class="n">RainbowDQNNetworkParameters</span><span class="p">()}</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
@@ -275,11 +278,14 @@
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))</span>
<span class="p">])</span>
<span class="c1"># add Q value samples for logging</span>
<span class="bp">self</span><span class="o">.</span><span class="n">q_values</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">distribution_prediction_to_q_values</span><span class="p">(</span><span class="n">TD_targets</span><span class="p">))</span>
<span class="c1"># only update the action that we have actually done in this transition (using the Double-DQN selected actions)</span>
<span class="n">target_actions</span> <span class="o">=</span> <span class="n">ddqn_selected_actions</span>
<span class="n">m</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="o">.</span><span class="n">size</span><span class="p">))</span>
<span class="n">m</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="o">.</span><span class="n">size</span><span class="p">))</span>
<span class="n">batches</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)</span>
<span class="n">batches</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">)</span>
<span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">z_values</span><span class="o">.</span><span class="n">size</span><span class="p">):</span>
<span class="c1"># we use batch.info(&#39;should_bootstrap_next_state&#39;) instead of (1 - batch.game_overs()) since with n-step,</span>
<span class="c1"># we will not bootstrap for the last n-step transitions in the episode</span>
@@ -0,0 +1,554 @@
<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>rl_coach.agents.soft_actor_critic_agent &mdash; Reinforcement Learning Coach 0.11.0 documentation</title>
<link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../../_static/css/custom.css" type="text/css" />
<link rel="index" title="Index" href="../../../genindex.html" />
<link rel="search" title="Search" href="../../../search.html" />
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
<script src="../../../_static/js/modernizr.min.js"></script>
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search">
<a href="../../../index.html" class="icon icon-home"> Reinforcement Learning Coach
<img src="../../../_static/dark_logo.png" class="logo" alt="Logo"/>
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption"><span class="caption-text">Intro</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../dist_usage.html">Usage - Distributed Coach</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../features/index.html">Features</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../selecting_an_algorithm.html">Selecting an Algorithm</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../dashboard.html">Coach Dashboard</a></li>
</ul>
<p class="caption"><span class="caption-text">Design</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../design/control_flow.html">Control Flow</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../design/network.html">Network Design</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../design/horizontal_scaling.html">Distributed Coach - Horizontal Scale-Out</a></li>
</ul>
<p class="caption"><span class="caption-text">Contributing</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_agent.html">Adding a New Agent</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_env.html">Adding a New Environment</a></li>
</ul>
<p class="caption"><span class="caption-text">Components</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../components/agents/index.html">Agents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/architectures/index.html">Architectures</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/data_stores/index.html">Data Stores</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/environments/index.html">Environments</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/exploration_policies/index.html">Exploration Policies</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/filters/index.html">Filters</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/memories/index.html">Memories</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/memory_backends/index.html">Memory Backends</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/orchestrators/index.html">Orchestrators</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/core_types.html">Core Types</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/spaces.html">Spaces</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/additional_parameters.html">Additional Parameters</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../../index.html">Reinforcement Learning Coach</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../../index.html">Docs</a> &raquo;</li>
<li><a href="../../index.html">Module code</a> &raquo;</li>
<li>rl_coach.agents.soft_actor_critic_agent</li>
<li class="wy-breadcrumbs-aside">
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<h1>Source code for rl_coach.agents.soft_actor_critic_agent</h1><div class="highlight"><pre>
<span></span><span class="c1">#</span>
<span class="c1"># Copyright (c) 2019 Intel Corporation</span>
<span class="c1">#</span>
<span class="c1"># Licensed under the Apache License, Version 2.0 (the &quot;License&quot;);</span>
<span class="c1"># you may not use this file except in compliance with the License.</span>
<span class="c1"># You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
<span class="c1"># distributed under the License is distributed on an &quot;AS IS&quot; BASIS,</span>
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="c1">#</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span>
<span class="kn">import</span> <span class="nn">copy</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">collections</span> <span class="k">import</span> <span class="n">OrderedDict</span>
<span class="kn">from</span> <span class="nn">rl_coach.agents.agent</span> <span class="k">import</span> <span class="n">Agent</span>
<span class="kn">from</span> <span class="nn">rl_coach.agents.policy_optimization_agent</span> <span class="k">import</span> <span class="n">PolicyOptimizationAgent</span>
<span class="kn">from</span> <span class="nn">rl_coach.architectures.head_parameters</span> <span class="k">import</span> <span class="n">SACQHeadParameters</span><span class="p">,</span><span class="n">SACPolicyHeadParameters</span><span class="p">,</span><span class="n">VHeadParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.architectures.middleware_parameters</span> <span class="k">import</span> <span class="n">FCMiddlewareParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.base_parameters</span> <span class="k">import</span> <span class="n">AlgorithmParameters</span><span class="p">,</span> <span class="n">NetworkParameters</span><span class="p">,</span> <span class="n">AgentParameters</span><span class="p">,</span> <span class="n">EmbedderScheme</span><span class="p">,</span> <span class="n">MiddlewareScheme</span>
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">ActionInfo</span><span class="p">,</span> <span class="n">EnvironmentSteps</span><span class="p">,</span> <span class="n">RunPhase</span>
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.additive_noise</span> <span class="k">import</span> <span class="n">AdditiveNoiseParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.memories.non_episodic.experience_replay</span> <span class="k">import</span> <span class="n">ExperienceReplayParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.architectures.embedder_parameters</span> <span class="k">import</span> <span class="n">InputEmbedderParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">BoxActionSpace</span>
<span class="c1"># There are 3 networks in SAC implementation. All have the same topology but parameters are not shared.</span>
<span class="c1"># The networks are:</span>
<span class="c1"># 1. State Value Network - SACValueNetwork</span>
<span class="c1"># 2. Soft Q Value Network - SACCriticNetwork</span>
<span class="c1"># 3. Policy Network - SACPolicyNetwork - currently supporting only Gaussian Policy</span>
<span class="c1"># 1. State Value Network - SACValueNetwork</span>
<span class="c1"># this is the state value network in SAC.</span>
<span class="c1"># The network is trained to predict (regression) the state value in the max-entropy settings</span>
<span class="c1"># The objective to be minimized is given in equation (5) in the paper:</span>
<span class="c1">#</span>
<span class="c1"># J(psi)= E_(s~D)[0.5*(V_psi(s)-y(s))^2]</span>
<span class="c1"># where y(s) = E_(a~pi)[Q_theta(s,a)-log(pi(a|s))]</span>
<span class="c1"># Default parameters for value network:</span>
<span class="c1"># topology :</span>
<span class="c1"># input embedder : EmbedderScheme.Medium (Dense(256)) , relu activation</span>
<span class="c1"># middleware : EmbedderScheme.Medium (Dense(256)) , relu activation</span>
<span class="k">class</span> <span class="nc">SACValueNetworkParameters</span><span class="p">(</span><span class="n">NetworkParameters</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">input_embedders_parameters</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;observation&#39;</span><span class="p">:</span> <span class="n">InputEmbedderParameters</span><span class="p">(</span><span class="n">activation_function</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">)}</span>
<span class="bp">self</span><span class="o">.</span><span class="n">middleware_parameters</span> <span class="o">=</span> <span class="n">FCMiddlewareParameters</span><span class="p">(</span><span class="n">activation_function</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">VHeadParameters</span><span class="p">(</span><span class="n">initializer</span><span class="o">=</span><span class="s1">&#39;xavier&#39;</span><span class="p">)]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rescale_gradient_from_head_by_factor</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_type</span> <span class="o">=</span> <span class="s1">&#39;Adam&#39;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">256</span>
<span class="bp">self</span><span class="o">.</span><span class="n">async_training</span> <span class="o">=</span> <span class="kc">False</span>
<span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">0.0003</span> <span class="c1"># 3e-4 see appendix D in the paper</span>
<span class="bp">self</span><span class="o">.</span><span class="n">create_target_network</span> <span class="o">=</span> <span class="kc">True</span> <span class="c1"># tau is set in SoftActorCriticAlgorithmParameters.rate_for_copying_weights_to_target</span>
<span class="c1"># 2. Soft Q Value Network - SACCriticNetwork</span>
<span class="c1"># the whole network is built in the SACQHeadParameters. we use empty input embedder and middleware</span>
<span class="k">class</span> <span class="nc">SACCriticNetworkParameters</span><span class="p">(</span><span class="n">NetworkParameters</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">input_embedders_parameters</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;observation&#39;</span><span class="p">:</span> <span class="n">InputEmbedderParameters</span><span class="p">(</span><span class="n">scheme</span><span class="o">=</span><span class="n">EmbedderScheme</span><span class="o">.</span><span class="n">Empty</span><span class="p">)}</span>
<span class="bp">self</span><span class="o">.</span><span class="n">middleware_parameters</span> <span class="o">=</span> <span class="n">FCMiddlewareParameters</span><span class="p">(</span><span class="n">scheme</span><span class="o">=</span><span class="n">MiddlewareScheme</span><span class="o">.</span><span class="n">Empty</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">SACQHeadParameters</span><span class="p">()]</span> <span class="c1"># SACQHeadParameters includes the topology of the head</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rescale_gradient_from_head_by_factor</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_type</span> <span class="o">=</span> <span class="s1">&#39;Adam&#39;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">256</span>
<span class="bp">self</span><span class="o">.</span><span class="n">async_training</span> <span class="o">=</span> <span class="kc">False</span>
<span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">0.0003</span>
<span class="bp">self</span><span class="o">.</span><span class="n">create_target_network</span> <span class="o">=</span> <span class="kc">False</span>
<span class="c1"># 3. policy Network</span>
<span class="c1"># Default parameters for policy network:</span>
<span class="c1"># topology :</span>
<span class="c1"># input embedder : EmbedderScheme.Medium (Dense(256)) , relu activation</span>
<span class="c1"># middleware : EmbedderScheme = [Dense(256)] , relu activation --&gt; scheme should be overridden in preset</span>
<span class="k">class</span> <span class="nc">SACPolicyNetworkParameters</span><span class="p">(</span><span class="n">NetworkParameters</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">input_embedders_parameters</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;observation&#39;</span><span class="p">:</span> <span class="n">InputEmbedderParameters</span><span class="p">(</span><span class="n">activation_function</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">)}</span>
<span class="bp">self</span><span class="o">.</span><span class="n">middleware_parameters</span> <span class="o">=</span> <span class="n">FCMiddlewareParameters</span><span class="p">(</span><span class="n">activation_function</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">SACPolicyHeadParameters</span><span class="p">()]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rescale_gradient_from_head_by_factor</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_type</span> <span class="o">=</span> <span class="s1">&#39;Adam&#39;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">256</span>
<span class="bp">self</span><span class="o">.</span><span class="n">async_training</span> <span class="o">=</span> <span class="kc">False</span>
<span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">0.0003</span>
<span class="bp">self</span><span class="o">.</span><span class="n">create_target_network</span> <span class="o">=</span> <span class="kc">False</span>
<span class="bp">self</span><span class="o">.</span><span class="n">l2_regularization</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># weight decay regularization. not used in the original paper</span>
<span class="c1"># Algorithm Parameters</span>
<div class="viewcode-block" id="SoftActorCriticAlgorithmParameters"><a class="viewcode-back" href="../../../components/agents/policy_optimization/sac.html#rl_coach.agents.soft_actor_critic_agent.SoftActorCriticAlgorithmParameters">[docs]</a><span class="k">class</span> <span class="nc">SoftActorCriticAlgorithmParameters</span><span class="p">(</span><span class="n">AlgorithmParameters</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> :param num_steps_between_copying_online_weights_to_target: (StepMethod)</span>
<span class="sd"> The number of steps between copying the online network weights to the target network weights.</span>
<span class="sd"> :param rate_for_copying_weights_to_target: (float)</span>
<span class="sd"> When copying the online network weights to the target network weights, a soft update will be used, which</span>
<span class="sd"> weight the new online network weights by rate_for_copying_weights_to_target. (Tau as defined in the paper)</span>
<span class="sd"> :param use_deterministic_for_evaluation: (bool)</span>
<span class="sd"> If True, during the evaluation phase, action are chosen deterministically according to the policy mean</span>
<span class="sd"> and not sampled from the policy distribution.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_steps_between_copying_online_weights_to_target</span> <span class="o">=</span> <span class="n">EnvironmentSteps</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rate_for_copying_weights_to_target</span> <span class="o">=</span> <span class="mf">0.005</span>
<span class="bp">self</span><span class="o">.</span><span class="n">use_deterministic_for_evaluation</span> <span class="o">=</span> <span class="kc">True</span> <span class="c1"># evaluate agent using deterministic policy (i.e. take the mean value)</span></div>
<span class="k">class</span> <span class="nc">SoftActorCriticAgentParameters</span><span class="p">(</span><span class="n">AgentParameters</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">algorithm</span><span class="o">=</span><span class="n">SoftActorCriticAlgorithmParameters</span><span class="p">(),</span>
<span class="n">exploration</span><span class="o">=</span><span class="n">AdditiveNoiseParameters</span><span class="p">(),</span>
<span class="n">memory</span><span class="o">=</span><span class="n">ExperienceReplayParameters</span><span class="p">(),</span> <span class="c1"># SAC doesnt use episodic related data</span>
<span class="c1"># network wrappers:</span>
<span class="n">networks</span><span class="o">=</span><span class="n">OrderedDict</span><span class="p">([(</span><span class="s2">&quot;policy&quot;</span><span class="p">,</span> <span class="n">SACPolicyNetworkParameters</span><span class="p">()),</span>
<span class="p">(</span><span class="s2">&quot;q&quot;</span><span class="p">,</span> <span class="n">SACCriticNetworkParameters</span><span class="p">()),</span>
<span class="p">(</span><span class="s2">&quot;v&quot;</span><span class="p">,</span> <span class="n">SACValueNetworkParameters</span><span class="p">())]))</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="s1">&#39;rl_coach.agents.soft_actor_critic_agent:SoftActorCriticAgent&#39;</span>
<span class="c1"># Soft Actor Critic - https://arxiv.org/abs/1801.01290</span>
<span class="k">class</span> <span class="nc">SoftActorCriticAgent</span><span class="p">(</span><span class="n">PolicyOptimizationAgent</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s1">&#39;LevelManager&#39;</span><span class="p">,</span> <span class="s1">&#39;CompositeAgent&#39;</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">last_gradient_update_step_idx</span> <span class="o">=</span> <span class="mi">0</span>
<span class="c1"># register signals to track (in learn_from_batch)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">policy_means</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">&#39;Policy_mu_avg&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">policy_logsig</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">&#39;Policy_logsig&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">policy_logprob_sampled</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">&#39;Policy_logp_sampled&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">policy_grads</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">&#39;Policy_grads_sumabs&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">q1_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s2">&quot;Q1&quot;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">TD_err1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s2">&quot;TD err1&quot;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">q2_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s2">&quot;Q2&quot;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">TD_err2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s2">&quot;TD err2&quot;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">v_tgt_ns</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">&#39;V_tgt_ns&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">v_onl_ys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">&#39;V_onl_ys&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">action_signal</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s2">&quot;actions&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">learn_from_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
<span class="c1">#########################################</span>
<span class="c1"># need to update the following networks:</span>
<span class="c1"># 1. actor (policy)</span>
<span class="c1"># 2. state value (v)</span>
<span class="c1"># 3. critic (q1 and q2)</span>
<span class="c1"># 4. target network - probably already handled by V</span>
<span class="c1">#########################################</span>
<span class="c1"># define the networks to be used</span>
<span class="c1"># State Value Network</span>
<span class="n">value_network</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;v&#39;</span><span class="p">]</span>
<span class="n">value_network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;v&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
<span class="c1"># Critic Network</span>
<span class="n">q_network</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;q&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span>
<span class="n">q_head</span> <span class="o">=</span> <span class="n">q_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">q_network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;q&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
<span class="c1"># Actor (policy) Network</span>
<span class="n">policy_network</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;policy&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span>
<span class="n">policy_network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;policy&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
<span class="c1">##########################################</span>
<span class="c1"># 1. updating the actor - according to (13) in the paper</span>
<span class="n">policy_inputs</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">policy_network_keys</span><span class="p">))</span>
<span class="n">policy_results</span> <span class="o">=</span> <span class="n">policy_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">policy_inputs</span><span class="p">)</span>
<span class="n">policy_mu</span><span class="p">,</span> <span class="n">policy_std</span><span class="p">,</span> <span class="n">sampled_raw_actions</span><span class="p">,</span> <span class="n">sampled_actions</span><span class="p">,</span> <span class="n">sampled_actions_logprob</span><span class="p">,</span> \
<span class="n">sampled_actions_logprob_mean</span> <span class="o">=</span> <span class="n">policy_results</span>
<span class="bp">self</span><span class="o">.</span><span class="n">policy_means</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">policy_mu</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">policy_logsig</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">policy_std</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">policy_logprob_sampled</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">sampled_actions_logprob_mean</span><span class="p">)</span>
<span class="c1"># get the state-action values for the replayed states and their corresponding actions from the policy</span>
<span class="n">q_inputs</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">q_network_keys</span><span class="p">))</span>
<span class="n">q_inputs</span><span class="p">[</span><span class="s1">&#39;output_0_0&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">sampled_actions</span>
<span class="n">log_target</span> <span class="o">=</span> <span class="n">q_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">q_inputs</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
<span class="c1"># log internal q values</span>
<span class="n">q1_vals</span><span class="p">,</span> <span class="n">q2_vals</span> <span class="o">=</span> <span class="n">q_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">q_inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="p">[</span><span class="n">q_head</span><span class="o">.</span><span class="n">q1_output</span><span class="p">,</span> <span class="n">q_head</span><span class="o">.</span><span class="n">q2_output</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">q1_values</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">q1_vals</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">q2_values</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">q2_vals</span><span class="p">)</span>
<span class="c1"># calculate the gradients according to (13)</span>
<span class="c1"># get the gradients of log_prob w.r.t the weights (parameters) - indicated as phi in the paper</span>
<span class="n">initial_feed_dict</span> <span class="o">=</span> <span class="p">{</span><span class="n">policy_network</span><span class="o">.</span><span class="n">gradients_weights_ph</span><span class="p">[</span><span class="mi">5</span><span class="p">]:</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)}</span>
<span class="n">dlogp_dphi</span> <span class="o">=</span> <span class="n">policy_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">policy_inputs</span><span class="p">,</span>
<span class="n">outputs</span><span class="o">=</span><span class="n">policy_network</span><span class="o">.</span><span class="n">weighted_gradients</span><span class="p">[</span><span class="mi">5</span><span class="p">],</span>
<span class="n">initial_feed_dict</span><span class="o">=</span><span class="n">initial_feed_dict</span><span class="p">)</span>
<span class="c1"># calculate dq_da</span>
<span class="n">dq_da</span> <span class="o">=</span> <span class="n">q_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">q_inputs</span><span class="p">,</span>
<span class="n">outputs</span><span class="o">=</span><span class="n">q_network</span><span class="o">.</span><span class="n">gradients_wrt_inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">][</span><span class="s1">&#39;output_0_0&#39;</span><span class="p">])</span>
<span class="c1"># calculate da_dphi</span>
<span class="n">initial_feed_dict</span> <span class="o">=</span> <span class="p">{</span><span class="n">policy_network</span><span class="o">.</span><span class="n">gradients_weights_ph</span><span class="p">[</span><span class="mi">3</span><span class="p">]:</span> <span class="n">dq_da</span><span class="p">}</span>
<span class="n">dq_dphi</span> <span class="o">=</span> <span class="n">policy_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">policy_inputs</span><span class="p">,</span>
<span class="n">outputs</span><span class="o">=</span><span class="n">policy_network</span><span class="o">.</span><span class="n">weighted_gradients</span><span class="p">[</span><span class="mi">3</span><span class="p">],</span>
<span class="n">initial_feed_dict</span><span class="o">=</span><span class="n">initial_feed_dict</span><span class="p">)</span>
<span class="c1"># now given dlogp_dphi, dq_dphi we need to calculate the policy gradients according to (13)</span>
<span class="n">policy_grads</span> <span class="o">=</span> <span class="p">[</span><span class="n">dlogp_dphi</span><span class="p">[</span><span class="n">l</span><span class="p">]</span> <span class="o">-</span> <span class="n">dq_dphi</span><span class="p">[</span><span class="n">l</span><span class="p">]</span> <span class="k">for</span> <span class="n">l</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">dlogp_dphi</span><span class="p">))]</span>
<span class="c1"># apply the gradients to policy networks</span>
<span class="n">policy_network</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="n">policy_grads</span><span class="p">)</span>
<span class="n">grads_sumabs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">([</span><span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">policy_grads</span><span class="p">[</span><span class="n">l</span><span class="p">]))</span> <span class="k">for</span> <span class="n">l</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">policy_grads</span><span class="p">))])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">policy_grads</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">grads_sumabs</span><span class="p">)</span>
<span class="c1">##########################################</span>
<span class="c1"># 2. updating the state value online network weights</span>
<span class="c1"># done by calculating the targets for the v head according to (5) in the paper</span>
<span class="c1"># value_targets = log_targets-sampled_actions_logprob</span>
<span class="n">value_inputs</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">value_network_keys</span><span class="p">))</span>
<span class="n">value_targets</span> <span class="o">=</span> <span class="n">log_target</span> <span class="o">-</span> <span class="n">sampled_actions_logprob</span>
<span class="bp">self</span><span class="o">.</span><span class="n">v_onl_ys</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">value_targets</span><span class="p">)</span>
<span class="c1"># call value_network apply gradients with this target</span>
<span class="n">value_loss</span> <span class="o">=</span> <span class="n">value_network</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">train_on_batch</span><span class="p">(</span><span class="n">value_inputs</span><span class="p">,</span> <span class="n">value_targets</span><span class="p">[:,</span><span class="kc">None</span><span class="p">])[</span><span class="mi">0</span><span class="p">]</span>
<span class="c1">##########################################</span>
<span class="c1"># 3. updating the critic (q networks)</span>
<span class="c1"># updating q networks according to (7) in the paper</span>
<span class="c1"># define the input to the q network: state has been already updated previously, but now we need</span>
<span class="c1"># the actions from the batch (and not those sampled by the policy)</span>
<span class="n">q_inputs</span><span class="p">[</span><span class="s1">&#39;output_0_0&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)</span>
<span class="c1"># define the targets : scale_reward * reward + (1-terminal)*discount*v_target_next_state</span>
<span class="c1"># define v_target_next_state</span>
<span class="n">value_inputs</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">next_states</span><span class="p">(</span><span class="n">value_network_keys</span><span class="p">))</span>
<span class="n">v_target_next_state</span> <span class="o">=</span> <span class="n">value_network</span><span class="o">.</span><span class="n">target_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">value_inputs</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">v_tgt_ns</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">v_target_next_state</span><span class="p">)</span>
<span class="c1"># Note: reward is assumed to be rescaled by RewardRescaleFilter in the preset parameters</span>
<span class="n">TD_targets</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">rewards</span><span class="p">(</span><span class="n">expand_dims</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">+</span> \
<span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">batch</span><span class="o">.</span><span class="n">game_overs</span><span class="p">(</span><span class="n">expand_dims</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">discount</span> <span class="o">*</span> <span class="n">v_target_next_state</span>
<span class="c1"># call critic network update</span>
<span class="n">result</span> <span class="o">=</span> <span class="n">q_network</span><span class="o">.</span><span class="n">train_on_batch</span><span class="p">(</span><span class="n">q_inputs</span><span class="p">,</span> <span class="n">TD_targets</span><span class="p">,</span> <span class="n">additional_fetches</span><span class="o">=</span><span class="p">[</span><span class="n">q_head</span><span class="o">.</span><span class="n">q1_loss</span><span class="p">,</span> <span class="n">q_head</span><span class="o">.</span><span class="n">q2_loss</span><span class="p">])</span>
<span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span> <span class="o">=</span> <span class="n">result</span><span class="p">[:</span><span class="mi">3</span><span class="p">]</span>
<span class="n">q1_loss</span><span class="p">,</span> <span class="n">q2_loss</span> <span class="o">=</span> <span class="n">result</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">TD_err1</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">q1_loss</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">TD_err2</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">q2_loss</span><span class="p">)</span>
<span class="c1">##########################################</span>
<span class="c1"># 4. updating the value target network</span>
<span class="c1"># I just need to set the parameter rate_for_copying_weights_to_target in the agent parameters to be 1-tau</span>
<span class="c1"># where tau is the hyper parameter as defined in sac original implementation</span>
<span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span>
<span class="k">def</span> <span class="nf">get_prediction</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> get the mean and stdev of the policy distribution given &#39;states&#39;</span>
<span class="sd"> :param states: the states for which we need to sample actions from the policy</span>
<span class="sd"> :return: mean and stdev</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">tf_input_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_batch_for_inference</span><span class="p">(</span><span class="n">states</span><span class="p">,</span> <span class="s1">&#39;policy&#39;</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;policy&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">tf_input_state</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="c1"># since the algorithm works with experience replay buffer (non-episodic),</span>
<span class="c1"># we cant use the policy optimization train method. we need Agent.train</span>
<span class="c1"># note that since in Agent.train there is no apply_gradients, we need to do it in learn from batch</span>
<span class="k">return</span> <span class="n">Agent</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">choose_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">curr_state</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> choose_action - chooses the most likely action</span>
<span class="sd"> if &#39;deterministic&#39; - take the mean of the policy which is the prediction of the policy network.</span>
<span class="sd"> else - use the exploration policy</span>
<span class="sd"> :param curr_state:</span>
<span class="sd"> :return: action wrapped in ActionInfo</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">action</span><span class="p">,</span> <span class="n">BoxActionSpace</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;SAC works only for continuous control problems&quot;</span><span class="p">)</span>
<span class="c1"># convert to batch so we can run it through the network</span>
<span class="n">tf_input_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_batch_for_inference</span><span class="p">(</span><span class="n">curr_state</span><span class="p">,</span> <span class="s1">&#39;policy&#39;</span><span class="p">)</span>
<span class="c1"># use the online network for prediction</span>
<span class="n">policy_network</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;policy&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span>
<span class="n">policy_head</span> <span class="o">=</span> <span class="n">policy_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">result</span> <span class="o">=</span> <span class="n">policy_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">tf_input_state</span><span class="p">,</span>
<span class="n">outputs</span><span class="o">=</span><span class="p">[</span><span class="n">policy_head</span><span class="o">.</span><span class="n">policy_mean</span><span class="p">,</span> <span class="n">policy_head</span><span class="o">.</span><span class="n">actions</span><span class="p">])</span>
<span class="n">action_mean</span><span class="p">,</span> <span class="n">action_sample</span> <span class="o">=</span> <span class="n">result</span>
<span class="c1"># if using deterministic policy, take the mean values. else, use exploration policy to sample from the pdf</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">==</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">use_deterministic_for_evaluation</span><span class="p">:</span>
<span class="n">action</span> <span class="o">=</span> <span class="n">action_mean</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">action</span> <span class="o">=</span> <span class="n">action_sample</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">action_signal</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">action</span><span class="p">)</span>
<span class="n">action_info</span> <span class="o">=</span> <span class="n">ActionInfo</span><span class="p">(</span><span class="n">action</span><span class="o">=</span><span class="n">action</span><span class="p">)</span>
<span class="k">return</span> <span class="n">action_info</span>
</pre></div>
</div>
</div>
<footer>
<hr/>
<div role="contentinfo">
<p>
&copy; Copyright 2018, Intel AI Lab
</p>
</div>
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/js/theme.js"></script>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>
@@ -193,16 +193,17 @@
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="c1">#</span>
<span class="kn">from</span> <span class="nn">collections</span> <span class="k">import</span> <span class="n">OrderedDict</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">rl_coach.agents.agent</span> <span class="k">import</span> <span class="n">Agent</span>
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">ActionInfo</span><span class="p">,</span> <span class="n">StateType</span>
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">ActionInfo</span><span class="p">,</span> <span class="n">StateType</span><span class="p">,</span> <span class="n">Batch</span>
<span class="kn">from</span> <span class="nn">rl_coach.logger</span> <span class="k">import</span> <span class="n">screen</span>
<span class="kn">from</span> <span class="nn">rl_coach.memories.non_episodic.prioritized_experience_replay</span> <span class="k">import</span> <span class="n">PrioritizedExperienceReplay</span>
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">DiscreteActionSpace</span>
<span class="kn">from</span> <span class="nn">copy</span> <span class="k">import</span> <span class="n">deepcopy</span>
<span class="c1">## This is an abstract agent - there is no learn_from_batch method ##</span>
@@ -229,8 +230,9 @@
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">return</span> <span class="n">actions_q_values</span>
<span class="k">def</span> <span class="nf">get_prediction</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">prepare_batch_for_inference</span><span class="p">(</span><span class="n">states</span><span class="p">,</span> <span class="s1">&#39;main&#39;</span><span class="p">))</span>
<span class="k">def</span> <span class="nf">get_prediction</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">prepare_batch_for_inference</span><span class="p">(</span><span class="n">states</span><span class="p">,</span> <span class="s1">&#39;main&#39;</span><span class="p">),</span>
<span class="n">outputs</span><span class="o">=</span><span class="n">outputs</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">update_transition_priorities_and_get_weights</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">TD_errors</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
<span class="c1"># update errors in prioritized replay buffer</span>
@@ -259,10 +261,12 @@
<span class="c1"># this is for bootstrapped dqn</span>
<span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="n">actions_q_values</span><span class="p">)</span> <span class="o">==</span> <span class="nb">list</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">actions_q_values</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">last_action_values</span>
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="n">actions_q_values</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
<span class="c1"># store the q values statistics for logging</span>
<span class="bp">self</span><span class="o">.</span><span class="n">q_values</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">actions_q_values</span><span class="p">)</span>
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="n">actions_q_values</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">q_value</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">actions_q_values</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">q_value_for_action</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">q_value</span><span class="p">)</span>
@@ -276,6 +280,77 @@
<span class="k">def</span> <span class="nf">learn_from_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;ValueOptimizationAgent is an abstract agent. Not to be used directly.&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">run_off_policy_evaluation</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Run the off-policy evaluation estimators to get a prediction for the performance of the current policy based on</span>
<span class="sd"> an evaluation dataset, which was collected by another policy(ies).</span>
<span class="sd"> :return: None</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">ope_manager</span>
<span class="n">dataset_as_episodes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">&#39;get_all_complete_episodes_from_to&#39;</span><span class="p">,</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">&#39;get_last_training_set_episode_id&#39;</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">&#39;num_complete_episodes&#39;</span><span class="p">)))</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">dataset_as_episodes</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;train_to_eval_ratio is too high causing the evaluation set to be empty. &#39;</span>
<span class="s1">&#39;Consider decreasing its value.&#39;</span><span class="p">)</span>
<span class="n">ips</span><span class="p">,</span> <span class="n">dm</span><span class="p">,</span> <span class="n">dr</span><span class="p">,</span> <span class="n">seq_dr</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ope_manager</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span>
<span class="n">dataset_as_episodes</span><span class="o">=</span><span class="n">dataset_as_episodes</span><span class="p">,</span>
<span class="n">batch_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">discount_factor</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">discount</span><span class="p">,</span>
<span class="n">reward_model</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;reward_model&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="p">,</span>
<span class="n">q_network</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="p">,</span>
<span class="n">network_keys</span><span class="o">=</span><span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()))</span>
<span class="c1"># get the estimators out to the screen</span>
<span class="n">log</span> <span class="o">=</span> <span class="n">OrderedDict</span><span class="p">()</span>
<span class="n">log</span><span class="p">[</span><span class="s1">&#39;Epoch&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_epoch</span>
<span class="n">log</span><span class="p">[</span><span class="s1">&#39;IPS&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">ips</span>
<span class="n">log</span><span class="p">[</span><span class="s1">&#39;DM&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">dm</span>
<span class="n">log</span><span class="p">[</span><span class="s1">&#39;DR&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">dr</span>
<span class="n">log</span><span class="p">[</span><span class="s1">&#39;Sequential-DR&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">seq_dr</span>
<span class="n">screen</span><span class="o">.</span><span class="n">log_dict</span><span class="p">(</span><span class="n">log</span><span class="p">,</span> <span class="n">prefix</span><span class="o">=</span><span class="s1">&#39;Off-Policy Evaluation&#39;</span><span class="p">)</span>
<span class="c1"># get the estimators out to dashboard</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">set_current_time</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">get_current_time</span><span class="p">()</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">&#39;Inverse Propensity Score&#39;</span><span class="p">,</span> <span class="n">ips</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">&#39;Direct Method Reward&#39;</span><span class="p">,</span> <span class="n">dm</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">&#39;Doubly Robust&#39;</span><span class="p">,</span> <span class="n">dr</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">create_signal_value</span><span class="p">(</span><span class="s1">&#39;Sequential Doubly Robust&#39;</span><span class="p">,</span> <span class="n">seq_dr</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">get_reward_model_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Batch</span><span class="p">):</span>
<span class="n">network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;reward_model&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
<span class="n">current_rewards_prediction_for_all_actions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;reward_model&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span>
<span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">))</span>
<span class="n">current_rewards_prediction_for_all_actions</span><span class="p">[</span><span class="nb">range</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">),</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()]</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">rewards</span><span class="p">()</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;reward_model&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">train_and_sync_networks</span><span class="p">(</span>
<span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">network_keys</span><span class="p">),</span> <span class="n">current_rewards_prediction_for_all_actions</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">def</span> <span class="nf">improve_reward_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epochs</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Train a reward model to be used by the doubly-robust estimator</span>
<span class="sd"> :param epochs: The total number of epochs to use for training a reward model</span>
<span class="sd"> :return: None</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;reward_model&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span>
<span class="c1"># this is fitted from the training dataset</span>
<span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">epochs</span><span class="p">):</span>
<span class="n">loss</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">total_transitions_processed</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">batch</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">&#39;get_shuffled_data_generator&#39;</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">)):</span>
<span class="n">batch</span> <span class="o">=</span> <span class="n">Batch</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_reward_model_loss</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
<span class="n">total_transitions_processed</span> <span class="o">+=</span> <span class="n">batch</span><span class="o">.</span><span class="n">size</span>
<span class="n">log</span> <span class="o">=</span> <span class="n">OrderedDict</span><span class="p">()</span>
<span class="n">log</span><span class="p">[</span><span class="s1">&#39;Epoch&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">epoch</span>
<span class="n">log</span><span class="p">[</span><span class="s1">&#39;loss&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">loss</span> <span class="o">/</span> <span class="n">total_transitions_processed</span>
<span class="n">screen</span><span class="o">.</span><span class="n">log_dict</span><span class="p">(</span><span class="n">log</span><span class="p">,</span> <span class="n">prefix</span><span class="o">=</span><span class="s1">&#39;Training Reward Model&#39;</span><span class="p">)</span>
</pre></div>
</div>
+36 -20
View File
@@ -205,6 +205,7 @@
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">TrainingSteps</span><span class="p">,</span> <span class="n">EnvironmentSteps</span><span class="p">,</span> <span class="n">GradientClippingMethod</span><span class="p">,</span> <span class="n">RunPhase</span><span class="p">,</span> \
<span class="n">SelectedPhaseOnlyDumpFilter</span><span class="p">,</span> <span class="n">MaxDumpFilter</span>
<span class="kn">from</span> <span class="nn">rl_coach.filters.filter</span> <span class="k">import</span> <span class="n">NoInputFilter</span>
<span class="kn">from</span> <span class="nn">rl_coach.logger</span> <span class="k">import</span> <span class="n">screen</span>
<span class="k">class</span> <span class="nc">Frameworks</span><span class="p">(</span><span class="n">Enum</span><span class="p">):</span>
@@ -379,9 +380,6 @@
<span class="c1"># distributed agents params</span>
<span class="bp">self</span><span class="o">.</span><span class="n">share_statistics_between_workers</span> <span class="o">=</span> <span class="kc">True</span>
<span class="c1"># intrinsic reward</span>
<span class="bp">self</span><span class="o">.</span><span class="n">scale_external_reward_by_intrinsic_reward_value</span> <span class="o">=</span> <span class="kc">False</span>
<span class="c1"># n-step returns</span>
<span class="bp">self</span><span class="o">.</span><span class="n">n_step</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span> <span class="c1"># calculate the total return (no bootstrap, by default)</span>
@@ -470,7 +468,8 @@
<span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span>
<span class="n">replace_mse_with_huber_loss</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">create_target_network</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">tensorflow_support</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<span class="n">tensorflow_support</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">softmax_temperature</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> :param force_cpu:</span>
<span class="sd"> Force the neural networks to run on the CPU even if a GPU is available</span>
@@ -553,6 +552,8 @@
<span class="sd"> online network at will.</span>
<span class="sd"> :param tensorflow_support:</span>
<span class="sd"> A flag which specifies if the network is supported by the TensorFlow framework.</span>
<span class="sd"> :param softmax_temperature:</span>
<span class="sd"> If a softmax is present in the network head output, use this temperature</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">framework</span> <span class="o">=</span> <span class="n">Frameworks</span><span class="o">.</span><span class="n">tensorflow</span>
@@ -583,16 +584,19 @@
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="n">heads_parameters</span>
<span class="bp">self</span><span class="o">.</span><span class="n">use_separate_networks_per_head</span> <span class="o">=</span> <span class="n">use_separate_networks_per_head</span>
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_type</span> <span class="o">=</span> <span class="n">optimizer_type</span>
<span class="bp">self</span><span class="o">.</span><span class="n">replace_mse_with_huber_loss</span> <span class="o">=</span> <span class="n">replace_mse_with_huber_loss</span>
<span class="bp">self</span><span class="o">.</span><span class="n">create_target_network</span> <span class="o">=</span> <span class="n">create_target_network</span>
<span class="c1"># Framework support</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tensorflow_support</span> <span class="o">=</span> <span class="n">tensorflow_support</span>
<span class="c1"># Hyper-Parameter values</span>
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_epsilon</span> <span class="o">=</span> <span class="n">optimizer_epsilon</span>
<span class="bp">self</span><span class="o">.</span><span class="n">adam_optimizer_beta1</span> <span class="o">=</span> <span class="n">adam_optimizer_beta1</span>
<span class="bp">self</span><span class="o">.</span><span class="n">adam_optimizer_beta2</span> <span class="o">=</span> <span class="n">adam_optimizer_beta2</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rms_prop_optimizer_decay</span> <span class="o">=</span> <span class="n">rms_prop_optimizer_decay</span>
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">replace_mse_with_huber_loss</span> <span class="o">=</span> <span class="n">replace_mse_with_huber_loss</span>
<span class="bp">self</span><span class="o">.</span><span class="n">create_target_network</span> <span class="o">=</span> <span class="n">create_target_network</span>
<span class="c1"># Framework support</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tensorflow_support</span> <span class="o">=</span> <span class="n">tensorflow_support</span></div>
<span class="bp">self</span><span class="o">.</span><span class="n">softmax_temperature</span> <span class="o">=</span> <span class="n">softmax_temperature</span></div>
<span class="k">class</span> <span class="nc">NetworkComponentParameters</span><span class="p">(</span><span class="n">Parameters</span><span class="p">):</span>
@@ -723,6 +727,7 @@
<span class="bp">self</span><span class="o">.</span><span class="n">is_a_highest_level_agent</span> <span class="o">=</span> <span class="kc">True</span>
<span class="bp">self</span><span class="o">.</span><span class="n">is_a_lowest_level_agent</span> <span class="o">=</span> <span class="kc">True</span>
<span class="bp">self</span><span class="o">.</span><span class="n">task_parameters</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">is_batch_rl_training</span> <span class="o">=</span> <span class="kc">False</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
@@ -730,18 +735,22 @@
<div class="viewcode-block" id="TaskParameters"><a class="viewcode-back" href="../../components/additional_parameters.html#rl_coach.base_parameters.TaskParameters">[docs]</a><span class="k">class</span> <span class="nc">TaskParameters</span><span class="p">(</span><span class="n">Parameters</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">framework_type</span><span class="p">:</span> <span class="n">Frameworks</span><span class="o">=</span><span class="n">Frameworks</span><span class="o">.</span><span class="n">tensorflow</span><span class="p">,</span> <span class="n">evaluate_only</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">use_cpu</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">framework_type</span><span class="p">:</span> <span class="n">Frameworks</span><span class="o">=</span><span class="n">Frameworks</span><span class="o">.</span><span class="n">tensorflow</span><span class="p">,</span> <span class="n">evaluate_only</span><span class="p">:</span> <span class="nb">int</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">use_cpu</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">experiment_path</span><span class="o">=</span><span class="s1">&#39;/tmp&#39;</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">checkpoint_save_secs</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">checkpoint_restore_dir</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">checkpoint_save_dir</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">export_onnx_graph</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">apply_stop_condition</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">num_gpu</span><span class="p">:</span> <span class="nb">int</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
<span class="n">checkpoint_restore_path</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">checkpoint_save_dir</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">export_onnx_graph</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">apply_stop_condition</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">num_gpu</span><span class="p">:</span> <span class="nb">int</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> :param framework_type: deep learning framework type. currently only tensorflow is supported</span>
<span class="sd"> :param evaluate_only: the task will be used only for evaluating the model</span>
<span class="sd"> :param evaluate_only: if not None, the task will be used only for evaluating the model for the given number of steps.</span>
<span class="sd"> A value of 0 means that task will be evaluated for an infinite number of steps.</span>
<span class="sd"> :param use_cpu: use the cpu for this task</span>
<span class="sd"> :param experiment_path: the path to the directory which will store all the experiment outputs</span>
<span class="sd"> :param seed: a seed to use for the random numbers generator</span>
<span class="sd"> :param checkpoint_save_secs: the number of seconds between each checkpoint saving</span>
<span class="sd"> :param checkpoint_restore_dir: the directory to restore the checkpoints from</span>
<span class="sd"> :param checkpoint_restore_dir:</span>
<span class="sd"> [DEPECRATED - will be removed in one of the next releases - switch to checkpoint_restore_path]</span>
<span class="sd"> the dir to restore the checkpoints from</span>
<span class="sd"> :param checkpoint_restore_path: the path to restore the checkpoints from</span>
<span class="sd"> :param checkpoint_save_dir: the directory to store the checkpoints in</span>
<span class="sd"> :param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved</span>
<span class="sd"> :param apply_stop_condition: If set to True, this will apply the stop condition defined by reaching a target success rate</span>
@@ -753,7 +762,13 @@
<span class="bp">self</span><span class="o">.</span><span class="n">use_cpu</span> <span class="o">=</span> <span class="n">use_cpu</span>
<span class="bp">self</span><span class="o">.</span><span class="n">experiment_path</span> <span class="o">=</span> <span class="n">experiment_path</span>
<span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_save_secs</span> <span class="o">=</span> <span class="n">checkpoint_save_secs</span>
<span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_restore_dir</span> <span class="o">=</span> <span class="n">checkpoint_restore_dir</span>
<span class="k">if</span> <span class="n">checkpoint_restore_dir</span><span class="p">:</span>
<span class="n">screen</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s1">&#39;TaskParameters.checkpoint_restore_dir is DEPECRATED and will be removed in one of the next &#39;</span>
<span class="s1">&#39;releases. Please switch to using TaskParameters.checkpoint_restore_path, with your &#39;</span>
<span class="s1">&#39;directory path. &#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_restore_path</span> <span class="o">=</span> <span class="n">checkpoint_restore_dir</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_restore_path</span> <span class="o">=</span> <span class="n">checkpoint_restore_path</span>
<span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_save_dir</span> <span class="o">=</span> <span class="n">checkpoint_save_dir</span>
<span class="bp">self</span><span class="o">.</span><span class="n">seed</span> <span class="o">=</span> <span class="n">seed</span>
<span class="bp">self</span><span class="o">.</span><span class="n">export_onnx_graph</span> <span class="o">=</span> <span class="n">export_onnx_graph</span>
@@ -763,13 +778,14 @@
<div class="viewcode-block" id="DistributedTaskParameters"><a class="viewcode-back" href="../../components/additional_parameters.html#rl_coach.base_parameters.DistributedTaskParameters">[docs]</a><span class="k">class</span> <span class="nc">DistributedTaskParameters</span><span class="p">(</span><span class="n">TaskParameters</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">framework_type</span><span class="p">:</span> <span class="n">Frameworks</span><span class="p">,</span> <span class="n">parameters_server_hosts</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">worker_hosts</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">job_type</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">task_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">evaluate_only</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">num_tasks</span><span class="p">:</span> <span class="nb">int</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">task_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">evaluate_only</span><span class="p">:</span> <span class="nb">int</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">num_tasks</span><span class="p">:</span> <span class="nb">int</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">num_training_tasks</span><span class="p">:</span> <span class="nb">int</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">use_cpu</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">experiment_path</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dnd</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">shared_memory_scratchpad</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">checkpoint_save_secs</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">checkpoint_restore_dir</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">shared_memory_scratchpad</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">checkpoint_save_secs</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">checkpoint_restore_path</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">checkpoint_save_dir</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">export_onnx_graph</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">apply_stop_condition</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> :param framework_type: deep learning framework type. currently only tensorflow is supported</span>
<span class="sd"> :param evaluate_only: the task will be used only for evaluating the model</span>
<span class="sd"> :param evaluate_only: if not None, the task will be used only for evaluating the model for the given number of steps.</span>
<span class="sd"> A value of 0 means that task will be evaluated for an infinite number of steps.</span>
<span class="sd"> :param parameters_server_hosts: comma-separated list of hostname:port pairs to which the parameter servers are</span>
<span class="sd"> assigned</span>
<span class="sd"> :param worker_hosts: comma-separated list of hostname:port pairs to which the workers are assigned</span>
@@ -782,7 +798,7 @@
<span class="sd"> :param dnd: an external DND to use for NEC. This is a workaround needed for a shared DND not using the scratchpad.</span>
<span class="sd"> :param seed: a seed to use for the random numbers generator</span>
<span class="sd"> :param checkpoint_save_secs: the number of seconds between each checkpoint saving</span>
<span class="sd"> :param checkpoint_restore_dir: the directory to restore the checkpoints from</span>
<span class="sd"> :param checkpoint_restore_path: the path to restore the checkpoints from</span>
<span class="sd"> :param checkpoint_save_dir: the directory to store the checkpoints in</span>
<span class="sd"> :param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved</span>
<span class="sd"> :param apply_stop_condition: If set to True, this will apply the stop condition defined by reaching a target success rate</span>
@@ -790,7 +806,7 @@
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">framework_type</span><span class="o">=</span><span class="n">framework_type</span><span class="p">,</span> <span class="n">evaluate_only</span><span class="o">=</span><span class="n">evaluate_only</span><span class="p">,</span> <span class="n">use_cpu</span><span class="o">=</span><span class="n">use_cpu</span><span class="p">,</span>
<span class="n">experiment_path</span><span class="o">=</span><span class="n">experiment_path</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="n">seed</span><span class="p">,</span> <span class="n">checkpoint_save_secs</span><span class="o">=</span><span class="n">checkpoint_save_secs</span><span class="p">,</span>
<span class="n">checkpoint_restore_dir</span><span class="o">=</span><span class="n">checkpoint_restore_dir</span><span class="p">,</span> <span class="n">checkpoint_save_dir</span><span class="o">=</span><span class="n">checkpoint_save_dir</span><span class="p">,</span>
<span class="n">checkpoint_restore_path</span><span class="o">=</span><span class="n">checkpoint_restore_path</span><span class="p">,</span> <span class="n">checkpoint_save_dir</span><span class="o">=</span><span class="n">checkpoint_save_dir</span><span class="p">,</span>
<span class="n">export_onnx_graph</span><span class="o">=</span><span class="n">export_onnx_graph</span><span class="p">,</span> <span class="n">apply_stop_condition</span><span class="o">=</span><span class="n">apply_stop_condition</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">parameters_server_hosts</span> <span class="o">=</span> <span class="n">parameters_server_hosts</span>
<span class="bp">self</span><span class="o">.</span><span class="n">worker_hosts</span> <span class="o">=</span> <span class="n">worker_hosts</span>
+62 -7
View File
@@ -193,9 +193,10 @@
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="c1">#</span>
<span class="kn">from</span> <span class="nn">collections</span> <span class="k">import</span> <span class="n">namedtuple</span>
<span class="kn">import</span> <span class="nn">copy</span>
<span class="kn">import</span> <span class="nn">math</span>
<span class="kn">from</span> <span class="nn">enum</span> <span class="k">import</span> <span class="n">Enum</span>
<span class="kn">from</span> <span class="nn">random</span> <span class="k">import</span> <span class="n">shuffle</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Type</span>
@@ -218,6 +219,17 @@
<span class="n">Measurements</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">Record</span> <span class="o">=</span> <span class="n">namedtuple</span><span class="p">(</span><span class="s1">&#39;Record&#39;</span><span class="p">,</span> <span class="p">[</span><span class="s1">&#39;name&#39;</span><span class="p">,</span> <span class="s1">&#39;label&#39;</span><span class="p">])</span>
<span class="k">class</span> <span class="nc">TimeTypes</span><span class="p">(</span><span class="n">Enum</span><span class="p">):</span>
<span class="n">EpisodeNumber</span> <span class="o">=</span> <span class="n">Record</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;Episode #&#39;</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s1">&#39;Episode #&#39;</span><span class="p">)</span>
<span class="n">TrainingIteration</span> <span class="o">=</span> <span class="n">Record</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;Training Iter&#39;</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s1">&#39;Training Iteration&#39;</span><span class="p">)</span>
<span class="n">EnvironmentSteps</span> <span class="o">=</span> <span class="n">Record</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;Total steps&#39;</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s1">&#39;Total steps (per worker)&#39;</span><span class="p">)</span>
<span class="n">WallClockTime</span> <span class="o">=</span> <span class="n">Record</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;Wall-Clock Time&#39;</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s1">&#39;Wall-Clock Time (minutes)&#39;</span><span class="p">)</span>
<span class="n">Epoch</span> <span class="o">=</span> <span class="n">Record</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;Epoch&#39;</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s1">&#39;Epoch #&#39;</span><span class="p">)</span>
<span class="c1"># step methods</span>
<span class="k">class</span> <span class="nc">StepMethod</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
@@ -232,6 +244,37 @@
<span class="k">def</span> <span class="nf">num_steps</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">val</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_num_steps</span> <span class="o">=</span> <span class="n">val</span>
<span class="k">def</span> <span class="nf">__eq__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">other</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_steps</span> <span class="o">==</span> <span class="n">other</span><span class="o">.</span><span class="n">num_steps</span>
<span class="k">def</span> <span class="nf">__truediv__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">other</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> divide this step method with other. If other is an integer, returns an object of the same</span>
<span class="sd"> type as self. If other is the same type of self, returns an integer. In either case, any</span>
<span class="sd"> floating point value is rounded up under the assumption that if we are dividing Steps, we</span>
<span class="sd"> would rather overestimate than underestimate.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">other</span><span class="p">,</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="p">)):</span>
<span class="k">return</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_steps</span> <span class="o">/</span> <span class="n">other</span><span class="o">.</span><span class="n">num_steps</span><span class="p">)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">other</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="k">return</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="p">)(</span><span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_steps</span> <span class="o">/</span> <span class="n">other</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;cannot divide </span><span class="si">{}</span><span class="s2"> by </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="p">),</span> <span class="nb">type</span><span class="p">(</span><span class="n">other</span><span class="p">)))</span>
<span class="k">def</span> <span class="nf">__rtruediv__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">other</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> divide this step method with other. If other is an integer, returns an object of the same</span>
<span class="sd"> type as self. If other is the same type of self, returns an integer. In either case, any</span>
<span class="sd"> floating point value is rounded up under the assumption that if we are dividing Steps, we</span>
<span class="sd"> would rather overestimate than underestimate.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">other</span><span class="p">,</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="p">)):</span>
<span class="k">return</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">other</span><span class="o">.</span><span class="n">num_steps</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_steps</span><span class="p">)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">other</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="k">return</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="p">)(</span><span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">other</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_steps</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;cannot divide </span><span class="si">{}</span><span class="s2"> by </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">other</span><span class="p">),</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="p">)))</span>
<span class="k">class</span> <span class="nc">Frames</span><span class="p">(</span><span class="n">StepMethod</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_steps</span><span class="p">):</span>
@@ -429,6 +472,9 @@
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="o">.</span><span class="n">keys</span><span class="p">(),</span> <span class="n">new_info</span><span class="o">.</span><span class="n">keys</span><span class="p">()))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">new_info</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">update_info</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">new_info</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">new_info</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">__copy__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="n">new_transition</span> <span class="o">=</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="p">)()</span>
<span class="n">new_transition</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)</span>
@@ -510,8 +556,7 @@
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action</span><span class="p">:</span> <span class="n">ActionType</span><span class="p">,</span> <span class="n">all_action_probabilities</span><span class="p">:</span> <span class="nb">float</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="n">action_value</span><span class="p">:</span> <span class="nb">float</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">state_value</span><span class="p">:</span> <span class="nb">float</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">max_action_value</span><span class="p">:</span> <span class="nb">float</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">action_intrinsic_reward</span><span class="p">:</span> <span class="nb">float</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
<span class="n">action_value</span><span class="p">:</span> <span class="nb">float</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">state_value</span><span class="p">:</span> <span class="nb">float</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">max_action_value</span><span class="p">:</span> <span class="nb">float</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> :param action: the action</span>
<span class="sd"> :param all_action_probabilities: the probability that the action was given when selecting it</span>
@@ -520,8 +565,6 @@
<span class="sd"> :param max_action_value: in case this is an action that was selected randomly, this is the value of the action</span>
<span class="sd"> that received the maximum value. if no value is given, the action is assumed to be the</span>
<span class="sd"> action with the maximum value</span>
<span class="sd"> :param action_intrinsic_reward: can contain any intrinsic reward that the agent wants to add to this action</span>
<span class="sd"> selection</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">action</span> <span class="o">=</span> <span class="n">action</span>
<span class="bp">self</span><span class="o">.</span><span class="n">all_action_probabilities</span> <span class="o">=</span> <span class="n">all_action_probabilities</span>
@@ -530,8 +573,7 @@
<span class="k">if</span> <span class="ow">not</span> <span class="n">max_action_value</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_action_value</span> <span class="o">=</span> <span class="n">action_value</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_action_value</span> <span class="o">=</span> <span class="n">max_action_value</span>
<span class="bp">self</span><span class="o">.</span><span class="n">action_intrinsic_reward</span> <span class="o">=</span> <span class="n">action_intrinsic_reward</span></div>
<span class="bp">self</span><span class="o">.</span><span class="n">max_action_value</span> <span class="o">=</span> <span class="n">max_action_value</span></div>
<div class="viewcode-block" id="Batch"><a class="viewcode-back" href="../../components/core_types.html#rl_coach.core_types.Batch">[docs]</a><span class="k">class</span> <span class="nc">Batch</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
@@ -1047,6 +1089,19 @@
<span class="k">return</span> <span class="kc">True</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="kc">False</span>
<span class="c1"># TODO move to a NamedTuple, once we move to Python3.6</span>
<span class="c1"># https://stackoverflow.com/questions/34269772/type-hints-in-namedtuple/34269877</span>
<span class="k">class</span> <span class="nc">CsvDataset</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">filepath</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">is_episodic</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">filepath</span> <span class="o">=</span> <span class="n">filepath</span>
<span class="bp">self</span><span class="o">.</span><span class="n">is_episodic</span> <span class="o">=</span> <span class="n">is_episodic</span>
<span class="k">class</span> <span class="nc">PickledReplayBuffer</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">filepath</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">filepath</span> <span class="o">=</span> <span class="n">filepath</span>
</pre></div>
</div>
@@ -463,7 +463,12 @@
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Got exception: </span><span class="si">%s</span><span class="se">\n</span><span class="s2"> while deleting NFS PVC&quot;</span><span class="p">,</span> <span class="n">e</span><span class="p">)</span>
<span class="k">return</span> <span class="kc">False</span>
<span class="k">return</span> <span class="kc">True</span></div>
<span class="k">return</span> <span class="kc">True</span>
<span class="k">def</span> <span class="nf">setup_checkpoint_dir</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">crd</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="k">if</span> <span class="n">crd</span><span class="p">:</span>
<span class="c1"># TODO: find a way to upload this to the deployed nfs store.</span>
<span class="k">pass</span></div>
</pre></div>
</div>
@@ -257,6 +257,9 @@
<span class="k">return</span> <span class="kc">True</span>
<span class="k">def</span> <span class="nf">save_to_store</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_save_to_store</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">checkpoint_dir</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">_save_to_store</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">checkpoint_dir</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> save_to_store() uploads the policy checkpoint, gifs and videos to the S3 data store. It reads the checkpoint state files and</span>
<span class="sd"> uploads only the latest checkpoint files to S3. It is used by the trainer in Coach when used in the distributed mode.</span>
@@ -268,24 +271,32 @@
<span class="c1"># Acquire lock</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">put_object</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">SyncFiles</span><span class="o">.</span><span class="n">LOCKFILE</span><span class="o">.</span><span class="n">value</span><span class="p">,</span> <span class="n">io</span><span class="o">.</span><span class="n">BytesIO</span><span class="p">(</span><span class="sa">b</span><span class="s1">&#39;&#39;</span><span class="p">),</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">state_file</span> <span class="o">=</span> <span class="n">CheckpointStateFile</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">checkpoint_dir</span><span class="p">))</span>
<span class="n">state_file</span> <span class="o">=</span> <span class="n">CheckpointStateFile</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">))</span>
<span class="k">if</span> <span class="n">state_file</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span>
<span class="n">ckpt_state</span> <span class="o">=</span> <span class="n">state_file</span><span class="o">.</span><span class="n">read</span><span class="p">()</span>
<span class="n">checkpoint_file</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">for</span> <span class="n">root</span><span class="p">,</span> <span class="n">dirs</span><span class="p">,</span> <span class="n">files</span> <span class="ow">in</span> <span class="n">os</span><span class="o">.</span><span class="n">walk</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">checkpoint_dir</span><span class="p">):</span>
<span class="k">for</span> <span class="n">root</span><span class="p">,</span> <span class="n">dirs</span><span class="p">,</span> <span class="n">files</span> <span class="ow">in</span> <span class="n">os</span><span class="o">.</span><span class="n">walk</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">):</span>
<span class="k">for</span> <span class="n">filename</span> <span class="ow">in</span> <span class="n">files</span><span class="p">:</span>
<span class="k">if</span> <span class="n">filename</span> <span class="o">==</span> <span class="n">CheckpointStateFile</span><span class="o">.</span><span class="n">checkpoint_state_filename</span><span class="p">:</span>
<span class="n">checkpoint_file</span> <span class="o">=</span> <span class="p">(</span><span class="n">root</span><span class="p">,</span> <span class="n">filename</span><span class="p">)</span>
<span class="k">continue</span>
<span class="k">if</span> <span class="n">filename</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="n">ckpt_state</span><span class="o">.</span><span class="n">name</span><span class="p">):</span>
<span class="n">abs_name</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">root</span><span class="p">,</span> <span class="n">filename</span><span class="p">))</span>
<span class="n">rel_name</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">relpath</span><span class="p">(</span><span class="n">abs_name</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">checkpoint_dir</span><span class="p">)</span>
<span class="n">rel_name</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">relpath</span><span class="p">(</span><span class="n">abs_name</span><span class="p">,</span> <span class="n">checkpoint_dir</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">fput_object</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">rel_name</span><span class="p">,</span> <span class="n">abs_name</span><span class="p">)</span>
<span class="n">abs_name</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">checkpoint_file</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">checkpoint_file</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span>
<span class="n">rel_name</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">relpath</span><span class="p">(</span><span class="n">abs_name</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">checkpoint_dir</span><span class="p">)</span>
<span class="n">rel_name</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">relpath</span><span class="p">(</span><span class="n">abs_name</span><span class="p">,</span> <span class="n">checkpoint_dir</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">fput_object</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">rel_name</span><span class="p">,</span> <span class="n">abs_name</span><span class="p">)</span>
<span class="c1"># upload Finished if present</span>
<span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">,</span> <span class="n">SyncFiles</span><span class="o">.</span><span class="n">FINISHED</span><span class="o">.</span><span class="n">value</span><span class="p">)):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">put_object</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">SyncFiles</span><span class="o">.</span><span class="n">FINISHED</span><span class="o">.</span><span class="n">value</span><span class="p">,</span> <span class="n">io</span><span class="o">.</span><span class="n">BytesIO</span><span class="p">(</span><span class="sa">b</span><span class="s1">&#39;&#39;</span><span class="p">),</span> <span class="mi">0</span><span class="p">)</span>
<span class="c1"># upload Ready if present</span>
<span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">,</span> <span class="n">SyncFiles</span><span class="o">.</span><span class="n">TRAINER_READY</span><span class="o">.</span><span class="n">value</span><span class="p">)):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">put_object</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">SyncFiles</span><span class="o">.</span><span class="n">TRAINER_READY</span><span class="o">.</span><span class="n">value</span><span class="p">,</span> <span class="n">io</span><span class="o">.</span><span class="n">BytesIO</span><span class="p">(</span><span class="sa">b</span><span class="s1">&#39;&#39;</span><span class="p">),</span> <span class="mi">0</span><span class="p">)</span>
<span class="c1"># release lock</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">remove_object</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">SyncFiles</span><span class="o">.</span><span class="n">LOCKFILE</span><span class="o">.</span><span class="n">value</span><span class="p">)</span>
@@ -301,6 +312,7 @@
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">expt_dir</span> <span class="ow">and</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">expt_dir</span><span class="p">,</span> <span class="s1">&#39;gifs&#39;</span><span class="p">)):</span>
<span class="k">for</span> <span class="n">filename</span> <span class="ow">in</span> <span class="n">os</span><span class="o">.</span><span class="n">listdir</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">expt_dir</span><span class="p">,</span> <span class="s1">&#39;gifs&#39;</span><span class="p">)):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">fput_object</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">filename</span><span class="p">,</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">expt_dir</span><span class="p">,</span> <span class="s1">&#39;gifs&#39;</span><span class="p">,</span> <span class="n">filename</span><span class="p">))</span>
<span class="k">except</span> <span class="n">ResponseError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Got exception: </span><span class="si">%s</span><span class="se">\n</span><span class="s2"> while saving to S3&quot;</span><span class="p">,</span> <span class="n">e</span><span class="p">)</span>
@@ -337,6 +349,18 @@
<span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
<span class="k">pass</span>
<span class="c1"># Check if there&#39;s a ready file</span>
<span class="n">objects</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">list_objects_v2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">SyncFiles</span><span class="o">.</span><span class="n">TRAINER_READY</span><span class="o">.</span><span class="n">value</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">next</span><span class="p">(</span><span class="n">objects</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">try</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">fget_object</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">SyncFiles</span><span class="o">.</span><span class="n">TRAINER_READY</span><span class="o">.</span><span class="n">value</span><span class="p">,</span>
<span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">checkpoint_dir</span><span class="p">,</span> <span class="n">SyncFiles</span><span class="o">.</span><span class="n">TRAINER_READY</span><span class="o">.</span><span class="n">value</span><span class="p">))</span>
<span class="p">)</span>
<span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
<span class="k">pass</span>
<span class="n">checkpoint_state</span> <span class="o">=</span> <span class="n">state_file</span><span class="o">.</span><span class="n">read</span><span class="p">()</span>
<span class="k">if</span> <span class="n">checkpoint_state</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">objects</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">list_objects_v2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">prefix</span><span class="o">=</span><span class="n">checkpoint_state</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">recursive</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
@@ -346,7 +370,11 @@
<span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">fget_object</span><span class="p">(</span><span class="n">obj</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">obj</span><span class="o">.</span><span class="n">object_name</span><span class="p">,</span> <span class="n">filename</span><span class="p">)</span>
<span class="k">except</span> <span class="n">ResponseError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Got exception: </span><span class="si">%s</span><span class="se">\n</span><span class="s2"> while loading from S3&quot;</span><span class="p">,</span> <span class="n">e</span><span class="p">)</span></div>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Got exception: </span><span class="si">%s</span><span class="se">\n</span><span class="s2"> while loading from S3&quot;</span><span class="p">,</span> <span class="n">e</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setup_checkpoint_dir</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">crd</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="k">if</span> <span class="n">crd</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_save_to_store</span><span class="p">(</span><span class="n">crd</span><span class="p">)</span></div>
</pre></div>
</div>
@@ -480,7 +480,6 @@
<span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">seed</span><span class="p">)</span>
<span class="c1"># frame skip and max between consecutive frames</span>
<span class="bp">self</span><span class="o">.</span><span class="n">is_robotics_env</span> <span class="o">=</span> <span class="s1">&#39;robotics&#39;</span> <span class="ow">in</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">unwrapped</span><span class="o">.</span><span class="vm">__class__</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">is_mujoco_env</span> <span class="o">=</span> <span class="s1">&#39;mujoco&#39;</span> <span class="ow">in</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">unwrapped</span><span class="o">.</span><span class="vm">__class__</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">is_roboschool_env</span> <span class="o">=</span> <span class="s1">&#39;roboschool&#39;</span> <span class="ow">in</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">unwrapped</span><span class="o">.</span><span class="vm">__class__</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">is_atari_env</span> <span class="o">=</span> <span class="s1">&#39;Atari&#39;</span> <span class="ow">in</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">unwrapped</span><span class="o">.</span><span class="vm">__class__</span><span class="p">)</span>
@@ -501,7 +500,7 @@
<span class="bp">self</span><span class="o">.</span><span class="n">state_space</span> <span class="o">=</span> <span class="n">StateSpace</span><span class="p">({})</span>
<span class="c1"># observations</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">observation_space</span><span class="p">,</span> <span class="n">gym</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">dict_space</span><span class="o">.</span><span class="n">Dict</span><span class="p">):</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">observation_space</span><span class="p">,</span> <span class="n">gym</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">dict</span><span class="o">.</span><span class="n">Dict</span><span class="p">):</span>
<span class="n">state_space</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;observation&#39;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">observation_space</span><span class="p">}</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">state_space</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">observation_space</span><span class="o">.</span><span class="n">spaces</span>
@@ -665,38 +664,11 @@
<span class="c1"># initialize the number of lives</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_update_ale_lives</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">_set_mujoco_camera</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">camera_idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> This function can be used to set the camera for rendering the mujoco simulator</span>
<span class="sd"> :param camera_idx: The index of the camera to use. Should be defined in the model</span>
<span class="sd"> :return: None</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">unwrapped</span><span class="o">.</span><span class="n">viewer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">unwrapped</span><span class="o">.</span><span class="n">viewer</span><span class="o">.</span><span class="n">cam</span><span class="o">.</span><span class="n">fixedcamid</span> <span class="o">!=</span> <span class="n">camera_idx</span><span class="p">:</span>
<span class="kn">from</span> <span class="nn">mujoco_py.generated</span> <span class="k">import</span> <span class="n">const</span>
<span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">unwrapped</span><span class="o">.</span><span class="n">viewer</span><span class="o">.</span><span class="n">cam</span><span class="o">.</span><span class="n">type</span> <span class="o">=</span> <span class="n">const</span><span class="o">.</span><span class="n">CAMERA_FIXED</span>
<span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">unwrapped</span><span class="o">.</span><span class="n">viewer</span><span class="o">.</span><span class="n">cam</span><span class="o">.</span><span class="n">fixedcamid</span> <span class="o">=</span> <span class="n">camera_idx</span>
<span class="k">def</span> <span class="nf">_get_robotics_image</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">render</span><span class="p">()</span>
<span class="n">image</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">unwrapped</span><span class="o">.</span><span class="n">_get_viewer</span><span class="p">()</span><span class="o">.</span><span class="n">read_pixels</span><span class="p">(</span><span class="mi">1600</span><span class="p">,</span> <span class="mi">900</span><span class="p">,</span> <span class="n">depth</span><span class="o">=</span><span class="kc">False</span><span class="p">)[::</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:]</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">scipy</span><span class="o">.</span><span class="n">misc</span><span class="o">.</span><span class="n">imresize</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="p">(</span><span class="mi">270</span><span class="p">,</span> <span class="mi">480</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="k">return</span> <span class="n">image</span>
<span class="k">def</span> <span class="nf">_render</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">render</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s1">&#39;human&#39;</span><span class="p">)</span>
<span class="c1"># required for setting up a fixed camera for mujoco</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_mujoco_env</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_roboschool_env</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_set_mujoco_camera</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">get_rendered_image</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_robotics_env</span><span class="p">:</span>
<span class="c1"># necessary for fetch since the rendered image is cropped to an irrelevant part of the simulator</span>
<span class="n">image</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_robotics_image</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">image</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">render</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s1">&#39;rgb_array&#39;</span><span class="p">)</span>
<span class="c1"># required for setting up a fixed camera for mujoco</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_mujoco_env</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_roboschool_env</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_set_mujoco_camera</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">image</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">render</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s1">&#39;rgb_array&#39;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">image</span>
<span class="k">def</span> <span class="nf">get_target_success_rate</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
@@ -179,6 +179,7 @@
<h1>Source code for rl_coach.memories.episodic.episodic_experience_replay</h1><div class="highlight"><pre>
<span></span><span class="c1">#</span>
<span class="c1">#</span>
<span class="c1"># Copyright (c) 2017 Intel Corporation</span>
<span class="c1">#</span>
<span class="c1"># Licensed under the Apache License, Version 2.0 (the &quot;License&quot;);</span>
@@ -193,14 +194,19 @@
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="c1">#</span>
<span class="kn">import</span> <span class="nn">ast</span>
<span class="kn">import</span> <span class="nn">math</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Any</span>
<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">random</span>
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">Transition</span><span class="p">,</span> <span class="n">Episode</span>
<span class="kn">from</span> <span class="nn">rl_coach.logger</span> <span class="k">import</span> <span class="n">screen</span>
<span class="kn">from</span> <span class="nn">rl_coach.memories.memory</span> <span class="k">import</span> <span class="n">Memory</span><span class="p">,</span> <span class="n">MemoryGranularity</span><span class="p">,</span> <span class="n">MemoryParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.utils</span> <span class="k">import</span> <span class="n">ReaderWriterLock</span>
<span class="kn">from</span> <span class="nn">rl_coach.utils</span> <span class="k">import</span> <span class="n">ReaderWriterLock</span><span class="p">,</span> <span class="n">ProgressBar</span>
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">CsvDataset</span>
<span class="k">class</span> <span class="nc">EpisodicExperienceReplayParameters</span><span class="p">(</span><span class="n">MemoryParameters</span><span class="p">):</span>
@@ -208,6 +214,7 @@
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">MemoryGranularity</span><span class="o">.</span><span class="n">Transitions</span><span class="p">,</span> <span class="mi">1000000</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">n_step</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span>
<span class="bp">self</span><span class="o">.</span><span class="n">train_to_eval_ratio</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># for OPE we&#39;ll want a value &lt; 1</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
@@ -220,7 +227,9 @@
<span class="sd"> calculations of total return and other values that depend on the sequential behavior of the transitions</span>
<span class="sd"> in the episode.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">max_size</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">MemoryGranularity</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span><span class="o">=</span><span class="p">(</span><span class="n">MemoryGranularity</span><span class="o">.</span><span class="n">Transitions</span><span class="p">,</span> <span class="mi">1000000</span><span class="p">),</span> <span class="n">n_step</span><span class="o">=-</span><span class="mi">1</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">max_size</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">MemoryGranularity</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">MemoryGranularity</span><span class="o">.</span><span class="n">Transitions</span><span class="p">,</span> <span class="mi">1000000</span><span class="p">),</span> <span class="n">n_step</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span>
<span class="n">train_to_eval_ratio</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> :param max_size: the maximum number of transitions or episodes to hold in the memory</span>
<span class="sd"> &quot;&quot;&quot;</span>
@@ -232,8 +241,11 @@
<span class="bp">self</span><span class="o">.</span><span class="n">_num_transitions</span> <span class="o">=</span> <span class="mi">0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_num_transitions_in_complete_episodes</span> <span class="o">=</span> <span class="mi">0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span> <span class="o">=</span> <span class="n">ReaderWriterLock</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_episode_id</span> <span class="o">=</span> <span class="kc">None</span> <span class="c1"># used in batch-rl</span>
<span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_transition_id</span> <span class="o">=</span> <span class="kc">None</span> <span class="c1"># used in batch-rl</span>
<span class="bp">self</span><span class="o">.</span><span class="n">train_to_eval_ratio</span> <span class="o">=</span> <span class="n">train_to_eval_ratio</span> <span class="c1"># used in batch-rl</span>
<span class="k">def</span> <span class="nf">length</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
<span class="k">def</span> <span class="nf">length</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Get the number of episodes in the ER (even if they are not complete)</span>
<span class="sd"> &quot;&quot;&quot;</span>
@@ -255,6 +267,9 @@
<span class="k">def</span> <span class="nf">num_transitions_in_complete_episodes</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_num_transitions_in_complete_episodes</span>
<span class="k">def</span> <span class="nf">get_last_training_set_episode_id</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_episode_id</span>
<span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">is_consecutive_transitions</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Transition</span><span class="p">]:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sample a batch of transitions from the replay buffer. If the requested size is larger than the number</span>
@@ -272,7 +287,7 @@
<span class="n">batch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">[</span><span class="n">episode_idx</span><span class="p">]</span><span class="o">.</span><span class="n">transitions</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">transition_idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">[</span><span class="n">episode_idx</span><span class="p">]</span><span class="o">.</span><span class="n">length</span><span class="p">())</span>
<span class="n">batch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">[</span><span class="n">episode_idx</span><span class="p">]</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="n">transition_idx</span><span class="o">-</span><span class="n">size</span><span class="p">:</span><span class="n">transition_idx</span><span class="p">]</span>
<span class="n">batch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">[</span><span class="n">episode_idx</span><span class="p">]</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="n">transition_idx</span> <span class="o">-</span> <span class="n">size</span><span class="p">:</span><span class="n">transition_idx</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">transitions_idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_transitions_in_complete_episodes</span><span class="p">(),</span> <span class="n">size</span><span class="o">=</span><span class="n">size</span><span class="p">)</span>
<span class="n">batch</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">transitions_idx</span><span class="p">]</span>
@@ -285,6 +300,78 @@
<span class="k">return</span> <span class="n">batch</span>
<span class="k">def</span> <span class="nf">get_episode_for_transition</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">transition</span><span class="p">:</span> <span class="n">Transition</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="p">(</span><span class="nb">int</span><span class="p">,</span> <span class="n">Episode</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Get the episode from which that transition came from.</span>
<span class="sd"> :param transition: The transition to lookup the episode for</span>
<span class="sd"> :return: (Episode number, the episode) or (-1, None) if could not find a matching episode.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">episode</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">):</span>
<span class="k">if</span> <span class="n">transition</span> <span class="ow">in</span> <span class="n">episode</span><span class="o">.</span><span class="n">transitions</span><span class="p">:</span>
<span class="k">return</span> <span class="n">i</span><span class="p">,</span> <span class="n">episode</span>
<span class="k">return</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="kc">None</span>
<span class="k">def</span> <span class="nf">shuffle_episodes</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Shuffle all the episodes in the replay buffer</span>
<span class="sd"> :return:</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">transitions</span> <span class="o">=</span> <span class="p">[</span><span class="n">t</span> <span class="k">for</span> <span class="n">e</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">e</span><span class="o">.</span><span class="n">transitions</span><span class="p">]</span>
<span class="k">def</span> <span class="nf">get_shuffled_data_generator</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">size</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Transition</span><span class="p">]:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Get an generator for iterating through the shuffled replay buffer, for processing the data in epochs.</span>
<span class="sd"> If the requested size is larger than the number of samples available in the replay buffer then the batch will</span>
<span class="sd"> return empty. The last returned batch may be smaller than the size requested, to accommodate for all the</span>
<span class="sd"> transitions in the replay buffer.</span>
<span class="sd"> :param size: the size of the batch to return</span>
<span class="sd"> :return: a batch (list) of selected transitions from the replay buffer</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">lock_writing</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_transition_id</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_to_eval_ratio</span> <span class="o">&lt;</span> <span class="mi">0</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_to_eval_ratio</span> <span class="o">&gt;=</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;train_to_eval_ratio should be in the (0, 1] range.&#39;</span><span class="p">)</span>
<span class="n">transition</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="nb">round</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_to_eval_ratio</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_transitions_in_complete_episodes</span><span class="p">())]</span>
<span class="n">episode_num</span><span class="p">,</span> <span class="n">episode</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_episode_for_transition</span><span class="p">(</span><span class="n">transition</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_episode_id</span> <span class="o">=</span> <span class="n">episode_num</span>
<span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_transition_id</span> <span class="o">=</span> \
<span class="nb">len</span><span class="p">([</span><span class="n">t</span> <span class="k">for</span> <span class="n">e</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_all_complete_episodes_from_to</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_episode_id</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">e</span><span class="p">])</span>
<span class="n">shuffled_transition_indices</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_transition_id</span><span class="p">))</span>
<span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">shuffled_transition_indices</span><span class="p">)</span>
<span class="c1"># The last batch drawn will usually be &lt; batch_size (=the size variable)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">shuffled_transition_indices</span><span class="p">)</span> <span class="o">/</span> <span class="n">size</span><span class="p">)):</span>
<span class="n">sample_data</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="n">shuffled_transition_indices</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">size</span><span class="p">:</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">size</span><span class="p">]]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing</span><span class="p">()</span>
<span class="k">yield</span> <span class="n">sample_data</span>
<span class="k">def</span> <span class="nf">get_all_complete_episodes_transitions</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Transition</span><span class="p">]:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Get all the transitions from all the complete episodes in the buffer</span>
<span class="sd"> :return: a list of transitions</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">[:</span><span class="bp">self</span><span class="o">.</span><span class="n">num_transitions_in_complete_episodes</span><span class="p">()]</span>
<span class="k">def</span> <span class="nf">get_all_complete_episodes</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Episode</span><span class="p">]:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Get all the transitions from all the complete episodes in the buffer</span>
<span class="sd"> :return: a list of transitions</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_all_complete_episodes_from_to</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_complete_episodes</span><span class="p">())</span>
<span class="k">def</span> <span class="nf">get_all_complete_episodes_from_to</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">start_episode_id</span><span class="p">,</span> <span class="n">end_episode_id</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Episode</span><span class="p">]:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Get all the transitions from all the complete episodes in the buffer matching the given episode range</span>
<span class="sd"> :return: a list of transitions</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">[</span><span class="n">start_episode_id</span><span class="p">:</span><span class="n">end_episode_id</span><span class="p">]</span>
<span class="k">def</span> <span class="nf">_enforce_max_length</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Make sure that the size of the replay buffer does not pass the maximum size allowed.</span>
@@ -368,7 +455,7 @@
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing_and_reading</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">store_episode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">episode</span><span class="p">:</span> <span class="n">Episode</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">def</span> <span class="nf">store_episode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">episode</span><span class="p">:</span> <span class="n">Episode</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Store a new episode in the memory.</span>
<span class="sd"> :param episode: the new episode to store</span>
@@ -391,7 +478,7 @@
<span class="k">if</span> <span class="n">lock</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing_and_reading</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">get_episode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">episode_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="n">Episode</span><span class="p">]:</span>
<span class="k">def</span> <span class="nf">get_episode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">episode_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="n">Episode</span><span class="p">]:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Returns the episode in the given index. If the episode does not exist, returns None instead.</span>
<span class="sd"> :param episode_index: the index of the episode to return</span>
@@ -436,7 +523,7 @@
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing_and_reading</span><span class="p">()</span>
<span class="c1"># for API compatibility</span>
<span class="k">def</span> <span class="nf">get</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">episode_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="n">Episode</span><span class="p">]:</span>
<span class="k">def</span> <span class="nf">get</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">episode_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="n">Episode</span><span class="p">]:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Returns the episode in the given index. If the episode does not exist, returns None instead.</span>
<span class="sd"> :param episode_index: the index of the episode to return</span>
@@ -494,7 +581,51 @@
<span class="n">mean</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">([</span><span class="n">transition</span><span class="o">.</span><span class="n">reward</span> <span class="k">for</span> <span class="n">transition</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing</span><span class="p">()</span>
<span class="k">return</span> <span class="n">mean</span></div>
<span class="k">return</span> <span class="n">mean</span>
<span class="k">def</span> <span class="nf">load_csv</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">csv_dataset</span><span class="p">:</span> <span class="n">CsvDataset</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Restore the replay buffer contents from a csv file.</span>
<span class="sd"> The csv file is assumed to include a list of transitions.</span>
<span class="sd"> :param csv_dataset: A construct which holds the dataset parameters</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">df</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_csv</span><span class="p">(</span><span class="n">csv_dataset</span><span class="o">.</span><span class="n">filepath</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">df</span><span class="p">)</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]:</span>
<span class="n">screen</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;Warning! The number of transitions to load into the replay buffer (</span><span class="si">{}</span><span class="s2">) is &quot;</span>
<span class="s2">&quot;bigger than the max size of the replay buffer (</span><span class="si">{}</span><span class="s2">). The excessive transitions will &quot;</span>
<span class="s2">&quot;not be stored.&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">df</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span>
<span class="n">episode_ids</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="s1">&#39;episode_id&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">unique</span><span class="p">()</span>
<span class="n">progress_bar</span> <span class="o">=</span> <span class="n">ProgressBar</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">episode_ids</span><span class="p">))</span>
<span class="n">state_columns</span> <span class="o">=</span> <span class="p">[</span><span class="n">col</span> <span class="k">for</span> <span class="n">col</span> <span class="ow">in</span> <span class="n">df</span><span class="o">.</span><span class="n">columns</span> <span class="k">if</span> <span class="n">col</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s1">&#39;state_feature&#39;</span><span class="p">)]</span>
<span class="k">for</span> <span class="n">e_id</span> <span class="ow">in</span> <span class="n">episode_ids</span><span class="p">:</span>
<span class="n">progress_bar</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">e_id</span><span class="p">)</span>
<span class="n">df_episode_transitions</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="n">df</span><span class="p">[</span><span class="s1">&#39;episode_id&#39;</span><span class="p">]</span> <span class="o">==</span> <span class="n">e_id</span><span class="p">]</span>
<span class="n">episode</span> <span class="o">=</span> <span class="n">Episode</span><span class="p">()</span>
<span class="k">for</span> <span class="p">(</span><span class="n">_</span><span class="p">,</span> <span class="n">current_transition</span><span class="p">),</span> <span class="p">(</span><span class="n">_</span><span class="p">,</span> <span class="n">next_transition</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">df_episode_transitions</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">iterrows</span><span class="p">(),</span>
<span class="n">df_episode_transitions</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span><span class="o">.</span><span class="n">iterrows</span><span class="p">()):</span>
<span class="n">state</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">current_transition</span><span class="p">[</span><span class="n">col</span><span class="p">]</span> <span class="k">for</span> <span class="n">col</span> <span class="ow">in</span> <span class="n">state_columns</span><span class="p">])</span>
<span class="n">next_state</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">next_transition</span><span class="p">[</span><span class="n">col</span><span class="p">]</span> <span class="k">for</span> <span class="n">col</span> <span class="ow">in</span> <span class="n">state_columns</span><span class="p">])</span>
<span class="n">episode</span><span class="o">.</span><span class="n">insert</span><span class="p">(</span>
<span class="n">Transition</span><span class="p">(</span><span class="n">state</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;observation&#39;</span><span class="p">:</span> <span class="n">state</span><span class="p">},</span>
<span class="n">action</span><span class="o">=</span><span class="n">current_transition</span><span class="p">[</span><span class="s1">&#39;action&#39;</span><span class="p">],</span> <span class="n">reward</span><span class="o">=</span><span class="n">current_transition</span><span class="p">[</span><span class="s1">&#39;reward&#39;</span><span class="p">],</span>
<span class="n">next_state</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;observation&#39;</span><span class="p">:</span> <span class="n">next_state</span><span class="p">},</span> <span class="n">game_over</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">info</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;all_action_probabilities&#39;</span><span class="p">:</span>
<span class="n">ast</span><span class="o">.</span><span class="n">literal_eval</span><span class="p">(</span><span class="n">current_transition</span><span class="p">[</span><span class="s1">&#39;all_action_probabilities&#39;</span><span class="p">])}))</span>
<span class="c1"># Set the last transition to end the episode</span>
<span class="k">if</span> <span class="n">csv_dataset</span><span class="o">.</span><span class="n">is_episodic</span><span class="p">:</span>
<span class="n">episode</span><span class="o">.</span><span class="n">get_last_transition</span><span class="p">()</span><span class="o">.</span><span class="n">game_over</span> <span class="o">=</span> <span class="kc">True</span>
<span class="bp">self</span><span class="o">.</span><span class="n">store_episode</span><span class="p">(</span><span class="n">episode</span><span class="p">)</span>
<span class="c1"># close the progress bar</span>
<span class="n">progress_bar</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">episode_ids</span><span class="p">))</span>
<span class="n">progress_bar</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">shuffle_episodes</span><span class="p">()</span></div>
</pre></div>
</div>
@@ -194,10 +194,10 @@
<span class="c1"># limitations under the License.</span>
<span class="c1">#</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Any</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
<span class="kn">import</span> <span class="nn">pickle</span>
<span class="kn">import</span> <span class="nn">sys</span>
<span class="kn">import</span> <span class="nn">time</span>
<span class="kn">import</span> <span class="nn">random</span>
<span class="kn">import</span> <span class="nn">math</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
@@ -252,7 +252,6 @@
<span class="sd"> Sample a batch of transitions form the replay buffer. If the requested size is larger than the number</span>
<span class="sd"> of samples available in the replay buffer then the batch will return empty.</span>
<span class="sd"> :param size: the size of the batch to sample</span>
<span class="sd"> :param beta: the beta parameter used for importance sampling</span>
<span class="sd"> :return: a batch (list) of selected transitions from the replay buffer</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">lock_writing</span><span class="p">()</span>
@@ -272,6 +271,28 @@
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing</span><span class="p">()</span>
<span class="k">return</span> <span class="n">batch</span>
<span class="k">def</span> <span class="nf">get_shuffled_data_generator</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">size</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Transition</span><span class="p">]:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Get an generator for iterating through the shuffled replay buffer, for processing the data in epochs.</span>
<span class="sd"> If the requested size is larger than the number of samples available in the replay buffer then the batch will</span>
<span class="sd"> return empty. The last returned batch may be smaller than the size requested, to accommodate for all the</span>
<span class="sd"> transitions in the replay buffer.</span>
<span class="sd"> :param size: the size of the batch to return</span>
<span class="sd"> :return: a batch (list) of selected transitions from the replay buffer</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">lock_writing</span><span class="p">()</span>
<span class="n">shuffled_transition_indices</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">)))</span>
<span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">shuffled_transition_indices</span><span class="p">)</span>
<span class="c1"># we deliberately drop some of the ending data which is left after dividing to batches of size `size`</span>
<span class="c1"># for i in range(math.ceil(len(shuffled_transition_indices) / size)):</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">shuffled_transition_indices</span><span class="p">)</span> <span class="o">/</span> <span class="n">size</span><span class="p">)):</span>
<span class="n">sample_data</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="n">shuffled_transition_indices</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">size</span><span class="p">:</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">size</span><span class="p">]]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing</span><span class="p">()</span>
<span class="k">yield</span> <span class="n">sample_data</span>
<span class="k">def</span> <span class="nf">_enforce_max_length</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Make sure that the size of the replay buffer does not pass the maximum size allowed.</span>
@@ -395,7 +416,7 @@
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">file_path</span><span class="p">,</span> <span class="s1">&#39;wb&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">file</span><span class="p">:</span>
<span class="n">pickle</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">,</span> <span class="n">file</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">file_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">def</span> <span class="nf">load_pickled</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">file_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Restore the replay buffer contents from a pickle file.</span>
<span class="sd"> The pickle file is assumed to include a list of transitions.</span>
@@ -418,6 +439,7 @@
<span class="n">progress_bar</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">transition_idx</span><span class="p">)</span>
<span class="n">progress_bar</span><span class="o">.</span><span class="n">close</span><span class="p">()</span></div>
</pre></div>
</div>
@@ -298,7 +298,7 @@
<span class="bp">self</span><span class="o">.</span><span class="n">s3_access_key</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;ACCESS_KEY_ID&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">s3_secret_key</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;SECRET_ACCESS_KEY&#39;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">def</span> <span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">crd</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Deploys the memory backend and data stores if required.</span>
<span class="sd"> &quot;&quot;&quot;</span>
@@ -308,6 +308,9 @@
<span class="k">return</span> <span class="kc">False</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">&quot;nfs&quot;</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nfs_pvc</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_store</span><span class="o">.</span><span class="n">get_info</span><span class="p">()</span>
<span class="c1"># Upload checkpoints in checkpoint_restore_dir (if provided) to the data store</span>
<span class="bp">self</span><span class="o">.</span><span class="n">data_store</span><span class="o">.</span><span class="n">setup_checkpoint_dir</span><span class="p">(</span><span class="n">crd</span><span class="p">)</span>
<span class="k">return</span> <span class="kc">True</span>
<span class="k">def</span> <span class="nf">deploy_trainer</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
@@ -321,7 +324,6 @@
<span class="n">trainer_params</span><span class="o">.</span><span class="n">command</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;--memory_backend_params&#39;</span><span class="p">,</span> <span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">memory_backend_parameters</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)]</span>
<span class="n">trainer_params</span><span class="o">.</span><span class="n">command</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;--data_store_params&#39;</span><span class="p">,</span> <span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)]</span>
<span class="n">name</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="si">{}</span><span class="s2">-</span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">run_type</span><span class="p">,</span> <span class="n">uuid</span><span class="o">.</span><span class="n">uuid4</span><span class="p">())</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">&quot;nfs&quot;</span><span class="p">:</span>
@@ -346,7 +348,7 @@
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;nfs-pvc&quot;</span><span class="p">,</span>
<span class="n">persistent_volume_claim</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">nfs_pvc</span>
<span class="p">)],</span>
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">&#39;OnFailure&#39;</span>
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">&#39;Never&#39;</span>
<span class="p">),</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
@@ -365,7 +367,7 @@
<span class="n">metadata</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1ObjectMeta</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;app&#39;</span><span class="p">:</span> <span class="n">name</span><span class="p">}),</span>
<span class="n">spec</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1PodSpec</span><span class="p">(</span>
<span class="n">containers</span><span class="o">=</span><span class="p">[</span><span class="n">container</span><span class="p">],</span>
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">&#39;OnFailure&#39;</span>
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">&#39;Never&#39;</span>
<span class="p">),</span>
<span class="p">)</span>
@@ -427,7 +429,7 @@
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;nfs-pvc&quot;</span><span class="p">,</span>
<span class="n">persistent_volume_claim</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">nfs_pvc</span>
<span class="p">)],</span>
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">&#39;OnFailure&#39;</span>
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">&#39;Never&#39;</span>
<span class="p">),</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
@@ -446,7 +448,7 @@
<span class="n">metadata</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1ObjectMeta</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;app&#39;</span><span class="p">:</span> <span class="n">name</span><span class="p">}),</span>
<span class="n">spec</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1PodSpec</span><span class="p">(</span>
<span class="n">containers</span><span class="o">=</span><span class="p">[</span><span class="n">container</span><span class="p">],</span>
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">&#39;OnFailure&#39;</span>
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">&#39;Never&#39;</span>
<span class="p">)</span>
<span class="p">)</span>
@@ -496,7 +498,7 @@
<span class="k">return</span>
<span class="k">for</span> <span class="n">pod</span> <span class="ow">in</span> <span class="n">pods</span><span class="o">.</span><span class="n">items</span><span class="p">:</span>
<span class="n">Process</span><span class="p">(</span><span class="n">target</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tail_log_file</span><span class="p">,</span> <span class="n">args</span><span class="o">=</span><span class="p">(</span><span class="n">pod</span><span class="o">.</span><span class="n">metadata</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">api_client</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">namespace</span><span class="p">,</span> <span class="n">path</span><span class="p">))</span><span class="o">.</span><span class="n">start</span><span class="p">()</span>
<span class="n">Process</span><span class="p">(</span><span class="n">target</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tail_log_file</span><span class="p">,</span> <span class="n">args</span><span class="o">=</span><span class="p">(</span><span class="n">pod</span><span class="o">.</span><span class="n">metadata</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">api_client</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">namespace</span><span class="p">,</span> <span class="n">path</span><span class="p">),</span> <span class="n">daemon</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span><span class="o">.</span><span class="n">start</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">_tail_log_file</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pod_name</span><span class="p">,</span> <span class="n">api_client</span><span class="p">,</span> <span class="n">namespace</span><span class="p">,</span> <span class="n">path</span><span class="p">):</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">path</span><span class="p">):</span>
@@ -528,7 +530,7 @@
<span class="k">if</span> <span class="ow">not</span> <span class="n">pod</span><span class="p">:</span>
<span class="k">return</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tail_log</span><span class="p">(</span><span class="n">pod</span><span class="o">.</span><span class="n">metadata</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">api_client</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">tail_log</span><span class="p">(</span><span class="n">pod</span><span class="o">.</span><span class="n">metadata</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">api_client</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">tail_log</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pod_name</span><span class="p">,</span> <span class="n">corev1_api</span><span class="p">):</span>
<span class="k">while</span> <span class="kc">True</span><span class="p">:</span>
@@ -562,9 +564,9 @@
<span class="n">container_status</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">waiting</span><span class="o">.</span><span class="n">reason</span> <span class="o">==</span> <span class="s1">&#39;CrashLoopBackOff&#39;</span> <span class="ow">or</span> \
<span class="n">container_status</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">waiting</span><span class="o">.</span><span class="n">reason</span> <span class="o">==</span> <span class="s1">&#39;ImagePullBackOff&#39;</span> <span class="ow">or</span> \
<span class="n">container_status</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">waiting</span><span class="o">.</span><span class="n">reason</span> <span class="o">==</span> <span class="s1">&#39;ErrImagePull&#39;</span><span class="p">:</span>
<span class="k">return</span>
<span class="k">return</span> <span class="mi">1</span>
<span class="k">if</span> <span class="n">container_status</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">terminated</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">return</span>
<span class="k">return</span> <span class="n">container_status</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">terminated</span><span class="o">.</span><span class="n">exit_code</span>
<span class="k">def</span> <span class="nf">undeploy</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
+1 -1
View File
@@ -828,7 +828,7 @@
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">state</span><span class="p">:</span> <span class="n">StateSpace</span><span class="p">,</span>
<span class="n">goal</span><span class="p">:</span> <span class="n">ObservationSpace</span><span class="p">,</span>
<span class="n">goal</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">ObservationSpace</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span>
<span class="n">action</span><span class="p">:</span> <span class="n">ActionSpace</span><span class="p">,</span>
<span class="n">reward</span><span class="p">:</span> <span class="n">RewardSpace</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">state</span> <span class="o">=</span> <span class="n">state</span>
@@ -21,6 +21,7 @@ A detailed description of those algorithms can be found by navigating to each of
imitation/cil
policy_optimization/cppo
policy_optimization/ddpg
policy_optimization/sac
other/dfp
value_optimization/double_dqn
value_optimization/dqn
@@ -38,6 +38,7 @@ Each update perform the following procedure:
.. math:: \text{where} \quad \bar{\rho}_{t} = \min{\left\{c,\rho_t\right\}},\quad \rho_t=\frac{\pi (a_t \mid s_t)}{\mu (a_t \mid s_t)}
3. **Accumulate gradients:**
:math:`\bullet` **Policy gradients (with bias correction):**
.. math:: \hat{g}_t^{policy} & = & \bar{\rho}_{t} \nabla \log \pi (a_t \mid s_t) [Q^{ret}(s_t,a_t) - V(s_t)] \\
@@ -0,0 +1,49 @@
Soft Actor-Critic
============
**Actions space:** Continuous
**References:** `Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor <https://arxiv.org/abs/1801.01290>`_
Network Structure
-----------------
.. image:: /_static/img/design_imgs/sac.png
:align: center
Algorithm Description
---------------------
Choosing an action - Continuous actions
+++++++++++++++++++++++++++++++++++++
The policy network is used in order to predict mean and log std for each action. While training, a sample is taken
from a Gaussian distribution with these mean and std values. When testing, the agent can choose deterministically
by picking the mean value or sample from a gaussian distribution like in training.
Training the network
++++++++++++++++++++
Start by sampling a batch :math:`B` of transitions from the experience replay.
* To train the **Q network**, use the following targets:
.. math:: y_t^Q=r(s_t,a_t)+\gamma \cdot V(s_{t+1})
The state value used in the above target is acquired by running the target state value network.
* To train the **State Value network**, use the following targets:
.. math:: y_t^V = \min_{i=1,2}Q_i(s_t,\tilde{a}) - log\pi (\tilde{a} \vert s),\,\,\,\, \tilde{a} \sim \pi(\cdot \vert s_t)
The state value network is trained using a sample-based approximation of the connection between and state value and state
action values, The actions used for constructing the target are **not** sampled from the replay buffer, but rather sampled
from the current policy.
* To train the **actor network**, use the following equation:
.. math:: \nabla_{\theta} J \approx \nabla_{\theta} \frac{1}{\vert B \vert} \sum_{s_t\in B} \left( Q \left(s_t, \tilde{a}_\theta(s_t)\right) - log\pi_{\theta}(\tilde{a}_{\theta}(s_t)\vert s_t) \right),\,\,\,\, \tilde{a} \sim \pi(\cdot \vert s_t)
After every training step, do a soft update of the V target network's weights from the online networks.
.. autoclass:: rl_coach.agents.soft_actor_critic_agent.SoftActorCriticAlgorithmParameters
@@ -198,6 +198,14 @@ The algorithms are ordered by their release date in descending order.
improve stability it also employs bias correction and trust region optimization techniques.
</span>
</div>
<div class="algorithm continuous off-policy" data-year="201808">
<span class="badge">
<a href="components/agents/policy_optimization/sac.html">SAC</a>
<br>
Soft Actor-Critic is an algorithm which optimizes a stochastic policy in an off-policy way.
One of the key features of SAC is that it solves a maximum entropy reinforcement learning problem.
</span>
</div>
<div class="algorithm continuous off-policy" data-year="201509">
<span class="badge">
<a href="components/agents/policy_optimization/ddpg.html">DDPG</a>
+2 -1
View File
@@ -113,7 +113,8 @@ In Coach, this can be done in two steps -
.. code-block:: python
coach -p Doom_Basic_BC -cp='agent.load_memory_from_file_path=\"<experiment dir>/replay_buffer.p\"'
from rl_coach.core_types import PickledReplayBuffer
coach -p Doom_Basic_BC -cp='agent.load_memory_from_file_path=PickledReplayBuffer(\"<experiment dir>/replay_buffer.p\"')
Visualizations
-61
View File
@@ -1,61 +0,0 @@
/* Docs background */
.wy-side-nav-search{
background-color: #043c74;
}
/* Mobile version */
.wy-nav-top{
background-color: #043c74;
}
.green {
color: green;
}
.red {
color: red;
}
.blue {
color: blue;
}
.yellow {
color: yellow;
}
.badge {
border: 2px;
border-style: solid;
border-color: #6C8EBF;
border-radius: 5px;
padding: 3px 15px 3px 15px;
margin: 5px;
display: inline-block;
font-weight: bold;
font-size: 16px;
background: #DAE8FC;
}
.badge:hover {
cursor: pointer;
}
.badge > a {
color: black;
}
.bordered-container {
border: 0px;
border-style: solid;
border-radius: 8px;
padding: 15px;
margin-bottom: 20px;
background: #f2f2f2;
}
.questionnaire {
font-size: 1.2em;
line-height: 1.5em;
}
+10 -6
View File
@@ -276,19 +276,22 @@ of the trace tests suite.</li>
<h2>TaskParameters<a class="headerlink" href="#taskparameters" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.base_parameters.TaskParameters">
<em class="property">class </em><code class="descclassname">rl_coach.base_parameters.</code><code class="descname">TaskParameters</code><span class="sig-paren">(</span><em>framework_type: rl_coach.base_parameters.Frameworks = &lt;Frameworks.tensorflow: 'TensorFlow'&gt;</em>, <em>evaluate_only: bool = False</em>, <em>use_cpu: bool = False</em>, <em>experiment_path='/tmp'</em>, <em>seed=None</em>, <em>checkpoint_save_secs=None</em>, <em>checkpoint_restore_dir=None</em>, <em>checkpoint_save_dir=None</em>, <em>export_onnx_graph: bool = False</em>, <em>apply_stop_condition: bool = False</em>, <em>num_gpu: int = 1</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/base_parameters.html#TaskParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.TaskParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="descclassname">rl_coach.base_parameters.</code><code class="descname">TaskParameters</code><span class="sig-paren">(</span><em>framework_type: rl_coach.base_parameters.Frameworks = &lt;Frameworks.tensorflow: 'TensorFlow'&gt;</em>, <em>evaluate_only: int = None</em>, <em>use_cpu: bool = False</em>, <em>experiment_path='/tmp'</em>, <em>seed=None</em>, <em>checkpoint_save_secs=None</em>, <em>checkpoint_restore_dir=None</em>, <em>checkpoint_restore_path=None</em>, <em>checkpoint_save_dir=None</em>, <em>export_onnx_graph: bool = False</em>, <em>apply_stop_condition: bool = False</em>, <em>num_gpu: int = 1</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/base_parameters.html#TaskParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.TaskParameters" title="Permalink to this definition"></a></dt>
<dd><table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
<col class="field-body" />
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple">
<li><strong>framework_type</strong> deep learning framework type. currently only tensorflow is supported</li>
<li><strong>evaluate_only</strong> the task will be used only for evaluating the model</li>
<li><strong>evaluate_only</strong> if not None, the task will be used only for evaluating the model for the given number of steps.
A value of 0 means that task will be evaluated for an infinite number of steps.</li>
<li><strong>use_cpu</strong> use the cpu for this task</li>
<li><strong>experiment_path</strong> the path to the directory which will store all the experiment outputs</li>
<li><strong>seed</strong> a seed to use for the random numbers generator</li>
<li><strong>checkpoint_save_secs</strong> the number of seconds between each checkpoint saving</li>
<li><strong>checkpoint_restore_dir</strong> the directory to restore the checkpoints from</li>
<li><strong>checkpoint_restore_dir</strong> [DEPECRATED - will be removed in one of the next releases - switch to checkpoint_restore_path]
the dir to restore the checkpoints from</li>
<li><strong>checkpoint_restore_path</strong> the path to restore the checkpoints from</li>
<li><strong>checkpoint_save_dir</strong> the directory to store the checkpoints in</li>
<li><strong>export_onnx_graph</strong> If set to True, this will export an onnx graph each time a checkpoint is saved</li>
<li><strong>apply_stop_condition</strong> If set to True, this will apply the stop condition defined by reaching a target success rate</li>
@@ -305,14 +308,15 @@ of the trace tests suite.</li>
<h2>DistributedTaskParameters<a class="headerlink" href="#distributedtaskparameters" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.base_parameters.DistributedTaskParameters">
<em class="property">class </em><code class="descclassname">rl_coach.base_parameters.</code><code class="descname">DistributedTaskParameters</code><span class="sig-paren">(</span><em>framework_type: rl_coach.base_parameters.Frameworks</em>, <em>parameters_server_hosts: str</em>, <em>worker_hosts: str</em>, <em>job_type: str</em>, <em>task_index: int</em>, <em>evaluate_only: bool = False</em>, <em>num_tasks: int = None</em>, <em>num_training_tasks: int = None</em>, <em>use_cpu: bool = False</em>, <em>experiment_path=None</em>, <em>dnd=None</em>, <em>shared_memory_scratchpad=None</em>, <em>seed=None</em>, <em>checkpoint_save_secs=None</em>, <em>checkpoint_restore_dir=None</em>, <em>checkpoint_save_dir=None</em>, <em>export_onnx_graph: bool = False</em>, <em>apply_stop_condition: bool = False</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/base_parameters.html#DistributedTaskParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.DistributedTaskParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="descclassname">rl_coach.base_parameters.</code><code class="descname">DistributedTaskParameters</code><span class="sig-paren">(</span><em>framework_type: rl_coach.base_parameters.Frameworks</em>, <em>parameters_server_hosts: str</em>, <em>worker_hosts: str</em>, <em>job_type: str</em>, <em>task_index: int</em>, <em>evaluate_only: int = None</em>, <em>num_tasks: int = None</em>, <em>num_training_tasks: int = None</em>, <em>use_cpu: bool = False</em>, <em>experiment_path=None</em>, <em>dnd=None</em>, <em>shared_memory_scratchpad=None</em>, <em>seed=None</em>, <em>checkpoint_save_secs=None</em>, <em>checkpoint_restore_path=None</em>, <em>checkpoint_save_dir=None</em>, <em>export_onnx_graph: bool = False</em>, <em>apply_stop_condition: bool = False</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/base_parameters.html#DistributedTaskParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.DistributedTaskParameters" title="Permalink to this definition"></a></dt>
<dd><table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
<col class="field-body" />
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple">
<li><strong>framework_type</strong> deep learning framework type. currently only tensorflow is supported</li>
<li><strong>evaluate_only</strong> the task will be used only for evaluating the model</li>
<li><strong>evaluate_only</strong> if not None, the task will be used only for evaluating the model for the given number of steps.
A value of 0 means that task will be evaluated for an infinite number of steps.</li>
<li><strong>parameters_server_hosts</strong> comma-separated list of hostname:port pairs to which the parameter servers are
assigned</li>
<li><strong>worker_hosts</strong> comma-separated list of hostname:port pairs to which the workers are assigned</li>
@@ -325,7 +329,7 @@ assigned</li>
<li><strong>dnd</strong> an external DND to use for NEC. This is a workaround needed for a shared DND not using the scratchpad.</li>
<li><strong>seed</strong> a seed to use for the random numbers generator</li>
<li><strong>checkpoint_save_secs</strong> the number of seconds between each checkpoint saving</li>
<li><strong>checkpoint_restore_dir</strong> the directory to restore the checkpoints from</li>
<li><strong>checkpoint_restore_path</strong> the path to restore the checkpoints from</li>
<li><strong>checkpoint_save_dir</strong> the directory to store the checkpoints in</li>
<li><strong>export_onnx_graph</strong> If set to True, this will export an onnx graph each time a checkpoint is saved</li>
<li><strong>apply_stop_condition</strong> If set to True, this will apply the stop condition defined by reaching a target success rate</li>
+1
View File
@@ -121,6 +121,7 @@
<li class="toctree-l2"><a class="reference internal" href="cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/dqn.html">Deep Q Networks</a></li>
@@ -121,6 +121,7 @@
</li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/dqn.html">Deep Q Networks</a></li>
+22 -23
View File
@@ -114,6 +114,7 @@
<li class="toctree-l2"><a class="reference internal" href="imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="value_optimization/double_dqn.html">Double DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="value_optimization/dqn.html">Deep Q Networks</a></li>
@@ -221,6 +222,7 @@ A detailed description of those algorithms can be found by navigating to each of
<li class="toctree-l1"><a class="reference internal" href="imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l1"><a class="reference internal" href="policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l1"><a class="reference internal" href="policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l1"><a class="reference internal" href="policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l1"><a class="reference internal" href="other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l1"><a class="reference internal" href="value_optimization/double_dqn.html">Double DQN</a></li>
<li class="toctree-l1"><a class="reference internal" href="value_optimization/dqn.html">Deep Q Networks</a></li>
@@ -280,13 +282,15 @@ used for visualization purposes, such as printing to the screen, rendering, and
</table>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.act">
<code class="descname">act</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.ActionInfo<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.act"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.act" title="Permalink to this definition"></a></dt>
<code class="descname">act</code><span class="sig-paren">(</span><em>action: Union[None</em>, <em>int</em>, <em>float</em>, <em>numpy.ndarray</em>, <em>List] = None</em><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.ActionInfo<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.act"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.act" title="Permalink to this definition"></a></dt>
<dd><p>Given the agents current knowledge, decide on the next action to apply to the environment</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
<col class="field-body" />
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Returns:</th><td class="field-body">An ActionInfo object, which contains the action and any additional info from the action decision process</td>
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>action</strong> An action to take, overriding whatever the current policy is</td>
</tr>
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body">An ActionInfo object, which contains the action and any additional info from the action decision process</td>
</tr>
</tbody>
</table>
@@ -357,26 +361,6 @@ for creating the network.</p>
</table>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.emulate_act_on_trainer">
<code class="descname">emulate_act_on_trainer</code><span class="sig-paren">(</span><em>transition: rl_coach.core_types.Transition</em><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.ActionInfo<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.emulate_act_on_trainer"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.emulate_act_on_trainer" title="Permalink to this definition"></a></dt>
<dd><p>This emulates the act using the transition obtained from the rollout worker on the training worker
in case of distributed training.
Given the agents current knowledge, decide on the next action to apply to the environment
:return: an action and a dictionary containing any additional info from the action decision process</p>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.emulate_observe_on_trainer">
<code class="descname">emulate_observe_on_trainer</code><span class="sig-paren">(</span><em>transition: rl_coach.core_types.Transition</em><span class="sig-paren">)</span> &#x2192; bool<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.emulate_observe_on_trainer"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.emulate_observe_on_trainer" title="Permalink to this definition"></a></dt>
<dd><p>This emulates the observe using the transition obtained from the rollout worker on the training worker
in case of distributed training.
Given a response from the environment, distill the observation from it and store it for later use.
The response should be a dictionary containing the performed action, the new observation and measurements,
the reward, a game over flag and any additional information necessary.
:return:</p>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.get_predictions">
<code class="descname">get_predictions</code><span class="sig-paren">(</span><em>states: List[Dict[str, numpy.ndarray]], prediction_type: rl_coach.core_types.PredictionType</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.get_predictions"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.get_predictions" title="Permalink to this definition"></a></dt>
@@ -540,7 +524,7 @@ given observation</td>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.prepare_batch_for_inference">
<code class="descname">prepare_batch_for_inference</code><span class="sig-paren">(</span><em>states: Union[Dict[str, numpy.ndarray], List[Dict[str, numpy.ndarray]]], network_name: str</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.core.multiarray.array]<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.prepare_batch_for_inference"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.prepare_batch_for_inference" title="Permalink to this definition"></a></dt>
<code class="descname">prepare_batch_for_inference</code><span class="sig-paren">(</span><em>states: Union[Dict[str, numpy.ndarray], List[Dict[str, numpy.ndarray]]], network_name: str</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.array]<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.prepare_batch_for_inference"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.prepare_batch_for_inference" title="Permalink to this definition"></a></dt>
<dd><p>Convert curr_state into input tensors tensorflow is expecting. i.e. if we have several inputs states, stack all
observations together, measurements together, etc.</p>
<table class="docutils field-list" frame="void" rules="none">
@@ -632,6 +616,21 @@ by val, and by the current phase set in self.phase.</p>
</table>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.run_off_policy_evaluation">
<code class="descname">run_off_policy_evaluation</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.agent.Agent.run_off_policy_evaluation" title="Permalink to this definition"></a></dt>
<dd><p>Run off-policy evaluation estimators to evaluate the trained policy performance against a dataset.
Should only be implemented for off-policy RL algorithms.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
<col class="field-body" />
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Returns:</th><td class="field-body">None</td>
</tr>
</tbody>
</table>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.run_pre_network_filter_for_inference">
<code class="descname">run_pre_network_filter_for_inference</code><span class="sig-paren">(</span><em>state: Dict[str, numpy.ndarray], update_filter_internal_state: bool = True</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.ndarray]<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.run_pre_network_filter_for_inference"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.run_pre_network_filter_for_inference" title="Permalink to this definition"></a></dt>
+3 -2
View File
@@ -30,7 +30,7 @@
<link rel="index" title="Index" href="../../../genindex.html" />
<link rel="search" title="Search" href="../../../search.html" />
<link rel="next" title="Double DQN" href="../value_optimization/double_dqn.html" />
<link rel="prev" title="Deep Deterministic Policy Gradient" href="../policy_optimization/ddpg.html" />
<link rel="prev" title="Soft Actor-Critic" href="../policy_optimization/sac.html" />
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
@@ -114,6 +114,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Direct Future Prediction</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#network-structure">Network Structure</a></li>
<li class="toctree-l3"><a class="reference internal" href="#algorithm-description">Algorithm Description</a><ul>
@@ -296,7 +297,7 @@ have a different scale and you want to normalize them to the same scale.</li>
<a href="../value_optimization/double_dqn.html" class="btn btn-neutral float-right" title="Double DQN" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a>
<a href="../policy_optimization/ddpg.html" class="btn btn-neutral" title="Deep Deterministic Policy Gradient" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a>
<a href="../policy_optimization/sac.html" class="btn btn-neutral" title="Soft Actor-Critic" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a>
</div>
@@ -122,6 +122,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/dqn.html">Deep Q Networks</a></li>
@@ -122,6 +122,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/dqn.html">Deep Q Networks</a></li>
@@ -247,21 +248,20 @@ and <span class="math notranslate nohighlight">\(n\)</span> (replay ratio) off-p
\[\text{where} \quad \bar{\rho}_{t} = \min{\left\{c,\rho_t\right\}},\quad \rho_t=\frac{\pi (a_t \mid s_t)}{\mu (a_t \mid s_t)}\]</div>
</div></blockquote>
</li>
<li><dl class="first docutils">
<dt><strong>Accumulate gradients:</strong></dt>
<dd><p class="first"><span class="math notranslate nohighlight">\(\bullet\)</span> <strong>Policy gradients (with bias correction):</strong></p>
<li><p class="first"><strong>Accumulate gradients:</strong></p>
<blockquote>
<div><p><span class="math notranslate nohighlight">\(\bullet\)</span> <strong>Policy gradients (with bias correction):</strong></p>
<blockquote>
<div><div class="math notranslate nohighlight">
\[\begin{split}\hat{g}_t^{policy} &amp; = &amp; \bar{\rho}_{t} \nabla \log \pi (a_t \mid s_t) [Q^{ret}(s_t,a_t) - V(s_t)] \\
&amp; &amp; + \mathbb{E}_{a \sim \pi} \left(\left[\frac{\rho_t(a)-c}{\rho_t(a)}\right] \nabla \log \pi (a \mid s_t) [Q(s_t,a) - V(s_t)] \right)\end{split}\]</div>
</div></blockquote>
<p><span class="math notranslate nohighlight">\(\bullet\)</span> <strong>Q-Head gradients (MSE):</strong></p>
<blockquote class="last">
<blockquote>
<div><div class="math notranslate nohighlight">
\[\begin{split}\hat{g}_t^{Q} = (Q^{ret}(s_t,a_t) - Q(s_t,a_t)) \nabla Q(s_t,a_t)\\\end{split}\]</div>
</div></blockquote>
</dd>
</dl>
</div></blockquote>
</li>
<li><p class="first"><strong>(Optional) Trust region update:</strong> change the policy loss gradient w.r.t network output:</p>
<blockquote>
@@ -122,6 +122,7 @@
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/dqn.html">Deep Q Networks</a></li>
@@ -29,7 +29,7 @@
<link rel="stylesheet" href="../../../_static/css/custom.css" type="text/css" />
<link rel="index" title="Index" href="../../../genindex.html" />
<link rel="search" title="Search" href="../../../search.html" />
<link rel="next" title="Direct Future Prediction" href="../other/dfp.html" />
<link rel="next" title="Soft Actor-Critic" href="sac.html" />
<link rel="prev" title="Clipped Proximal Policy Optimization" href="cppo.html" />
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
@@ -122,6 +122,7 @@
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/dqn.html">Deep Q Networks</a></li>
@@ -297,7 +298,7 @@ values. If set to False, the terminal states reward will be taken as the target
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="../other/dfp.html" class="btn btn-neutral float-right" title="Direct Future Prediction" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a>
<a href="sac.html" class="btn btn-neutral float-right" title="Soft Actor-Critic" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a>
<a href="cppo.html" class="btn btn-neutral" title="Clipped Proximal Policy Optimization" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a>
@@ -114,6 +114,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/dqn.html">Deep Q Networks</a></li>
@@ -114,6 +114,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/dqn.html">Deep Q Networks</a></li>
@@ -0,0 +1,343 @@
<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Soft Actor-Critic &mdash; Reinforcement Learning Coach 0.11.0 documentation</title>
<link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../../_static/css/custom.css" type="text/css" />
<link rel="index" title="Index" href="../../../genindex.html" />
<link rel="search" title="Search" href="../../../search.html" />
<link rel="next" title="Direct Future Prediction" href="../other/dfp.html" />
<link rel="prev" title="Deep Deterministic Policy Gradient" href="ddpg.html" />
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
<script src="../../../_static/js/modernizr.min.js"></script>
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search">
<a href="../../../index.html" class="icon icon-home"> Reinforcement Learning Coach
<img src="../../../_static/dark_logo.png" class="logo" alt="Logo"/>
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption"><span class="caption-text">Intro</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../dist_usage.html">Usage - Distributed Coach</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../features/index.html">Features</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../selecting_an_algorithm.html">Selecting an Algorithm</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../dashboard.html">Coach Dashboard</a></li>
</ul>
<p class="caption"><span class="caption-text">Design</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../design/control_flow.html">Control Flow</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../design/network.html">Network Design</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../design/horizontal_scaling.html">Distributed Coach - Horizontal Scale-Out</a></li>
</ul>
<p class="caption"><span class="caption-text">Contributing</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_agent.html">Adding a New Agent</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_env.html">Adding a New Environment</a></li>
</ul>
<p class="caption"><span class="caption-text">Components</span></p>
<ul class="current">
<li class="toctree-l1 current"><a class="reference internal" href="../index.html">Agents</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="ac.html">Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="acer.html">ACER</a></li>
<li class="toctree-l2"><a class="reference internal" href="../imitation/bc.html">Behavioral Cloning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/bs_dqn.html">Bootstrapped DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/categorical_dqn.html">Categorical DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Soft Actor-Critic</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#network-structure">Network Structure</a></li>
<li class="toctree-l3"><a class="reference internal" href="#algorithm-description">Algorithm Description</a><ul>
<li class="toctree-l4"><a class="reference internal" href="#choosing-an-action-continuous-actions">Choosing an action - Continuous actions</a></li>
<li class="toctree-l4"><a class="reference internal" href="#training-the-network">Training the network</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/dqn.html">Deep Q Networks</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/dueling_dqn.html">Dueling DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/mmc.html">Mixed Monte Carlo</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/n_step.html">N-Step Q Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/naf.html">Normalized Advantage Functions</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/nec.html">Neural Episodic Control</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/pal.html">Persistent Advantage Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="pg.html">Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="ppo.html">Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/rainbow.html">Rainbow</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/qr_dqn.html">Quantile Regression DQN</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../../architectures/index.html">Architectures</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../data_stores/index.html">Data Stores</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../environments/index.html">Environments</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../exploration_policies/index.html">Exploration Policies</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../filters/index.html">Filters</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../memories/index.html">Memories</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../memory_backends/index.html">Memory Backends</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../orchestrators/index.html">Orchestrators</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../core_types.html">Core Types</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../spaces.html">Spaces</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../additional_parameters.html">Additional Parameters</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../../index.html">Reinforcement Learning Coach</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../../index.html">Docs</a> &raquo;</li>
<li><a href="../index.html">Agents</a> &raquo;</li>
<li>Soft Actor-Critic</li>
<li class="wy-breadcrumbs-aside">
<a href="../../../_sources/components/agents/policy_optimization/sac.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="section" id="soft-actor-critic">
<h1>Soft Actor-Critic<a class="headerlink" href="#soft-actor-critic" title="Permalink to this headline"></a></h1>
<p><strong>Actions space:</strong> Continuous</p>
<p><strong>References:</strong> <a class="reference external" href="https://arxiv.org/abs/1801.01290">Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor</a></p>
<div class="section" id="network-structure">
<h2>Network Structure<a class="headerlink" href="#network-structure" title="Permalink to this headline"></a></h2>
<img alt="../../../_images/sac.png" class="align-center" src="../../../_images/sac.png" />
</div>
<div class="section" id="algorithm-description">
<h2>Algorithm Description<a class="headerlink" href="#algorithm-description" title="Permalink to this headline"></a></h2>
<div class="section" id="choosing-an-action-continuous-actions">
<h3>Choosing an action - Continuous actions<a class="headerlink" href="#choosing-an-action-continuous-actions" title="Permalink to this headline"></a></h3>
<p>The policy network is used in order to predict mean and log std for each action. While training, a sample is taken
from a Gaussian distribution with these mean and std values. When testing, the agent can choose deterministically
by picking the mean value or sample from a gaussian distribution like in training.</p>
</div>
<div class="section" id="training-the-network">
<h3>Training the network<a class="headerlink" href="#training-the-network" title="Permalink to this headline"></a></h3>
<p>Start by sampling a batch <span class="math notranslate nohighlight">\(B\)</span> of transitions from the experience replay.</p>
<ul>
<li><p class="first">To train the <strong>Q network</strong>, use the following targets:</p>
<div class="math notranslate nohighlight">
\[y_t^Q=r(s_t,a_t)+\gamma \cdot V(s_{t+1})\]</div>
<p>The state value used in the above target is acquired by running the target state value network.</p>
</li>
<li><p class="first">To train the <strong>State Value network</strong>, use the following targets:</p>
<div class="math notranslate nohighlight">
\[y_t^V = \min_{i=1,2}Q_i(s_t,\tilde{a}) - log\pi (\tilde{a} \vert s),\,\,\,\, \tilde{a} \sim \pi(\cdot \vert s_t)\]</div>
<p>The state value network is trained using a sample-based approximation of the connection between and state value and state
action values, The actions used for constructing the target are <strong>not</strong> sampled from the replay buffer, but rather sampled
from the current policy.</p>
</li>
<li><p class="first">To train the <strong>actor network</strong>, use the following equation:</p>
<div class="math notranslate nohighlight">
\[\nabla_{\theta} J \approx \nabla_{\theta} \frac{1}{\vert B \vert} \sum_{s_t\in B} \left( Q \left(s_t, \tilde{a}_\theta(s_t)\right) - log\pi_{\theta}(\tilde{a}_{\theta}(s_t)\vert s_t) \right),\,\,\,\, \tilde{a} \sim \pi(\cdot \vert s_t)\]</div>
</li>
</ul>
<p>After every training step, do a soft update of the V target networks weights from the online networks.</p>
<dl class="class">
<dt id="rl_coach.agents.soft_actor_critic_agent.SoftActorCriticAlgorithmParameters">
<em class="property">class </em><code class="descclassname">rl_coach.agents.soft_actor_critic_agent.</code><code class="descname">SoftActorCriticAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/soft_actor_critic_agent.html#SoftActorCriticAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.soft_actor_critic_agent.SoftActorCriticAlgorithmParameters" title="Permalink to this definition"></a></dt>
<dd><table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
<col class="field-body" />
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple">
<li><strong>num_steps_between_copying_online_weights_to_target</strong> (StepMethod)
The number of steps between copying the online network weights to the target network weights.</li>
<li><strong>rate_for_copying_weights_to_target</strong> (float)
When copying the online network weights to the target network weights, a soft update will be used, which
weight the new online network weights by rate_for_copying_weights_to_target. (Tau as defined in the paper)</li>
<li><strong>use_deterministic_for_evaluation</strong> (bool)
If True, during the evaluation phase, action are chosen deterministically according to the policy mean
and not sampled from the policy distribution.</li>
</ul>
</td>
</tr>
</tbody>
</table>
</dd></dl>
</div>
</div>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="../other/dfp.html" class="btn btn-neutral float-right" title="Direct Future Prediction" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a>
<a href="ddpg.html" class="btn btn-neutral" title="Deep Deterministic Policy Gradient" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<p>
&copy; Copyright 2018, Intel AI Lab
</p>
</div>
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/js/theme.js"></script>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>
@@ -123,6 +123,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="double_dqn.html">Double DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="dqn.html">Deep Q Networks</a></li>
@@ -121,6 +121,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="double_dqn.html">Double DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="dqn.html">Deep Q Networks</a></li>
@@ -114,6 +114,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Double DQN</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#network-structure">Network Structure</a></li>
@@ -114,6 +114,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="double_dqn.html">Double DQN</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Deep Q Networks</a><ul>
@@ -114,6 +114,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="double_dqn.html">Double DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="dqn.html">Deep Q Networks</a></li>
@@ -114,6 +114,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="double_dqn.html">Double DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="dqn.html">Deep Q Networks</a></li>
@@ -114,6 +114,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="double_dqn.html">Double DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="dqn.html">Deep Q Networks</a></li>
@@ -114,6 +114,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="double_dqn.html">Double DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="dqn.html">Deep Q Networks</a></li>
@@ -114,6 +114,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="double_dqn.html">Double DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="dqn.html">Deep Q Networks</a></li>
@@ -114,6 +114,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="double_dqn.html">Double DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="dqn.html">Deep Q Networks</a></li>
@@ -114,6 +114,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="double_dqn.html">Double DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="dqn.html">Deep Q Networks</a></li>
@@ -114,6 +114,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="double_dqn.html">Double DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="dqn.html">Deep Q Networks</a></li>
+2 -1
View File
@@ -193,7 +193,7 @@ own components under a dedicated directory. For example, tensorflow components w
parts that are implemented using TensorFlow.</p>
<dl class="class">
<dt id="rl_coach.base_parameters.NetworkParameters">
<em class="property">class </em><code class="descclassname">rl_coach.base_parameters.</code><code class="descname">NetworkParameters</code><span class="sig-paren">(</span><em>force_cpu=False</em>, <em>async_training=False</em>, <em>shared_optimizer=True</em>, <em>scale_down_gradients_by_number_of_workers_for_sync_training=True</em>, <em>clip_gradients=None</em>, <em>gradients_clipping_method=&lt;GradientClippingMethod.ClipByGlobalNorm: 0&gt;</em>, <em>l2_regularization=0</em>, <em>learning_rate=0.00025</em>, <em>learning_rate_decay_rate=0</em>, <em>learning_rate_decay_steps=0</em>, <em>input_embedders_parameters={}</em>, <em>embedding_merger_type=&lt;EmbeddingMergerType.Concat: 0&gt;</em>, <em>middleware_parameters=None</em>, <em>heads_parameters=[]</em>, <em>use_separate_networks_per_head=False</em>, <em>optimizer_type='Adam'</em>, <em>optimizer_epsilon=0.0001</em>, <em>adam_optimizer_beta1=0.9</em>, <em>adam_optimizer_beta2=0.99</em>, <em>rms_prop_optimizer_decay=0.9</em>, <em>batch_size=32</em>, <em>replace_mse_with_huber_loss=False</em>, <em>create_target_network=False</em>, <em>tensorflow_support=True</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/base_parameters.html#NetworkParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.NetworkParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="descclassname">rl_coach.base_parameters.</code><code class="descname">NetworkParameters</code><span class="sig-paren">(</span><em>force_cpu=False</em>, <em>async_training=False</em>, <em>shared_optimizer=True</em>, <em>scale_down_gradients_by_number_of_workers_for_sync_training=True</em>, <em>clip_gradients=None</em>, <em>gradients_clipping_method=&lt;GradientClippingMethod.ClipByGlobalNorm: 0&gt;</em>, <em>l2_regularization=0</em>, <em>learning_rate=0.00025</em>, <em>learning_rate_decay_rate=0</em>, <em>learning_rate_decay_steps=0</em>, <em>input_embedders_parameters={}</em>, <em>embedding_merger_type=&lt;EmbeddingMergerType.Concat: 0&gt;</em>, <em>middleware_parameters=None</em>, <em>heads_parameters=[]</em>, <em>use_separate_networks_per_head=False</em>, <em>optimizer_type='Adam'</em>, <em>optimizer_epsilon=0.0001</em>, <em>adam_optimizer_beta1=0.9</em>, <em>adam_optimizer_beta2=0.99</em>, <em>rms_prop_optimizer_decay=0.9</em>, <em>batch_size=32</em>, <em>replace_mse_with_huber_loss=False</em>, <em>create_target_network=False</em>, <em>tensorflow_support=True</em>, <em>softmax_temperature=1</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/base_parameters.html#NetworkParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.NetworkParameters" title="Permalink to this definition"></a></dt>
<dd><table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
<col class="field-body" />
@@ -257,6 +257,7 @@ selected for this network.</li>
same weights as the online network. It can then be queried, and its weights can be synced from the
online network at will.</li>
<li><strong>tensorflow_support</strong> A flag which specifies if the network is supported by the TensorFlow framework.</li>
<li><strong>softmax_temperature</strong> If a softmax is present in the network head output, use this temperature</li>
</ul>
</td>
</tr>
+1 -3
View File
@@ -194,7 +194,7 @@
<h2>ActionInfo<a class="headerlink" href="#actioninfo" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.core_types.ActionInfo">
<em class="property">class </em><code class="descclassname">rl_coach.core_types.</code><code class="descname">ActionInfo</code><span class="sig-paren">(</span><em>action: Union[int, float, numpy.ndarray, List], all_action_probabilities: float = 0, action_value: float = 0.0, state_value: float = 0.0, max_action_value: float = None, action_intrinsic_reward: float = 0</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/core_types.html#ActionInfo"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.ActionInfo" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="descclassname">rl_coach.core_types.</code><code class="descname">ActionInfo</code><span class="sig-paren">(</span><em>action: Union[int, float, numpy.ndarray, List], all_action_probabilities: float = 0, action_value: float = 0.0, state_value: float = 0.0, max_action_value: float = None</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/core_types.html#ActionInfo"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.ActionInfo" title="Permalink to this definition"></a></dt>
<dd><p>Action info is a class that holds an action and various additional information details about it</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
@@ -208,8 +208,6 @@
<li><strong>max_action_value</strong> in case this is an action that was selected randomly, this is the value of the action
that received the maximum value. if no value is given, the action is assumed to be the
action with the maximum value</li>
<li><strong>action_intrinsic_reward</strong> can contain any intrinsic reward that the agent wants to add to this action
selection</li>
</ul>
</td>
</tr>
+1 -1
View File
@@ -206,7 +206,7 @@
<h3>EpisodicExperienceReplay<a class="headerlink" href="#episodicexperiencereplay" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.memories.episodic.EpisodicExperienceReplay">
<em class="property">class </em><code class="descclassname">rl_coach.memories.episodic.</code><code class="descname">EpisodicExperienceReplay</code><span class="sig-paren">(</span><em>max_size: Tuple[rl_coach.memories.memory.MemoryGranularity</em>, <em>int] = (&lt;MemoryGranularity.Transitions: 0&gt;</em>, <em>1000000)</em>, <em>n_step=-1</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/memories/episodic/episodic_experience_replay.html#EpisodicExperienceReplay"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.episodic.EpisodicExperienceReplay" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="descclassname">rl_coach.memories.episodic.</code><code class="descname">EpisodicExperienceReplay</code><span class="sig-paren">(</span><em>max_size: Tuple[rl_coach.memories.memory.MemoryGranularity</em>, <em>int] = (&lt;MemoryGranularity.Transitions: 0&gt;</em>, <em>1000000)</em>, <em>n_step=-1</em>, <em>train_to_eval_ratio: int = 1</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/memories/episodic/episodic_experience_replay.html#EpisodicExperienceReplay"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.episodic.EpisodicExperienceReplay" title="Permalink to this definition"></a></dt>
<dd><p>A replay buffer that stores episodes of transitions. The additional structure allows performing various
calculations of total return and other values that depend on the sequential behavior of the transitions
in the episode.</p>
+14 -16
View File
@@ -382,26 +382,14 @@
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="components/exploration_policies/index.html#rl_coach.exploration_policies.e_greedy.EGreedy">EGreedy (class in rl_coach.exploration_policies.e_greedy)</a>
</li>
<li><a href="components/agents/index.html#rl_coach.agents.agent.Agent.emulate_act_on_trainer">emulate_act_on_trainer() (rl_coach.agents.agent.Agent method)</a>
<ul>
<li><a href="test.html#rl_coach.agents.dqn_agent.DQNAgent.emulate_act_on_trainer">(rl_coach.agents.dqn_agent.DQNAgent method)</a>
</li>
</ul></li>
<li><a href="components/agents/index.html#rl_coach.agents.agent.Agent.emulate_observe_on_trainer">emulate_observe_on_trainer() (rl_coach.agents.agent.Agent method)</a>
<ul>
<li><a href="test.html#rl_coach.agents.dqn_agent.DQNAgent.emulate_observe_on_trainer">(rl_coach.agents.dqn_agent.DQNAgent method)</a>
</li>
</ul></li>
<li><a href="components/environments/index.html#rl_coach.environments.environment.Environment">Environment (class in rl_coach.environments.environment)</a>
</li>
</ul></td>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="components/core_types.html#rl_coach.core_types.EnvResponse">EnvResponse (class in rl_coach.core_types)</a>
</li>
<li><a href="components/core_types.html#rl_coach.core_types.Episode">Episode (class in rl_coach.core_types)</a>
</li>
</ul></td>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="components/memories/index.html#rl_coach.memories.episodic.EpisodicExperienceReplay">EpisodicExperienceReplay (class in rl_coach.memories.episodic)</a>
</li>
<li><a href="components/memories/index.html#rl_coach.memories.episodic.EpisodicHindsightExperienceReplay">EpisodicHindsightExperienceReplay (class in rl_coach.memories.episodic)</a>
@@ -503,6 +491,8 @@
<table style="width: 100%" class="indextable genindextable"><tr>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="components/spaces.html#rl_coach.spaces.ImageObservationSpace">ImageObservationSpace (class in rl_coach.spaces)</a>
</li>
<li><a href="test.html#rl_coach.agents.dqn_agent.DQNAgent.improve_reward_model">improve_reward_model() (rl_coach.agents.dqn_agent.DQNAgent method)</a>
</li>
<li><a href="components/core_types.html#rl_coach.core_types.Batch.info">info() (rl_coach.core_types.Batch method)</a>
</li>
@@ -738,8 +728,6 @@
<li><a href="test.html#rl_coach.agents.dqn_agent.DQNAgent.reset_evaluation_state">(rl_coach.agents.dqn_agent.DQNAgent method)</a>
</li>
</ul></li>
</ul></td>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="components/agents/index.html#rl_coach.agents.agent.Agent.reset_internal_state">reset_internal_state() (rl_coach.agents.agent.Agent method)</a>
<ul>
@@ -748,6 +736,8 @@
<li><a href="components/environments/index.html#rl_coach.environments.environment.Environment.reset_internal_state">(rl_coach.environments.environment.Environment method)</a>
</li>
</ul></li>
</ul></td>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="components/agents/index.html#rl_coach.agents.agent.Agent.restore_checkpoint">restore_checkpoint() (rl_coach.agents.agent.Agent method)</a>
<ul>
@@ -762,6 +752,12 @@
</li>
<li><a href="components/core_types.html#rl_coach.core_types.Batch.rewards">rewards() (rl_coach.core_types.Batch method)</a>
</li>
<li><a href="components/agents/index.html#rl_coach.agents.agent.Agent.run_off_policy_evaluation">run_off_policy_evaluation() (rl_coach.agents.agent.Agent method)</a>
<ul>
<li><a href="test.html#rl_coach.agents.dqn_agent.DQNAgent.run_off_policy_evaluation">(rl_coach.agents.dqn_agent.DQNAgent method)</a>
</li>
</ul></li>
<li><a href="components/agents/index.html#rl_coach.agents.agent.Agent.run_pre_network_filter_for_inference">run_pre_network_filter_for_inference() (rl_coach.agents.agent.Agent method)</a>
<ul>
@@ -839,6 +835,8 @@
<li><a href="components/core_types.html#rl_coach.core_types.Batch.size">size (rl_coach.core_types.Batch attribute)</a>
</li>
<li><a href="components/core_types.html#rl_coach.core_types.Batch.slice">slice() (rl_coach.core_types.Batch method)</a>
</li>
<li><a href="components/agents/policy_optimization/sac.html#rl_coach.agents.soft_actor_critic_agent.SoftActorCriticAlgorithmParameters">SoftActorCriticAlgorithmParameters (class in rl_coach.agents.soft_actor_critic_agent)</a>
</li>
<li><a href="components/spaces.html#rl_coach.spaces.Space">Space (class in rl_coach.spaces)</a>
</li>
BIN
View File
Binary file not shown.
+1 -1
View File
File diff suppressed because one or more lines are too long
+8
View File
@@ -372,6 +372,14 @@ $(document).ready(function() {
improve stability it also employs bias correction and trust region optimization techniques.
</span>
</div>
<div class="algorithm continuous off-policy" data-year="201808">
<span class="badge">
<a href="components/agents/policy_optimization/sac.html">SAC</a>
<br>
Soft Actor-Critic is an algorithm which optimizes a stochastic policy in an off-policy way.
One of the key features of SAC is that it solves a maximum entropy reinforcement learning problem.
</span>
</div>
<div class="algorithm continuous off-policy" data-year="201509">
<span class="badge">
<a href="components/agents/policy_optimization/ddpg.html">DDPG</a>
+29 -23
View File
@@ -190,13 +190,15 @@
<em class="property">class </em><code class="descclassname">rl_coach.agents.dqn_agent.</code><code class="descname">DQNAgent</code><span class="sig-paren">(</span><em>agent_parameters</em>, <em>parent: Union[LevelManager</em>, <em>CompositeAgent] = None</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/rl_coach/agents/dqn_agent.html#DQNAgent"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent" title="Permalink to this definition"></a></dt>
<dd><dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.act">
<code class="descname">act</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.ActionInfo<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.act" title="Permalink to this definition"></a></dt>
<code class="descname">act</code><span class="sig-paren">(</span><em>action: Union[None</em>, <em>int</em>, <em>float</em>, <em>numpy.ndarray</em>, <em>List] = None</em><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.ActionInfo<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.act" title="Permalink to this definition"></a></dt>
<dd><p>Given the agents current knowledge, decide on the next action to apply to the environment</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
<col class="field-body" />
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Returns:</th><td class="field-body">An ActionInfo object, which contains the action and any additional info from the action decision process</td>
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>action</strong> An action to take, overriding whatever the current policy is</td>
</tr>
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body">An ActionInfo object, which contains the action and any additional info from the action decision process</td>
</tr>
</tbody>
</table>
@@ -267,26 +269,6 @@ for creating the network.</p>
</table>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.emulate_act_on_trainer">
<code class="descname">emulate_act_on_trainer</code><span class="sig-paren">(</span><em>transition: rl_coach.core_types.Transition</em><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.ActionInfo<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.emulate_act_on_trainer" title="Permalink to this definition"></a></dt>
<dd><p>This emulates the act using the transition obtained from the rollout worker on the training worker
in case of distributed training.
Given the agents current knowledge, decide on the next action to apply to the environment
:return: an action and a dictionary containing any additional info from the action decision process</p>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.emulate_observe_on_trainer">
<code class="descname">emulate_observe_on_trainer</code><span class="sig-paren">(</span><em>transition: rl_coach.core_types.Transition</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.emulate_observe_on_trainer" title="Permalink to this definition"></a></dt>
<dd><p>This emulates the observe using the transition obtained from the rollout worker on the training worker
in case of distributed training.
Given a response from the environment, distill the observation from it and store it for later use.
The response should be a dictionary containing the performed action, the new observation and measurements,
the reward, a game over flag and any additional information necessary.
:return:</p>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.get_predictions">
<code class="descname">get_predictions</code><span class="sig-paren">(</span><em>states: List[Dict[str, numpy.ndarray]], prediction_type: rl_coach.core_types.PredictionType</em><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.get_predictions" title="Permalink to this definition"></a></dt>
@@ -342,6 +324,22 @@ This function is called right after each episode is ended.</p>
</table>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.improve_reward_model">
<code class="descname">improve_reward_model</code><span class="sig-paren">(</span><em>epochs: int</em><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.improve_reward_model" title="Permalink to this definition"></a></dt>
<dd><p>Train a reward model to be used by the doubly-robust estimator</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
<col class="field-body" />
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>epochs</strong> The total number of epochs to use for training a reward model</td>
</tr>
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body">None</td>
</tr>
</tbody>
</table>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.init_environment_dependent_modules">
<code class="descname">init_environment_dependent_modules</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.init_environment_dependent_modules" title="Permalink to this definition"></a></dt>
@@ -450,7 +448,7 @@ given observation</td>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.prepare_batch_for_inference">
<code class="descname">prepare_batch_for_inference</code><span class="sig-paren">(</span><em>states: Union[Dict[str, numpy.ndarray], List[Dict[str, numpy.ndarray]]], network_name: str</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.core.multiarray.array]<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.prepare_batch_for_inference" title="Permalink to this definition"></a></dt>
<code class="descname">prepare_batch_for_inference</code><span class="sig-paren">(</span><em>states: Union[Dict[str, numpy.ndarray], List[Dict[str, numpy.ndarray]]], network_name: str</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.array]<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.prepare_batch_for_inference" title="Permalink to this definition"></a></dt>
<dd><p>Convert curr_state into input tensors tensorflow is expecting. i.e. if we have several inputs states, stack all
observations together, measurements together, etc.</p>
<table class="docutils field-list" frame="void" rules="none">
@@ -542,6 +540,14 @@ by val, and by the current phase set in self.phase.</p>
</table>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.run_off_policy_evaluation">
<code class="descname">run_off_policy_evaluation</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.run_off_policy_evaluation" title="Permalink to this definition"></a></dt>
<dd><p>Run the off-policy evaluation estimators to get a prediction for the performance of the current policy based on
an evaluation dataset, which was collected by another policy(ies).
:return: None</p>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.run_pre_network_filter_for_inference">
<code class="descname">run_pre_network_filter_for_inference</code><span class="sig-paren">(</span><em>state: Dict[str, numpy.ndarray], update_filter_internal_state: bool = True</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.ndarray]<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.run_pre_network_filter_for_inference" title="Permalink to this definition"></a></dt>
+2 -1
View File
@@ -278,7 +278,8 @@ To do so, you should select an environment type and level through the command li
</dl>
</li>
</ol>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">coach</span> <span class="o">-</span><span class="n">p</span> <span class="n">Doom_Basic_BC</span> <span class="o">-</span><span class="n">cp</span><span class="o">=</span><span class="s1">&#39;agent.load_memory_from_file_path=</span><span class="se">\&quot;</span><span class="s1">&lt;experiment dir&gt;/replay_buffer.p</span><span class="se">\&quot;</span><span class="s1">&#39;</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="kn">import</span> <span class="n">PickledReplayBuffer</span>
<span class="n">coach</span> <span class="o">-</span><span class="n">p</span> <span class="n">Doom_Basic_BC</span> <span class="o">-</span><span class="n">cp</span><span class="o">=</span><span class="s1">&#39;agent.load_memory_from_file_path=PickledReplayBuffer(</span><span class="se">\&quot;</span><span class="s1">&lt;experiment dir&gt;/replay_buffer.p</span><span class="se">\&quot;</span><span class="s1">&#39;</span><span class="p">)</span>
</pre></div>
</div>
</div>