1
0
mirror of https://github.com/gryf/coach.git synced 2026-04-20 06:33:31 +02:00
This commit is contained in:
Gal Leibovich
2019-06-16 11:11:21 +03:00
committed by GitHub
parent 8df3c46756
commit 7eb884c5b2
107 changed files with 2200 additions and 495 deletions
+1 -1
View File
@@ -253,7 +253,7 @@ dashboard
## Supported Algorithms
<img src="img/algorithms.png" alt="Coach Design" style="width: 800px;"/>
<img src="docs_raw/source/_static/img/algorithms.png" alt="Coach Design" style="width: 800px;"/>
+48
View File
@@ -0,0 +1,48 @@
# Twin Delayed DDPG
Each experiment uses 5 seeds and is trained for 1M environment steps.
The parameters used for TD3 are the same parameters as described in the [original paper](https://arxiv.org/pdf/1802.09477.pdf), and [repository](https://github.com/sfujim/TD3).
### Ant TD3 - single worker
```bash
coach -p Mujoco_TD3 -lvl ant
```
<img src="ant.png" alt="Ant TD3" width="800"/>
### Hopper TD3 - single worker
```bash
coach -p Mujoco_TD3 -lvl hopper
```
<img src="hopper.png" alt="Hopper TD3" width="800"/>
### Half Cheetah TD3 - single worker
```bash
coach -p Mujoco_TD3 -lvl half_cheetah
```
<img src="half_cheetah.png" alt="Half Cheetah TD3" width="800"/>
### Reacher TD3 - single worker
```bash
coach -p Mujoco_TD3 -lvl reacher
```
<img src="reacher.png" alt="Reacher TD3" width="800"/>
### Walker2D TD3 - single worker
```bash
coach -p Mujoco_TD3 -lvl walker2d
```
<img src="walker2d.png" alt="Walker2D TD3" width="800"/>
Binary file not shown.

After

Width:  |  Height:  |  Size: 88 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 72 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 113 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 96 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 92 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 59 KiB

After

Width:  |  Height:  |  Size: 60 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 48 KiB

+1
View File
@@ -200,6 +200,7 @@
<li><a href="rl_coach/agents/qr_dqn_agent.html">rl_coach.agents.qr_dqn_agent</a></li>
<li><a href="rl_coach/agents/rainbow_dqn_agent.html">rl_coach.agents.rainbow_dqn_agent</a></li>
<li><a href="rl_coach/agents/soft_actor_critic_agent.html">rl_coach.agents.soft_actor_critic_agent</a></li>
<li><a href="rl_coach/agents/td3_agent.html">rl_coach.agents.td3_agent</a></li>
<li><a href="rl_coach/agents/value_optimization_agent.html">rl_coach.agents.value_optimization_agent</a></li>
<li><a href="rl_coach/architectures/architecture.html">rl_coach.architectures.architecture</a></li>
<li><a href="rl_coach/architectures/network_wrapper.html">rl_coach.architectures.network_wrapper</a></li>
+45 -15
View File
@@ -278,19 +278,6 @@
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">memory_backend_params</span><span class="o">.</span><span class="n">run_type</span> <span class="o">!=</span> <span class="s1">&#39;trainer&#39;</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">set_memory_backend</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory_backend</span><span class="p">)</span>
<span class="k">if</span> <span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">,</span> <span class="n">PickledReplayBuffer</span><span class="p">):</span>
<span class="n">screen</span><span class="o">.</span><span class="n">log_title</span><span class="p">(</span><span class="s2">&quot;Loading a pickled replay buffer. Pickled file path: </span><span class="si">{}</span><span class="s2">&quot;</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="o">.</span><span class="n">filepath</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_pickled</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="o">.</span><span class="n">filepath</span><span class="p">)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">,</span> <span class="n">CsvDataset</span><span class="p">):</span>
<span class="n">screen</span><span class="o">.</span><span class="n">log_title</span><span class="p">(</span><span class="s2">&quot;Loading a replay buffer from a CSV file. CSV file path: </span><span class="si">{}</span><span class="s2">&quot;</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="o">.</span><span class="n">filepath</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_csv</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;Trying to load a replay buffer using an unsupported method - </span><span class="si">{}</span><span class="s1">. &#39;</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">agent_parameters</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">))</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">shared_memory</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_chief</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">shared_memory_scratchpad</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory_lookup_name</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="p">)</span>
@@ -444,7 +431,39 @@
<span class="bp">self</span><span class="o">.</span><span class="n">input_filter</span><span class="o">.</span><span class="n">set_session</span><span class="p">(</span><span class="n">sess</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">output_filter</span><span class="o">.</span><span class="n">set_session</span><span class="p">(</span><span class="n">sess</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">set_session</span><span class="p">(</span><span class="n">sess</span><span class="p">)</span>
<span class="p">[</span><span class="n">network</span><span class="o">.</span><span class="n">set_session</span><span class="p">(</span><span class="n">sess</span><span class="p">)</span> <span class="k">for</span> <span class="n">network</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="o">.</span><span class="n">values</span><span class="p">()]</span></div>
<span class="p">[</span><span class="n">network</span><span class="o">.</span><span class="n">set_session</span><span class="p">(</span><span class="n">sess</span><span class="p">)</span> <span class="k">for</span> <span class="n">network</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="o">.</span><span class="n">values</span><span class="p">()]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">initialize_session_dependent_components</span><span class="p">()</span></div>
<div class="viewcode-block" id="Agent.initialize_session_dependent_components"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent.initialize_session_dependent_components">[docs]</a> <span class="k">def</span> <span class="nf">initialize_session_dependent_components</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Initialize components which require a session as part of their initialization.</span>
<span class="sd"> :return: None</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="c1"># Loading a memory from a CSV file, requires an input filter to filter through the data.</span>
<span class="c1"># The filter needs a session before it can be used.</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">load_memory_from_file</span><span class="p">()</span></div>
<div class="viewcode-block" id="Agent.load_memory_from_file"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent.load_memory_from_file">[docs]</a> <span class="k">def</span> <span class="nf">load_memory_from_file</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Load memory transitions from a file.</span>
<span class="sd"> :return: None</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">,</span> <span class="n">PickledReplayBuffer</span><span class="p">):</span>
<span class="n">screen</span><span class="o">.</span><span class="n">log_title</span><span class="p">(</span><span class="s2">&quot;Loading a pickled replay buffer. Pickled file path: </span><span class="si">{}</span><span class="s2">&quot;</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="o">.</span><span class="n">filepath</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_pickled</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="o">.</span><span class="n">filepath</span><span class="p">)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">,</span> <span class="n">CsvDataset</span><span class="p">):</span>
<span class="n">screen</span><span class="o">.</span><span class="n">log_title</span><span class="p">(</span><span class="s2">&quot;Loading a replay buffer from a CSV file. CSV file path: </span><span class="si">{}</span><span class="s2">&quot;</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="o">.</span><span class="n">filepath</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_csv</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_filter</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;Trying to load a replay buffer using an unsupported method - </span><span class="si">{}</span><span class="s1">. &#39;</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">load_memory_from_file_path</span><span class="p">))</span></div>
<div class="viewcode-block" id="Agent.register_signal"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent.register_signal">[docs]</a> <span class="k">def</span> <span class="nf">register_signal</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">signal_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">dump_one_value_per_episode</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">dump_one_value_per_step</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Signal</span><span class="p">:</span>
@@ -868,6 +887,9 @@
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">loss</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_should_train</span><span class="p">():</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">is_batch_rl_training</span><span class="p">:</span>
<span class="c1"># when training an agent for generating a dataset in batch-rl, we don&#39;t want it to be counted as part of</span>
<span class="c1"># the training epochs. we only care for training epochs in batch-rl anyway.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">training_epoch</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="k">for</span> <span class="n">network</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
<span class="n">network</span><span class="o">.</span><span class="n">set_is_training</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
@@ -1229,7 +1251,15 @@
<span class="n">TimeTypes</span><span class="o">.</span><span class="n">TrainingIteration</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_iteration</span><span class="p">,</span>
<span class="n">TimeTypes</span><span class="o">.</span><span class="n">EnvironmentSteps</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_steps_counter</span><span class="p">,</span>
<span class="n">TimeTypes</span><span class="o">.</span><span class="n">WallClockTime</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">agent_logger</span><span class="o">.</span><span class="n">get_current_wall_clock_time</span><span class="p">(),</span>
<span class="n">TimeTypes</span><span class="o">.</span><span class="n">Epoch</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_epoch</span><span class="p">}[</span><span class="bp">self</span><span class="o">.</span><span class="n">parent_level_manager</span><span class="o">.</span><span class="n">parent_graph_manager</span><span class="o">.</span><span class="n">time_metric</span><span class="p">]</span></div>
<span class="n">TimeTypes</span><span class="o">.</span><span class="n">Epoch</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_epoch</span><span class="p">}[</span><span class="bp">self</span><span class="o">.</span><span class="n">parent_level_manager</span><span class="o">.</span><span class="n">parent_graph_manager</span><span class="o">.</span><span class="n">time_metric</span><span class="p">]</span>
<div class="viewcode-block" id="Agent.freeze_memory"><a class="viewcode-back" href="../../../components/agents/index.html#rl_coach.agents.agent.Agent.freeze_memory">[docs]</a> <span class="k">def</span> <span class="nf">freeze_memory</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Shuffle episodes in the memory and freeze it to make sure that no extra data is being pushed anymore.</span>
<span class="sd"> :return: None</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">&#39;shuffle_episodes&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">&#39;freeze&#39;</span><span class="p">)</span></div></div>
</pre></div>
</div>
@@ -196,7 +196,6 @@
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="c1">#</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
@@ -266,13 +265,22 @@
<span class="c1"># prediction&#39;s format is (batch,actions,atoms)</span>
<span class="k">def</span> <span class="nf">get_all_q_values_for_states</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">:</span> <span class="n">StateType</span><span class="p">):</span>
<span class="n">q_values</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">requires_action_values</span><span class="p">():</span>
<span class="n">q_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_prediction</span><span class="p">(</span><span class="n">states</span><span class="p">,</span>
<span class="n">outputs</span><span class="o">=</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">q_values</span><span class="p">])</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">q_values</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">return</span> <span class="n">q_values</span>
<span class="k">def</span> <span class="nf">get_all_q_values_for_states_and_softmax_probabilities</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">:</span> <span class="n">StateType</span><span class="p">):</span>
<span class="n">actions_q_values</span><span class="p">,</span> <span class="n">softmax_probabilities</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">requires_action_values</span><span class="p">():</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">q_values</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">softmax</span><span class="p">]</span>
<span class="n">actions_q_values</span><span class="p">,</span> <span class="n">softmax_probabilities</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_prediction</span><span class="p">(</span><span class="n">states</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">outputs</span><span class="p">)</span>
<span class="k">return</span> <span class="n">actions_q_values</span><span class="p">,</span> <span class="n">softmax_probabilities</span>
<span class="k">def</span> <span class="nf">learn_from_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
<span class="n">network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
@@ -206,7 +206,7 @@
<span class="kn">from</span> <span class="nn">rl_coach.agents.actor_critic_agent</span> <span class="k">import</span> <span class="n">ActorCriticAgent</span>
<span class="kn">from</span> <span class="nn">rl_coach.agents.agent</span> <span class="k">import</span> <span class="n">Agent</span>
<span class="kn">from</span> <span class="nn">rl_coach.architectures.embedder_parameters</span> <span class="k">import</span> <span class="n">InputEmbedderParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.architectures.head_parameters</span> <span class="k">import</span> <span class="n">DDPGActorHeadParameters</span><span class="p">,</span> <span class="n">VHeadParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.architectures.head_parameters</span> <span class="k">import</span> <span class="n">DDPGActorHeadParameters</span><span class="p">,</span> <span class="n">DDPGVHeadParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.architectures.middleware_parameters</span> <span class="k">import</span> <span class="n">FCMiddlewareParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.base_parameters</span> <span class="k">import</span> <span class="n">NetworkParameters</span><span class="p">,</span> <span class="n">AlgorithmParameters</span><span class="p">,</span> \
<span class="n">AgentParameters</span><span class="p">,</span> <span class="n">EmbedderScheme</span>
@@ -222,14 +222,17 @@
<span class="bp">self</span><span class="o">.</span><span class="n">input_embedders_parameters</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;observation&#39;</span><span class="p">:</span> <span class="n">InputEmbedderParameters</span><span class="p">(</span><span class="n">batchnorm</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span>
<span class="s1">&#39;action&#39;</span><span class="p">:</span> <span class="n">InputEmbedderParameters</span><span class="p">(</span><span class="n">scheme</span><span class="o">=</span><span class="n">EmbedderScheme</span><span class="o">.</span><span class="n">Shallow</span><span class="p">)}</span>
<span class="bp">self</span><span class="o">.</span><span class="n">middleware_parameters</span> <span class="o">=</span> <span class="n">FCMiddlewareParameters</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">VHeadParameters</span><span class="p">()]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">DDPGVHeadParameters</span><span class="p">()]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_type</span> <span class="o">=</span> <span class="s1">&#39;Adam&#39;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">64</span>
<span class="bp">self</span><span class="o">.</span><span class="n">async_training</span> <span class="o">=</span> <span class="kc">False</span>
<span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">0.001</span>
<span class="bp">self</span><span class="o">.</span><span class="n">adam_optimizer_beta2</span> <span class="o">=</span> <span class="mf">0.999</span>
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_epsilon</span> <span class="o">=</span> <span class="mf">1e-8</span>
<span class="bp">self</span><span class="o">.</span><span class="n">create_target_network</span> <span class="o">=</span> <span class="kc">True</span>
<span class="bp">self</span><span class="o">.</span><span class="n">shared_optimizer</span> <span class="o">=</span> <span class="kc">True</span>
<span class="bp">self</span><span class="o">.</span><span class="n">scale_down_gradients_by_number_of_workers_for_sync_training</span> <span class="o">=</span> <span class="kc">False</span>
<span class="c1"># self.l2_regularization = 1e-2</span>
<span class="k">class</span> <span class="nc">DDPGActorNetworkParameters</span><span class="p">(</span><span class="n">NetworkParameters</span><span class="p">):</span>
@@ -240,6 +243,8 @@
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">DDPGActorHeadParameters</span><span class="p">()]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_type</span> <span class="o">=</span> <span class="s1">&#39;Adam&#39;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">64</span>
<span class="bp">self</span><span class="o">.</span><span class="n">adam_optimizer_beta2</span> <span class="o">=</span> <span class="mf">0.999</span>
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_epsilon</span> <span class="o">=</span> <span class="mf">1e-8</span>
<span class="bp">self</span><span class="o">.</span><span class="n">async_training</span> <span class="o">=</span> <span class="kc">False</span>
<span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">0.0001</span>
<span class="bp">self</span><span class="o">.</span><span class="n">create_target_network</span> <span class="o">=</span> <span class="kc">True</span>
@@ -323,7 +328,7 @@
<span class="n">critic_inputs</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">next_states</span><span class="p">(</span><span class="n">critic_keys</span><span class="p">))</span>
<span class="n">critic_inputs</span><span class="p">[</span><span class="s1">&#39;action&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">next_actions</span>
<span class="n">q_st_plus_1</span> <span class="o">=</span> <span class="n">critic</span><span class="o">.</span><span class="n">target_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">critic_inputs</span><span class="p">)</span>
<span class="n">q_st_plus_1</span> <span class="o">=</span> <span class="n">critic</span><span class="o">.</span><span class="n">target_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">critic_inputs</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="c1"># calculate the bootstrapped TD targets while discounting terminal states according to</span>
<span class="c1"># use_non_zero_discount_for_terminal_states</span>
@@ -343,7 +348,7 @@
<span class="n">critic_inputs</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">critic_keys</span><span class="p">))</span>
<span class="n">critic_inputs</span><span class="p">[</span><span class="s1">&#39;action&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">actions_mean</span>
<span class="n">action_gradients</span> <span class="o">=</span> <span class="n">critic</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">critic_inputs</span><span class="p">,</span>
<span class="n">outputs</span><span class="o">=</span><span class="n">critic</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">gradients_wrt_inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="s1">&#39;action&#39;</span><span class="p">])</span>
<span class="n">outputs</span><span class="o">=</span><span class="n">critic</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">gradients_wrt_inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">][</span><span class="s1">&#39;action&#39;</span><span class="p">])</span>
<span class="c1"># train the critic</span>
<span class="n">critic_inputs</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">critic_keys</span><span class="p">))</span>
+1 -1
View File
@@ -365,7 +365,7 @@
<span class="n">action_values</span> <span class="o">=</span> <span class="kc">None</span>
<span class="c1"># choose action according to the exploration policy and the current phase (evaluating or training the agent)</span>
<span class="n">action</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">get_action</span><span class="p">(</span><span class="n">action_values</span><span class="p">)</span>
<span class="n">action</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">get_action</span><span class="p">(</span><span class="n">action_values</span><span class="p">)</span>
<span class="k">if</span> <span class="n">action_values</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">action_values</span> <span class="o">=</span> <span class="n">action_values</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
@@ -232,6 +232,7 @@
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">32</span>
<span class="bp">self</span><span class="o">.</span><span class="n">replace_mse_with_huber_loss</span> <span class="o">=</span> <span class="kc">True</span>
<span class="bp">self</span><span class="o">.</span><span class="n">create_target_network</span> <span class="o">=</span> <span class="kc">True</span>
<span class="bp">self</span><span class="o">.</span><span class="n">should_get_softmax_probabilities</span> <span class="o">=</span> <span class="kc">False</span>
<span class="k">class</span> <span class="nc">DQNAgentParameters</span><span class="p">(</span><span class="n">AgentParameters</span><span class="p">):</span>
+20 -5
View File
@@ -199,7 +199,7 @@
<span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">pickle</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span><span class="p">,</span> <span class="n">List</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
@@ -223,6 +223,7 @@
<span class="bp">self</span><span class="o">.</span><span class="n">middleware_parameters</span> <span class="o">=</span> <span class="n">FCMiddlewareParameters</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">DNDQHeadParameters</span><span class="p">()]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_type</span> <span class="o">=</span> <span class="s1">&#39;Adam&#39;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">should_get_softmax_probabilities</span> <span class="o">=</span> <span class="kc">False</span>
<div class="viewcode-block" id="NECAlgorithmParameters"><a class="viewcode-back" href="../../../components/agents/value_optimization/nec.html#rl_coach.agents.nec_agent.NECAlgorithmParameters">[docs]</a><span class="k">class</span> <span class="nc">NECAlgorithmParameters</span><span class="p">(</span><span class="n">AlgorithmParameters</span><span class="p">):</span>
@@ -349,11 +350,25 @@
<span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">act</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">get_all_q_values_for_states</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">:</span> <span class="n">StateType</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">get_all_q_values_for_states</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">:</span> <span class="n">StateType</span><span class="p">,</span> <span class="n">additional_outputs</span><span class="p">:</span> <span class="n">List</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="c1"># we need to store the state embeddings regardless if the action is random or not</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_prediction</span><span class="p">(</span><span class="n">states</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_prediction_and_update_embeddings</span><span class="p">(</span><span class="n">states</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">get_prediction</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">get_all_q_values_for_states_and_softmax_probabilities</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">:</span> <span class="n">StateType</span><span class="p">):</span>
<span class="c1"># get the actions q values and the state embedding</span>
<span class="n">embedding</span><span class="p">,</span> <span class="n">actions_q_values</span><span class="p">,</span> <span class="n">softmax_probabilities</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">prepare_batch_for_inference</span><span class="p">(</span><span class="n">states</span><span class="p">,</span> <span class="s1">&#39;main&#39;</span><span class="p">),</span>
<span class="n">outputs</span><span class="o">=</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">state_embedding</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">output</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">softmax</span><span class="p">]</span>
<span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">!=</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span><span class="p">:</span>
<span class="c1"># store the state embedding for inserting it to the DND later</span>
<span class="bp">self</span><span class="o">.</span><span class="n">current_episode_state_embeddings</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">embedding</span><span class="o">.</span><span class="n">squeeze</span><span class="p">())</span>
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="n">actions_q_values</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
<span class="k">return</span> <span class="n">actions_q_values</span><span class="p">,</span> <span class="n">softmax_probabilities</span>
<span class="k">def</span> <span class="nf">get_prediction_and_update_embeddings</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">):</span>
<span class="c1"># get the actions q values and the state embedding</span>
<span class="n">embedding</span><span class="p">,</span> <span class="n">actions_q_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">prepare_batch_for_inference</span><span class="p">(</span><span class="n">states</span><span class="p">,</span> <span class="s1">&#39;main&#39;</span><span class="p">),</span>
@@ -362,7 +377,7 @@
<span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">!=</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span><span class="p">:</span>
<span class="c1"># store the state embedding for inserting it to the DND later</span>
<span class="bp">self</span><span class="o">.</span><span class="n">current_episode_state_embeddings</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">embedding</span><span class="o">.</span><span class="n">squeeze</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">current_episode_state_embeddings</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">embedding</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">squeeze</span><span class="p">())</span>
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="n">actions_q_values</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
<span class="k">return</span> <span class="n">actions_q_values</span>
@@ -196,7 +196,7 @@
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="c1">#</span>
<span class="kn">from</span> <span class="nn">copy</span> <span class="k">import</span> <span class="n">copy</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
@@ -262,6 +262,17 @@
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">return</span> <span class="n">actions_q_values</span>
<span class="c1"># prediction&#39;s format is (batch,actions,atoms)</span>
<span class="k">def</span> <span class="nf">get_all_q_values_for_states_and_softmax_probabilities</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">:</span> <span class="n">StateType</span><span class="p">):</span>
<span class="n">actions_q_values</span><span class="p">,</span> <span class="n">softmax_probabilities</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">requires_action_values</span><span class="p">():</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">copy</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">outputs</span><span class="p">)</span>
<span class="n">outputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">softmax</span><span class="p">)</span>
<span class="n">quantile_values</span><span class="p">,</span> <span class="n">softmax_probabilities</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_prediction</span><span class="p">(</span><span class="n">states</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span>
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_q_values</span><span class="p">(</span><span class="n">quantile_values</span><span class="p">)</span>
<span class="k">return</span> <span class="n">actions_q_values</span><span class="p">,</span> <span class="n">softmax_probabilities</span>
<span class="k">def</span> <span class="nf">learn_from_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
<span class="n">network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
@@ -0,0 +1,448 @@
<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>rl_coach.agents.td3_agent &mdash; Reinforcement Learning Coach 0.12.0 documentation</title>
<script type="text/javascript" src="../../../_static/js/modernizr.min.js"></script>
<script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/js/theme.js"></script>
<link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../../_static/css/custom.css" type="text/css" />
<link rel="index" title="Index" href="../../../genindex.html" />
<link rel="search" title="Search" href="../../../search.html" />
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../../index.html" class="icon icon-home"> Reinforcement Learning Coach
<img src="../../../_static/dark_logo.png" class="logo" alt="Logo"/>
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption"><span class="caption-text">Intro</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../dist_usage.html">Usage - Distributed Coach</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../features/index.html">Features</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../selecting_an_algorithm.html">Selecting an Algorithm</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../dashboard.html">Coach Dashboard</a></li>
</ul>
<p class="caption"><span class="caption-text">Design</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../design/control_flow.html">Control Flow</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../design/network.html">Network Design</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../design/horizontal_scaling.html">Distributed Coach - Horizontal Scale-Out</a></li>
</ul>
<p class="caption"><span class="caption-text">Contributing</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_agent.html">Adding a New Agent</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_env.html">Adding a New Environment</a></li>
</ul>
<p class="caption"><span class="caption-text">Components</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../components/agents/index.html">Agents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/architectures/index.html">Architectures</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/data_stores/index.html">Data Stores</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/environments/index.html">Environments</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/exploration_policies/index.html">Exploration Policies</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/filters/index.html">Filters</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/memories/index.html">Memories</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/memory_backends/index.html">Memory Backends</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/orchestrators/index.html">Orchestrators</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/core_types.html">Core Types</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/spaces.html">Spaces</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../components/additional_parameters.html">Additional Parameters</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../../index.html">Reinforcement Learning Coach</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../../index.html">Docs</a> &raquo;</li>
<li><a href="../../index.html">Module code</a> &raquo;</li>
<li>rl_coach.agents.td3_agent</li>
<li class="wy-breadcrumbs-aside">
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<h1>Source code for rl_coach.agents.td3_agent</h1><div class="highlight"><pre>
<span></span><span class="c1">#</span>
<span class="c1"># Copyright (c) 2019 Intel Corporation</span>
<span class="c1">#</span>
<span class="c1"># Licensed under the Apache License, Version 2.0 (the &quot;License&quot;);</span>
<span class="c1"># you may not use this file except in compliance with the License.</span>
<span class="c1"># You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
<span class="c1"># distributed under the License is distributed on an &quot;AS IS&quot; BASIS,</span>
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="c1">#</span>
<span class="kn">import</span> <span class="nn">copy</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span>
<span class="kn">from</span> <span class="nn">collections</span> <span class="k">import</span> <span class="n">OrderedDict</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">rl_coach.agents.agent</span> <span class="k">import</span> <span class="n">Agent</span>
<span class="kn">from</span> <span class="nn">rl_coach.agents.ddpg_agent</span> <span class="k">import</span> <span class="n">DDPGAgent</span>
<span class="kn">from</span> <span class="nn">rl_coach.architectures.embedder_parameters</span> <span class="k">import</span> <span class="n">InputEmbedderParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.architectures.head_parameters</span> <span class="k">import</span> <span class="n">DDPGActorHeadParameters</span><span class="p">,</span> <span class="n">TD3VHeadParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.architectures.middleware_parameters</span> <span class="k">import</span> <span class="n">FCMiddlewareParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.base_parameters</span> <span class="k">import</span> <span class="n">NetworkParameters</span><span class="p">,</span> <span class="n">AlgorithmParameters</span><span class="p">,</span> \
<span class="n">AgentParameters</span><span class="p">,</span> <span class="n">EmbedderScheme</span>
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">ActionInfo</span><span class="p">,</span> <span class="n">TrainingSteps</span><span class="p">,</span> <span class="n">Transition</span>
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.additive_noise</span> <span class="k">import</span> <span class="n">AdditiveNoiseParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.memories.episodic.episodic_experience_replay</span> <span class="k">import</span> <span class="n">EpisodicExperienceReplayParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">BoxActionSpace</span><span class="p">,</span> <span class="n">GoalsSpace</span>
<span class="k">class</span> <span class="nc">TD3CriticNetworkParameters</span><span class="p">(</span><span class="n">NetworkParameters</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_q_networks</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">input_embedders_parameters</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;observation&#39;</span><span class="p">:</span> <span class="n">InputEmbedderParameters</span><span class="p">(),</span>
<span class="s1">&#39;action&#39;</span><span class="p">:</span> <span class="n">InputEmbedderParameters</span><span class="p">(</span><span class="n">scheme</span><span class="o">=</span><span class="n">EmbedderScheme</span><span class="o">.</span><span class="n">Shallow</span><span class="p">)}</span>
<span class="bp">self</span><span class="o">.</span><span class="n">middleware_parameters</span> <span class="o">=</span> <span class="n">FCMiddlewareParameters</span><span class="p">(</span><span class="n">num_streams</span><span class="o">=</span><span class="n">num_q_networks</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">TD3VHeadParameters</span><span class="p">()]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_type</span> <span class="o">=</span> <span class="s1">&#39;Adam&#39;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">adam_optimizer_beta2</span> <span class="o">=</span> <span class="mf">0.999</span>
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_epsilon</span> <span class="o">=</span> <span class="mf">1e-8</span>
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">100</span>
<span class="bp">self</span><span class="o">.</span><span class="n">async_training</span> <span class="o">=</span> <span class="kc">False</span>
<span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">0.001</span>
<span class="bp">self</span><span class="o">.</span><span class="n">create_target_network</span> <span class="o">=</span> <span class="kc">True</span>
<span class="bp">self</span><span class="o">.</span><span class="n">shared_optimizer</span> <span class="o">=</span> <span class="kc">True</span>
<span class="bp">self</span><span class="o">.</span><span class="n">scale_down_gradients_by_number_of_workers_for_sync_training</span> <span class="o">=</span> <span class="kc">False</span>
<span class="k">class</span> <span class="nc">TD3ActorNetworkParameters</span><span class="p">(</span><span class="n">NetworkParameters</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">input_embedders_parameters</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;observation&#39;</span><span class="p">:</span> <span class="n">InputEmbedderParameters</span><span class="p">()}</span>
<span class="bp">self</span><span class="o">.</span><span class="n">middleware_parameters</span> <span class="o">=</span> <span class="n">FCMiddlewareParameters</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">DDPGActorHeadParameters</span><span class="p">(</span><span class="n">batchnorm</span><span class="o">=</span><span class="kc">False</span><span class="p">)]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_type</span> <span class="o">=</span> <span class="s1">&#39;Adam&#39;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">adam_optimizer_beta2</span> <span class="o">=</span> <span class="mf">0.999</span>
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_epsilon</span> <span class="o">=</span> <span class="mf">1e-8</span>
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">100</span>
<span class="bp">self</span><span class="o">.</span><span class="n">async_training</span> <span class="o">=</span> <span class="kc">False</span>
<span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">0.001</span>
<span class="bp">self</span><span class="o">.</span><span class="n">create_target_network</span> <span class="o">=</span> <span class="kc">True</span>
<span class="bp">self</span><span class="o">.</span><span class="n">shared_optimizer</span> <span class="o">=</span> <span class="kc">True</span>
<span class="bp">self</span><span class="o">.</span><span class="n">scale_down_gradients_by_number_of_workers_for_sync_training</span> <span class="o">=</span> <span class="kc">False</span>
<div class="viewcode-block" id="TD3AlgorithmParameters"><a class="viewcode-back" href="../../../components/agents/policy_optimization/td3.html#rl_coach.agents.td3_agent.TD3AlgorithmParameters">[docs]</a><span class="k">class</span> <span class="nc">TD3AlgorithmParameters</span><span class="p">(</span><span class="n">AlgorithmParameters</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> :param num_steps_between_copying_online_weights_to_target: (StepMethod)</span>
<span class="sd"> The number of steps between copying the online network weights to the target network weights.</span>
<span class="sd"> :param rate_for_copying_weights_to_target: (float)</span>
<span class="sd"> When copying the online network weights to the target network weights, a soft update will be used, which</span>
<span class="sd"> weight the new online network weights by rate_for_copying_weights_to_target</span>
<span class="sd"> :param num_consecutive_playing_steps: (StepMethod)</span>
<span class="sd"> The number of consecutive steps to act between every two training iterations</span>
<span class="sd"> :param use_target_network_for_evaluation: (bool)</span>
<span class="sd"> If set to True, the target network will be used for predicting the actions when choosing actions to act.</span>
<span class="sd"> Since the target network weights change more slowly, the predicted actions will be more consistent.</span>
<span class="sd"> :param action_penalty: (float)</span>
<span class="sd"> The amount by which to penalize the network on high action feature (pre-activation) values.</span>
<span class="sd"> This can prevent the actions features from saturating the TanH activation function, and therefore prevent the</span>
<span class="sd"> gradients from becoming very low.</span>
<span class="sd"> :param clip_critic_targets: (Tuple[float, float] or None)</span>
<span class="sd"> The range to clip the critic target to in order to prevent overestimation of the action values.</span>
<span class="sd"> :param use_non_zero_discount_for_terminal_states: (bool)</span>
<span class="sd"> If set to True, the discount factor will be used for terminal states to bootstrap the next predicted state</span>
<span class="sd"> values. If set to False, the terminal states reward will be taken as the target return for the network.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rate_for_copying_weights_to_target</span> <span class="o">=</span> <span class="mf">0.005</span>
<span class="bp">self</span><span class="o">.</span><span class="n">use_target_network_for_evaluation</span> <span class="o">=</span> <span class="kc">False</span>
<span class="bp">self</span><span class="o">.</span><span class="n">action_penalty</span> <span class="o">=</span> <span class="mi">0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">clip_critic_targets</span> <span class="o">=</span> <span class="kc">None</span> <span class="c1"># expected to be a tuple of the form (min_clip_value, max_clip_value) or None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">use_non_zero_discount_for_terminal_states</span> <span class="o">=</span> <span class="kc">False</span>
<span class="bp">self</span><span class="o">.</span><span class="n">act_for_full_episodes</span> <span class="o">=</span> <span class="kc">True</span>
<span class="bp">self</span><span class="o">.</span><span class="n">update_policy_every_x_episode_steps</span> <span class="o">=</span> <span class="mi">2</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_steps_between_copying_online_weights_to_target</span> <span class="o">=</span> <span class="n">TrainingSteps</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">update_policy_every_x_episode_steps</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">policy_noise</span> <span class="o">=</span> <span class="mf">0.2</span>
<span class="bp">self</span><span class="o">.</span><span class="n">noise_clipping</span> <span class="o">=</span> <span class="mf">0.5</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_q_networks</span> <span class="o">=</span> <span class="mi">2</span></div>
<span class="k">class</span> <span class="nc">TD3AgentExplorationParameters</span><span class="p">(</span><span class="n">AdditiveNoiseParameters</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">noise_as_percentage_from_action_space</span> <span class="o">=</span> <span class="kc">False</span>
<span class="k">class</span> <span class="nc">TD3AgentParameters</span><span class="p">(</span><span class="n">AgentParameters</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="n">td3_algorithm_params</span> <span class="o">=</span> <span class="n">TD3AlgorithmParameters</span><span class="p">()</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">algorithm</span><span class="o">=</span><span class="n">td3_algorithm_params</span><span class="p">,</span>
<span class="n">exploration</span><span class="o">=</span><span class="n">TD3AgentExplorationParameters</span><span class="p">(),</span>
<span class="n">memory</span><span class="o">=</span><span class="n">EpisodicExperienceReplayParameters</span><span class="p">(),</span>
<span class="n">networks</span><span class="o">=</span><span class="n">OrderedDict</span><span class="p">([(</span><span class="s2">&quot;actor&quot;</span><span class="p">,</span> <span class="n">TD3ActorNetworkParameters</span><span class="p">()),</span>
<span class="p">(</span><span class="s2">&quot;critic&quot;</span><span class="p">,</span>
<span class="n">TD3CriticNetworkParameters</span><span class="p">(</span><span class="n">td3_algorithm_params</span><span class="o">.</span><span class="n">num_q_networks</span><span class="p">))]))</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="s1">&#39;rl_coach.agents.td3_agent:TD3Agent&#39;</span>
<span class="c1"># Twin Delayed DDPG - https://arxiv.org/pdf/1802.09477.pdf</span>
<span class="k">class</span> <span class="nc">TD3Agent</span><span class="p">(</span><span class="n">DDPGAgent</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s1">&#39;LevelManager&#39;</span><span class="p">,</span> <span class="s1">&#39;CompositeAgent&#39;</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">q_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s2">&quot;Q&quot;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">TD_targets_signal</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s2">&quot;TD targets&quot;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">action_signal</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s2">&quot;actions&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">learn_from_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
<span class="n">actor</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;actor&#39;</span><span class="p">]</span>
<span class="n">critic</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;critic&#39;</span><span class="p">]</span>
<span class="n">actor_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;actor&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
<span class="n">critic_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;critic&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
<span class="c1"># TD error = r + discount*max(q_st_plus_1) - q_st</span>
<span class="n">next_actions</span><span class="p">,</span> <span class="n">actions_mean</span> <span class="o">=</span> <span class="n">actor</span><span class="o">.</span><span class="n">parallel_prediction</span><span class="p">([</span>
<span class="p">(</span><span class="n">actor</span><span class="o">.</span><span class="n">target_network</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">next_states</span><span class="p">(</span><span class="n">actor_keys</span><span class="p">)),</span>
<span class="p">(</span><span class="n">actor</span><span class="o">.</span><span class="n">online_network</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">actor_keys</span><span class="p">))</span>
<span class="p">])</span>
<span class="c1"># add noise to the next_actions</span>
<span class="n">noise</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">policy_noise</span><span class="p">,</span> <span class="n">next_actions</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span>
<span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">noise_clipping</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">noise_clipping</span><span class="p">)</span>
<span class="n">next_actions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">action</span><span class="o">.</span><span class="n">clip_action_to_space</span><span class="p">(</span><span class="n">next_actions</span> <span class="o">+</span> <span class="n">noise</span><span class="p">)</span>
<span class="n">critic_inputs</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">next_states</span><span class="p">(</span><span class="n">critic_keys</span><span class="p">))</span>
<span class="n">critic_inputs</span><span class="p">[</span><span class="s1">&#39;action&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">next_actions</span>
<span class="n">q_st_plus_1</span> <span class="o">=</span> <span class="n">critic</span><span class="o">.</span><span class="n">target_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">critic_inputs</span><span class="p">)[</span><span class="mi">2</span><span class="p">]</span> <span class="c1"># output #2 is the min (Q1, Q2)</span>
<span class="c1"># calculate the bootstrapped TD targets while discounting terminal states according to</span>
<span class="c1"># use_non_zero_discount_for_terminal_states</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">use_non_zero_discount_for_terminal_states</span><span class="p">:</span>
<span class="n">TD_targets</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">rewards</span><span class="p">(</span><span class="n">expand_dims</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">discount</span> <span class="o">*</span> <span class="n">q_st_plus_1</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">TD_targets</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">rewards</span><span class="p">(</span><span class="n">expand_dims</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">+</span> \
<span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">batch</span><span class="o">.</span><span class="n">game_overs</span><span class="p">(</span><span class="n">expand_dims</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">discount</span> <span class="o">*</span> <span class="n">q_st_plus_1</span>
<span class="c1"># clip the TD targets to prevent overestimation errors</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">clip_critic_targets</span><span class="p">:</span>
<span class="n">TD_targets</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">TD_targets</span><span class="p">,</span> <span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">clip_critic_targets</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">TD_targets_signal</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">TD_targets</span><span class="p">)</span>
<span class="c1"># train the critic</span>
<span class="n">critic_inputs</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">critic_keys</span><span class="p">))</span>
<span class="n">critic_inputs</span><span class="p">[</span><span class="s1">&#39;action&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">result</span> <span class="o">=</span> <span class="n">critic</span><span class="o">.</span><span class="n">train_and_sync_networks</span><span class="p">(</span><span class="n">critic_inputs</span><span class="p">,</span> <span class="n">TD_targets</span><span class="p">)</span>
<span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span> <span class="o">=</span> <span class="n">result</span><span class="p">[:</span><span class="mi">3</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_iteration</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">update_policy_every_x_episode_steps</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="c1"># get the gradients of output #3 (=mean of Q1 network) w.r.t the action</span>
<span class="n">critic_inputs</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">critic_keys</span><span class="p">))</span>
<span class="n">critic_inputs</span><span class="p">[</span><span class="s1">&#39;action&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">actions_mean</span>
<span class="n">action_gradients</span> <span class="o">=</span> <span class="n">critic</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">critic_inputs</span><span class="p">,</span>
<span class="n">outputs</span><span class="o">=</span><span class="n">critic</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">gradients_wrt_inputs</span><span class="p">[</span><span class="mi">3</span><span class="p">][</span><span class="s1">&#39;action&#39;</span><span class="p">])</span>
<span class="c1"># apply the gradients from the critic to the actor</span>
<span class="n">initial_feed_dict</span> <span class="o">=</span> <span class="p">{</span><span class="n">actor</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">gradients_weights_ph</span><span class="p">[</span><span class="mi">0</span><span class="p">]:</span> <span class="o">-</span><span class="n">action_gradients</span><span class="p">}</span>
<span class="n">gradients</span> <span class="o">=</span> <span class="n">actor</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">actor_keys</span><span class="p">),</span>
<span class="n">outputs</span><span class="o">=</span><span class="n">actor</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">weighted_gradients</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
<span class="n">initial_feed_dict</span><span class="o">=</span><span class="n">initial_feed_dict</span><span class="p">)</span>
<span class="k">if</span> <span class="n">actor</span><span class="o">.</span><span class="n">has_global</span><span class="p">:</span>
<span class="n">actor</span><span class="o">.</span><span class="n">apply_gradients_to_global_network</span><span class="p">(</span><span class="n">gradients</span><span class="p">)</span>
<span class="n">actor</span><span class="o">.</span><span class="n">update_online_network</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">actor</span><span class="o">.</span><span class="n">apply_gradients_to_online_network</span><span class="p">(</span><span class="n">gradients</span><span class="p">)</span>
<span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span>
<span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">num_consecutive_training_steps</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode_steps_counter</span>
<span class="k">return</span> <span class="n">Agent</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">update_transition_before_adding_to_replay_buffer</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">transition</span><span class="p">:</span> <span class="n">Transition</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Transition</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Allows agents to update the transition just before adding it to the replay buffer.</span>
<span class="sd"> Can be useful for agents that want to tweak the reward, termination signal, etc.</span>
<span class="sd"> :param transition: the transition to update</span>
<span class="sd"> :return: the updated transition</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">transition</span><span class="o">.</span><span class="n">game_over</span> <span class="o">=</span> <span class="kc">False</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode_steps_counter</span> <span class="o">==</span>\
<span class="bp">self</span><span class="o">.</span><span class="n">parent_level_manager</span><span class="o">.</span><span class="n">environment</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">_max_episode_steps</span>\
<span class="k">else</span> <span class="n">transition</span><span class="o">.</span><span class="n">game_over</span>
<span class="k">return</span> <span class="n">transition</span>
</pre></div>
</div>
</div>
<footer>
<hr/>
<div role="contentinfo">
<p>
&copy; Copyright 2018-2019, Intel AI Lab
</p>
</div>
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>
@@ -197,7 +197,7 @@
<span class="c1"># limitations under the License.</span>
<span class="c1">#</span>
<span class="kn">from</span> <span class="nn">collections</span> <span class="k">import</span> <span class="n">OrderedDict</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span><span class="p">,</span> <span class="n">List</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
@@ -207,7 +207,8 @@
<span class="kn">from</span> <span class="nn">rl_coach.logger</span> <span class="k">import</span> <span class="n">screen</span>
<span class="kn">from</span> <span class="nn">rl_coach.memories.non_episodic.prioritized_experience_replay</span> <span class="k">import</span> <span class="n">PrioritizedExperienceReplay</span>
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">DiscreteActionSpace</span>
<span class="kn">from</span> <span class="nn">copy</span> <span class="k">import</span> <span class="n">deepcopy</span>
<span class="kn">from</span> <span class="nn">copy</span> <span class="k">import</span> <span class="n">deepcopy</span><span class="p">,</span> <span class="n">copy</span>
<span class="c1">## This is an abstract agent - there is no learn_from_batch method ##</span>
@@ -218,6 +219,12 @@
<span class="bp">self</span><span class="o">.</span><span class="n">q_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s2">&quot;Q&quot;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">q_value_for_action</span> <span class="o">=</span> <span class="p">{}</span>
<span class="c1"># currently we use softmax action probabilities only in batch-rl,</span>
<span class="c1"># but we might want to extend this later at some point.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">should_get_softmax_probabilities</span> <span class="o">=</span> \
<span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">],</span> <span class="s1">&#39;should_get_softmax_probabilities&#39;</span><span class="p">)</span> <span class="ow">and</span> \
<span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">should_get_softmax_probabilities</span>
<span class="k">def</span> <span class="nf">init_environment_dependent_modules</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">init_environment_dependent_modules</span><span class="p">()</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">action</span><span class="p">,</span> <span class="n">DiscreteActionSpace</span><span class="p">):</span>
@@ -228,12 +235,21 @@
<span class="c1"># Algorithms for which q_values are calculated from predictions will override this function</span>
<span class="k">def</span> <span class="nf">get_all_q_values_for_states</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">:</span> <span class="n">StateType</span><span class="p">):</span>
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">requires_action_values</span><span class="p">():</span>
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_prediction</span><span class="p">(</span><span class="n">states</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">return</span> <span class="n">actions_q_values</span>
<span class="k">def</span> <span class="nf">get_all_q_values_for_states_and_softmax_probabilities</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">:</span> <span class="n">StateType</span><span class="p">):</span>
<span class="n">actions_q_values</span><span class="p">,</span> <span class="n">softmax_probabilities</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">requires_action_values</span><span class="p">():</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">copy</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">outputs</span><span class="p">)</span>
<span class="n">outputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">softmax</span><span class="p">)</span>
<span class="n">actions_q_values</span><span class="p">,</span> <span class="n">softmax_probabilities</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_prediction</span><span class="p">(</span><span class="n">states</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">outputs</span><span class="p">)</span>
<span class="k">return</span> <span class="n">actions_q_values</span><span class="p">,</span> <span class="n">softmax_probabilities</span>
<span class="k">def</span> <span class="nf">get_prediction</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">&#39;main&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">prepare_batch_for_inference</span><span class="p">(</span><span class="n">states</span><span class="p">,</span> <span class="s1">&#39;main&#39;</span><span class="p">),</span>
<span class="n">outputs</span><span class="o">=</span><span class="n">outputs</span><span class="p">)</span>
@@ -255,10 +271,19 @@
<span class="p">)</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">policy</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">))</span>
<span class="k">def</span> <span class="nf">choose_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">curr_state</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">should_get_softmax_probabilities</span><span class="p">:</span>
<span class="n">actions_q_values</span><span class="p">,</span> <span class="n">softmax_probabilities</span> <span class="o">=</span> \
<span class="bp">self</span><span class="o">.</span><span class="n">get_all_q_values_for_states_and_softmax_probabilities</span><span class="p">(</span><span class="n">curr_state</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_all_q_values_for_states</span><span class="p">(</span><span class="n">curr_state</span><span class="p">)</span>
<span class="c1"># choose action according to the exploration policy and the current phase (evaluating or training the agent)</span>
<span class="n">action</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">get_action</span><span class="p">(</span><span class="n">actions_q_values</span><span class="p">)</span>
<span class="n">action</span><span class="p">,</span> <span class="n">action_probabilities</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">get_action</span><span class="p">(</span><span class="n">actions_q_values</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">should_get_softmax_probabilities</span> <span class="ow">and</span> <span class="n">softmax_probabilities</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># override the exploration policy&#39;s generated probabilities when an action was taken</span>
<span class="c1"># with the agent&#39;s actual policy</span>
<span class="n">action_probabilities</span> <span class="o">=</span> <span class="n">softmax_probabilities</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_validate_action</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="p">,</span> <span class="n">action</span><span class="p">)</span>
<span class="k">if</span> <span class="n">actions_q_values</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
@@ -270,15 +295,18 @@
<span class="bp">self</span><span class="o">.</span><span class="n">q_values</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">actions_q_values</span><span class="p">)</span>
<span class="n">actions_q_values</span> <span class="o">=</span> <span class="n">actions_q_values</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
<span class="n">action_probabilities</span> <span class="o">=</span> <span class="n">action_probabilities</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">q_value</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">actions_q_values</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">q_value_for_action</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">q_value</span><span class="p">)</span>
<span class="n">action_info</span> <span class="o">=</span> <span class="n">ActionInfo</span><span class="p">(</span><span class="n">action</span><span class="o">=</span><span class="n">action</span><span class="p">,</span>
<span class="n">action_value</span><span class="o">=</span><span class="n">actions_q_values</span><span class="p">[</span><span class="n">action</span><span class="p">],</span>
<span class="n">max_action_value</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">actions_q_values</span><span class="p">))</span>
<span class="n">max_action_value</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">actions_q_values</span><span class="p">),</span>
<span class="n">all_action_probabilities</span><span class="o">=</span><span class="n">action_probabilities</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">action_info</span> <span class="o">=</span> <span class="n">ActionInfo</span><span class="p">(</span><span class="n">action</span><span class="o">=</span><span class="n">action</span><span class="p">)</span>
<span class="n">action_info</span> <span class="o">=</span> <span class="n">ActionInfo</span><span class="p">(</span><span class="n">action</span><span class="o">=</span><span class="n">action</span><span class="p">,</span> <span class="n">all_action_probabilities</span><span class="o">=</span><span class="n">action_probabilities</span><span class="p">)</span>
<span class="k">return</span> <span class="n">action_info</span>
+7 -2
View File
@@ -182,6 +182,7 @@
<h1>Source code for rl_coach.base_parameters</h1><div class="highlight"><pre>
<span></span><span class="c1">#</span>
<span class="c1">#</span>
<span class="c1"># Copyright (c) 2017 Intel Corporation</span>
<span class="c1">#</span>
<span class="c1"># Licensed under the Apache License, Version 2.0 (the &quot;License&quot;);</span>
@@ -405,7 +406,8 @@
<span class="n">reward_test_level</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">test_using_a_trace_test</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">trace_test_levels</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">trace_max_env_steps</span><span class="o">=</span><span class="mi">5000</span><span class="p">):</span>
<span class="n">trace_max_env_steps</span><span class="o">=</span><span class="mi">5000</span><span class="p">,</span>
<span class="n">read_csv_tries</span><span class="o">=</span><span class="mi">200</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> :param test:</span>
<span class="sd"> A flag which specifies if the preset should be tested as part of the validation process.</span>
@@ -428,6 +430,8 @@
<span class="sd"> :param trace_max_env_steps:</span>
<span class="sd"> An integer representing the maximum number of environment steps to run when running this preset as part</span>
<span class="sd"> of the trace tests suite.</span>
<span class="sd"> :param read_csv_tries:</span>
<span class="sd"> The number of retries to attempt for reading the experiment csv file, before declaring failure.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
@@ -443,7 +447,8 @@
<span class="bp">self</span><span class="o">.</span><span class="n">reward_test_level</span> <span class="o">=</span> <span class="n">reward_test_level</span>
<span class="bp">self</span><span class="o">.</span><span class="n">test_using_a_trace_test</span> <span class="o">=</span> <span class="n">test_using_a_trace_test</span>
<span class="bp">self</span><span class="o">.</span><span class="n">trace_test_levels</span> <span class="o">=</span> <span class="n">trace_test_levels</span>
<span class="bp">self</span><span class="o">.</span><span class="n">trace_max_env_steps</span> <span class="o">=</span> <span class="n">trace_max_env_steps</span></div>
<span class="bp">self</span><span class="o">.</span><span class="n">trace_max_env_steps</span> <span class="o">=</span> <span class="n">trace_max_env_steps</span>
<span class="bp">self</span><span class="o">.</span><span class="n">read_csv_tries</span> <span class="o">=</span> <span class="n">read_csv_tries</span></div>
<div class="viewcode-block" id="NetworkParameters"><a class="viewcode-back" href="../../components/architectures/index.html#rl_coach.base_parameters.NetworkParameters">[docs]</a><span class="k">class</span> <span class="nc">NetworkParameters</span><span class="p">(</span><span class="n">Parameters</span><span class="p">):</span>
@@ -202,24 +202,27 @@
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">RunPhase</span><span class="p">,</span> <span class="n">ActionType</span>
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.exploration_policy</span> <span class="k">import</span> <span class="n">ExplorationPolicy</span><span class="p">,</span> <span class="n">ExplorationParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.exploration_policy</span> <span class="k">import</span> <span class="n">ContinuousActionExplorationPolicy</span><span class="p">,</span> <span class="n">ExplorationParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.schedules</span> <span class="k">import</span> <span class="n">Schedule</span><span class="p">,</span> <span class="n">LinearSchedule</span>
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">ActionSpace</span><span class="p">,</span> <span class="n">BoxActionSpace</span>
<span class="c1"># TODO: consider renaming to gaussian sampling</span>
<span class="k">class</span> <span class="nc">AdditiveNoiseParameters</span><span class="p">(</span><span class="n">ExplorationParameters</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">noise_percentage_schedule</span> <span class="o">=</span> <span class="n">LinearSchedule</span><span class="p">(</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mi">50000</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">evaluation_noise_percentage</span> <span class="o">=</span> <span class="mf">0.05</span>
<span class="bp">self</span><span class="o">.</span><span class="n">noise_schedule</span> <span class="o">=</span> <span class="n">LinearSchedule</span><span class="p">(</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mi">50000</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">evaluation_noise</span> <span class="o">=</span> <span class="mf">0.05</span>
<span class="bp">self</span><span class="o">.</span><span class="n">noise_as_percentage_from_action_space</span> <span class="o">=</span> <span class="kc">True</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="s1">&#39;rl_coach.exploration_policies.additive_noise:AdditiveNoise&#39;</span>
<div class="viewcode-block" id="AdditiveNoise"><a class="viewcode-back" href="../../../components/exploration_policies/index.html#rl_coach.exploration_policies.additive_noise.AdditiveNoise">[docs]</a><span class="k">class</span> <span class="nc">AdditiveNoise</span><span class="p">(</span><span class="n">ExplorationPolicy</span><span class="p">):</span>
<div class="viewcode-block" id="AdditiveNoise"><a class="viewcode-back" href="../../../components/exploration_policies/index.html#rl_coach.exploration_policies.additive_noise.AdditiveNoise">[docs]</a><span class="k">class</span> <span class="nc">AdditiveNoise</span><span class="p">(</span><span class="n">ContinuousActionExplorationPolicy</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> AdditiveNoise is an exploration policy intended for continuous action spaces. It takes the action from the agent</span>
<span class="sd"> and adds a Gaussian distributed noise to it. The amount of noise added to the action follows the noise amount that</span>
@@ -228,17 +231,19 @@
<span class="sd"> 2. Specified by the agents action. In case the agents action is a list with 2 values, the 1st one is assumed to</span>
<span class="sd"> be the mean of the action, and 2nd is assumed to be its standard deviation.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action_space</span><span class="p">:</span> <span class="n">ActionSpace</span><span class="p">,</span> <span class="n">noise_percentage_schedule</span><span class="p">:</span> <span class="n">Schedule</span><span class="p">,</span>
<span class="n">evaluation_noise_percentage</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action_space</span><span class="p">:</span> <span class="n">ActionSpace</span><span class="p">,</span> <span class="n">noise_schedule</span><span class="p">:</span> <span class="n">Schedule</span><span class="p">,</span>
<span class="n">evaluation_noise</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">noise_as_percentage_from_action_space</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> :param action_space: the action space used by the environment</span>
<span class="sd"> :param noise_percentage_schedule: the schedule for the noise variance percentage relative to the absolute range</span>
<span class="sd"> of the action space</span>
<span class="sd"> :param evaluation_noise_percentage: the noise variance percentage that will be used during evaluation phases</span>
<span class="sd"> :param noise_schedule: the schedule for the noise</span>
<span class="sd"> :param evaluation_noise: the noise variance that will be used during evaluation phases</span>
<span class="sd"> :param noise_as_percentage_from_action_space: a bool deciding whether the noise is absolute or as a percentage</span>
<span class="sd"> from the action space</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">action_space</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">noise_percentage_schedule</span> <span class="o">=</span> <span class="n">noise_percentage_schedule</span>
<span class="bp">self</span><span class="o">.</span><span class="n">evaluation_noise_percentage</span> <span class="o">=</span> <span class="n">evaluation_noise_percentage</span>
<span class="bp">self</span><span class="o">.</span><span class="n">noise_schedule</span> <span class="o">=</span> <span class="n">noise_schedule</span>
<span class="bp">self</span><span class="o">.</span><span class="n">evaluation_noise</span> <span class="o">=</span> <span class="n">evaluation_noise</span>
<span class="bp">self</span><span class="o">.</span><span class="n">noise_as_percentage_from_action_space</span> <span class="o">=</span> <span class="n">noise_as_percentage_from_action_space</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_space</span><span class="p">,</span> <span class="n">BoxActionSpace</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Additive noise exploration works only for continuous controls.&quot;</span>
@@ -248,19 +253,20 @@
<span class="ow">or</span> <span class="ow">not</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="o">-</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span> <span class="o">&lt;</span> <span class="n">action_space</span><span class="o">.</span><span class="n">low</span><span class="p">)</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">action_space</span><span class="o">.</span><span class="n">low</span> <span class="o">&lt;</span> <span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Additive noise exploration requires bounded actions&quot;</span><span class="p">)</span>
<span class="c1"># TODO: allow working with unbounded actions by defining the noise in terms of range and not percentage</span>
<span class="k">def</span> <span class="nf">get_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action_values</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">ActionType</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">ActionType</span><span class="p">:</span>
<span class="c1"># TODO-potential-bug consider separating internally defined stdev and externally defined stdev into 2 policies</span>
<span class="c1"># set the current noise percentage</span>
<span class="c1"># set the current noise</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">==</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span><span class="p">:</span>
<span class="n">current_noise_precentage</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">evaluation_noise_percentage</span>
<span class="n">current_noise</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">evaluation_noise</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">current_noise_precentage</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">noise_percentage_schedule</span><span class="o">.</span><span class="n">current_value</span>
<span class="n">current_noise</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">noise_schedule</span><span class="o">.</span><span class="n">current_value</span>
<span class="c1"># scale the noise to the action space range</span>
<span class="n">action_values_std</span> <span class="o">=</span> <span class="n">current_noise_precentage</span> <span class="o">*</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">high</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">low</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">noise_as_percentage_from_action_space</span><span class="p">:</span>
<span class="n">action_values_std</span> <span class="o">=</span> <span class="n">current_noise</span> <span class="o">*</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">high</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">low</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">action_values_std</span> <span class="o">=</span> <span class="n">current_noise</span>
<span class="c1"># extract the mean values</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_values</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
@@ -272,18 +278,21 @@
<span class="c1"># step the noise schedule</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">noise_percentage_schedule</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">noise_schedule</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="c1"># the second element of the list is assumed to be the standard deviation</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_values</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">action_values</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">action_values_std</span> <span class="o">=</span> <span class="n">action_values</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
<span class="c1"># add noise to the action means</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span><span class="p">:</span>
<span class="n">action</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">action_values_mean</span><span class="p">,</span> <span class="n">action_values_std</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">action</span> <span class="o">=</span> <span class="n">action_values_mean</span>
<span class="k">return</span> <span class="n">action</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">atleast_1d</span><span class="p">(</span><span class="n">action</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">get_control_param</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">noise_percentage_schedule</span><span class="o">.</span><span class="n">current_value</span></div>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">noise_schedule</span><span class="o">.</span><span class="n">current_value</span></div>
</pre></div>
</div>
@@ -202,7 +202,7 @@
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">RunPhase</span><span class="p">,</span> <span class="n">ActionType</span>
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.exploration_policy</span> <span class="k">import</span> <span class="n">ExplorationPolicy</span><span class="p">,</span> <span class="n">ExplorationParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.exploration_policy</span> <span class="k">import</span> <span class="n">DiscreteActionExplorationPolicy</span><span class="p">,</span> <span class="n">ExplorationParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.schedules</span> <span class="k">import</span> <span class="n">Schedule</span>
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">ActionSpace</span>
@@ -217,8 +217,7 @@
<span class="k">return</span> <span class="s1">&#39;rl_coach.exploration_policies.boltzmann:Boltzmann&#39;</span>
<div class="viewcode-block" id="Boltzmann"><a class="viewcode-back" href="../../../components/exploration_policies/index.html#rl_coach.exploration_policies.boltzmann.Boltzmann">[docs]</a><span class="k">class</span> <span class="nc">Boltzmann</span><span class="p">(</span><span class="n">ExplorationPolicy</span><span class="p">):</span>
<div class="viewcode-block" id="Boltzmann"><a class="viewcode-back" href="../../../components/exploration_policies/index.html#rl_coach.exploration_policies.boltzmann.Boltzmann">[docs]</a><span class="k">class</span> <span class="nc">Boltzmann</span><span class="p">(</span><span class="n">DiscreteActionExplorationPolicy</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> The Boltzmann exploration policy is intended for discrete action spaces. It assumes that each of the possible</span>
<span class="sd"> actions has some value assigned to it (such as the Q value), and uses a softmax function to convert these values</span>
@@ -233,7 +232,7 @@
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">action_space</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">temperature_schedule</span> <span class="o">=</span> <span class="n">temperature_schedule</span>
<span class="k">def</span> <span class="nf">get_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action_values</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">ActionType</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">ActionType</span><span class="p">:</span>
<span class="k">def</span> <span class="nf">get_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action_values</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">ActionType</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="p">(</span><span class="n">ActionType</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">==</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TRAIN</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">temperature_schedule</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="c1"># softmax calculation</span>
@@ -242,7 +241,8 @@
<span class="c1"># make sure probs sum to 1</span>
<span class="n">probabilities</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">probabilities</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="c1"># choose actions according to the probabilities</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span> <span class="n">p</span><span class="o">=</span><span class="n">probabilities</span><span class="p">)</span>
<span class="n">action</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span> <span class="n">p</span><span class="o">=</span><span class="n">probabilities</span><span class="p">)</span>
<span class="k">return</span> <span class="n">action</span><span class="p">,</span> <span class="n">probabilities</span>
<span class="k">def</span> <span class="nf">get_control_param</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">temperature_schedule</span><span class="o">.</span><span class="n">current_value</span></div>
@@ -202,7 +202,7 @@
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">RunPhase</span><span class="p">,</span> <span class="n">ActionType</span>
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.exploration_policy</span> <span class="k">import</span> <span class="n">ExplorationPolicy</span><span class="p">,</span> <span class="n">ExplorationParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.exploration_policy</span> <span class="k">import</span> <span class="n">DiscreteActionExplorationPolicy</span><span class="p">,</span> <span class="n">ExplorationParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">ActionSpace</span>
@@ -212,7 +212,7 @@
<span class="k">return</span> <span class="s1">&#39;rl_coach.exploration_policies.categorical:Categorical&#39;</span>
<div class="viewcode-block" id="Categorical"><a class="viewcode-back" href="../../../components/exploration_policies/index.html#rl_coach.exploration_policies.categorical.Categorical">[docs]</a><span class="k">class</span> <span class="nc">Categorical</span><span class="p">(</span><span class="n">ExplorationPolicy</span><span class="p">):</span>
<div class="viewcode-block" id="Categorical"><a class="viewcode-back" href="../../../components/exploration_policies/index.html#rl_coach.exploration_policies.categorical.Categorical">[docs]</a><span class="k">class</span> <span class="nc">Categorical</span><span class="p">(</span><span class="n">DiscreteActionExplorationPolicy</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Categorical exploration policy is intended for discrete action spaces. It expects the action values to</span>
<span class="sd"> represent a probability distribution over the action, from which a single action will be sampled.</span>
@@ -225,13 +225,18 @@
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">action_space</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">get_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action_values</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">ActionType</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">ActionType</span><span class="p">:</span>
<span class="k">def</span> <span class="nf">get_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action_values</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">ActionType</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="p">(</span><span class="n">ActionType</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">==</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TRAIN</span><span class="p">:</span>
<span class="c1"># choose actions according to the probabilities</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">actions</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="n">action_values</span><span class="p">)</span>
<span class="n">action</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">actions</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="n">action_values</span><span class="p">)</span>
<span class="k">return</span> <span class="n">action</span><span class="p">,</span> <span class="n">action_values</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># take the action with the highest probability</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">action_values</span><span class="p">)</span>
<span class="n">action</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">action_values</span><span class="p">)</span>
<span class="n">one_hot_action_probabilities</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">actions</span><span class="p">))</span>
<span class="n">one_hot_action_probabilities</span><span class="p">[</span><span class="n">action</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="k">return</span> <span class="n">action</span><span class="p">,</span> <span class="n">one_hot_action_probabilities</span>
<span class="k">def</span> <span class="nf">get_control_param</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="mi">0</span></div>
@@ -203,8 +203,7 @@
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">RunPhase</span><span class="p">,</span> <span class="n">ActionType</span>
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.additive_noise</span> <span class="k">import</span> <span class="n">AdditiveNoiseParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.exploration_policy</span> <span class="k">import</span> <span class="n">ExplorationParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.exploration_policy</span> <span class="k">import</span> <span class="n">ExplorationPolicy</span>
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.exploration_policy</span> <span class="k">import</span> <span class="n">ExplorationParameters</span><span class="p">,</span> <span class="n">ExplorationPolicy</span>
<span class="kn">from</span> <span class="nn">rl_coach.schedules</span> <span class="k">import</span> <span class="n">Schedule</span><span class="p">,</span> <span class="n">LinearSchedule</span>
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">ActionSpace</span><span class="p">,</span> <span class="n">DiscreteActionSpace</span><span class="p">,</span> <span class="n">BoxActionSpace</span>
<span class="kn">from</span> <span class="nn">rl_coach.utils</span> <span class="k">import</span> <span class="n">dynamic_import_and_instantiate_module_from_params</span>
@@ -216,7 +215,7 @@
<span class="bp">self</span><span class="o">.</span><span class="n">epsilon_schedule</span> <span class="o">=</span> <span class="n">LinearSchedule</span><span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.01</span><span class="p">,</span> <span class="mi">50000</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">evaluation_epsilon</span> <span class="o">=</span> <span class="mf">0.05</span>
<span class="bp">self</span><span class="o">.</span><span class="n">continuous_exploration_policy_parameters</span> <span class="o">=</span> <span class="n">AdditiveNoiseParameters</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">continuous_exploration_policy_parameters</span><span class="o">.</span><span class="n">noise_percentage_schedule</span> <span class="o">=</span> <span class="n">LinearSchedule</span><span class="p">(</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mi">50000</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">continuous_exploration_policy_parameters</span><span class="o">.</span><span class="n">noise_schedule</span> <span class="o">=</span> <span class="n">LinearSchedule</span><span class="p">(</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mi">50000</span><span class="p">)</span>
<span class="c1"># for continuous control -</span>
<span class="c1"># (see http://www.cs.ubc.ca/~van/papers/2017-TOG-deepLoco/2017-TOG-deepLoco.pdf)</span>
@@ -265,25 +264,31 @@
<span class="n">epsilon</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">evaluation_epsilon</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">==</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">epsilon_schedule</span><span class="o">.</span><span class="n">current_value</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_random_value</span> <span class="o">&gt;=</span> <span class="n">epsilon</span>
<span class="k">def</span> <span class="nf">get_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action_values</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">ActionType</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">ActionType</span><span class="p">:</span>
<span class="k">def</span> <span class="nf">get_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action_values</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">ActionType</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="p">(</span><span class="n">ActionType</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]):</span>
<span class="n">epsilon</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">evaluation_epsilon</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">==</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">epsilon_schedule</span><span class="o">.</span><span class="n">current_value</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="p">,</span> <span class="n">DiscreteActionSpace</span><span class="p">):</span>
<span class="n">top_action</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">action_values</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_random_value</span> <span class="o">&lt;</span> <span class="n">epsilon</span><span class="p">:</span>
<span class="n">chosen_action</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">sample</span><span class="p">()</span>
<span class="n">probabilities</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">full</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">actions</span><span class="p">),</span>
<span class="mf">1.</span> <span class="o">/</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">high</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">low</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">chosen_action</span> <span class="o">=</span> <span class="n">top_action</span>
<span class="n">chosen_action</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">action_values</span><span class="p">)</span>
<span class="c1"># one-hot probabilities vector</span>
<span class="n">probabilities</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">actions</span><span class="p">))</span>
<span class="n">probabilities</span><span class="p">[</span><span class="n">chosen_action</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="bp">self</span><span class="o">.</span><span class="n">step_epsilon</span><span class="p">()</span>
<span class="k">return</span> <span class="n">chosen_action</span><span class="p">,</span> <span class="n">probabilities</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_random_value</span> <span class="o">&lt;</span> <span class="n">epsilon</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">==</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TRAIN</span><span class="p">:</span>
<span class="n">chosen_action</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">sample</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">chosen_action</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">continuous_exploration_policy</span><span class="o">.</span><span class="n">get_action</span><span class="p">(</span><span class="n">action_values</span><span class="p">)</span>
<span class="c1"># step the epsilon schedule and generate a new random value for next time</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">==</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TRAIN</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">epsilon_schedule</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">current_random_value</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">rand</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">step_epsilon</span><span class="p">()</span>
<span class="k">return</span> <span class="n">chosen_action</span>
<span class="k">def</span> <span class="nf">get_control_param</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
@@ -295,7 +300,13 @@
<span class="k">def</span> <span class="nf">change_phase</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">phase</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">change_phase</span><span class="p">(</span><span class="n">phase</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="p">,</span> <span class="n">BoxActionSpace</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">continuous_exploration_policy</span><span class="o">.</span><span class="n">change_phase</span><span class="p">(</span><span class="n">phase</span><span class="p">)</span></div>
<span class="bp">self</span><span class="o">.</span><span class="n">continuous_exploration_policy</span><span class="o">.</span><span class="n">change_phase</span><span class="p">(</span><span class="n">phase</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">step_epsilon</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="c1"># step the epsilon schedule and generate a new random value for next time</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">==</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TRAIN</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">epsilon_schedule</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">current_random_value</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">rand</span><span class="p">()</span></div>
</pre></div>
</div>
@@ -201,7 +201,7 @@
<span class="kn">from</span> <span class="nn">rl_coach.base_parameters</span> <span class="k">import</span> <span class="n">Parameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">RunPhase</span><span class="p">,</span> <span class="n">ActionType</span>
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">ActionSpace</span>
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">ActionSpace</span><span class="p">,</span> <span class="n">DiscreteActionSpace</span><span class="p">,</span> <span class="n">BoxActionSpace</span><span class="p">,</span> <span class="n">GoalsSpace</span>
<span class="k">class</span> <span class="nc">ExplorationParameters</span><span class="p">(</span><span class="n">Parameters</span><span class="p">):</span>
@@ -237,14 +237,10 @@
<span class="sd"> Given a list of values corresponding to each action, </span>
<span class="sd"> choose one actions according to the exploration policy</span>
<span class="sd"> :param action_values: A list of action values</span>
<span class="sd"> :return: The chosen action</span>
<span class="sd"> :return: The chosen action,</span>
<span class="sd"> The probability of the action (if available, otherwise 1 for absolute certainty in the action)</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span> <span class="o">==</span> <span class="n">ExplorationPolicy</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;The ExplorationPolicy class is an abstract class and should not be used directly. &quot;</span>
<span class="s2">&quot;Please set the exploration parameters to point to an inheriting class like EGreedy or &quot;</span>
<span class="s2">&quot;AdditiveNoise&quot;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;The get_action function should be overridden in the inheriting exploration class&quot;</span><span class="p">)</span></div>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">()</span></div>
<div class="viewcode-block" id="ExplorationPolicy.change_phase"><a class="viewcode-back" href="../../../components/exploration_policies/index.html#rl_coach.exploration_policies.exploration_policy.ExplorationPolicy.change_phase">[docs]</a> <span class="k">def</span> <span class="nf">change_phase</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">phase</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
@@ -265,6 +261,45 @@
<span class="k">def</span> <span class="nf">get_control_param</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="mi">0</span></div>
<span class="k">class</span> <span class="nc">DiscreteActionExplorationPolicy</span><span class="p">(</span><span class="n">ExplorationPolicy</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> A discrete action exploration policy.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action_space</span><span class="p">:</span> <span class="n">ActionSpace</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> :param action_space: the action space used by the environment</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_space</span><span class="p">,</span> <span class="n">DiscreteActionSpace</span><span class="p">)</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">action_space</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">get_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action_values</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">ActionType</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="p">(</span><span class="n">ActionType</span><span class="p">,</span> <span class="n">List</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Given a list of values corresponding to each action,</span>
<span class="sd"> choose one actions according to the exploration policy</span>
<span class="sd"> :param action_values: A list of action values</span>
<span class="sd"> :return: The chosen action,</span>
<span class="sd"> The probabilities of actions to select from (if not available a one-hot vector)</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span> <span class="o">==</span> <span class="n">ExplorationPolicy</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;The ExplorationPolicy class is an abstract class and should not be used directly. &quot;</span>
<span class="s2">&quot;Please set the exploration parameters to point to an inheriting class like EGreedy or &quot;</span>
<span class="s2">&quot;AdditiveNoise&quot;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;The get_action function should be overridden in the inheriting exploration class&quot;</span><span class="p">)</span>
<span class="k">class</span> <span class="nc">ContinuousActionExplorationPolicy</span><span class="p">(</span><span class="n">ExplorationPolicy</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> A continuous action exploration policy.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action_space</span><span class="p">:</span> <span class="n">ActionSpace</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> :param action_space: the action space used by the environment</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_space</span><span class="p">,</span> <span class="n">BoxActionSpace</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_space</span><span class="p">,</span> <span class="n">GoalsSpace</span><span class="p">)</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">action_space</span><span class="p">)</span>
</pre></div>
</div>
@@ -202,7 +202,7 @@
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">ActionType</span>
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.exploration_policy</span> <span class="k">import</span> <span class="n">ExplorationPolicy</span><span class="p">,</span> <span class="n">ExplorationParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.exploration_policy</span> <span class="k">import</span> <span class="n">ExplorationParameters</span><span class="p">,</span> <span class="n">ExplorationPolicy</span>
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">ActionSpace</span><span class="p">,</span> <span class="n">DiscreteActionSpace</span><span class="p">,</span> <span class="n">BoxActionSpace</span>
@@ -224,9 +224,12 @@
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">action_space</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">get_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action_values</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">ActionType</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">ActionType</span><span class="p">:</span>
<span class="k">def</span> <span class="nf">get_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action_values</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">ActionType</span><span class="p">]):</span>
<span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="p">)</span> <span class="o">==</span> <span class="n">DiscreteActionSpace</span><span class="p">:</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">action_values</span><span class="p">)</span>
<span class="n">action</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">action_values</span><span class="p">)</span>
<span class="n">one_hot_action_probabilities</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">actions</span><span class="p">))</span>
<span class="n">one_hot_action_probabilities</span><span class="p">[</span><span class="n">action</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="k">return</span> <span class="n">action</span><span class="p">,</span> <span class="n">one_hot_action_probabilities</span>
<span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="p">)</span> <span class="o">==</span> <span class="n">BoxActionSpace</span><span class="p">:</span>
<span class="k">return</span> <span class="n">action_values</span>
@@ -202,12 +202,13 @@
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">RunPhase</span><span class="p">,</span> <span class="n">ActionType</span>
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.exploration_policy</span> <span class="k">import</span> <span class="n">ExplorationPolicy</span><span class="p">,</span> <span class="n">ExplorationParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.exploration_policy</span> <span class="k">import</span> <span class="n">ContinuousActionExplorationPolicy</span><span class="p">,</span> <span class="n">ExplorationParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">ActionSpace</span><span class="p">,</span> <span class="n">BoxActionSpace</span><span class="p">,</span> <span class="n">GoalsSpace</span>
<span class="c1"># Based on on the description in:</span>
<span class="c1"># https://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab</span>
<span class="k">class</span> <span class="nc">OUProcessParameters</span><span class="p">(</span><span class="n">ExplorationParameters</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
@@ -222,7 +223,7 @@
<span class="c1"># Ornstein-Uhlenbeck process</span>
<div class="viewcode-block" id="OUProcess"><a class="viewcode-back" href="../../../components/exploration_policies/index.html#rl_coach.exploration_policies.ou_process.OUProcess">[docs]</a><span class="k">class</span> <span class="nc">OUProcess</span><span class="p">(</span><span class="n">ExplorationPolicy</span><span class="p">):</span>
<div class="viewcode-block" id="OUProcess"><a class="viewcode-back" href="../../../components/exploration_policies/index.html#rl_coach.exploration_policies.ou_process.OUProcess">[docs]</a><span class="k">class</span> <span class="nc">OUProcess</span><span class="p">(</span><span class="n">ContinuousActionExplorationPolicy</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> OUProcess exploration policy is intended for continuous action spaces, and selects the action according to</span>
<span class="sd"> an Ornstein-Uhlenbeck process. The Ornstein-Uhlenbeck process implements the action as a Gaussian process, where</span>
@@ -239,10 +240,6 @@
<span class="bp">self</span><span class="o">.</span><span class="n">state</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dt</span> <span class="o">=</span> <span class="n">dt</span>
<span class="k">if</span> <span class="ow">not</span> <span class="p">(</span><span class="nb">isinstance</span><span class="p">(</span><span class="n">action_space</span><span class="p">,</span> <span class="n">BoxActionSpace</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_space</span><span class="p">,</span> <span class="n">GoalsSpace</span><span class="p">)):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;OU process exploration works only for continuous controls.&quot;</span>
<span class="s2">&quot;The given action space is of type: </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">action_space</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">))</span>
<span class="k">def</span> <span class="nf">reset</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">state</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
@@ -242,9 +242,13 @@
<span class="bp">self</span><span class="o">.</span><span class="n">network_params</span> <span class="o">=</span> <span class="n">network_params</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_replace_network_dense_layers</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">get_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action_values</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">ActionType</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">ActionType</span><span class="p">:</span>
<span class="k">def</span> <span class="nf">get_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action_values</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">ActionType</span><span class="p">]):</span>
<span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="p">)</span> <span class="o">==</span> <span class="n">DiscreteActionSpace</span><span class="p">:</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">action_values</span><span class="p">)</span>
<span class="n">action</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">action_values</span><span class="p">)</span>
<span class="n">one_hot_action_probabilities</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">actions</span><span class="p">))</span>
<span class="n">one_hot_action_probabilities</span><span class="p">[</span><span class="n">action</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="k">return</span> <span class="n">action</span><span class="p">,</span> <span class="n">one_hot_action_probabilities</span>
<span class="k">elif</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="p">)</span> <span class="o">==</span> <span class="n">BoxActionSpace</span><span class="p">:</span>
<span class="n">action_values_mean</span> <span class="o">=</span> <span class="n">action_values</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
<span class="n">action_values_std</span> <span class="o">=</span> <span class="n">action_values</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
@@ -203,7 +203,7 @@
<span class="kn">from</span> <span class="nn">scipy.stats</span> <span class="k">import</span> <span class="n">truncnorm</span>
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">RunPhase</span><span class="p">,</span> <span class="n">ActionType</span>
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.exploration_policy</span> <span class="k">import</span> <span class="n">ExplorationPolicy</span><span class="p">,</span> <span class="n">ExplorationParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.exploration_policy</span> <span class="k">import</span> <span class="n">ExplorationParameters</span><span class="p">,</span> <span class="n">ContinuousActionExplorationPolicy</span>
<span class="kn">from</span> <span class="nn">rl_coach.schedules</span> <span class="k">import</span> <span class="n">Schedule</span><span class="p">,</span> <span class="n">LinearSchedule</span>
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">ActionSpace</span><span class="p">,</span> <span class="n">BoxActionSpace</span>
@@ -211,17 +211,18 @@
<span class="k">class</span> <span class="nc">TruncatedNormalParameters</span><span class="p">(</span><span class="n">ExplorationParameters</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">noise_percentage_schedule</span> <span class="o">=</span> <span class="n">LinearSchedule</span><span class="p">(</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mi">50000</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">evaluation_noise_percentage</span> <span class="o">=</span> <span class="mf">0.05</span>
<span class="bp">self</span><span class="o">.</span><span class="n">noise_schedule</span> <span class="o">=</span> <span class="n">LinearSchedule</span><span class="p">(</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mi">50000</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">evaluation_noise</span> <span class="o">=</span> <span class="mf">0.05</span>
<span class="bp">self</span><span class="o">.</span><span class="n">clip_low</span> <span class="o">=</span> <span class="mi">0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">clip_high</span> <span class="o">=</span> <span class="mi">1</span>
<span class="bp">self</span><span class="o">.</span><span class="n">noise_as_percentage_from_action_space</span> <span class="o">=</span> <span class="kc">True</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="s1">&#39;rl_coach.exploration_policies.truncated_normal:TruncatedNormal&#39;</span>
<div class="viewcode-block" id="TruncatedNormal"><a class="viewcode-back" href="../../../components/exploration_policies/index.html#rl_coach.exploration_policies.truncated_normal.TruncatedNormal">[docs]</a><span class="k">class</span> <span class="nc">TruncatedNormal</span><span class="p">(</span><span class="n">ExplorationPolicy</span><span class="p">):</span>
<div class="viewcode-block" id="TruncatedNormal"><a class="viewcode-back" href="../../../components/exploration_policies/index.html#rl_coach.exploration_policies.truncated_normal.TruncatedNormal">[docs]</a><span class="k">class</span> <span class="nc">TruncatedNormal</span><span class="p">(</span><span class="n">ContinuousActionExplorationPolicy</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> The TruncatedNormal exploration policy is intended for continuous action spaces. It samples the action from a</span>
<span class="sd"> normal distribution, where the mean action is given by the agent, and the standard deviation can be given in t</span>
@@ -232,17 +233,20 @@
<span class="sd"> When the sampled action is outside of the action bounds given by the user, it is sampled again and again, until it</span>
<span class="sd"> is within the bounds.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action_space</span><span class="p">:</span> <span class="n">ActionSpace</span><span class="p">,</span> <span class="n">noise_percentage_schedule</span><span class="p">:</span> <span class="n">Schedule</span><span class="p">,</span>
<span class="n">evaluation_noise_percentage</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">clip_low</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">clip_high</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action_space</span><span class="p">:</span> <span class="n">ActionSpace</span><span class="p">,</span> <span class="n">noise_schedule</span><span class="p">:</span> <span class="n">Schedule</span><span class="p">,</span>
<span class="n">evaluation_noise</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">clip_low</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">clip_high</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
<span class="n">noise_as_percentage_from_action_space</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> :param action_space: the action space used by the environment</span>
<span class="sd"> :param noise_percentage_schedule: the schedule for the noise variance percentage relative to the absolute range</span>
<span class="sd"> of the action space</span>
<span class="sd"> :param evaluation_noise_percentage: the noise variance percentage that will be used during evaluation phases</span>
<span class="sd"> :param noise_schedule: the schedule for the noise variance</span>
<span class="sd"> :param evaluation_noise: the noise variance that will be used during evaluation phases</span>
<span class="sd"> :param noise_as_percentage_from_action_space: whether to consider the noise as a percentage of the action space</span>
<span class="sd"> or absolute value</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">action_space</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">noise_percentage_schedule</span> <span class="o">=</span> <span class="n">noise_percentage_schedule</span>
<span class="bp">self</span><span class="o">.</span><span class="n">evaluation_noise_percentage</span> <span class="o">=</span> <span class="n">evaluation_noise_percentage</span>
<span class="bp">self</span><span class="o">.</span><span class="n">noise_schedule</span> <span class="o">=</span> <span class="n">noise_schedule</span>
<span class="bp">self</span><span class="o">.</span><span class="n">evaluation_noise</span> <span class="o">=</span> <span class="n">evaluation_noise</span>
<span class="bp">self</span><span class="o">.</span><span class="n">noise_as_percentage_from_action_space</span> <span class="o">=</span> <span class="n">noise_as_percentage_from_action_space</span>
<span class="bp">self</span><span class="o">.</span><span class="n">clip_low</span> <span class="o">=</span> <span class="n">clip_low</span>
<span class="bp">self</span><span class="o">.</span><span class="n">clip_high</span> <span class="o">=</span> <span class="n">clip_high</span>
@@ -254,17 +258,21 @@
<span class="ow">or</span> <span class="ow">not</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="o">-</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span> <span class="o">&lt;</span> <span class="n">action_space</span><span class="o">.</span><span class="n">low</span><span class="p">)</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">action_space</span><span class="o">.</span><span class="n">low</span> <span class="o">&lt;</span> <span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Additive noise exploration requires bounded actions&quot;</span><span class="p">)</span>
<span class="c1"># TODO: allow working with unbounded actions by defining the noise in terms of range and not percentage</span>
<span class="k">def</span> <span class="nf">get_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action_values</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">ActionType</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">ActionType</span><span class="p">:</span>
<span class="c1"># set the current noise percentage</span>
<span class="c1"># set the current noise</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">==</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span><span class="p">:</span>
<span class="n">current_noise_precentage</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">evaluation_noise_percentage</span>
<span class="n">current_noise</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">evaluation_noise</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">current_noise_precentage</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">noise_percentage_schedule</span><span class="o">.</span><span class="n">current_value</span>
<span class="n">current_noise</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">noise_schedule</span><span class="o">.</span><span class="n">current_value</span>
<span class="c1"># scale the noise to the action space range</span>
<span class="n">action_values_std</span> <span class="o">=</span> <span class="n">current_noise_precentage</span> <span class="o">*</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">high</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">low</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">noise_as_percentage_from_action_space</span><span class="p">:</span>
<span class="n">action_values_std</span> <span class="o">=</span> <span class="n">current_noise</span> <span class="o">*</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">high</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">low</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">action_values_std</span> <span class="o">=</span> <span class="n">current_noise</span>
<span class="c1"># scale the noise to the action space range</span>
<span class="n">action_values_std</span> <span class="o">=</span> <span class="n">current_noise</span> <span class="o">*</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">high</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">low</span><span class="p">)</span>
<span class="c1"># extract the mean values</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_values</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
@@ -276,7 +284,7 @@
<span class="c1"># step the noise schedule</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">noise_percentage_schedule</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">noise_schedule</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="c1"># the second element of the list is assumed to be the standard deviation</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_values</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">action_values</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">action_values_std</span> <span class="o">=</span> <span class="n">action_values</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
@@ -290,7 +298,7 @@
<span class="k">return</span> <span class="n">action</span>
<span class="k">def</span> <span class="nf">get_control_param</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">noise_percentage_schedule</span><span class="o">.</span><span class="n">current_value</span></div>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">noise_schedule</span><span class="o">.</span><span class="n">current_value</span></div>
</pre></div>
</div>
@@ -204,7 +204,7 @@
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">ObservationType</span>
<span class="kn">from</span> <span class="nn">rl_coach.filters.observation.observation_filter</span> <span class="k">import</span> <span class="n">ObservationFilter</span>
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">ObservationSpace</span>
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">ObservationSpace</span><span class="p">,</span> <span class="n">VectorObservationSpace</span>
<span class="k">class</span> <span class="nc">LazyStack</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
@@ -246,6 +246,7 @@
<span class="bp">self</span><span class="o">.</span><span class="n">stack_size</span> <span class="o">=</span> <span class="n">stack_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">stacking_axis</span> <span class="o">=</span> <span class="n">stacking_axis</span>
<span class="bp">self</span><span class="o">.</span><span class="n">stack</span> <span class="o">=</span> <span class="p">[]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">input_observation_space</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">stack_size</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;The stack shape must be a positive number&quot;</span><span class="p">)</span>
@@ -269,7 +270,6 @@
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;The stacking axis is larger than the number of dimensions in the observation space&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">filter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">observation</span><span class="p">:</span> <span class="n">ObservationType</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">ObservationType</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stack</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">stack</span> <span class="o">=</span> <span class="n">deque</span><span class="p">([</span><span class="n">observation</span><span class="p">]</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">stack_size</span><span class="p">,</span> <span class="n">maxlen</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">stack_size</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
@@ -277,9 +277,16 @@
<span class="bp">self</span><span class="o">.</span><span class="n">stack</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">observation</span><span class="p">)</span>
<span class="n">observation</span> <span class="o">=</span> <span class="n">LazyStack</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stack</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">stacking_axis</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">input_observation_space</span><span class="p">,</span> <span class="n">VectorObservationSpace</span><span class="p">):</span>
<span class="c1"># when stacking vectors, we cannot avoid copying the memory as we&#39;re flattening it all</span>
<span class="n">observation</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">observation</span><span class="p">)</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span>
<span class="k">return</span> <span class="n">observation</span>
<span class="k">def</span> <span class="nf">get_filtered_observation_space</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_observation_space</span><span class="p">:</span> <span class="n">ObservationSpace</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">ObservationSpace</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">input_observation_space</span><span class="p">,</span> <span class="n">VectorObservationSpace</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">input_observation_space</span> <span class="o">=</span> <span class="n">input_observation_space</span> <span class="o">=</span> <span class="n">VectorObservationSpace</span><span class="p">(</span><span class="n">input_observation_space</span><span class="o">.</span><span class="n">shape</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">stack_size</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">stacking_axis</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">:</span>
<span class="n">input_observation_space</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">input_observation_space</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">values</span><span class="o">=</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">stack_size</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
@@ -208,6 +208,7 @@
<span class="kn">import</span> <span class="nn">random</span>
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">Transition</span><span class="p">,</span> <span class="n">Episode</span>
<span class="kn">from</span> <span class="nn">rl_coach.filters.filter</span> <span class="k">import</span> <span class="n">InputFilter</span>
<span class="kn">from</span> <span class="nn">rl_coach.logger</span> <span class="k">import</span> <span class="n">screen</span>
<span class="kn">from</span> <span class="nn">rl_coach.memories.memory</span> <span class="k">import</span> <span class="n">Memory</span><span class="p">,</span> <span class="n">MemoryGranularity</span><span class="p">,</span> <span class="n">MemoryParameters</span>
<span class="kn">from</span> <span class="nn">rl_coach.utils</span> <span class="k">import</span> <span class="n">ReaderWriterLock</span><span class="p">,</span> <span class="n">ProgressBar</span>
@@ -591,11 +592,12 @@
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing</span><span class="p">()</span>
<span class="k">return</span> <span class="n">mean</span>
<span class="k">def</span> <span class="nf">load_csv</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">csv_dataset</span><span class="p">:</span> <span class="n">CsvDataset</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">def</span> <span class="nf">load_csv</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">csv_dataset</span><span class="p">:</span> <span class="n">CsvDataset</span><span class="p">,</span> <span class="n">input_filter</span><span class="p">:</span> <span class="n">InputFilter</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Restore the replay buffer contents from a csv file.</span>
<span class="sd"> The csv file is assumed to include a list of transitions.</span>
<span class="sd"> :param csv_dataset: A construct which holds the dataset parameters</span>
<span class="sd"> :param input_filter: A filter used to filter the CSV data before feeding it to the memory.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">assert_not_frozen</span><span class="p">()</span>
@@ -612,18 +614,30 @@
<span class="k">for</span> <span class="n">e_id</span> <span class="ow">in</span> <span class="n">episode_ids</span><span class="p">:</span>
<span class="n">progress_bar</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">e_id</span><span class="p">)</span>
<span class="n">df_episode_transitions</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="n">df</span><span class="p">[</span><span class="s1">&#39;episode_id&#39;</span><span class="p">]</span> <span class="o">==</span> <span class="n">e_id</span><span class="p">]</span>
<span class="n">input_filter</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">df_episode_transitions</span><span class="p">)</span> <span class="o">&lt;</span> <span class="mi">2</span><span class="p">:</span>
<span class="c1"># we have to have at least 2 rows in each episode for creating a transition</span>
<span class="k">continue</span>
<span class="n">episode</span> <span class="o">=</span> <span class="n">Episode</span><span class="p">()</span>
<span class="n">transitions</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="p">(</span><span class="n">_</span><span class="p">,</span> <span class="n">current_transition</span><span class="p">),</span> <span class="p">(</span><span class="n">_</span><span class="p">,</span> <span class="n">next_transition</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">df_episode_transitions</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">iterrows</span><span class="p">(),</span>
<span class="n">df_episode_transitions</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span><span class="o">.</span><span class="n">iterrows</span><span class="p">()):</span>
<span class="n">state</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">current_transition</span><span class="p">[</span><span class="n">col</span><span class="p">]</span> <span class="k">for</span> <span class="n">col</span> <span class="ow">in</span> <span class="n">state_columns</span><span class="p">])</span>
<span class="n">next_state</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">next_transition</span><span class="p">[</span><span class="n">col</span><span class="p">]</span> <span class="k">for</span> <span class="n">col</span> <span class="ow">in</span> <span class="n">state_columns</span><span class="p">])</span>
<span class="n">episode</span><span class="o">.</span><span class="n">insert</span><span class="p">(</span>
<span class="n">transitions</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="n">Transition</span><span class="p">(</span><span class="n">state</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;observation&#39;</span><span class="p">:</span> <span class="n">state</span><span class="p">},</span>
<span class="n">action</span><span class="o">=</span><span class="n">current_transition</span><span class="p">[</span><span class="s1">&#39;action&#39;</span><span class="p">],</span> <span class="n">reward</span><span class="o">=</span><span class="n">current_transition</span><span class="p">[</span><span class="s1">&#39;reward&#39;</span><span class="p">],</span>
<span class="n">next_state</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;observation&#39;</span><span class="p">:</span> <span class="n">next_state</span><span class="p">},</span> <span class="n">game_over</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">info</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;all_action_probabilities&#39;</span><span class="p">:</span>
<span class="n">ast</span><span class="o">.</span><span class="n">literal_eval</span><span class="p">(</span><span class="n">current_transition</span><span class="p">[</span><span class="s1">&#39;all_action_probabilities&#39;</span><span class="p">])}))</span>
<span class="n">ast</span><span class="o">.</span><span class="n">literal_eval</span><span class="p">(</span><span class="n">current_transition</span><span class="p">[</span><span class="s1">&#39;all_action_probabilities&#39;</span><span class="p">])}),</span>
<span class="p">)</span>
<span class="n">transitions</span> <span class="o">=</span> <span class="n">input_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">transitions</span><span class="p">,</span> <span class="n">deep_copy</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">transitions</span><span class="p">:</span>
<span class="n">episode</span><span class="o">.</span><span class="n">insert</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
<span class="c1"># Set the last transition to end the episode</span>
<span class="k">if</span> <span class="n">csv_dataset</span><span class="o">.</span><span class="n">is_episodic</span><span class="p">:</span>
@@ -635,8 +649,6 @@
<span class="n">progress_bar</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">episode_ids</span><span class="p">))</span>
<span class="n">progress_bar</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">shuffle_episodes</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">freeze</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Freezing the replay buffer does not allow any new transitions to be added to the memory.</span>
@@ -21,6 +21,7 @@ A detailed description of those algorithms can be found by navigating to each of
imitation/cil
policy_optimization/cppo
policy_optimization/ddpg
policy_optimization/td3
policy_optimization/sac
other/dfp
value_optimization/double_dqn
@@ -0,0 +1,55 @@
Twin Delayed Deep Deterministic Policy Gradient
==================================
**Actions space:** Continuous
**References:** `Addressing Function Approximation Error in Actor-Critic Methods <https://arxiv.org/pdf/1802.09477>`_
Network Structure
-----------------
.. image:: /_static/img/design_imgs/td3.png
:align: center
Algorithm Description
---------------------
Choosing an action
++++++++++++++++++
Pass the current states through the actor network, and get an action mean vector :math:`\mu`.
While in training phase, use a continuous exploration policy, such as a small zero-meaned gaussian noise,
to add exploration noise to the action. When testing, use the mean vector :math:`\mu` as-is.
Training the network
++++++++++++++++++++
Start by sampling a batch of transitions from the experience replay.
* To train the two **critic networks**, use the following targets:
:math:`y_t=r(s_t,a_t )+\gamma \cdot \min_{i=1,2} Q_{i}(s_{t+1},\mu(s_{t+1} )+[\mathcal{N}(0,\,\sigma^{2})]^{MAX\_NOISE}_{MIN\_NOISE})`
First run the actor target network, using the next states as the inputs, and get :math:`\mu (s_{t+1} )`. Then, add a
clipped gaussian noise to these actions, and clip the resulting actions to the actions space.
Next, run the critic target networks using the next states and :math:`\mu (s_{t+1} )+[\mathcal{N}(0,\,\sigma^{2})]^{MAX\_NOISE}_{MIN\_NOISE}`,
and use the minimum between the two critic networks predictions in order to calculate :math:`y_t` according to the
equation above. To train the networks, use the current states and actions as the inputs, and :math:`y_t`
as the targets.
* To train the **actor network**, use the following equation:
:math:`\nabla_{\theta^\mu } J \approx E_{s_t \tilde{} \rho^\beta } [\nabla_a Q_{1}(s,a)|_{s=s_t,a=\mu (s_t ) } \cdot \nabla_{\theta^\mu} \mu(s)|_{s=s_t} ]`
Use the actor's online network to get the action mean values using the current states as the inputs.
Then, use the first critic's online network in order to get the gradients of the critic output with respect to the
action mean values :math:`\nabla _a Q_{1}(s,a)|_{s=s_t,a=\mu(s_t ) }`.
Using the chain rule, calculate the gradients of the actor's output, with respect to the actor weights,
given :math:`\nabla_a Q(s,a)`. Finally, apply those gradients to the actor network.
The actor's training is done at a slower frequency than the critic's training, in order to allow the critic to better fit the
current policy, before exercising the critic in order to train the actor.
Following the same, delayed, actor's training cadence, do a soft update of the critic and actor target networks' weights
from the online networks.
.. autoclass:: rl_coach.agents.td3_agent.TD3AlgorithmParameters
@@ -214,6 +214,16 @@ The algorithms are ordered by their release date in descending order.
and therefore it is able to use a replay buffer in order to improve sample efficiency.
</span>
</div>
<div class="algorithm continuous off-policy" data-year="201509">
<span class="badge">
<a href="components/agents/policy_optimization/td3.html">TD3</a>
<br>
Very similar to DDPG, i.e. an actor-critic for continuous action spaces, that uses a replay buffer in
order to improve sample efficiency. TD3 uses two critic networks in order to mitigate the overestimation
in the Q state-action value prediction, slows down the actor updates in order to increase stability and
adds noise to actions while training the critic in order to smooth out the critic's predictions.
</span>
</div>
<div class="algorithm continuous discrete on-policy" data-year="201706">
<span class="badge">
<a href="components/agents/policy_optimization/ppo.html">PPO</a>
+15
View File
@@ -289,6 +289,12 @@ img.align-center, .figure.align-center, object.align-center {
margin-right: auto;
}
img.align-default, .figure.align-default {
display: block;
margin-left: auto;
margin-right: auto;
}
.align-left {
text-align: left;
}
@@ -297,6 +303,10 @@ img.align-center, .figure.align-center, object.align-center {
text-align: center;
}
.align-default {
text-align: center;
}
.align-right {
text-align: right;
}
@@ -368,6 +378,11 @@ table.align-center {
margin-right: auto;
}
table.align-default {
margin-left: auto;
margin-right: auto;
}
table caption span.caption-number {
font-style: italic;
}
+4 -3
View File
@@ -319,12 +319,13 @@ var Search = {
for (var prefix in objects) {
for (var name in objects[prefix]) {
var fullname = (prefix ? prefix + '.' : '') + name;
if (fullname.toLowerCase().indexOf(object) > -1) {
var fullnameLower = fullname.toLowerCase()
if (fullnameLower.indexOf(object) > -1) {
var score = 0;
var parts = fullname.split('.');
var parts = fullnameLower.split('.');
// check for different match types: exact matches of full name or
// "last name" (i.e. last dotted part)
if (fullname == object || parts[parts.length - 1] == object) {
if (fullnameLower == object || parts[parts.length - 1] == object) {
score += Scorer.objNameMatch;
// matches in last name
} else if (parts[parts.length - 1].indexOf(object) > -1) {
+5 -4
View File
@@ -195,7 +195,7 @@
<h2>VisualizationParameters<a class="headerlink" href="#visualizationparameters" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.base_parameters.VisualizationParameters">
<em class="property">class </em><code class="descclassname">rl_coach.base_parameters.</code><code class="descname">VisualizationParameters</code><span class="sig-paren">(</span><em>print_networks_summary=False</em>, <em>dump_csv=True</em>, <em>dump_signals_to_csv_every_x_episodes=5</em>, <em>dump_gifs=False</em>, <em>dump_mp4=False</em>, <em>video_dump_methods=None</em>, <em>dump_in_episode_signals=False</em>, <em>dump_parameters_documentation=True</em>, <em>render=False</em>, <em>native_rendering=False</em>, <em>max_fps_for_human_control=10</em>, <em>tensorboard=False</em>, <em>add_rendered_image_to_env_response=False</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/base_parameters.html#VisualizationParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.VisualizationParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.base_parameters.</code><code class="sig-name descname">VisualizationParameters</code><span class="sig-paren">(</span><em class="sig-param">print_networks_summary=False</em>, <em class="sig-param">dump_csv=True</em>, <em class="sig-param">dump_signals_to_csv_every_x_episodes=5</em>, <em class="sig-param">dump_gifs=False</em>, <em class="sig-param">dump_mp4=False</em>, <em class="sig-param">video_dump_methods=None</em>, <em class="sig-param">dump_in_episode_signals=False</em>, <em class="sig-param">dump_parameters_documentation=True</em>, <em class="sig-param">render=False</em>, <em class="sig-param">native_rendering=False</em>, <em class="sig-param">max_fps_for_human_control=10</em>, <em class="sig-param">tensorboard=False</em>, <em class="sig-param">add_rendered_image_to_env_response=False</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/base_parameters.html#VisualizationParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.VisualizationParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -244,7 +244,7 @@ which will be passed to the agent and allow using those images.</p></li>
<h2>PresetValidationParameters<a class="headerlink" href="#presetvalidationparameters" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.base_parameters.PresetValidationParameters">
<em class="property">class </em><code class="descclassname">rl_coach.base_parameters.</code><code class="descname">PresetValidationParameters</code><span class="sig-paren">(</span><em>test=False</em>, <em>min_reward_threshold=0</em>, <em>max_episodes_to_achieve_reward=1</em>, <em>num_workers=1</em>, <em>reward_test_level=None</em>, <em>test_using_a_trace_test=True</em>, <em>trace_test_levels=None</em>, <em>trace_max_env_steps=5000</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/base_parameters.html#PresetValidationParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.PresetValidationParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.base_parameters.</code><code class="sig-name descname">PresetValidationParameters</code><span class="sig-paren">(</span><em class="sig-param">test=False</em>, <em class="sig-param">min_reward_threshold=0</em>, <em class="sig-param">max_episodes_to_achieve_reward=1</em>, <em class="sig-param">num_workers=1</em>, <em class="sig-param">reward_test_level=None</em>, <em class="sig-param">test_using_a_trace_test=True</em>, <em class="sig-param">trace_test_levels=None</em>, <em class="sig-param">trace_max_env_steps=5000</em>, <em class="sig-param">read_csv_tries=200</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/base_parameters.html#PresetValidationParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.PresetValidationParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -261,6 +261,7 @@ reward tests suite.</p></li>
trace tests suite.</p></li>
<li><p><strong>trace_max_env_steps</strong> An integer representing the maximum number of environment steps to run when running this preset as part
of the trace tests suite.</p></li>
<li><p><strong>read_csv_tries</strong> The number of retries to attempt for reading the experiment csv file, before declaring failure.</p></li>
</ul>
</dd>
</dl>
@@ -271,7 +272,7 @@ of the trace tests suite.</p></li>
<h2>TaskParameters<a class="headerlink" href="#taskparameters" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.base_parameters.TaskParameters">
<em class="property">class </em><code class="descclassname">rl_coach.base_parameters.</code><code class="descname">TaskParameters</code><span class="sig-paren">(</span><em>framework_type: rl_coach.base_parameters.Frameworks = &lt;Frameworks.tensorflow: 'TensorFlow'&gt;</em>, <em>evaluate_only: int = None</em>, <em>use_cpu: bool = False</em>, <em>experiment_path='/tmp'</em>, <em>seed=None</em>, <em>checkpoint_save_secs=None</em>, <em>checkpoint_restore_dir=None</em>, <em>checkpoint_restore_path=None</em>, <em>checkpoint_save_dir=None</em>, <em>export_onnx_graph: bool = False</em>, <em>apply_stop_condition: bool = False</em>, <em>num_gpu: int = 1</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/base_parameters.html#TaskParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.TaskParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.base_parameters.</code><code class="sig-name descname">TaskParameters</code><span class="sig-paren">(</span><em class="sig-param">framework_type: rl_coach.base_parameters.Frameworks = &lt;Frameworks.tensorflow: 'TensorFlow'&gt;</em>, <em class="sig-param">evaluate_only: int = None</em>, <em class="sig-param">use_cpu: bool = False</em>, <em class="sig-param">experiment_path='/tmp'</em>, <em class="sig-param">seed=None</em>, <em class="sig-param">checkpoint_save_secs=None</em>, <em class="sig-param">checkpoint_restore_dir=None</em>, <em class="sig-param">checkpoint_restore_path=None</em>, <em class="sig-param">checkpoint_save_dir=None</em>, <em class="sig-param">export_onnx_graph: bool = False</em>, <em class="sig-param">apply_stop_condition: bool = False</em>, <em class="sig-param">num_gpu: int = 1</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/base_parameters.html#TaskParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.TaskParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -299,7 +300,7 @@ the dir to restore the checkpoints from</p></li>
<h2>DistributedTaskParameters<a class="headerlink" href="#distributedtaskparameters" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.base_parameters.DistributedTaskParameters">
<em class="property">class </em><code class="descclassname">rl_coach.base_parameters.</code><code class="descname">DistributedTaskParameters</code><span class="sig-paren">(</span><em>framework_type: rl_coach.base_parameters.Frameworks</em>, <em>parameters_server_hosts: str</em>, <em>worker_hosts: str</em>, <em>job_type: str</em>, <em>task_index: int</em>, <em>evaluate_only: int = None</em>, <em>num_tasks: int = None</em>, <em>num_training_tasks: int = None</em>, <em>use_cpu: bool = False</em>, <em>experiment_path=None</em>, <em>dnd=None</em>, <em>shared_memory_scratchpad=None</em>, <em>seed=None</em>, <em>checkpoint_save_secs=None</em>, <em>checkpoint_restore_path=None</em>, <em>checkpoint_save_dir=None</em>, <em>export_onnx_graph: bool = False</em>, <em>apply_stop_condition: bool = False</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/base_parameters.html#DistributedTaskParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.DistributedTaskParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.base_parameters.</code><code class="sig-name descname">DistributedTaskParameters</code><span class="sig-paren">(</span><em class="sig-param">framework_type: rl_coach.base_parameters.Frameworks</em>, <em class="sig-param">parameters_server_hosts: str</em>, <em class="sig-param">worker_hosts: str</em>, <em class="sig-param">job_type: str</em>, <em class="sig-param">task_index: int</em>, <em class="sig-param">evaluate_only: int = None</em>, <em class="sig-param">num_tasks: int = None</em>, <em class="sig-param">num_training_tasks: int = None</em>, <em class="sig-param">use_cpu: bool = False</em>, <em class="sig-param">experiment_path=None</em>, <em class="sig-param">dnd=None</em>, <em class="sig-param">shared_memory_scratchpad=None</em>, <em class="sig-param">seed=None</em>, <em class="sig-param">checkpoint_save_secs=None</em>, <em class="sig-param">checkpoint_restore_path=None</em>, <em class="sig-param">checkpoint_save_dir=None</em>, <em class="sig-param">export_onnx_graph: bool = False</em>, <em class="sig-param">apply_stop_condition: bool = False</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/base_parameters.html#DistributedTaskParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.DistributedTaskParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
+2 -1
View File
@@ -124,6 +124,7 @@
<li class="toctree-l2"><a class="reference internal" href="cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
@@ -239,7 +240,7 @@ the expert for each state.</p>
</ol>
<dl class="class">
<dt id="rl_coach.agents.bc_agent.BCAlgorithmParameters">
<em class="property">class </em><code class="descclassname">rl_coach.agents.bc_agent.</code><code class="descname">BCAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/bc_agent.html#BCAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.bc_agent.BCAlgorithmParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.bc_agent.</code><code class="sig-name descname">BCAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/bc_agent.html#BCAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.bc_agent.BCAlgorithmParameters" title="Permalink to this definition"></a></dt>
<dd></dd></dl>
</div>
+2 -1
View File
@@ -124,6 +124,7 @@
</li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
@@ -245,7 +246,7 @@ so that the loss for the other heads will be zeroed out.</p></li>
</ol>
<dl class="class">
<dt id="rl_coach.agents.cil_agent.CILAlgorithmParameters">
<em class="property">class </em><code class="descclassname">rl_coach.agents.cil_agent.</code><code class="descname">CILAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/cil_agent.html#CILAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.cil_agent.CILAlgorithmParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.cil_agent.</code><code class="sig-name descname">CILAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/cil_agent.html#CILAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.cil_agent.CILAlgorithmParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>state_key_with_the_class_index</strong> (str)
+67 -36
View File
@@ -117,6 +117,7 @@
<li class="toctree-l2"><a class="reference internal" href="imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="policy_optimization/td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="value_optimization/double_dqn.html">Double DQN</a></li>
@@ -225,6 +226,7 @@ A detailed description of those algorithms can be found by navigating to each of
<li class="toctree-l1"><a class="reference internal" href="imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l1"><a class="reference internal" href="policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l1"><a class="reference internal" href="policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l1"><a class="reference internal" href="policy_optimization/td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l1"><a class="reference internal" href="policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l1"><a class="reference internal" href="other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l1"><a class="reference internal" href="value_optimization/double_dqn.html">Double DQN</a></li>
@@ -243,7 +245,7 @@ A detailed description of those algorithms can be found by navigating to each of
</div>
<dl class="class">
<dt id="rl_coach.base_parameters.AgentParameters">
<em class="property">class </em><code class="descclassname">rl_coach.base_parameters.</code><code class="descname">AgentParameters</code><span class="sig-paren">(</span><em>algorithm: rl_coach.base_parameters.AlgorithmParameters, exploration: ExplorationParameters, memory: MemoryParameters, networks: Dict[str, rl_coach.base_parameters.NetworkParameters], visualization: rl_coach.base_parameters.VisualizationParameters = &lt;rl_coach.base_parameters.VisualizationParameters object&gt;</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/base_parameters.html#AgentParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.AgentParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.base_parameters.</code><code class="sig-name descname">AgentParameters</code><span class="sig-paren">(</span><em class="sig-param">algorithm: rl_coach.base_parameters.AlgorithmParameters, exploration: ExplorationParameters, memory: MemoryParameters, networks: Dict[str, rl_coach.base_parameters.NetworkParameters], visualization: rl_coach.base_parameters.VisualizationParameters = &lt;rl_coach.base_parameters.VisualizationParameters object&gt;</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/base_parameters.html#AgentParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.AgentParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -270,7 +272,7 @@ used for visualization purposes, such as printing to the screen, rendering, and
<dl class="class">
<dt id="rl_coach.agents.agent.Agent">
<em class="property">class </em><code class="descclassname">rl_coach.agents.agent.</code><code class="descname">Agent</code><span class="sig-paren">(</span><em>agent_parameters: rl_coach.base_parameters.AgentParameters</em>, <em>parent: Union[LevelManager</em>, <em>CompositeAgent] = None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.agent.</code><code class="sig-name descname">Agent</code><span class="sig-paren">(</span><em class="sig-param">agent_parameters: rl_coach.base_parameters.AgentParameters</em>, <em class="sig-param">parent: Union[LevelManager</em>, <em class="sig-param">CompositeAgent] = None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>agent_parameters</strong> A AgentParameters class instance with all the agent parameters</p>
@@ -278,7 +280,7 @@ used for visualization purposes, such as printing to the screen, rendering, and
</dl>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.act">
<code class="descname">act</code><span class="sig-paren">(</span><em>action: Union[None</em>, <em>int</em>, <em>float</em>, <em>numpy.ndarray</em>, <em>List] = None</em><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.ActionInfo<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.act"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.act" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">act</code><span class="sig-paren">(</span><em class="sig-param">action: Union[None</em>, <em class="sig-param">int</em>, <em class="sig-param">float</em>, <em class="sig-param">numpy.ndarray</em>, <em class="sig-param">List] = None</em><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.ActionInfo<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.act"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.act" title="Permalink to this definition"></a></dt>
<dd><p>Given the agents current knowledge, decide on the next action to apply to the environment</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -292,7 +294,7 @@ used for visualization purposes, such as printing to the screen, rendering, and
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.call_memory">
<code class="descname">call_memory</code><span class="sig-paren">(</span><em>func</em>, <em>args=()</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.call_memory"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.call_memory" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">call_memory</code><span class="sig-paren">(</span><em class="sig-param">func</em>, <em class="sig-param">args=()</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.call_memory"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.call_memory" title="Permalink to this definition"></a></dt>
<dd><p>This function is a wrapper to allow having the same calls for shared or unshared memories.
It should be used instead of calling the memory directly in order to allow different algorithms to work
both with a shared and a local memory.</p>
@@ -311,7 +313,7 @@ both with a shared and a local memory.</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.choose_action">
<code class="descname">choose_action</code><span class="sig-paren">(</span><em>curr_state</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.choose_action"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.choose_action" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">choose_action</code><span class="sig-paren">(</span><em class="sig-param">curr_state</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.choose_action"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.choose_action" title="Permalink to this definition"></a></dt>
<dd><p>choose an action to act with in the current episode being played. Different behavior might be exhibited when
training or testing.</p>
<dl class="field-list simple">
@@ -326,7 +328,7 @@ training or testing.</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.collect_savers">
<code class="descname">collect_savers</code><span class="sig-paren">(</span><em>parent_path_suffix: str</em><span class="sig-paren">)</span> &#x2192; rl_coach.saver.SaverCollection<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.collect_savers"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.collect_savers" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">collect_savers</code><span class="sig-paren">(</span><em class="sig-param">parent_path_suffix: str</em><span class="sig-paren">)</span> &#x2192; rl_coach.saver.SaverCollection<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.collect_savers"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.collect_savers" title="Permalink to this definition"></a></dt>
<dd><p>Collect all of agents network savers
:param parent_path_suffix: path suffix of the parent of the agent
(could be name of level manager or composite agent)
@@ -335,7 +337,7 @@ training or testing.</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.create_networks">
<code class="descname">create_networks</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; Dict[str, rl_coach.architectures.network_wrapper.NetworkWrapper]<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.create_networks"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.create_networks" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">create_networks</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; Dict[str, rl_coach.architectures.network_wrapper.NetworkWrapper]<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.create_networks"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.create_networks" title="Permalink to this definition"></a></dt>
<dd><p>Create all the networks of the agent.
The network creation will be done after setting the environment parameters for the agent, since they are needed
for creating the network.</p>
@@ -346,9 +348,16 @@ for creating the network.</p>
</dl>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.freeze_memory">
<code class="sig-name descname">freeze_memory</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.freeze_memory"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.freeze_memory" title="Permalink to this definition"></a></dt>
<dd><p>Shuffle episodes in the memory and freeze it to make sure that no extra data is being pushed anymore.
:return: None</p>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.get_predictions">
<code class="descname">get_predictions</code><span class="sig-paren">(</span><em>states: List[Dict[str, numpy.ndarray]], prediction_type: rl_coach.core_types.PredictionType</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.get_predictions"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.get_predictions" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">get_predictions</code><span class="sig-paren">(</span><em class="sig-param">states: List[Dict[str, numpy.ndarray]], prediction_type: rl_coach.core_types.PredictionType</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.get_predictions"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.get_predictions" title="Permalink to this definition"></a></dt>
<dd><p>Get a prediction from the agent with regard to the requested prediction_type.
If the agent cannot predict this type of prediction_type, or if there is more than possible way to do so,
raise a ValueException.</p>
@@ -367,7 +376,7 @@ raise a ValueException.</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.get_state_embedding">
<code class="descname">get_state_embedding</code><span class="sig-paren">(</span><em>state: dict</em><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.get_state_embedding"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.get_state_embedding" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">get_state_embedding</code><span class="sig-paren">(</span><em class="sig-param">state: dict</em><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.get_state_embedding"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.get_state_embedding" title="Permalink to this definition"></a></dt>
<dd><p>Given a state, get the corresponding state embedding from the main network</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -381,7 +390,7 @@ raise a ValueException.</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.handle_episode_ended">
<code class="descname">handle_episode_ended</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.handle_episode_ended"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.handle_episode_ended" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">handle_episode_ended</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.handle_episode_ended"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.handle_episode_ended" title="Permalink to this definition"></a></dt>
<dd><p>Make any changes needed when each episode is ended.
This includes incrementing counters, updating full episode dependent values, updating logs, etc.
This function is called right after each episode is ended.</p>
@@ -394,7 +403,7 @@ This function is called right after each episode is ended.</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.init_environment_dependent_modules">
<code class="descname">init_environment_dependent_modules</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.init_environment_dependent_modules"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.init_environment_dependent_modules" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">init_environment_dependent_modules</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.init_environment_dependent_modules"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.init_environment_dependent_modules" title="Permalink to this definition"></a></dt>
<dd><p>Initialize any modules that depend on knowing information about the environment such as the action space or
the observation space</p>
<dl class="field-list simple">
@@ -404,9 +413,20 @@ the observation space</p>
</dl>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.initialize_session_dependent_components">
<code class="sig-name descname">initialize_session_dependent_components</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.initialize_session_dependent_components"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.initialize_session_dependent_components" title="Permalink to this definition"></a></dt>
<dd><p>Initialize components which require a session as part of their initialization.</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
<dd class="field-odd"><p>None</p>
</dd>
</dl>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.learn_from_batch">
<code class="descname">learn_from_batch</code><span class="sig-paren">(</span><em>batch</em><span class="sig-paren">)</span> &#x2192; Tuple[float, List, List]<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.learn_from_batch"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.learn_from_batch" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">learn_from_batch</code><span class="sig-paren">(</span><em class="sig-param">batch</em><span class="sig-paren">)</span> &#x2192; Tuple[float, List, List]<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.learn_from_batch"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.learn_from_batch" title="Permalink to this definition"></a></dt>
<dd><p>Given a batch of transitions, calculates their target values and updates the network.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -418,9 +438,20 @@ the observation space</p>
</dl>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.load_memory_from_file">
<code class="sig-name descname">load_memory_from_file</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.load_memory_from_file"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.load_memory_from_file" title="Permalink to this definition"></a></dt>
<dd><p>Load memory transitions from a file.</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
<dd class="field-odd"><p>None</p>
</dd>
</dl>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.log_to_screen">
<code class="descname">log_to_screen</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.log_to_screen"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.log_to_screen" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">log_to_screen</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.log_to_screen"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.log_to_screen" title="Permalink to this definition"></a></dt>
<dd><p>Write an episode summary line to the terminal</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -431,7 +462,7 @@ the observation space</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.observe">
<code class="descname">observe</code><span class="sig-paren">(</span><em>env_response: rl_coach.core_types.EnvResponse</em><span class="sig-paren">)</span> &#x2192; bool<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.observe"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.observe" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">observe</code><span class="sig-paren">(</span><em class="sig-param">env_response: rl_coach.core_types.EnvResponse</em><span class="sig-paren">)</span> &#x2192; bool<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.observe"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.observe" title="Permalink to this definition"></a></dt>
<dd><p>Given a response from the environment, distill the observation from it and store it for later use.
The response should be a dictionary containing the performed action, the new observation and measurements,
the reward, a game over flag and any additional information necessary.</p>
@@ -446,9 +477,9 @@ given observation</p>
</dl>
</dd></dl>
<dl class="attribute">
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.parent">
<code class="descname">parent</code><a class="headerlink" href="#rl_coach.agents.agent.Agent.parent" title="Permalink to this definition"></a></dt>
<em class="property">property </em><code class="sig-name descname">parent</code><a class="headerlink" href="#rl_coach.agents.agent.Agent.parent" title="Permalink to this definition"></a></dt>
<dd><p>Get the parent class of the agent</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -457,9 +488,9 @@ given observation</p>
</dl>
</dd></dl>
<dl class="attribute">
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.phase">
<code class="descname">phase</code><a class="headerlink" href="#rl_coach.agents.agent.Agent.phase" title="Permalink to this definition"></a></dt>
<em class="property">property </em><code class="sig-name descname">phase</code><a class="headerlink" href="#rl_coach.agents.agent.Agent.phase" title="Permalink to this definition"></a></dt>
<dd><p>The current running phase of the agent</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -470,7 +501,7 @@ given observation</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.post_training_commands">
<code class="descname">post_training_commands</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.post_training_commands"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.post_training_commands" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">post_training_commands</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.post_training_commands"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.post_training_commands" title="Permalink to this definition"></a></dt>
<dd><p>A function which allows adding any functionality that is required to run right after the training phase ends.</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -481,7 +512,7 @@ given observation</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.prepare_batch_for_inference">
<code class="descname">prepare_batch_for_inference</code><span class="sig-paren">(</span><em>states: Union[Dict[str, numpy.ndarray], List[Dict[str, numpy.ndarray]]], network_name: str</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.array]<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.prepare_batch_for_inference"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.prepare_batch_for_inference" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">prepare_batch_for_inference</code><span class="sig-paren">(</span><em class="sig-param">states: Union[Dict[str, numpy.ndarray], List[Dict[str, numpy.ndarray]]], network_name: str</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.core.multiarray.array]<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.prepare_batch_for_inference"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.prepare_batch_for_inference" title="Permalink to this definition"></a></dt>
<dd><p>Convert curr_state into input tensors tensorflow is expecting. i.e. if we have several inputs states, stack all
observations together, measurements together, etc.</p>
<dl class="field-list simple">
@@ -501,7 +532,7 @@ the observation relevant for the network from the states.</p></li>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.register_signal">
<code class="descname">register_signal</code><span class="sig-paren">(</span><em>signal_name: str</em>, <em>dump_one_value_per_episode: bool = True</em>, <em>dump_one_value_per_step: bool = False</em><span class="sig-paren">)</span> &#x2192; rl_coach.utils.Signal<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.register_signal"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.register_signal" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">register_signal</code><span class="sig-paren">(</span><em class="sig-param">signal_name: str</em>, <em class="sig-param">dump_one_value_per_episode: bool = True</em>, <em class="sig-param">dump_one_value_per_step: bool = False</em><span class="sig-paren">)</span> &#x2192; rl_coach.utils.Signal<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.register_signal"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.register_signal" title="Permalink to this definition"></a></dt>
<dd><p>Register a signal such that its statistics will be dumped and be viewable through dashboard</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -519,7 +550,7 @@ the observation relevant for the network from the states.</p></li>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.reset_evaluation_state">
<code class="descname">reset_evaluation_state</code><span class="sig-paren">(</span><em>val: rl_coach.core_types.RunPhase</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.reset_evaluation_state"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.reset_evaluation_state" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">reset_evaluation_state</code><span class="sig-paren">(</span><em class="sig-param">val: rl_coach.core_types.RunPhase</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.reset_evaluation_state"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.reset_evaluation_state" title="Permalink to this definition"></a></dt>
<dd><p>Perform accumulators initialization when entering an evaluation phase, and signal dumping when exiting an
evaluation phase. Entering or exiting the evaluation phase is determined according to the new phase given
by val, and by the current phase set in self.phase.</p>
@@ -535,7 +566,7 @@ by val, and by the current phase set in self.phase.</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.reset_internal_state">
<code class="descname">reset_internal_state</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.reset_internal_state"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.reset_internal_state" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">reset_internal_state</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.reset_internal_state"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.reset_internal_state" title="Permalink to this definition"></a></dt>
<dd><p>Reset all the episodic parameters. This function is called right before each episode starts.</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -546,7 +577,7 @@ by val, and by the current phase set in self.phase.</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.restore_checkpoint">
<code class="descname">restore_checkpoint</code><span class="sig-paren">(</span><em>checkpoint_dir: str</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.restore_checkpoint"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.restore_checkpoint" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">restore_checkpoint</code><span class="sig-paren">(</span><em class="sig-param">checkpoint_dir: str</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.restore_checkpoint"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.restore_checkpoint" title="Permalink to this definition"></a></dt>
<dd><p>Allows agents to store additional information when saving checkpoints.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -560,7 +591,7 @@ by val, and by the current phase set in self.phase.</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.run_off_policy_evaluation">
<code class="descname">run_off_policy_evaluation</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.agent.Agent.run_off_policy_evaluation" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">run_off_policy_evaluation</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.agent.Agent.run_off_policy_evaluation" title="Permalink to this definition"></a></dt>
<dd><p>Run off-policy evaluation estimators to evaluate the trained policy performance against a dataset.
Should only be implemented for off-policy RL algorithms.</p>
<dl class="field-list simple">
@@ -572,7 +603,7 @@ Should only be implemented for off-policy RL algorithms.</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.run_pre_network_filter_for_inference">
<code class="descname">run_pre_network_filter_for_inference</code><span class="sig-paren">(</span><em>state: Dict[str, numpy.ndarray], update_filter_internal_state: bool = True</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.ndarray]<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.run_pre_network_filter_for_inference"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.run_pre_network_filter_for_inference" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">run_pre_network_filter_for_inference</code><span class="sig-paren">(</span><em class="sig-param">state: Dict[str, numpy.ndarray], update_filter_internal_state: bool = True</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.ndarray]<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.run_pre_network_filter_for_inference"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.run_pre_network_filter_for_inference" title="Permalink to this definition"></a></dt>
<dd><p>Run filters which where defined for being applied right before using the state for inference.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -589,7 +620,7 @@ Should only be implemented for off-policy RL algorithms.</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.save_checkpoint">
<code class="descname">save_checkpoint</code><span class="sig-paren">(</span><em>checkpoint_prefix: str</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.save_checkpoint"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.save_checkpoint" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">save_checkpoint</code><span class="sig-paren">(</span><em class="sig-param">checkpoint_prefix: str</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.save_checkpoint"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.save_checkpoint" title="Permalink to this definition"></a></dt>
<dd><p>Allows agents to store additional information when saving checkpoints.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -603,7 +634,7 @@ Should only be implemented for off-policy RL algorithms.</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.set_environment_parameters">
<code class="descname">set_environment_parameters</code><span class="sig-paren">(</span><em>spaces: rl_coach.spaces.SpacesDefinition</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.set_environment_parameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.set_environment_parameters" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">set_environment_parameters</code><span class="sig-paren">(</span><em class="sig-param">spaces: rl_coach.spaces.SpacesDefinition</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.set_environment_parameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.set_environment_parameters" title="Permalink to this definition"></a></dt>
<dd><p>Sets the parameters that are environment dependent. As a side effect, initializes all the components that are
dependent on those values, by calling init_environment_dependent_modules</p>
<dl class="field-list simple">
@@ -618,7 +649,7 @@ dependent on those values, by calling init_environment_dependent_modules</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.set_incoming_directive">
<code class="descname">set_incoming_directive</code><span class="sig-paren">(</span><em>action: Union[int, float, numpy.ndarray, List]</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.set_incoming_directive"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.set_incoming_directive" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">set_incoming_directive</code><span class="sig-paren">(</span><em class="sig-param">action: Union[int, float, numpy.ndarray, List]</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.set_incoming_directive"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.set_incoming_directive" title="Permalink to this definition"></a></dt>
<dd><p>Allows setting a directive for the agent to follow. This is useful in hierarchy structures, where the agent
has another master agent that is controlling it. In such cases, the master agent can define the goals for the
slave agent, define its observation, possible actions, etc. The directive type is defined by the agent
@@ -635,7 +666,7 @@ in-action-space.</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.set_session">
<code class="descname">set_session</code><span class="sig-paren">(</span><em>sess</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.set_session"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.set_session" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">set_session</code><span class="sig-paren">(</span><em class="sig-param">sess</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.set_session"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.set_session" title="Permalink to this definition"></a></dt>
<dd><p>Set the deep learning framework session for all the agents in the composite agent</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -646,7 +677,7 @@ in-action-space.</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.setup_logger">
<code class="descname">setup_logger</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.setup_logger"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.setup_logger" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">setup_logger</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.setup_logger"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.setup_logger" title="Permalink to this definition"></a></dt>
<dd><p>Setup the logger for the agent</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -657,7 +688,7 @@ in-action-space.</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.sync">
<code class="descname">sync</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.sync"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.sync" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">sync</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.sync"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.sync" title="Permalink to this definition"></a></dt>
<dd><p>Sync the global network parameters to local networks</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -668,7 +699,7 @@ in-action-space.</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.train">
<code class="descname">train</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; float<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.train"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.train" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">train</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; float<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.train"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.train" title="Permalink to this definition"></a></dt>
<dd><p>Check if a training phase should be done as configured by num_consecutive_playing_steps.
If it should, then do several training steps as configured by num_consecutive_training_steps.
A single training iteration: Sample a batch, train on it and update target networks.</p>
@@ -681,7 +712,7 @@ A single training iteration: Sample a batch, train on it and update target netwo
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.update_log">
<code class="descname">update_log</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.update_log"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.update_log" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">update_log</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.update_log"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.update_log" title="Permalink to this definition"></a></dt>
<dd><p>Updates the episodic log file with all the signal values from the most recent episode.
Additional signals for logging can be set by the creating a new signal using self.register_signal,
and then updating it with some internal agent values.</p>
@@ -694,7 +725,7 @@ and then updating it with some internal agent values.</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.update_step_in_episode_log">
<code class="descname">update_step_in_episode_log</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.update_step_in_episode_log"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.update_step_in_episode_log" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">update_step_in_episode_log</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.update_step_in_episode_log"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.update_step_in_episode_log" title="Permalink to this definition"></a></dt>
<dd><p>Updates the in-episode log file with all the signal values from the most recent step.</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -705,7 +736,7 @@ and then updating it with some internal agent values.</p>
<dl class="method">
<dt id="rl_coach.agents.agent.Agent.update_transition_before_adding_to_replay_buffer">
<code class="descname">update_transition_before_adding_to_replay_buffer</code><span class="sig-paren">(</span><em>transition: rl_coach.core_types.Transition</em><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.Transition<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.update_transition_before_adding_to_replay_buffer"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.update_transition_before_adding_to_replay_buffer" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">update_transition_before_adding_to_replay_buffer</code><span class="sig-paren">(</span><em class="sig-param">transition: rl_coach.core_types.Transition</em><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.Transition<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.update_transition_before_adding_to_replay_buffer"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.update_transition_before_adding_to_replay_buffer" title="Permalink to this definition"></a></dt>
<dd><p>Allows agents to update the transition just before adding it to the replay buffer.
Can be useful for agents that want to tweak the reward, termination signal, etc.</p>
<dl class="field-list simple">
+2 -1
View File
@@ -117,6 +117,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Direct Future Prediction</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#network-structure">Network Structure</a></li>
@@ -249,7 +250,7 @@ measurements that were seen in time-steps <span class="math notranslate nohighli
For the actions that were not taken, the targets are the current values.</p>
<dl class="class">
<dt id="rl_coach.agents.dfp_agent.DFPAlgorithmParameters">
<em class="property">class </em><code class="descclassname">rl_coach.agents.dfp_agent.</code><code class="descname">DFPAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/dfp_agent.html#DFPAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.dfp_agent.DFPAlgorithmParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.dfp_agent.</code><code class="sig-name descname">DFPAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/dfp_agent.html#DFPAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.dfp_agent.DFPAlgorithmParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -125,6 +125,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
@@ -247,7 +248,7 @@ where <span class="math notranslate nohighlight">\(k\)</span> is <span class="ma
<span class="math notranslate nohighlight">\(L = -\mathop{\mathbb{E}} [log (\pi) \cdot A]\)</span></p>
<dl class="class">
<dt id="rl_coach.agents.actor_critic_agent.ActorCriticAlgorithmParameters">
<em class="property">class </em><code class="descclassname">rl_coach.agents.actor_critic_agent.</code><code class="descname">ActorCriticAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/actor_critic_agent.html#ActorCriticAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.actor_critic_agent.ActorCriticAlgorithmParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.actor_critic_agent.</code><code class="sig-name descname">ActorCriticAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/actor_critic_agent.html#ActorCriticAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.actor_critic_agent.ActorCriticAlgorithmParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -125,6 +125,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
@@ -279,7 +280,7 @@ The goal of the trust region update is to the difference between the updated pol
</ol>
<dl class="class">
<dt id="rl_coach.agents.acer_agent.ACERAlgorithmParameters">
<em class="property">class </em><code class="descclassname">rl_coach.agents.acer_agent.</code><code class="descname">ACERAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/acer_agent.html#ACERAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.acer_agent.ACERAlgorithmParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.acer_agent.</code><code class="sig-name descname">ACERAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/acer_agent.html#ACERAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.acer_agent.ACERAlgorithmParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -125,6 +125,7 @@
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
@@ -252,7 +253,7 @@ clipped surrogate loss:</p>
</ol>
<dl class="class">
<dt id="rl_coach.agents.clipped_ppo_agent.ClippedPPOAlgorithmParameters">
<em class="property">class </em><code class="descclassname">rl_coach.agents.clipped_ppo_agent.</code><code class="descname">ClippedPPOAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/clipped_ppo_agent.html#ClippedPPOAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.clipped_ppo_agent.ClippedPPOAlgorithmParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.clipped_ppo_agent.</code><code class="sig-name descname">ClippedPPOAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/clipped_ppo_agent.html#ClippedPPOAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.clipped_ppo_agent.ClippedPPOAlgorithmParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -37,7 +37,7 @@
<link rel="stylesheet" href="../../../_static/css/custom.css" type="text/css" />
<link rel="index" title="Index" href="../../../genindex.html" />
<link rel="search" title="Search" href="../../../search.html" />
<link rel="next" title="Soft Actor-Critic" href="sac.html" />
<link rel="next" title="Twin Delayed Deep Deterministic Policy Gradient" href="td3.html" />
<link rel="prev" title="Clipped Proximal Policy Optimization" href="cppo.html" />
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
@@ -125,6 +125,7 @@
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
@@ -257,7 +258,7 @@ given <span class="math notranslate nohighlight">\(\nabla_a Q(s,a)\)</span>. Fin
<p>After every training step, do a soft update of the critic and actor target networks weights from the online networks.</p>
<dl class="class">
<dt id="rl_coach.agents.ddpg_agent.DDPGAlgorithmParameters">
<em class="property">class </em><code class="descclassname">rl_coach.agents.ddpg_agent.</code><code class="descname">DDPGAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/ddpg_agent.html#DDPGAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.ddpg_agent.DDPGAlgorithmParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.ddpg_agent.</code><code class="sig-name descname">DDPGAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/ddpg_agent.html#DDPGAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.ddpg_agent.DDPGAlgorithmParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -297,7 +298,7 @@ values. If set to False, the terminal states reward will be taken as the target
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="sac.html" class="btn btn-neutral float-right" title="Soft Actor-Critic" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a>
<a href="td3.html" class="btn btn-neutral float-right" title="Twin Delayed Deep Deterministic Policy Gradient" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a>
<a href="cppo.html" class="btn btn-neutral float-left" title="Clipped Proximal Policy Optimization" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a>
@@ -117,6 +117,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
@@ -251,7 +252,7 @@ serves the same purpose - reducing the update variance. After accumulating gradi
the gradients are then applied to the network.</p>
<dl class="class">
<dt id="rl_coach.agents.policy_gradients_agent.PolicyGradientAlgorithmParameters">
<em class="property">class </em><code class="descclassname">rl_coach.agents.policy_gradients_agent.</code><code class="descname">PolicyGradientAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/policy_gradients_agent.html#PolicyGradientAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.policy_gradients_agent.PolicyGradientAlgorithmParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.policy_gradients_agent.</code><code class="sig-name descname">PolicyGradientAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/policy_gradients_agent.html#PolicyGradientAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.policy_gradients_agent.PolicyGradientAlgorithmParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -117,6 +117,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
@@ -253,7 +254,7 @@ increase the penalty, if it went too low, reduce it. Otherwise, leave it unchang
</ol>
<dl class="class">
<dt id="rl_coach.agents.ppo_agent.PPOAlgorithmParameters">
<em class="property">class </em><code class="descclassname">rl_coach.agents.ppo_agent.</code><code class="descname">PPOAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/ppo_agent.html#PPOAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.ppo_agent.PPOAlgorithmParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.ppo_agent.</code><code class="sig-name descname">PPOAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/ppo_agent.html#PPOAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.ppo_agent.PPOAlgorithmParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -38,7 +38,7 @@
<link rel="index" title="Index" href="../../../genindex.html" />
<link rel="search" title="Search" href="../../../search.html" />
<link rel="next" title="Direct Future Prediction" href="../other/dfp.html" />
<link rel="prev" title="Deep Deterministic Policy Gradient" href="ddpg.html" />
<link rel="prev" title="Twin Delayed Deep Deterministic Policy Gradient" href="td3.html" />
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
</head>
@@ -117,6 +117,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Soft Actor-Critic</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#network-structure">Network Structure</a></li>
<li class="toctree-l3"><a class="reference internal" href="#algorithm-description">Algorithm Description</a><ul>
@@ -258,7 +259,7 @@ from the current policy.</p>
<p>After every training step, do a soft update of the V target networks weights from the online networks.</p>
<dl class="class">
<dt id="rl_coach.agents.soft_actor_critic_agent.SoftActorCriticAlgorithmParameters">
<em class="property">class </em><code class="descclassname">rl_coach.agents.soft_actor_critic_agent.</code><code class="descname">SoftActorCriticAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/soft_actor_critic_agent.html#SoftActorCriticAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.soft_actor_critic_agent.SoftActorCriticAlgorithmParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.soft_actor_critic_agent.</code><code class="sig-name descname">SoftActorCriticAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/soft_actor_critic_agent.html#SoftActorCriticAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.soft_actor_critic_agent.SoftActorCriticAlgorithmParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -290,7 +291,7 @@ and not sampled from the policy distribution.</p></li>
<a href="../other/dfp.html" class="btn btn-neutral float-right" title="Direct Future Prediction" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a>
<a href="ddpg.html" class="btn btn-neutral float-left" title="Deep Deterministic Policy Gradient" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a>
<a href="td3.html" class="btn btn-neutral float-left" title="Twin Delayed Deep Deterministic Policy Gradient" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a>
</div>
@@ -0,0 +1,347 @@
<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Twin Delayed Deep Deterministic Policy Gradient &mdash; Reinforcement Learning Coach 0.12.0 documentation</title>
<script type="text/javascript" src="../../../_static/js/modernizr.min.js"></script>
<script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
<script type="text/javascript" src="../../../_static/jquery.js"></script>
<script type="text/javascript" src="../../../_static/underscore.js"></script>
<script type="text/javascript" src="../../../_static/doctools.js"></script>
<script type="text/javascript" src="../../../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../../_static/js/theme.js"></script>
<link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../../_static/css/custom.css" type="text/css" />
<link rel="index" title="Index" href="../../../genindex.html" />
<link rel="search" title="Search" href="../../../search.html" />
<link rel="next" title="Soft Actor-Critic" href="sac.html" />
<link rel="prev" title="Deep Deterministic Policy Gradient" href="ddpg.html" />
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../../index.html" class="icon icon-home"> Reinforcement Learning Coach
<img src="../../../_static/dark_logo.png" class="logo" alt="Logo"/>
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption"><span class="caption-text">Intro</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../dist_usage.html">Usage - Distributed Coach</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../features/index.html">Features</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../selecting_an_algorithm.html">Selecting an Algorithm</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../dashboard.html">Coach Dashboard</a></li>
</ul>
<p class="caption"><span class="caption-text">Design</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../design/control_flow.html">Control Flow</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../design/network.html">Network Design</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../design/horizontal_scaling.html">Distributed Coach - Horizontal Scale-Out</a></li>
</ul>
<p class="caption"><span class="caption-text">Contributing</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_agent.html">Adding a New Agent</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_env.html">Adding a New Environment</a></li>
</ul>
<p class="caption"><span class="caption-text">Components</span></p>
<ul class="current">
<li class="toctree-l1 current"><a class="reference internal" href="../index.html">Agents</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="ac.html">Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="acer.html">ACER</a></li>
<li class="toctree-l2"><a class="reference internal" href="../imitation/bc.html">Behavioral Cloning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/bs_dqn.html">Bootstrapped DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/categorical_dqn.html">Categorical DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Twin Delayed Deep Deterministic Policy Gradient</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#network-structure">Network Structure</a></li>
<li class="toctree-l3"><a class="reference internal" href="#algorithm-description">Algorithm Description</a><ul>
<li class="toctree-l4"><a class="reference internal" href="#choosing-an-action">Choosing an action</a></li>
<li class="toctree-l4"><a class="reference internal" href="#training-the-network">Training the network</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/dqn.html">Deep Q Networks</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/dueling_dqn.html">Dueling DQN</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/mmc.html">Mixed Monte Carlo</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/n_step.html">N-Step Q Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/naf.html">Normalized Advantage Functions</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/nec.html">Neural Episodic Control</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/pal.html">Persistent Advantage Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="pg.html">Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="ppo.html">Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/rainbow.html">Rainbow</a></li>
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/qr_dqn.html">Quantile Regression DQN</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../../architectures/index.html">Architectures</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../data_stores/index.html">Data Stores</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../environments/index.html">Environments</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../exploration_policies/index.html">Exploration Policies</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../filters/index.html">Filters</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../memories/index.html">Memories</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../memory_backends/index.html">Memory Backends</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../orchestrators/index.html">Orchestrators</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../core_types.html">Core Types</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../spaces.html">Spaces</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../additional_parameters.html">Additional Parameters</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../../index.html">Reinforcement Learning Coach</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../../index.html">Docs</a> &raquo;</li>
<li><a href="../index.html">Agents</a> &raquo;</li>
<li>Twin Delayed Deep Deterministic Policy Gradient</li>
<li class="wy-breadcrumbs-aside">
<a href="../../../_sources/components/agents/policy_optimization/td3.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="section" id="twin-delayed-deep-deterministic-policy-gradient">
<h1>Twin Delayed Deep Deterministic Policy Gradient<a class="headerlink" href="#twin-delayed-deep-deterministic-policy-gradient" title="Permalink to this headline"></a></h1>
<p><strong>Actions space:</strong> Continuous</p>
<p><strong>References:</strong> <a class="reference external" href="https://arxiv.org/pdf/1802.09477">Addressing Function Approximation Error in Actor-Critic Methods</a></p>
<div class="section" id="network-structure">
<h2>Network Structure<a class="headerlink" href="#network-structure" title="Permalink to this headline"></a></h2>
<img alt="../../../_images/td3.png" class="align-center" src="../../../_images/td3.png" />
</div>
<div class="section" id="algorithm-description">
<h2>Algorithm Description<a class="headerlink" href="#algorithm-description" title="Permalink to this headline"></a></h2>
<div class="section" id="choosing-an-action">
<h3>Choosing an action<a class="headerlink" href="#choosing-an-action" title="Permalink to this headline"></a></h3>
<p>Pass the current states through the actor network, and get an action mean vector <span class="math notranslate nohighlight">\(\mu\)</span>.
While in training phase, use a continuous exploration policy, such as a small zero-meaned gaussian noise,
to add exploration noise to the action. When testing, use the mean vector <span class="math notranslate nohighlight">\(\mu\)</span> as-is.</p>
</div>
<div class="section" id="training-the-network">
<h3>Training the network<a class="headerlink" href="#training-the-network" title="Permalink to this headline"></a></h3>
<p>Start by sampling a batch of transitions from the experience replay.</p>
<ul>
<li><p>To train the two <strong>critic networks</strong>, use the following targets:</p>
<p><span class="math notranslate nohighlight">\(y_t=r(s_t,a_t )+\gamma \cdot \min_{i=1,2} Q_{i}(s_{t+1},\mu(s_{t+1} )+[\mathcal{N}(0,\,\sigma^{2})]^{MAX\_NOISE}_{MIN\_NOISE})\)</span></p>
<p>First run the actor target network, using the next states as the inputs, and get <span class="math notranslate nohighlight">\(\mu (s_{t+1} )\)</span>. Then, add a
clipped gaussian noise to these actions, and clip the resulting actions to the actions space.
Next, run the critic target networks using the next states and <span class="math notranslate nohighlight">\(\mu (s_{t+1} )+[\mathcal{N}(0,\,\sigma^{2})]^{MAX\_NOISE}_{MIN\_NOISE}\)</span>,
and use the minimum between the two critic networks predictions in order to calculate <span class="math notranslate nohighlight">\(y_t\)</span> according to the
equation above. To train the networks, use the current states and actions as the inputs, and <span class="math notranslate nohighlight">\(y_t\)</span>
as the targets.</p>
</li>
<li><p>To train the <strong>actor network</strong>, use the following equation:</p>
<p><span class="math notranslate nohighlight">\(\nabla_{\theta^\mu } J \approx E_{s_t \tilde{} \rho^\beta } [\nabla_a Q_{1}(s,a)|_{s=s_t,a=\mu (s_t ) } \cdot \nabla_{\theta^\mu} \mu(s)|_{s=s_t} ]\)</span></p>
<p>Use the actors online network to get the action mean values using the current states as the inputs.
Then, use the first critics online network in order to get the gradients of the critic output with respect to the
action mean values <span class="math notranslate nohighlight">\(\nabla _a Q_{1}(s,a)|_{s=s_t,a=\mu(s_t ) }\)</span>.
Using the chain rule, calculate the gradients of the actors output, with respect to the actor weights,
given <span class="math notranslate nohighlight">\(\nabla_a Q(s,a)\)</span>. Finally, apply those gradients to the actor network.</p>
<p>The actors training is done at a slower frequency than the critics training, in order to allow the critic to better fit the
current policy, before exercising the critic in order to train the actor.
Following the same, delayed, actors training cadence, do a soft update of the critic and actor target networks weights
from the online networks.</p>
</li>
</ul>
<dl class="class">
<dt id="rl_coach.agents.td3_agent.TD3AlgorithmParameters">
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.td3_agent.</code><code class="sig-name descname">TD3AlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/td3_agent.html#TD3AlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.td3_agent.TD3AlgorithmParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>num_steps_between_copying_online_weights_to_target</strong> (StepMethod)
The number of steps between copying the online network weights to the target network weights.</p></li>
<li><p><strong>rate_for_copying_weights_to_target</strong> (float)
When copying the online network weights to the target network weights, a soft update will be used, which
weight the new online network weights by rate_for_copying_weights_to_target</p></li>
<li><p><strong>num_consecutive_playing_steps</strong> (StepMethod)
The number of consecutive steps to act between every two training iterations</p></li>
<li><p><strong>use_target_network_for_evaluation</strong> (bool)
If set to True, the target network will be used for predicting the actions when choosing actions to act.
Since the target network weights change more slowly, the predicted actions will be more consistent.</p></li>
<li><p><strong>action_penalty</strong> (float)
The amount by which to penalize the network on high action feature (pre-activation) values.
This can prevent the actions features from saturating the TanH activation function, and therefore prevent the
gradients from becoming very low.</p></li>
<li><p><strong>clip_critic_targets</strong> (Tuple[float, float] or None)
The range to clip the critic target to in order to prevent overestimation of the action values.</p></li>
<li><p><strong>use_non_zero_discount_for_terminal_states</strong> (bool)
If set to True, the discount factor will be used for terminal states to bootstrap the next predicted state
values. If set to False, the terminal states reward will be taken as the target return for the network.</p></li>
</ul>
</dd>
</dl>
</dd></dl>
</div>
</div>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="sac.html" class="btn btn-neutral float-right" title="Soft Actor-Critic" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a>
<a href="ddpg.html" class="btn btn-neutral float-left" title="Deep Deterministic Policy Gradient" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<p>
&copy; Copyright 2018-2019, Intel AI Lab
</p>
</div>
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>
@@ -126,6 +126,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="double_dqn.html">Double DQN</a></li>
@@ -124,6 +124,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="double_dqn.html">Double DQN</a></li>
@@ -244,7 +245,7 @@ probability distribution. Only the target of the actions that were actually ta
</ol>
<dl class="class">
<dt id="rl_coach.agents.categorical_dqn_agent.CategoricalDQNAlgorithmParameters">
<em class="property">class </em><code class="descclassname">rl_coach.agents.categorical_dqn_agent.</code><code class="descname">CategoricalDQNAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/categorical_dqn_agent.html#CategoricalDQNAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.categorical_dqn_agent.CategoricalDQNAlgorithmParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.categorical_dqn_agent.</code><code class="sig-name descname">CategoricalDQNAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/categorical_dqn_agent.html#CategoricalDQNAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.categorical_dqn_agent.CategoricalDQNAlgorithmParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -117,6 +117,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Double DQN</a><ul>
@@ -117,6 +117,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="double_dqn.html">Double DQN</a></li>
@@ -243,7 +244,7 @@ Set those values as the targets for the actions that were not actually played.</
</ol>
<dl class="class">
<dt id="rl_coach.agents.dqn_agent.DQNAlgorithmParameters">
<em class="property">class </em><code class="descclassname">rl_coach.agents.dqn_agent.</code><code class="descname">DQNAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/dqn_agent.html#DQNAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAlgorithmParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.dqn_agent.</code><code class="sig-name descname">DQNAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/dqn_agent.html#DQNAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAlgorithmParameters" title="Permalink to this definition"></a></dt>
<dd></dd></dl>
</div>
@@ -117,6 +117,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="double_dqn.html">Double DQN</a></li>
@@ -117,6 +117,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="double_dqn.html">Double DQN</a></li>
@@ -240,7 +241,7 @@
Once in every few thousand steps, copy the weights from the online network to the target network.</p>
<dl class="class">
<dt id="rl_coach.agents.mmc_agent.MixedMonteCarloAlgorithmParameters">
<em class="property">class </em><code class="descclassname">rl_coach.agents.mmc_agent.</code><code class="descname">MixedMonteCarloAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/mmc_agent.html#MixedMonteCarloAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.mmc_agent.MixedMonteCarloAlgorithmParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.mmc_agent.</code><code class="sig-name descname">MixedMonteCarloAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/mmc_agent.html#MixedMonteCarloAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.mmc_agent.MixedMonteCarloAlgorithmParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>monte_carlo_mixing_rate</strong> (float)
@@ -117,6 +117,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="double_dqn.html">Double DQN</a></li>
@@ -242,7 +243,7 @@ where <span class="math notranslate nohighlight">\(k\)</span> is <span class="ma
</ol>
<dl class="class">
<dt id="rl_coach.agents.n_step_q_agent.NStepQAlgorithmParameters">
<em class="property">class </em><code class="descclassname">rl_coach.agents.n_step_q_agent.</code><code class="descname">NStepQAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/n_step_q_agent.html#NStepQAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.n_step_q_agent.NStepQAlgorithmParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.n_step_q_agent.</code><code class="sig-name descname">NStepQAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/n_step_q_agent.html#NStepQAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.n_step_q_agent.NStepQAlgorithmParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -117,6 +117,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="double_dqn.html">Double DQN</a></li>
@@ -243,7 +244,7 @@ and <span class="math notranslate nohighlight">\(y_t\)</span> as the targets.
After every training step, use a soft update in order to copy the weights from the online network to the target network.</p>
<dl class="class">
<dt id="rl_coach.agents.naf_agent.NAFAlgorithmParameters">
<em class="property">class </em><code class="descclassname">rl_coach.agents.naf_agent.</code><code class="descname">NAFAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/naf_agent.html#NAFAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.naf_agent.NAFAlgorithmParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.naf_agent.</code><code class="sig-name descname">NAFAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/naf_agent.html#NAFAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.naf_agent.NAFAlgorithmParameters" title="Permalink to this definition"></a></dt>
<dd></dd></dl>
</div>
@@ -117,6 +117,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="double_dqn.html">Double DQN</a></li>
@@ -258,7 +259,7 @@ the network if necessary:
<span class="math notranslate nohighlight">\(y_t=\sum_{j=0}^{N-1}\gamma^j r(s_{t+j},a_{t+j} ) +\gamma^N max_a Q(s_{t+N},a)\)</span></p>
<dl class="class">
<dt id="rl_coach.agents.nec_agent.NECAlgorithmParameters">
<em class="property">class </em><code class="descclassname">rl_coach.agents.nec_agent.</code><code class="descname">NECAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/nec_agent.html#NECAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.nec_agent.NECAlgorithmParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.nec_agent.</code><code class="sig-name descname">NECAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/nec_agent.html#NECAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.nec_agent.NECAlgorithmParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -117,6 +117,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="double_dqn.html">Double DQN</a></li>
@@ -251,7 +252,7 @@ has the highest predicted <span class="math notranslate nohighlight">\(Q\)</span
</ol>
<dl class="class">
<dt id="rl_coach.agents.pal_agent.PALAlgorithmParameters">
<em class="property">class </em><code class="descclassname">rl_coach.agents.pal_agent.</code><code class="descname">PALAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/pal_agent.html#PALAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.pal_agent.PALAlgorithmParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.pal_agent.</code><code class="sig-name descname">PALAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/pal_agent.html#PALAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.pal_agent.PALAlgorithmParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -117,6 +117,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="double_dqn.html">Double DQN</a></li>
@@ -241,7 +242,7 @@ quantile locations. Only the targets of the actions that were actually taken are
</ol>
<dl class="class">
<dt id="rl_coach.agents.qr_dqn_agent.QuantileRegressionDQNAlgorithmParameters">
<em class="property">class </em><code class="descclassname">rl_coach.agents.qr_dqn_agent.</code><code class="descname">QuantileRegressionDQNAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/qr_dqn_agent.html#QuantileRegressionDQNAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.qr_dqn_agent.QuantileRegressionDQNAlgorithmParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.qr_dqn_agent.</code><code class="sig-name descname">QuantileRegressionDQNAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/qr_dqn_agent.html#QuantileRegressionDQNAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.qr_dqn_agent.QuantileRegressionDQNAlgorithmParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -117,6 +117,7 @@
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/td3.html">Twin Delayed Deep Deterministic Policy Gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
<li class="toctree-l2"><a class="reference internal" href="double_dqn.html">Double DQN</a></li>
@@ -256,7 +257,7 @@ using the KL divergence loss that is returned from the network.</p></li>
</ol>
<dl class="class">
<dt id="rl_coach.agents.rainbow_dqn_agent.RainbowDQNAlgorithmParameters">
<em class="property">class </em><code class="descclassname">rl_coach.agents.rainbow_dqn_agent.</code><code class="descname">RainbowDQNAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/rainbow_dqn_agent.html#RainbowDQNAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.rainbow_dqn_agent.RainbowDQNAlgorithmParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.rainbow_dqn_agent.</code><code class="sig-name descname">RainbowDQNAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/rainbow_dqn_agent.html#RainbowDQNAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.rainbow_dqn_agent.RainbowDQNAlgorithmParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
+28 -28
View File
@@ -196,7 +196,7 @@ own components under a dedicated directory. For example, tensorflow components w
parts that are implemented using TensorFlow.</p>
<dl class="class">
<dt id="rl_coach.base_parameters.NetworkParameters">
<em class="property">class </em><code class="descclassname">rl_coach.base_parameters.</code><code class="descname">NetworkParameters</code><span class="sig-paren">(</span><em>force_cpu=False</em>, <em>async_training=False</em>, <em>shared_optimizer=True</em>, <em>scale_down_gradients_by_number_of_workers_for_sync_training=True</em>, <em>clip_gradients=None</em>, <em>gradients_clipping_method=&lt;GradientClippingMethod.ClipByGlobalNorm: 0&gt;</em>, <em>l2_regularization=0</em>, <em>learning_rate=0.00025</em>, <em>learning_rate_decay_rate=0</em>, <em>learning_rate_decay_steps=0</em>, <em>input_embedders_parameters={}</em>, <em>embedding_merger_type=&lt;EmbeddingMergerType.Concat: 0&gt;</em>, <em>middleware_parameters=None</em>, <em>heads_parameters=[]</em>, <em>use_separate_networks_per_head=False</em>, <em>optimizer_type='Adam'</em>, <em>optimizer_epsilon=0.0001</em>, <em>adam_optimizer_beta1=0.9</em>, <em>adam_optimizer_beta2=0.99</em>, <em>rms_prop_optimizer_decay=0.9</em>, <em>batch_size=32</em>, <em>replace_mse_with_huber_loss=False</em>, <em>create_target_network=False</em>, <em>tensorflow_support=True</em>, <em>softmax_temperature=1</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/base_parameters.html#NetworkParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.NetworkParameters" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.base_parameters.</code><code class="sig-name descname">NetworkParameters</code><span class="sig-paren">(</span><em class="sig-param">force_cpu=False</em>, <em class="sig-param">async_training=False</em>, <em class="sig-param">shared_optimizer=True</em>, <em class="sig-param">scale_down_gradients_by_number_of_workers_for_sync_training=True</em>, <em class="sig-param">clip_gradients=None</em>, <em class="sig-param">gradients_clipping_method=&lt;GradientClippingMethod.ClipByGlobalNorm: 0&gt;</em>, <em class="sig-param">l2_regularization=0</em>, <em class="sig-param">learning_rate=0.00025</em>, <em class="sig-param">learning_rate_decay_rate=0</em>, <em class="sig-param">learning_rate_decay_steps=0</em>, <em class="sig-param">input_embedders_parameters={}</em>, <em class="sig-param">embedding_merger_type=&lt;EmbeddingMergerType.Concat: 0&gt;</em>, <em class="sig-param">middleware_parameters=None</em>, <em class="sig-param">heads_parameters=[]</em>, <em class="sig-param">use_separate_networks_per_head=False</em>, <em class="sig-param">optimizer_type='Adam'</em>, <em class="sig-param">optimizer_epsilon=0.0001</em>, <em class="sig-param">adam_optimizer_beta1=0.9</em>, <em class="sig-param">adam_optimizer_beta2=0.99</em>, <em class="sig-param">rms_prop_optimizer_decay=0.9</em>, <em class="sig-param">batch_size=32</em>, <em class="sig-param">replace_mse_with_huber_loss=False</em>, <em class="sig-param">create_target_network=False</em>, <em class="sig-param">tensorflow_support=True</em>, <em class="sig-param">softmax_temperature=1</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/base_parameters.html#NetworkParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.NetworkParameters" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -268,7 +268,7 @@ online network at will.</p></li>
<h2>Architecture<a class="headerlink" href="#architecture" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.architectures.architecture.Architecture">
<em class="property">class </em><code class="descclassname">rl_coach.architectures.architecture.</code><code class="descname">Architecture</code><span class="sig-paren">(</span><em>agent_parameters: rl_coach.base_parameters.AgentParameters</em>, <em>spaces: rl_coach.spaces.SpacesDefinition</em>, <em>name: str = ''</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.architectures.architecture.</code><code class="sig-name descname">Architecture</code><span class="sig-paren">(</span><em class="sig-param">agent_parameters: rl_coach.base_parameters.AgentParameters</em>, <em class="sig-param">spaces: rl_coach.spaces.SpacesDefinition</em>, <em class="sig-param">name: str = ''</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture" title="Permalink to this definition"></a></dt>
<dd><p>Creates a neural network architecture, that can be trained and used for inference.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -281,7 +281,7 @@ online network at will.</p></li>
</dl>
<dl class="method">
<dt id="rl_coach.architectures.architecture.Architecture.accumulate_gradients">
<code class="descname">accumulate_gradients</code><span class="sig-paren">(</span><em>inputs: Dict[str, numpy.ndarray], targets: List[numpy.ndarray], additional_fetches: list = None, importance_weights: numpy.ndarray = None, no_accumulation: bool = False</em><span class="sig-paren">)</span> &#x2192; Tuple[float, List[float], float, list]<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.accumulate_gradients"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.accumulate_gradients" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">accumulate_gradients</code><span class="sig-paren">(</span><em class="sig-param">inputs: Dict[str, numpy.ndarray], targets: List[numpy.ndarray], additional_fetches: list = None, importance_weights: numpy.ndarray = None, no_accumulation: bool = False</em><span class="sig-paren">)</span> &#x2192; Tuple[float, List[float], float, list]<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.accumulate_gradients"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.accumulate_gradients" title="Permalink to this definition"></a></dt>
<dd><p>Given a batch of inputs (i.e. states) and targets (e.g. discounted rewards), computes and accumulates the
gradients for model parameters. Will run forward and backward pass to compute gradients, clip the gradient
values if required and then accumulate gradients from all learners. It does not update the model weights,
@@ -324,7 +324,7 @@ fetched_tensors: all values for additional_fetches</p>
<dl class="method">
<dt id="rl_coach.architectures.architecture.Architecture.apply_and_reset_gradients">
<code class="descname">apply_and_reset_gradients</code><span class="sig-paren">(</span><em>gradients: List[numpy.ndarray], scaler: float = 1.0</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.apply_and_reset_gradients"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.apply_and_reset_gradients" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">apply_and_reset_gradients</code><span class="sig-paren">(</span><em class="sig-param">gradients: List[numpy.ndarray], scaler: float = 1.0</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.apply_and_reset_gradients"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.apply_and_reset_gradients" title="Permalink to this definition"></a></dt>
<dd><p>Applies the given gradients to the network weights and resets the gradient accumulations.
Has the same impact as calling <cite>apply_gradients</cite>, then <cite>reset_accumulated_gradients</cite>.</p>
<dl class="field-list simple">
@@ -340,7 +340,7 @@ of an identical network (either self or another identical network)</p></li>
<dl class="method">
<dt id="rl_coach.architectures.architecture.Architecture.apply_gradients">
<code class="descname">apply_gradients</code><span class="sig-paren">(</span><em>gradients: List[numpy.ndarray], scaler: float = 1.0</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.apply_gradients"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.apply_gradients" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">apply_gradients</code><span class="sig-paren">(</span><em class="sig-param">gradients: List[numpy.ndarray], scaler: float = 1.0</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.apply_gradients"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.apply_gradients" title="Permalink to this definition"></a></dt>
<dd><p>Applies the given gradients to the network weights.
Will be performed sync or async depending on <cite>network_parameters.async_training</cite></p>
<dl class="field-list simple">
@@ -356,7 +356,7 @@ of an identical network (either self or another identical network)</p></li>
<dl class="method">
<dt id="rl_coach.architectures.architecture.Architecture.collect_savers">
<code class="descname">collect_savers</code><span class="sig-paren">(</span><em>parent_path_suffix: str</em><span class="sig-paren">)</span> &#x2192; rl_coach.saver.SaverCollection<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.collect_savers"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.collect_savers" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">collect_savers</code><span class="sig-paren">(</span><em class="sig-param">parent_path_suffix: str</em><span class="sig-paren">)</span> &#x2192; rl_coach.saver.SaverCollection<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.collect_savers"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.collect_savers" title="Permalink to this definition"></a></dt>
<dd><p>Collection of all savers for the network (typically only one saver for network and one for ONNX export)
:param parent_path_suffix: path suffix of the parent of the network</p>
<blockquote>
@@ -369,9 +369,9 @@ of an identical network (either self or another identical network)</p></li>
</dl>
</dd></dl>
<dl class="staticmethod">
<dl class="method">
<dt id="rl_coach.architectures.architecture.Architecture.construct">
<em class="property">static </em><code class="descname">construct</code><span class="sig-paren">(</span><em>variable_scope: str, devices: List[str], *args, **kwargs</em><span class="sig-paren">)</span> &#x2192; rl_coach.architectures.architecture.Architecture<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.construct"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.construct" title="Permalink to this definition"></a></dt>
<em class="property">static </em><code class="sig-name descname">construct</code><span class="sig-paren">(</span><em class="sig-param">variable_scope: str, devices: List[str], *args, **kwargs</em><span class="sig-paren">)</span> &#x2192; rl_coach.architectures.architecture.Architecture<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.construct"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.construct" title="Permalink to this definition"></a></dt>
<dd><p>Construct a network class using the provided variable scope and on requested devices
:param variable_scope: string specifying variable scope under which to create network variables
:param devices: list of devices (can be list of Device objects, or string for TF distributed)
@@ -382,7 +382,7 @@ of an identical network (either self or another identical network)</p></li>
<dl class="method">
<dt id="rl_coach.architectures.architecture.Architecture.get_variable_value">
<code class="descname">get_variable_value</code><span class="sig-paren">(</span><em>variable: Any</em><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.get_variable_value"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.get_variable_value" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">get_variable_value</code><span class="sig-paren">(</span><em class="sig-param">variable: Any</em><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.get_variable_value"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.get_variable_value" title="Permalink to this definition"></a></dt>
<dd><p>Gets value of a specified variable. Type of variable is dependant on the framework.
Example of a variable is head.kl_coefficient, which could be a symbol for evaluation
or could be a string representing the value.</p>
@@ -398,7 +398,7 @@ or could be a string representing the value.</p>
<dl class="method">
<dt id="rl_coach.architectures.architecture.Architecture.get_weights">
<code class="descname">get_weights</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; List[numpy.ndarray]<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.get_weights"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.get_weights" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">get_weights</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; List[numpy.ndarray]<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.get_weights"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.get_weights" title="Permalink to this definition"></a></dt>
<dd><p>Gets model weights as a list of ndarrays. It is used for synchronizing weight between two identical networks.</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -407,9 +407,9 @@ or could be a string representing the value.</p>
</dl>
</dd></dl>
<dl class="staticmethod">
<dl class="method">
<dt id="rl_coach.architectures.architecture.Architecture.parallel_predict">
<em class="property">static </em><code class="descname">parallel_predict</code><span class="sig-paren">(</span><em>sess: Any, network_input_tuples: List[Tuple[Architecture, Dict[str, numpy.ndarray]]]</em><span class="sig-paren">)</span> &#x2192; Tuple[numpy.ndarray, ...]<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.parallel_predict"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.parallel_predict" title="Permalink to this definition"></a></dt>
<em class="property">static </em><code class="sig-name descname">parallel_predict</code><span class="sig-paren">(</span><em class="sig-param">sess: Any, network_input_tuples: List[Tuple[Architecture, Dict[str, numpy.ndarray]]]</em><span class="sig-paren">)</span> &#x2192; Tuple[numpy.ndarray, ...]<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.parallel_predict"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.parallel_predict" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -425,7 +425,7 @@ or could be a string representing the value.</p>
<dl class="method">
<dt id="rl_coach.architectures.architecture.Architecture.predict">
<code class="descname">predict</code><span class="sig-paren">(</span><em>inputs: Dict[str, numpy.ndarray], outputs: List[Any] = None, squeeze_output: bool = True, initial_feed_dict: Dict[Any, numpy.ndarray] = None</em><span class="sig-paren">)</span> &#x2192; Tuple[numpy.ndarray, ...]<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.predict"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.predict" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">predict</code><span class="sig-paren">(</span><em class="sig-param">inputs: Dict[str, numpy.ndarray], outputs: List[Any] = None, squeeze_output: bool = True, initial_feed_dict: Dict[Any, numpy.ndarray] = None</em><span class="sig-paren">)</span> &#x2192; Tuple[numpy.ndarray, ...]<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.predict"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.predict" title="Permalink to this definition"></a></dt>
<dd><p>Given input observations, use the model to make predictions (e.g. action or value).</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -446,7 +446,7 @@ depends on the framework backend.</p></li>
<dl class="method">
<dt id="rl_coach.architectures.architecture.Architecture.reset_accumulated_gradients">
<code class="descname">reset_accumulated_gradients</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.reset_accumulated_gradients"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.reset_accumulated_gradients" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">reset_accumulated_gradients</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.reset_accumulated_gradients"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.reset_accumulated_gradients" title="Permalink to this definition"></a></dt>
<dd><p>Sets gradient of all parameters to 0.</p>
<p>Once gradients are reset, they must be accessible by <cite>accumulated_gradients</cite> property of this class,
which must return a list of numpy ndarrays. Child class must ensure that <cite>accumulated_gradients</cite> is set.</p>
@@ -454,7 +454,7 @@ which must return a list of numpy ndarrays. Child class must ensure that <cite>a
<dl class="method">
<dt id="rl_coach.architectures.architecture.Architecture.set_variable_value">
<code class="descname">set_variable_value</code><span class="sig-paren">(</span><em>assign_op: Any</em>, <em>value: numpy.ndarray</em>, <em>placeholder: Any</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.set_variable_value"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.set_variable_value" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">set_variable_value</code><span class="sig-paren">(</span><em class="sig-param">assign_op: Any</em>, <em class="sig-param">value: numpy.ndarray</em>, <em class="sig-param">placeholder: Any</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.set_variable_value"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.set_variable_value" title="Permalink to this definition"></a></dt>
<dd><p>Updates the value of a specified variable. Type of assign_op is dependant on the framework
and is a unique identifier for assigning value to a variable. For example an agent may use
head.assign_kl_coefficient. There is a one to one mapping between assign_op and placeholder
@@ -472,7 +472,7 @@ head.assign_kl_coefficient. There is a one to one mapping between assign_op and
<dl class="method">
<dt id="rl_coach.architectures.architecture.Architecture.set_weights">
<code class="descname">set_weights</code><span class="sig-paren">(</span><em>weights: List[numpy.ndarray], rate: float = 1.0</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.set_weights"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.set_weights" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">set_weights</code><span class="sig-paren">(</span><em class="sig-param">weights: List[numpy.ndarray], rate: float = 1.0</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.set_weights"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.set_weights" title="Permalink to this definition"></a></dt>
<dd><p>Sets model weights for provided layer parameters.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -490,7 +490,7 @@ i.e. new_weight = rate * given_weight + (1 - rate) * old_weight</p></li>
<dl class="method">
<dt id="rl_coach.architectures.architecture.Architecture.train_on_batch">
<code class="descname">train_on_batch</code><span class="sig-paren">(</span><em>inputs: Dict[str, numpy.ndarray], targets: List[numpy.ndarray], scaler: float = 1.0, additional_fetches: list = None, importance_weights: numpy.ndarray = None</em><span class="sig-paren">)</span> &#x2192; Tuple[float, List[float], float, list]<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.train_on_batch"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.train_on_batch" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">train_on_batch</code><span class="sig-paren">(</span><em class="sig-param">inputs: Dict[str, numpy.ndarray], targets: List[numpy.ndarray], scaler: float = 1.0, additional_fetches: list = None, importance_weights: numpy.ndarray = None</em><span class="sig-paren">)</span> &#x2192; Tuple[float, List[float], float, list]<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.train_on_batch"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.train_on_batch" title="Permalink to this definition"></a></dt>
<dd><p>Given a batch of inputs (e.g. states) and targets (e.g. discounted rewards), takes a training step: i.e. runs a
forward pass and backward pass of the network, accumulates the gradients and applies an optimization step to
update the weights.
@@ -535,7 +535,7 @@ fetched_tensors: all values for additional_fetches</p>
<a class="reference internal image-reference" href="../../_images/distributed.png"><img alt="../../_images/distributed.png" class="align-center" src="../../_images/distributed.png" style="width: 600px;" /></a>
<dl class="class">
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper">
<em class="property">class </em><code class="descclassname">rl_coach.architectures.network_wrapper.</code><code class="descname">NetworkWrapper</code><span class="sig-paren">(</span><em>agent_parameters: rl_coach.base_parameters.AgentParameters</em>, <em>has_target: bool</em>, <em>has_global: bool</em>, <em>name: str</em>, <em>spaces: rl_coach.spaces.SpacesDefinition</em>, <em>replicated_device=None</em>, <em>worker_device=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.architectures.network_wrapper.</code><code class="sig-name descname">NetworkWrapper</code><span class="sig-paren">(</span><em class="sig-param">agent_parameters: rl_coach.base_parameters.AgentParameters</em>, <em class="sig-param">has_target: bool</em>, <em class="sig-param">has_global: bool</em>, <em class="sig-param">name: str</em>, <em class="sig-param">spaces: rl_coach.spaces.SpacesDefinition</em>, <em class="sig-param">replicated_device=None</em>, <em class="sig-param">worker_device=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper" title="Permalink to this definition"></a></dt>
<dd><p>The network wrapper contains multiple copies of the same network, each one with a different set of weights which is
updating in a different time scale. The network wrapper will always contain an online network.
It will contain an additional slow updating target network if it was requested by the user,
@@ -544,7 +544,7 @@ multi-process distributed mode. The network wrapper contains functionality for m
between them.</p>
<dl class="method">
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_and_sync_networks">
<code class="descname">apply_gradients_and_sync_networks</code><span class="sig-paren">(</span><em>reset_gradients=True</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_and_sync_networks"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_and_sync_networks" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">apply_gradients_and_sync_networks</code><span class="sig-paren">(</span><em class="sig-param">reset_gradients=True</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_and_sync_networks"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_and_sync_networks" title="Permalink to this definition"></a></dt>
<dd><p>Applies the gradients accumulated in the online network to the global network or to itself and syncs the
networks if necessary</p>
<dl class="field-list simple">
@@ -559,7 +559,7 @@ complexity for this function by around 10%</p>
<dl class="method">
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_global_network">
<code class="descname">apply_gradients_to_global_network</code><span class="sig-paren">(</span><em>gradients=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_to_global_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_global_network" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">apply_gradients_to_global_network</code><span class="sig-paren">(</span><em class="sig-param">gradients=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_to_global_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_global_network" title="Permalink to this definition"></a></dt>
<dd><p>Apply gradients from the online network on the global network</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -573,7 +573,7 @@ complexity for this function by around 10%</p>
<dl class="method">
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_online_network">
<code class="descname">apply_gradients_to_online_network</code><span class="sig-paren">(</span><em>gradients=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_to_online_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_online_network" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">apply_gradients_to_online_network</code><span class="sig-paren">(</span><em class="sig-param">gradients=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_to_online_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_online_network" title="Permalink to this definition"></a></dt>
<dd><p>Apply gradients from the online network on itself</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -584,7 +584,7 @@ complexity for this function by around 10%</p>
<dl class="method">
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.collect_savers">
<code class="descname">collect_savers</code><span class="sig-paren">(</span><em>parent_path_suffix: str</em><span class="sig-paren">)</span> &#x2192; rl_coach.saver.SaverCollection<a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.collect_savers"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.collect_savers" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">collect_savers</code><span class="sig-paren">(</span><em class="sig-param">parent_path_suffix: str</em><span class="sig-paren">)</span> &#x2192; rl_coach.saver.SaverCollection<a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.collect_savers"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.collect_savers" title="Permalink to this definition"></a></dt>
<dd><p>Collect all of networks savers for global or online network
Note: global, online, and target network are all copies fo the same network which parameters that are</p>
<blockquote>
@@ -610,7 +610,7 @@ for saving.</p>
<dl class="method">
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.parallel_prediction">
<code class="descname">parallel_prediction</code><span class="sig-paren">(</span><em>network_input_tuples: List[Tuple]</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.parallel_prediction"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.parallel_prediction" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">parallel_prediction</code><span class="sig-paren">(</span><em class="sig-param">network_input_tuples: List[Tuple]</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.parallel_prediction"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.parallel_prediction" title="Permalink to this definition"></a></dt>
<dd><p>Run several network prediction in parallel. Currently this only supports running each of the network once.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -625,7 +625,7 @@ target_network or global_network) and the second element is the inputs</p>
<dl class="method">
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.set_is_training">
<code class="descname">set_is_training</code><span class="sig-paren">(</span><em>state: bool</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.set_is_training"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.set_is_training" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">set_is_training</code><span class="sig-paren">(</span><em class="sig-param">state: bool</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.set_is_training"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.set_is_training" title="Permalink to this definition"></a></dt>
<dd><p>Set the phase of the network between training and testing</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -639,7 +639,7 @@ target_network or global_network) and the second element is the inputs</p>
<dl class="method">
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.sync">
<code class="descname">sync</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.sync"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.sync" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">sync</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.sync"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.sync" title="Permalink to this definition"></a></dt>
<dd><p>Initializes the weights of the networks to match each other</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -650,7 +650,7 @@ target_network or global_network) and the second element is the inputs</p>
<dl class="method">
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.train_and_sync_networks">
<code class="descname">train_and_sync_networks</code><span class="sig-paren">(</span><em>inputs</em>, <em>targets</em>, <em>additional_fetches=[]</em>, <em>importance_weights=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.train_and_sync_networks"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.train_and_sync_networks" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">train_and_sync_networks</code><span class="sig-paren">(</span><em class="sig-param">inputs</em>, <em class="sig-param">targets</em>, <em class="sig-param">additional_fetches=[]</em>, <em class="sig-param">importance_weights=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.train_and_sync_networks"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.train_and_sync_networks" title="Permalink to this definition"></a></dt>
<dd><p>A generic training function that enables multi-threading training using a global network if necessary.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -670,7 +670,7 @@ error of this sample. If it is not given, the samples losses wont be scaled</
<dl class="method">
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.update_online_network">
<code class="descname">update_online_network</code><span class="sig-paren">(</span><em>rate=1.0</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.update_online_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.update_online_network" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">update_online_network</code><span class="sig-paren">(</span><em class="sig-param">rate=1.0</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.update_online_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.update_online_network" title="Permalink to this definition"></a></dt>
<dd><p>Copy weights: global network &gt;&gt;&gt; online network</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -681,7 +681,7 @@ error of this sample. If it is not given, the samples losses wont be scaled</
<dl class="method">
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.update_target_network">
<code class="descname">update_target_network</code><span class="sig-paren">(</span><em>rate=1.0</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.update_target_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.update_target_network" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">update_target_network</code><span class="sig-paren">(</span><em class="sig-param">rate=1.0</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.update_target_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.update_target_network" title="Permalink to this definition"></a></dt>
<dd><p>Copy weights: online network &gt;&gt;&gt; target network</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
+26 -26
View File
@@ -197,7 +197,7 @@
<h2>ActionInfo<a class="headerlink" href="#actioninfo" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.core_types.ActionInfo">
<em class="property">class </em><code class="descclassname">rl_coach.core_types.</code><code class="descname">ActionInfo</code><span class="sig-paren">(</span><em>action: Union[int, float, numpy.ndarray, List], all_action_probabilities: float = 0, action_value: float = 0.0, state_value: float = 0.0, max_action_value: float = None</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/core_types.html#ActionInfo"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.ActionInfo" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.core_types.</code><code class="sig-name descname">ActionInfo</code><span class="sig-paren">(</span><em class="sig-param">action: Union[int, float, numpy.ndarray, List], all_action_probabilities: float = 0, action_value: float = 0.0, state_value: float = 0.0, max_action_value: float = None</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/core_types.html#ActionInfo"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.ActionInfo" title="Permalink to this definition"></a></dt>
<dd><p>Action info is a class that holds an action and various additional information details about it</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -219,7 +219,7 @@ action with the maximum value</p></li>
<h2>Batch<a class="headerlink" href="#batch" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.core_types.Batch">
<em class="property">class </em><code class="descclassname">rl_coach.core_types.</code><code class="descname">Batch</code><span class="sig-paren">(</span><em>transitions: List[rl_coach.core_types.Transition]</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.core_types.</code><code class="sig-name descname">Batch</code><span class="sig-paren">(</span><em class="sig-param">transitions: List[rl_coach.core_types.Transition]</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch" title="Permalink to this definition"></a></dt>
<dd><p>A wrapper around a list of transitions that helps extracting batches of parameters from it.
For example, one can extract a list of states corresponding to the list of transitions.
The class uses lazy evaluation in order to return each of the available parameters.</p>
@@ -230,7 +230,7 @@ The class uses lazy evaluation in order to return each of the available paramete
</dl>
<dl class="method">
<dt id="rl_coach.core_types.Batch.actions">
<code class="descname">actions</code><span class="sig-paren">(</span><em>expand_dims=False</em><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch.actions"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch.actions" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">actions</code><span class="sig-paren">(</span><em class="sig-param">expand_dims=False</em><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch.actions"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch.actions" title="Permalink to this definition"></a></dt>
<dd><p>if the actions were not converted to a batch before, extract them to a batch and then return the batch</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -244,7 +244,7 @@ The class uses lazy evaluation in order to return each of the available paramete
<dl class="method">
<dt id="rl_coach.core_types.Batch.game_overs">
<code class="descname">game_overs</code><span class="sig-paren">(</span><em>expand_dims=False</em><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch.game_overs"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch.game_overs" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">game_overs</code><span class="sig-paren">(</span><em class="sig-param">expand_dims=False</em><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch.game_overs"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch.game_overs" title="Permalink to this definition"></a></dt>
<dd><p>if the game_overs were not converted to a batch before, extract them to a batch and then return the batch</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -258,7 +258,7 @@ The class uses lazy evaluation in order to return each of the available paramete
<dl class="method">
<dt id="rl_coach.core_types.Batch.goals">
<code class="descname">goals</code><span class="sig-paren">(</span><em>expand_dims=False</em><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch.goals"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch.goals" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">goals</code><span class="sig-paren">(</span><em class="sig-param">expand_dims=False</em><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch.goals"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch.goals" title="Permalink to this definition"></a></dt>
<dd><p>if the goals were not converted to a batch before, extract them to a batch and then return the batch
if the goal was not filled, this will raise an exception</p>
<dl class="field-list simple">
@@ -273,7 +273,7 @@ if the goal was not filled, this will raise an exception</p>
<dl class="method">
<dt id="rl_coach.core_types.Batch.info">
<code class="descname">info</code><span class="sig-paren">(</span><em>key</em>, <em>expand_dims=False</em><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch.info"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch.info" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">info</code><span class="sig-paren">(</span><em class="sig-param">key</em>, <em class="sig-param">expand_dims=False</em><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch.info"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch.info" title="Permalink to this definition"></a></dt>
<dd><p>if the given info dictionary key was not converted to a batch before, extract it to a batch and then return the
batch. if the key is not part of the keys in the info dictionary, this will raise an exception</p>
<dl class="field-list simple">
@@ -288,7 +288,7 @@ batch. if the key is not part of the keys in the info dictionary, this will rais
<dl class="method">
<dt id="rl_coach.core_types.Batch.info_as_list">
<code class="descname">info_as_list</code><span class="sig-paren">(</span><em>key</em><span class="sig-paren">)</span> &#x2192; list<a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch.info_as_list"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch.info_as_list" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">info_as_list</code><span class="sig-paren">(</span><em class="sig-param">key</em><span class="sig-paren">)</span> &#x2192; list<a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch.info_as_list"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch.info_as_list" title="Permalink to this definition"></a></dt>
<dd><p>get the info and store it internally as a list, if wasnt stored before. return it as a list
:param expand_dims: add an extra dimension to the info batch
:return: a list containing all the info values of the batch corresponding to the given key</p>
@@ -296,7 +296,7 @@ batch. if the key is not part of the keys in the info dictionary, this will rais
<dl class="method">
<dt id="rl_coach.core_types.Batch.n_step_discounted_rewards">
<code class="descname">n_step_discounted_rewards</code><span class="sig-paren">(</span><em>expand_dims=False</em><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch.n_step_discounted_rewards"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch.n_step_discounted_rewards" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">n_step_discounted_rewards</code><span class="sig-paren">(</span><em class="sig-param">expand_dims=False</em><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch.n_step_discounted_rewards"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch.n_step_discounted_rewards" title="Permalink to this definition"></a></dt>
<dd><dl class="simple">
<dt>if the n_step_discounted_rewards were not converted to a batch before, extract them to a batch and then return</dt><dd><p>the batch</p>
</dd>
@@ -308,7 +308,7 @@ batch. if the key is not part of the keys in the info dictionary, this will rais
<dl class="method">
<dt id="rl_coach.core_types.Batch.next_states">
<code class="descname">next_states</code><span class="sig-paren">(</span><em>fetches: List[str], expand_dims=False</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.ndarray]<a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch.next_states"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch.next_states" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">next_states</code><span class="sig-paren">(</span><em class="sig-param">fetches: List[str], expand_dims=False</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.ndarray]<a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch.next_states"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch.next_states" title="Permalink to this definition"></a></dt>
<dd><p>follow the keys in fetches to extract the corresponding items from the next states in the batch
if these keys were not already extracted before. return only the values corresponding to those keys</p>
<dl class="field-list simple">
@@ -326,7 +326,7 @@ if these keys were not already extracted before. return only the values correspo
<dl class="method">
<dt id="rl_coach.core_types.Batch.rewards">
<code class="descname">rewards</code><span class="sig-paren">(</span><em>expand_dims=False</em><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch.rewards"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch.rewards" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">rewards</code><span class="sig-paren">(</span><em class="sig-param">expand_dims=False</em><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch.rewards"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch.rewards" title="Permalink to this definition"></a></dt>
<dd><p>if the rewards were not converted to a batch before, extract them to a batch and then return the batch</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -340,7 +340,7 @@ if these keys were not already extracted before. return only the values correspo
<dl class="method">
<dt id="rl_coach.core_types.Batch.shuffle">
<code class="descname">shuffle</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch.shuffle"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch.shuffle" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">shuffle</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch.shuffle"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch.shuffle" title="Permalink to this definition"></a></dt>
<dd><p>Shuffle all the transitions in the batch</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -349,9 +349,9 @@ if these keys were not already extracted before. return only the values correspo
</dl>
</dd></dl>
<dl class="attribute">
<dl class="method">
<dt id="rl_coach.core_types.Batch.size">
<code class="descname">size</code><a class="headerlink" href="#rl_coach.core_types.Batch.size" title="Permalink to this definition"></a></dt>
<em class="property">property </em><code class="sig-name descname">size</code><a class="headerlink" href="#rl_coach.core_types.Batch.size" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Returns</dt>
<dd class="field-odd"><p>the size of the batch</p>
@@ -361,7 +361,7 @@ if these keys were not already extracted before. return only the values correspo
<dl class="method">
<dt id="rl_coach.core_types.Batch.slice">
<code class="descname">slice</code><span class="sig-paren">(</span><em>start</em>, <em>end</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch.slice"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch.slice" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">slice</code><span class="sig-paren">(</span><em class="sig-param">start</em>, <em class="sig-param">end</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch.slice"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch.slice" title="Permalink to this definition"></a></dt>
<dd><p>Keep a slice from the batch and discard the rest of the batch</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -378,7 +378,7 @@ if these keys were not already extracted before. return only the values correspo
<dl class="method">
<dt id="rl_coach.core_types.Batch.states">
<code class="descname">states</code><span class="sig-paren">(</span><em>fetches: List[str], expand_dims=False</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.ndarray]<a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch.states"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch.states" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">states</code><span class="sig-paren">(</span><em class="sig-param">fetches: List[str], expand_dims=False</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.ndarray]<a class="reference internal" href="../_modules/rl_coach/core_types.html#Batch.states"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Batch.states" title="Permalink to this definition"></a></dt>
<dd><p>follow the keys in fetches to extract the corresponding items from the states in the batch
if these keys were not already extracted before. return only the values corresponding to those keys</p>
<dl class="field-list simple">
@@ -401,7 +401,7 @@ if these keys were not already extracted before. return only the values correspo
<h2>EnvResponse<a class="headerlink" href="#envresponse" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.core_types.EnvResponse">
<em class="property">class </em><code class="descclassname">rl_coach.core_types.</code><code class="descname">EnvResponse</code><span class="sig-paren">(</span><em>next_state: Dict[str, numpy.ndarray], reward: Union[int, float, numpy.ndarray], game_over: bool, info: Dict = None, goal: numpy.ndarray = None</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/core_types.html#EnvResponse"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.EnvResponse" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.core_types.</code><code class="sig-name descname">EnvResponse</code><span class="sig-paren">(</span><em class="sig-param">next_state: Dict[str, numpy.ndarray], reward: Union[int, float, numpy.ndarray], game_over: bool, info: Dict = None, goal: numpy.ndarray = None</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/core_types.html#EnvResponse"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.EnvResponse" title="Permalink to this definition"></a></dt>
<dd><p>An env response is a collection containing the information returning from the environment after a single action
has been performed on it.</p>
<dl class="field-list simple">
@@ -424,7 +424,7 @@ the execution of the action.</p></li>
<h2>Episode<a class="headerlink" href="#episode" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.core_types.Episode">
<em class="property">class </em><code class="descclassname">rl_coach.core_types.</code><code class="descname">Episode</code><span class="sig-paren">(</span><em>discount: float = 0.99</em>, <em>bootstrap_total_return_from_old_policy: bool = False</em>, <em>n_step: int = -1</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/core_types.html#Episode"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Episode" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.core_types.</code><code class="sig-name descname">Episode</code><span class="sig-paren">(</span><em class="sig-param">discount: float = 0.99</em>, <em class="sig-param">bootstrap_total_return_from_old_policy: bool = False</em>, <em class="sig-param">n_step: int = -1</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/core_types.html#Episode"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Episode" title="Permalink to this definition"></a></dt>
<dd><p>An Episode represents a set of sequential transitions, that end with a terminal state.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -438,7 +438,7 @@ memory</p></li>
</dl>
<dl class="method">
<dt id="rl_coach.core_types.Episode.get_first_transition">
<code class="descname">get_first_transition</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.Transition<a class="reference internal" href="../_modules/rl_coach/core_types.html#Episode.get_first_transition"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Episode.get_first_transition" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">get_first_transition</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.Transition<a class="reference internal" href="../_modules/rl_coach/core_types.html#Episode.get_first_transition"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Episode.get_first_transition" title="Permalink to this definition"></a></dt>
<dd><p>Get the first transition in the episode, or None if there are no transitions available</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -449,7 +449,7 @@ memory</p></li>
<dl class="method">
<dt id="rl_coach.core_types.Episode.get_last_transition">
<code class="descname">get_last_transition</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.Transition<a class="reference internal" href="../_modules/rl_coach/core_types.html#Episode.get_last_transition"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Episode.get_last_transition" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">get_last_transition</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.Transition<a class="reference internal" href="../_modules/rl_coach/core_types.html#Episode.get_last_transition"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Episode.get_last_transition" title="Permalink to this definition"></a></dt>
<dd><p>Get the last transition in the episode, or None if there are no transition available</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -460,7 +460,7 @@ memory</p></li>
<dl class="method">
<dt id="rl_coach.core_types.Episode.get_transition">
<code class="descname">get_transition</code><span class="sig-paren">(</span><em>transition_idx: int</em><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.Transition<a class="reference internal" href="../_modules/rl_coach/core_types.html#Episode.get_transition"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Episode.get_transition" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">get_transition</code><span class="sig-paren">(</span><em class="sig-param">transition_idx: int</em><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.Transition<a class="reference internal" href="../_modules/rl_coach/core_types.html#Episode.get_transition"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Episode.get_transition" title="Permalink to this definition"></a></dt>
<dd><p>Get a specific transition by its index.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -474,7 +474,7 @@ memory</p></li>
<dl class="method">
<dt id="rl_coach.core_types.Episode.get_transitions_attribute">
<code class="descname">get_transitions_attribute</code><span class="sig-paren">(</span><em>attribute_name: str</em><span class="sig-paren">)</span> &#x2192; List[Any]<a class="reference internal" href="../_modules/rl_coach/core_types.html#Episode.get_transitions_attribute"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Episode.get_transitions_attribute" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">get_transitions_attribute</code><span class="sig-paren">(</span><em class="sig-param">attribute_name: str</em><span class="sig-paren">)</span> &#x2192; List[Any]<a class="reference internal" href="../_modules/rl_coach/core_types.html#Episode.get_transitions_attribute"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Episode.get_transitions_attribute" title="Permalink to this definition"></a></dt>
<dd><p>Get the values for some transition attribute from all the transitions in the episode.
For example, this allows getting the rewards for all the transitions as a list by calling
get_transitions_attribute(reward)</p>
@@ -490,7 +490,7 @@ get_transitions_attribute(reward)</p>
<dl class="method">
<dt id="rl_coach.core_types.Episode.insert">
<code class="descname">insert</code><span class="sig-paren">(</span><em>transition: rl_coach.core_types.Transition</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../_modules/rl_coach/core_types.html#Episode.insert"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Episode.insert" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">insert</code><span class="sig-paren">(</span><em class="sig-param">transition: rl_coach.core_types.Transition</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../_modules/rl_coach/core_types.html#Episode.insert"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Episode.insert" title="Permalink to this definition"></a></dt>
<dd><p>Insert a new transition to the episode. If the game_over flag in the transition is set to True,
the episode will be marked as complete.</p>
<dl class="field-list simple">
@@ -505,7 +505,7 @@ the episode will be marked as complete.</p>
<dl class="method">
<dt id="rl_coach.core_types.Episode.is_empty">
<code class="descname">is_empty</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; bool<a class="reference internal" href="../_modules/rl_coach/core_types.html#Episode.is_empty"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Episode.is_empty" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">is_empty</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; bool<a class="reference internal" href="../_modules/rl_coach/core_types.html#Episode.is_empty"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Episode.is_empty" title="Permalink to this definition"></a></dt>
<dd><p>Check if the episode is empty</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -516,7 +516,7 @@ the episode will be marked as complete.</p>
<dl class="method">
<dt id="rl_coach.core_types.Episode.length">
<code class="descname">length</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; int<a class="reference internal" href="../_modules/rl_coach/core_types.html#Episode.length"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Episode.length" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">length</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; int<a class="reference internal" href="../_modules/rl_coach/core_types.html#Episode.length"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Episode.length" title="Permalink to this definition"></a></dt>
<dd><p>Return the length of the episode, which is the number of transitions it holds.</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -527,7 +527,7 @@ the episode will be marked as complete.</p>
<dl class="method">
<dt id="rl_coach.core_types.Episode.update_discounted_rewards">
<code class="descname">update_discounted_rewards</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/core_types.html#Episode.update_discounted_rewards"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Episode.update_discounted_rewards" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">update_discounted_rewards</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/core_types.html#Episode.update_discounted_rewards"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Episode.update_discounted_rewards" title="Permalink to this definition"></a></dt>
<dd><p>Update the discounted returns for all the transitions in the episode.
The returns will be calculated according to the rewards of each transition, together with the number of steps
to bootstrap from and the discount factor, as defined by n_step and discount respectively when initializing
@@ -546,7 +546,7 @@ the episode.</p>
<h2>Transition<a class="headerlink" href="#transition" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.core_types.Transition">
<em class="property">class </em><code class="descclassname">rl_coach.core_types.</code><code class="descname">Transition</code><span class="sig-paren">(</span><em>state: Dict[str</em>, <em>numpy.ndarray] = None</em>, <em>action: Union[int</em>, <em>float</em>, <em>numpy.ndarray</em>, <em>List] = None</em>, <em>reward: Union[int</em>, <em>float</em>, <em>numpy.ndarray] = None</em>, <em>next_state: Dict[str</em>, <em>numpy.ndarray] = None</em>, <em>game_over: bool = None</em>, <em>info: Dict = None</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/core_types.html#Transition"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Transition" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.core_types.</code><code class="sig-name descname">Transition</code><span class="sig-paren">(</span><em class="sig-param">state: Dict[str</em>, <em class="sig-param">numpy.ndarray] = None</em>, <em class="sig-param">action: Union[int</em>, <em class="sig-param">float</em>, <em class="sig-param">numpy.ndarray</em>, <em class="sig-param">List] = None</em>, <em class="sig-param">reward: Union[int</em>, <em class="sig-param">float</em>, <em class="sig-param">numpy.ndarray] = None</em>, <em class="sig-param">next_state: Dict[str</em>, <em class="sig-param">numpy.ndarray] = None</em>, <em class="sig-param">game_over: bool = None</em>, <em class="sig-param">info: Dict = None</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/core_types.html#Transition"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.Transition" title="Permalink to this definition"></a></dt>
<dd><p>A transition is a tuple containing the information of a single step of interaction
between the agent and the environment. The most basic version should contain the following values:
(current state, action, reward, next state, game over)
+2 -2
View File
@@ -194,7 +194,7 @@
<h2>S3DataStore<a class="headerlink" href="#s3datastore" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.data_stores.s3_data_store.S3DataStore">
<em class="property">class </em><code class="descclassname">rl_coach.data_stores.s3_data_store.</code><code class="descname">S3DataStore</code><span class="sig-paren">(</span><em>params: rl_coach.data_stores.s3_data_store.S3DataStoreParameters</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/data_stores/s3_data_store.html#S3DataStore"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.data_stores.s3_data_store.S3DataStore" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.data_stores.s3_data_store.</code><code class="sig-name descname">S3DataStore</code><span class="sig-paren">(</span><em class="sig-param">params: rl_coach.data_stores.s3_data_store.S3DataStoreParameters</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/data_stores/s3_data_store.html#S3DataStore"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.data_stores.s3_data_store.S3DataStore" title="Permalink to this definition"></a></dt>
<dd><p>An implementation of the data store using S3 for storing policy checkpoints when using Coach in distributed mode.
The policy checkpoints are written by the trainer and read by the rollout worker.</p>
<dl class="field-list simple">
@@ -209,7 +209,7 @@ The policy checkpoints are written by the trainer and read by the rollout worker
<h2>NFSDataStore<a class="headerlink" href="#nfsdatastore" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.data_stores.nfs_data_store.NFSDataStore">
<em class="property">class </em><code class="descclassname">rl_coach.data_stores.nfs_data_store.</code><code class="descname">NFSDataStore</code><span class="sig-paren">(</span><em>params: rl_coach.data_stores.nfs_data_store.NFSDataStoreParameters</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/data_stores/nfs_data_store.html#NFSDataStore"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.data_stores.nfs_data_store.NFSDataStore" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.data_stores.nfs_data_store.</code><code class="sig-name descname">NFSDataStore</code><span class="sig-paren">(</span><em class="sig-param">params: rl_coach.data_stores.nfs_data_store.NFSDataStoreParameters</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/data_stores/nfs_data_store.html#NFSDataStore"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.data_stores.nfs_data_store.NFSDataStore" title="Permalink to this definition"></a></dt>
<dd><p>An implementation of data store which uses NFS for storing policy checkpoints when using Coach in distributed mode.
The policy checkpoints are written by the trainer and read by the rollout worker.</p>
<dl class="field-list simple">
+27 -27
View File
@@ -195,7 +195,7 @@
<h1>Environments<a class="headerlink" href="#environments" title="Permalink to this headline"></a></h1>
<dl class="class">
<dt id="rl_coach.environments.environment.Environment">
<em class="property">class </em><code class="descclassname">rl_coach.environments.environment.</code><code class="descname">Environment</code><span class="sig-paren">(</span><em>level: rl_coach.environments.environment.LevelSelection, seed: int, frame_skip: int, human_control: bool, custom_reward_threshold: Union[int, float], visualization_parameters: rl_coach.base_parameters.VisualizationParameters, target_success_rate: float = 1.0, **kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.environments.environment.</code><code class="sig-name descname">Environment</code><span class="sig-paren">(</span><em class="sig-param">level: rl_coach.environments.environment.LevelSelection, seed: int, frame_skip: int, human_control: bool, custom_reward_threshold: Union[int, float], visualization_parameters: rl_coach.base_parameters.VisualizationParameters, target_success_rate: float = 1.0, **kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -210,9 +210,9 @@ additional arguments which will be ignored by this class, but might be used by o
</ul>
</dd>
</dl>
<dl class="attribute">
<dl class="method">
<dt id="rl_coach.environments.environment.Environment.action_space">
<code class="descname">action_space</code><a class="headerlink" href="#rl_coach.environments.environment.Environment.action_space" title="Permalink to this definition"></a></dt>
<em class="property">property </em><code class="sig-name descname">action_space</code><a class="headerlink" href="#rl_coach.environments.environment.Environment.action_space" title="Permalink to this definition"></a></dt>
<dd><p>Get the action space of the environment</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -223,7 +223,7 @@ additional arguments which will be ignored by this class, but might be used by o
<dl class="method">
<dt id="rl_coach.environments.environment.Environment.close">
<code class="descname">close</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment.close"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment.close" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">close</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment.close"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment.close" title="Permalink to this definition"></a></dt>
<dd><p>Clean up steps.</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -234,7 +234,7 @@ additional arguments which will be ignored by this class, but might be used by o
<dl class="method">
<dt id="rl_coach.environments.environment.Environment.get_action_from_user">
<code class="descname">get_action_from_user</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; Union[int, float, numpy.ndarray, List]<a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment.get_action_from_user"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment.get_action_from_user" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">get_action_from_user</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; Union[int, float, numpy.ndarray, List]<a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment.get_action_from_user"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment.get_action_from_user" title="Permalink to this definition"></a></dt>
<dd><p>Get an action from the user keyboard</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -245,7 +245,7 @@ additional arguments which will be ignored by this class, but might be used by o
<dl class="method">
<dt id="rl_coach.environments.environment.Environment.get_available_keys">
<code class="descname">get_available_keys</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; List[Tuple[str, Union[int, float, numpy.ndarray, List]]]<a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment.get_available_keys"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment.get_available_keys" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">get_available_keys</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; List[Tuple[str, Union[int, float, numpy.ndarray, List]]]<a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment.get_available_keys"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment.get_available_keys" title="Permalink to this definition"></a></dt>
<dd><p>Return a list of tuples mapping between action names and the keyboard key that triggers them</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -256,7 +256,7 @@ additional arguments which will be ignored by this class, but might be used by o
<dl class="method">
<dt id="rl_coach.environments.environment.Environment.get_goal">
<code class="descname">get_goal</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; Union[None, numpy.ndarray]<a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment.get_goal"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment.get_goal" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">get_goal</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; Union[None, numpy.ndarray]<a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment.get_goal"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment.get_goal" title="Permalink to this definition"></a></dt>
<dd><p>Get the current goal that the agents needs to achieve in the environment</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -267,7 +267,7 @@ additional arguments which will be ignored by this class, but might be used by o
<dl class="method">
<dt id="rl_coach.environments.environment.Environment.get_random_action">
<code class="descname">get_random_action</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; Union[int, float, numpy.ndarray, List]<a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment.get_random_action"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment.get_random_action" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">get_random_action</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; Union[int, float, numpy.ndarray, List]<a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment.get_random_action"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment.get_random_action" title="Permalink to this definition"></a></dt>
<dd><p>Returns an action picked uniformly from the available actions</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -278,7 +278,7 @@ additional arguments which will be ignored by this class, but might be used by o
<dl class="method">
<dt id="rl_coach.environments.environment.Environment.get_rendered_image">
<code class="descname">get_rendered_image</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment.get_rendered_image"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment.get_rendered_image" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">get_rendered_image</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment.get_rendered_image"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment.get_rendered_image" title="Permalink to this definition"></a></dt>
<dd><p>Return a numpy array containing the image that will be rendered to the screen.
This can be different from the observation. For example, mujocos observation is a measurements vector.</p>
<dl class="field-list simple">
@@ -288,9 +288,9 @@ This can be different from the observation. For example, mujocos observation
</dl>
</dd></dl>
<dl class="attribute">
<dl class="method">
<dt id="rl_coach.environments.environment.Environment.goal_space">
<code class="descname">goal_space</code><a class="headerlink" href="#rl_coach.environments.environment.Environment.goal_space" title="Permalink to this definition"></a></dt>
<em class="property">property </em><code class="sig-name descname">goal_space</code><a class="headerlink" href="#rl_coach.environments.environment.Environment.goal_space" title="Permalink to this definition"></a></dt>
<dd><p>Get the state space of the environment</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -301,7 +301,7 @@ This can be different from the observation. For example, mujocos observation
<dl class="method">
<dt id="rl_coach.environments.environment.Environment.handle_episode_ended">
<code class="descname">handle_episode_ended</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment.handle_episode_ended"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment.handle_episode_ended" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">handle_episode_ended</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment.handle_episode_ended"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment.handle_episode_ended" title="Permalink to this definition"></a></dt>
<dd><p>End an episode</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -310,9 +310,9 @@ This can be different from the observation. For example, mujocos observation
</dl>
</dd></dl>
<dl class="attribute">
<dl class="method">
<dt id="rl_coach.environments.environment.Environment.last_env_response">
<code class="descname">last_env_response</code><a class="headerlink" href="#rl_coach.environments.environment.Environment.last_env_response" title="Permalink to this definition"></a></dt>
<em class="property">property </em><code class="sig-name descname">last_env_response</code><a class="headerlink" href="#rl_coach.environments.environment.Environment.last_env_response" title="Permalink to this definition"></a></dt>
<dd><p>Get the last environment response</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -321,16 +321,16 @@ This can be different from the observation. For example, mujocos observation
</dl>
</dd></dl>
<dl class="attribute">
<dl class="method">
<dt id="rl_coach.environments.environment.Environment.phase">
<code class="descname">phase</code><a class="headerlink" href="#rl_coach.environments.environment.Environment.phase" title="Permalink to this definition"></a></dt>
<em class="property">property </em><code class="sig-name descname">phase</code><a class="headerlink" href="#rl_coach.environments.environment.Environment.phase" title="Permalink to this definition"></a></dt>
<dd><p>Get the phase of the environment
:return: the current phase</p>
</dd></dl>
<dl class="method">
<dt id="rl_coach.environments.environment.Environment.render">
<code class="descname">render</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment.render"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment.render" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">render</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment.render"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment.render" title="Permalink to this definition"></a></dt>
<dd><p>Call the environment function for rendering to the screen</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -341,7 +341,7 @@ This can be different from the observation. For example, mujocos observation
<dl class="method">
<dt id="rl_coach.environments.environment.Environment.reset_internal_state">
<code class="descname">reset_internal_state</code><span class="sig-paren">(</span><em>force_environment_reset=False</em><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.EnvResponse<a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment.reset_internal_state"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment.reset_internal_state" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">reset_internal_state</code><span class="sig-paren">(</span><em class="sig-param">force_environment_reset=False</em><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.EnvResponse<a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment.reset_internal_state"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment.reset_internal_state" title="Permalink to this definition"></a></dt>
<dd><p>Reset the environment and all the variable of the wrapper</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -355,7 +355,7 @@ This can be different from the observation. For example, mujocos observation
<dl class="method">
<dt id="rl_coach.environments.environment.Environment.set_goal">
<code class="descname">set_goal</code><span class="sig-paren">(</span><em>goal: Union[None, numpy.ndarray]</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment.set_goal"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment.set_goal" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">set_goal</code><span class="sig-paren">(</span><em class="sig-param">goal: Union[None, numpy.ndarray]</em><span class="sig-paren">)</span> &#x2192; None<a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment.set_goal"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment.set_goal" title="Permalink to this definition"></a></dt>
<dd><p>Set the current goal that the agent needs to achieve in the environment</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -367,9 +367,9 @@ This can be different from the observation. For example, mujocos observation
</dl>
</dd></dl>
<dl class="attribute">
<dl class="method">
<dt id="rl_coach.environments.environment.Environment.state_space">
<code class="descname">state_space</code><a class="headerlink" href="#rl_coach.environments.environment.Environment.state_space" title="Permalink to this definition"></a></dt>
<em class="property">property </em><code class="sig-name descname">state_space</code><a class="headerlink" href="#rl_coach.environments.environment.Environment.state_space" title="Permalink to this definition"></a></dt>
<dd><p>Get the state space of the environment</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -380,7 +380,7 @@ This can be different from the observation. For example, mujocos observation
<dl class="method">
<dt id="rl_coach.environments.environment.Environment.step">
<code class="descname">step</code><span class="sig-paren">(</span><em>action: Union[int, float, numpy.ndarray, List]</em><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.EnvResponse<a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment.step"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment.step" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">step</code><span class="sig-paren">(</span><em class="sig-param">action: Union[int, float, numpy.ndarray, List]</em><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.EnvResponse<a class="reference internal" href="../../_modules/rl_coach/environments/environment.html#Environment.step"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.environment.Environment.step" title="Permalink to this definition"></a></dt>
<dd><p>Make a single step in the environment using the given action</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -400,7 +400,7 @@ This can be different from the observation. For example, mujocos observation
<p>Website: <a class="reference external" href="https://github.com/deepmind/dm_control">DeepMind Control Suite</a></p>
<dl class="class">
<dt id="rl_coach.environments.control_suite_environment.ControlSuiteEnvironment">
<em class="property">class </em><code class="descclassname">rl_coach.environments.control_suite_environment.</code><code class="descname">ControlSuiteEnvironment</code><span class="sig-paren">(</span><em>level: rl_coach.environments.environment.LevelSelection</em>, <em>frame_skip: int</em>, <em>visualization_parameters: rl_coach.base_parameters.VisualizationParameters</em>, <em>target_success_rate: float = 1.0</em>, <em>seed: Union[None</em>, <em>int] = None</em>, <em>human_control: bool = False</em>, <em>observation_type: rl_coach.environments.control_suite_environment.ObservationType = &lt;ObservationType.Measurements: 1&gt;</em>, <em>custom_reward_threshold: Union[int</em>, <em>float] = None</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/environments/control_suite_environment.html#ControlSuiteEnvironment"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.control_suite_environment.ControlSuiteEnvironment" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.environments.control_suite_environment.</code><code class="sig-name descname">ControlSuiteEnvironment</code><span class="sig-paren">(</span><em class="sig-param">level: rl_coach.environments.environment.LevelSelection</em>, <em class="sig-param">frame_skip: int</em>, <em class="sig-param">visualization_parameters: rl_coach.base_parameters.VisualizationParameters</em>, <em class="sig-param">target_success_rate: float = 1.0</em>, <em class="sig-param">seed: Union[None</em>, <em class="sig-param">int] = None</em>, <em class="sig-param">human_control: bool = False</em>, <em class="sig-param">observation_type: rl_coach.environments.control_suite_environment.ObservationType = &lt;ObservationType.Measurements: 1&gt;</em>, <em class="sig-param">custom_reward_threshold: Union[int</em>, <em class="sig-param">float] = None</em>, <em class="sig-param">**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/environments/control_suite_environment.html#ControlSuiteEnvironment"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.control_suite_environment.ControlSuiteEnvironment" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -438,7 +438,7 @@ Allows defining a custom reward that will be used to decide when the agent succe
<p>Website: <a class="reference external" href="https://github.com/deepmind/pysc2">Blizzard Starcraft II</a></p>
<dl class="class">
<dt id="rl_coach.environments.starcraft2_environment.StarCraft2Environment">
<em class="property">class </em><code class="descclassname">rl_coach.environments.starcraft2_environment.</code><code class="descname">StarCraft2Environment</code><span class="sig-paren">(</span><em>level: rl_coach.environments.environment.LevelSelection</em>, <em>frame_skip: int</em>, <em>visualization_parameters: rl_coach.base_parameters.VisualizationParameters</em>, <em>target_success_rate: float = 1.0</em>, <em>seed: Union[None</em>, <em>int] = None</em>, <em>human_control: bool = False</em>, <em>custom_reward_threshold: Union[int</em>, <em>float] = None</em>, <em>screen_size: int = 84</em>, <em>minimap_size: int = 64</em>, <em>feature_minimap_maps_to_use: List = range(0</em>, <em>7)</em>, <em>feature_screen_maps_to_use: List = range(0</em>, <em>17)</em>, <em>observation_type: rl_coach.environments.starcraft2_environment.StarcraftObservationType = &lt;StarcraftObservationType.Features: 0&gt;</em>, <em>disable_fog: bool = False</em>, <em>auto_select_all_army: bool = True</em>, <em>use_full_action_space: bool = False</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/environments/starcraft2_environment.html#StarCraft2Environment"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.starcraft2_environment.StarCraft2Environment" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.environments.starcraft2_environment.</code><code class="sig-name descname">StarCraft2Environment</code><span class="sig-paren">(</span><em class="sig-param">level: rl_coach.environments.environment.LevelSelection</em>, <em class="sig-param">frame_skip: int</em>, <em class="sig-param">visualization_parameters: rl_coach.base_parameters.VisualizationParameters</em>, <em class="sig-param">target_success_rate: float = 1.0</em>, <em class="sig-param">seed: Union[None</em>, <em class="sig-param">int] = None</em>, <em class="sig-param">human_control: bool = False</em>, <em class="sig-param">custom_reward_threshold: Union[int</em>, <em class="sig-param">float] = None</em>, <em class="sig-param">screen_size: int = 84</em>, <em class="sig-param">minimap_size: int = 64</em>, <em class="sig-param">feature_minimap_maps_to_use: List = range(0</em>, <em class="sig-param">7)</em>, <em class="sig-param">feature_screen_maps_to_use: List = range(0</em>, <em class="sig-param">17)</em>, <em class="sig-param">observation_type: rl_coach.environments.starcraft2_environment.StarcraftObservationType = &lt;StarcraftObservationType.Features: 0&gt;</em>, <em class="sig-param">disable_fog: bool = False</em>, <em class="sig-param">auto_select_all_army: bool = True</em>, <em class="sig-param">use_full_action_space: bool = False</em>, <em class="sig-param">**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/environments/starcraft2_environment.html#StarCraft2Environment"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.starcraft2_environment.StarCraft2Environment" title="Permalink to this definition"></a></dt>
<dd></dd></dl>
</div>
@@ -448,7 +448,7 @@ Allows defining a custom reward that will be used to decide when the agent succe
<p>Website: <a class="reference external" href="http://vizdoom.cs.put.edu.pl/">ViZDoom</a></p>
<dl class="class">
<dt id="rl_coach.environments.doom_environment.DoomEnvironment">
<em class="property">class </em><code class="descclassname">rl_coach.environments.doom_environment.</code><code class="descname">DoomEnvironment</code><span class="sig-paren">(</span><em>level: rl_coach.environments.environment.LevelSelection, seed: int, frame_skip: int, human_control: bool, custom_reward_threshold: Union[int, float], visualization_parameters: rl_coach.base_parameters.VisualizationParameters, cameras: List[rl_coach.environments.doom_environment.DoomEnvironment.CameraTypes], target_success_rate: float = 1.0, **kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/environments/doom_environment.html#DoomEnvironment"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.doom_environment.DoomEnvironment" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.environments.doom_environment.</code><code class="sig-name descname">DoomEnvironment</code><span class="sig-paren">(</span><em class="sig-param">level: rl_coach.environments.environment.LevelSelection, seed: int, frame_skip: int, human_control: bool, custom_reward_threshold: Union[int, float], visualization_parameters: rl_coach.base_parameters.VisualizationParameters, cameras: List[rl_coach.environments.doom_environment.DoomEnvironment.CameraTypes], target_success_rate: float = 1.0, **kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/environments/doom_environment.html#DoomEnvironment"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.doom_environment.DoomEnvironment" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -491,7 +491,7 @@ Stop experiment if given target success rate was achieved.</p>
<p>Website: <a class="reference external" href="https://github.com/carla-simulator/carla">CARLA</a></p>
<dl class="class">
<dt id="rl_coach.environments.carla_environment.CarlaEnvironment">
<em class="property">class </em><code class="descclassname">rl_coach.environments.carla_environment.</code><code class="descname">CarlaEnvironment</code><span class="sig-paren">(</span><em>level: rl_coach.environments.environment.LevelSelection, seed: int, frame_skip: int, human_control: bool, custom_reward_threshold: Union[int, float], visualization_parameters: rl_coach.base_parameters.VisualizationParameters, server_height: int, server_width: int, camera_height: int, camera_width: int, verbose: bool, experiment_suite: carla.driving_benchmark.experiment_suites.experiment_suite.ExperimentSuite, config: str, episode_max_time: int, allow_braking: bool, quality: rl_coach.environments.carla_environment.CarlaEnvironmentParameters.Quality, cameras: List[rl_coach.environments.carla_environment.CameraTypes], weather_id: List[int], experiment_path: str, separate_actions_for_throttle_and_brake: bool, num_speedup_steps: int, max_speed: float, target_success_rate: float = 1.0, **kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/environments/carla_environment.html#CarlaEnvironment"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.carla_environment.CarlaEnvironment" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.environments.carla_environment.</code><code class="sig-name descname">CarlaEnvironment</code><span class="sig-paren">(</span><em class="sig-param">level: rl_coach.environments.environment.LevelSelection, seed: int, frame_skip: int, human_control: bool, custom_reward_threshold: Union[int, float], visualization_parameters: rl_coach.base_parameters.VisualizationParameters, server_height: int, server_width: int, camera_height: int, camera_width: int, verbose: bool, experiment_suite: carla.driving_benchmark.experiment_suites.experiment_suite.ExperimentSuite, config: str, episode_max_time: int, allow_braking: bool, quality: rl_coach.environments.carla_environment.CarlaEnvironmentParameters.Quality, cameras: List[rl_coach.environments.carla_environment.CameraTypes], weather_id: List[int], experiment_path: str, separate_actions_for_throttle_and_brake: bool, num_speedup_steps: int, max_speed: float, target_success_rate: float = 1.0, **kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/environments/carla_environment.html#CarlaEnvironment"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.carla_environment.CarlaEnvironment" title="Permalink to this definition"></a></dt>
<dd></dd></dl>
</div>
@@ -511,7 +511,7 @@ includes a set of robotics environments.</p></li>
</ul>
<dl class="class">
<dt id="rl_coach.environments.gym_environment.GymEnvironment">
<em class="property">class </em><code class="descclassname">rl_coach.environments.gym_environment.</code><code class="descname">GymEnvironment</code><span class="sig-paren">(</span><em>level: rl_coach.environments.environment.LevelSelection</em>, <em>frame_skip: int</em>, <em>visualization_parameters: rl_coach.base_parameters.VisualizationParameters</em>, <em>target_success_rate: float = 1.0</em>, <em>additional_simulator_parameters: Dict[str</em>, <em>Any] = {}</em>, <em>seed: Union[None</em>, <em>int] = None</em>, <em>human_control: bool = False</em>, <em>custom_reward_threshold: Union[int</em>, <em>float] = None</em>, <em>random_initialization_steps: int = 1</em>, <em>max_over_num_frames: int = 1</em>, <em>observation_space_type: rl_coach.environments.gym_environment.ObservationSpaceType = None</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/environments/gym_environment.html#GymEnvironment"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.gym_environment.GymEnvironment" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.environments.gym_environment.</code><code class="sig-name descname">GymEnvironment</code><span class="sig-paren">(</span><em class="sig-param">level: rl_coach.environments.environment.LevelSelection</em>, <em class="sig-param">frame_skip: int</em>, <em class="sig-param">visualization_parameters: rl_coach.base_parameters.VisualizationParameters</em>, <em class="sig-param">target_success_rate: float = 1.0</em>, <em class="sig-param">additional_simulator_parameters: Dict[str</em>, <em class="sig-param">Any] = {}</em>, <em class="sig-param">seed: Union[None</em>, <em class="sig-param">int] = None</em>, <em class="sig-param">human_control: bool = False</em>, <em class="sig-param">custom_reward_threshold: Union[int</em>, <em class="sig-param">float] = None</em>, <em class="sig-param">random_initialization_steps: int = 1</em>, <em class="sig-param">max_over_num_frames: int = 1</em>, <em class="sig-param">observation_space_type: rl_coach.environments.gym_environment.ObservationSpaceType = None</em>, <em class="sig-param">**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/environments/gym_environment.html#GymEnvironment"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.environments.gym_environment.GymEnvironment" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
+33 -27
View File
@@ -205,7 +205,7 @@ predefined policy. This is one of the most important aspects of reinforcement le
tuning to get it right. Coach supports several pre-defined exploration policies, and it can be easily extended with
custom policies. Note that not all exploration policies are expected to work for both discrete and continuous action
spaces.</p>
<table class="docutils align-center">
<table class="docutils align-default">
<colgroup>
<col style="width: 35%" />
<col style="width: 37%" />
@@ -268,7 +268,7 @@ spaces.</p>
<h2>ExplorationPolicy<a class="headerlink" href="#explorationpolicy" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.exploration_policies.exploration_policy.ExplorationPolicy">
<em class="property">class </em><code class="descclassname">rl_coach.exploration_policies.exploration_policy.</code><code class="descname">ExplorationPolicy</code><span class="sig-paren">(</span><em>action_space: rl_coach.spaces.ActionSpace</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/exploration_policy.html#ExplorationPolicy"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.exploration_policy.ExplorationPolicy" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.exploration_policies.exploration_policy.</code><code class="sig-name descname">ExplorationPolicy</code><span class="sig-paren">(</span><em class="sig-param">action_space: rl_coach.spaces.ActionSpace</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/exploration_policy.html#ExplorationPolicy"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.exploration_policy.ExplorationPolicy" title="Permalink to this definition"></a></dt>
<dd><p>An exploration policy takes the predicted actions or action values from the agent, and selects the action to
actually apply to the environment using some predefined algorithm.</p>
<dl class="field-list simple">
@@ -278,7 +278,7 @@ actually apply to the environment using some predefined algorithm.</p>
</dl>
<dl class="method">
<dt id="rl_coach.exploration_policies.exploration_policy.ExplorationPolicy.change_phase">
<code class="descname">change_phase</code><span class="sig-paren">(</span><em>phase</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/exploration_policy.html#ExplorationPolicy.change_phase"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.exploration_policy.ExplorationPolicy.change_phase" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">change_phase</code><span class="sig-paren">(</span><em class="sig-param">phase</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/exploration_policy.html#ExplorationPolicy.change_phase"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.exploration_policy.ExplorationPolicy.change_phase" title="Permalink to this definition"></a></dt>
<dd><p>Change between running phases of the algorithm
:param phase: Either Heatup or Train
:return: none</p>
@@ -286,16 +286,19 @@ actually apply to the environment using some predefined algorithm.</p>
<dl class="method">
<dt id="rl_coach.exploration_policies.exploration_policy.ExplorationPolicy.get_action">
<code class="descname">get_action</code><span class="sig-paren">(</span><em>action_values: List[Union[int, float, numpy.ndarray, List]]</em><span class="sig-paren">)</span> &#x2192; Union[int, float, numpy.ndarray, List]<a class="reference internal" href="../../_modules/rl_coach/exploration_policies/exploration_policy.html#ExplorationPolicy.get_action"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.exploration_policy.ExplorationPolicy.get_action" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">get_action</code><span class="sig-paren">(</span><em class="sig-param">action_values: List[Union[int, float, numpy.ndarray, List]]</em><span class="sig-paren">)</span> &#x2192; Union[int, float, numpy.ndarray, List]<a class="reference internal" href="../../_modules/rl_coach/exploration_policies/exploration_policy.html#ExplorationPolicy.get_action"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.exploration_policy.ExplorationPolicy.get_action" title="Permalink to this definition"></a></dt>
<dd><p>Given a list of values corresponding to each action,
choose one actions according to the exploration policy
:param action_values: A list of action values
:return: The chosen action</p>
:return: The chosen action,</p>
<blockquote>
<div><p>The probability of the action (if available, otherwise 1 for absolute certainty in the action)</p>
</div></blockquote>
</dd></dl>
<dl class="method">
<dt id="rl_coach.exploration_policies.exploration_policy.ExplorationPolicy.requires_action_values">
<code class="descname">requires_action_values</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; bool<a class="reference internal" href="../../_modules/rl_coach/exploration_policies/exploration_policy.html#ExplorationPolicy.requires_action_values"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.exploration_policy.ExplorationPolicy.requires_action_values" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">requires_action_values</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; bool<a class="reference internal" href="../../_modules/rl_coach/exploration_policies/exploration_policy.html#ExplorationPolicy.requires_action_values"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.exploration_policy.ExplorationPolicy.requires_action_values" title="Permalink to this definition"></a></dt>
<dd><p>Allows exploration policies to define if they require the action values for the current step.
This can save up a lot of computation. For example in e-greedy, if the random value generated is smaller
than epsilon, the action is completely random, and the action values dont need to be calculated
@@ -304,7 +307,7 @@ than epsilon, the action is completely random, and the action values dont nee
<dl class="method">
<dt id="rl_coach.exploration_policies.exploration_policy.ExplorationPolicy.reset">
<code class="descname">reset</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/exploration_policy.html#ExplorationPolicy.reset"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.exploration_policy.ExplorationPolicy.reset" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">reset</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/exploration_policy.html#ExplorationPolicy.reset"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.exploration_policy.ExplorationPolicy.reset" title="Permalink to this definition"></a></dt>
<dd><p>Used for resetting the exploration policy parameters when needed
:return: None</p>
</dd></dl>
@@ -316,7 +319,7 @@ than epsilon, the action is completely random, and the action values dont nee
<h2>AdditiveNoise<a class="headerlink" href="#additivenoise" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.exploration_policies.additive_noise.AdditiveNoise">
<em class="property">class </em><code class="descclassname">rl_coach.exploration_policies.additive_noise.</code><code class="descname">AdditiveNoise</code><span class="sig-paren">(</span><em>action_space: rl_coach.spaces.ActionSpace</em>, <em>noise_percentage_schedule: rl_coach.schedules.Schedule</em>, <em>evaluation_noise_percentage: float</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/additive_noise.html#AdditiveNoise"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.additive_noise.AdditiveNoise" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.exploration_policies.additive_noise.</code><code class="sig-name descname">AdditiveNoise</code><span class="sig-paren">(</span><em class="sig-param">action_space: rl_coach.spaces.ActionSpace</em>, <em class="sig-param">noise_schedule: rl_coach.schedules.Schedule</em>, <em class="sig-param">evaluation_noise: float</em>, <em class="sig-param">noise_as_percentage_from_action_space: bool = True</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/additive_noise.html#AdditiveNoise"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.additive_noise.AdditiveNoise" title="Permalink to this definition"></a></dt>
<dd><p>AdditiveNoise is an exploration policy intended for continuous action spaces. It takes the action from the agent
and adds a Gaussian distributed noise to it. The amount of noise added to the action follows the noise amount that
can be given in two different ways:
@@ -327,9 +330,10 @@ be the mean of the action, and 2nd is assumed to be its standard deviation.</p>
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>action_space</strong> the action space used by the environment</p></li>
<li><p><strong>noise_percentage_schedule</strong> the schedule for the noise variance percentage relative to the absolute range
of the action space</p></li>
<li><p><strong>evaluation_noise_percentage</strong> the noise variance percentage that will be used during evaluation phases</p></li>
<li><p><strong>noise_schedule</strong> the schedule for the noise</p></li>
<li><p><strong>evaluation_noise</strong> the noise variance that will be used during evaluation phases</p></li>
<li><p><strong>noise_as_percentage_from_action_space</strong> a bool deciding whether the noise is absolute or as a percentage
from the action space</p></li>
</ul>
</dd>
</dl>
@@ -340,7 +344,7 @@ of the action space</p></li>
<h2>Boltzmann<a class="headerlink" href="#boltzmann" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.exploration_policies.boltzmann.Boltzmann">
<em class="property">class </em><code class="descclassname">rl_coach.exploration_policies.boltzmann.</code><code class="descname">Boltzmann</code><span class="sig-paren">(</span><em>action_space: rl_coach.spaces.ActionSpace</em>, <em>temperature_schedule: rl_coach.schedules.Schedule</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/boltzmann.html#Boltzmann"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.boltzmann.Boltzmann" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.exploration_policies.boltzmann.</code><code class="sig-name descname">Boltzmann</code><span class="sig-paren">(</span><em class="sig-param">action_space: rl_coach.spaces.ActionSpace</em>, <em class="sig-param">temperature_schedule: rl_coach.schedules.Schedule</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/boltzmann.html#Boltzmann"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.boltzmann.Boltzmann" title="Permalink to this definition"></a></dt>
<dd><p>The Boltzmann exploration policy is intended for discrete action spaces. It assumes that each of the possible
actions has some value assigned to it (such as the Q value), and uses a softmax function to convert these values
into a distribution over the actions. It then samples the action for playing out of the calculated distribution.
@@ -360,7 +364,7 @@ An additional temperature schedule can be given by the user, and will control th
<h2>Bootstrapped<a class="headerlink" href="#bootstrapped" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.exploration_policies.bootstrapped.Bootstrapped">
<em class="property">class </em><code class="descclassname">rl_coach.exploration_policies.bootstrapped.</code><code class="descname">Bootstrapped</code><span class="sig-paren">(</span><em>action_space: rl_coach.spaces.ActionSpace</em>, <em>epsilon_schedule: rl_coach.schedules.Schedule</em>, <em>evaluation_epsilon: float</em>, <em>architecture_num_q_heads: int</em>, <em>continuous_exploration_policy_parameters: rl_coach.exploration_policies.exploration_policy.ExplorationParameters = &lt;rl_coach.exploration_policies.additive_noise.AdditiveNoiseParameters object&gt;</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/bootstrapped.html#Bootstrapped"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.bootstrapped.Bootstrapped" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.exploration_policies.bootstrapped.</code><code class="sig-name descname">Bootstrapped</code><span class="sig-paren">(</span><em class="sig-param">action_space: rl_coach.spaces.ActionSpace</em>, <em class="sig-param">epsilon_schedule: rl_coach.schedules.Schedule</em>, <em class="sig-param">evaluation_epsilon: float</em>, <em class="sig-param">architecture_num_q_heads: int</em>, <em class="sig-param">continuous_exploration_policy_parameters: rl_coach.exploration_policies.exploration_policy.ExplorationParameters = &lt;rl_coach.exploration_policies.additive_noise.AdditiveNoiseParameters object&gt;</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/bootstrapped.html#Bootstrapped"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.bootstrapped.Bootstrapped" title="Permalink to this definition"></a></dt>
<dd><p>Bootstrapped exploration policy is currently only used for discrete action spaces along with the
Bootstrapped DQN agent. It assumes that there is an ensemble of network heads, where each one predicts the
values for all the possible actions. For each episode, a single head is selected to lead the agent, according
@@ -390,7 +394,7 @@ if the e-greedy is used for a continuous policy</p></li>
<h2>Categorical<a class="headerlink" href="#categorical" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.exploration_policies.categorical.Categorical">
<em class="property">class </em><code class="descclassname">rl_coach.exploration_policies.categorical.</code><code class="descname">Categorical</code><span class="sig-paren">(</span><em>action_space: rl_coach.spaces.ActionSpace</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/categorical.html#Categorical"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.categorical.Categorical" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.exploration_policies.categorical.</code><code class="sig-name descname">Categorical</code><span class="sig-paren">(</span><em class="sig-param">action_space: rl_coach.spaces.ActionSpace</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/categorical.html#Categorical"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.categorical.Categorical" title="Permalink to this definition"></a></dt>
<dd><p>Categorical exploration policy is intended for discrete action spaces. It expects the action values to
represent a probability distribution over the action, from which a single action will be sampled.
In evaluation, the action that has the highest probability will be selected. This is particularly useful for
@@ -407,7 +411,7 @@ actor-critic schemes, where the actors output is a probability distribution over
<h2>ContinuousEntropy<a class="headerlink" href="#continuousentropy" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.exploration_policies.continuous_entropy.ContinuousEntropy">
<em class="property">class </em><code class="descclassname">rl_coach.exploration_policies.continuous_entropy.</code><code class="descname">ContinuousEntropy</code><span class="sig-paren">(</span><em>action_space: rl_coach.spaces.ActionSpace</em>, <em>noise_percentage_schedule: rl_coach.schedules.Schedule</em>, <em>evaluation_noise_percentage: float</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/continuous_entropy.html#ContinuousEntropy"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.continuous_entropy.ContinuousEntropy" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.exploration_policies.continuous_entropy.</code><code class="sig-name descname">ContinuousEntropy</code><span class="sig-paren">(</span><em class="sig-param">action_space: rl_coach.spaces.ActionSpace</em>, <em class="sig-param">noise_schedule: rl_coach.schedules.Schedule</em>, <em class="sig-param">evaluation_noise: float</em>, <em class="sig-param">noise_as_percentage_from_action_space: bool = True</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/continuous_entropy.html#ContinuousEntropy"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.continuous_entropy.ContinuousEntropy" title="Permalink to this definition"></a></dt>
<dd><p>Continuous entropy is an exploration policy that is actually implemented as part of the network.
The exploration policy class is only a placeholder for choosing this policy. The exploration policy is
implemented by adding a regularization factor to the network loss, which regularizes the entropy of the action.
@@ -422,9 +426,10 @@ There are only a few heads that actually are relevant and implement the entropy
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>action_space</strong> the action space used by the environment</p></li>
<li><p><strong>noise_percentage_schedule</strong> the schedule for the noise variance percentage relative to the absolute range
of the action space</p></li>
<li><p><strong>evaluation_noise_percentage</strong> the noise variance percentage that will be used during evaluation phases</p></li>
<li><p><strong>noise_schedule</strong> the schedule for the noise</p></li>
<li><p><strong>evaluation_noise</strong> the noise variance that will be used during evaluation phases</p></li>
<li><p><strong>noise_as_percentage_from_action_space</strong> a bool deciding whether the noise is absolute or as a percentage
from the action space</p></li>
</ul>
</dd>
</dl>
@@ -435,7 +440,7 @@ of the action space</p></li>
<h2>EGreedy<a class="headerlink" href="#egreedy" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.exploration_policies.e_greedy.EGreedy">
<em class="property">class </em><code class="descclassname">rl_coach.exploration_policies.e_greedy.</code><code class="descname">EGreedy</code><span class="sig-paren">(</span><em>action_space: rl_coach.spaces.ActionSpace</em>, <em>epsilon_schedule: rl_coach.schedules.Schedule</em>, <em>evaluation_epsilon: float</em>, <em>continuous_exploration_policy_parameters: rl_coach.exploration_policies.exploration_policy.ExplorationParameters = &lt;rl_coach.exploration_policies.additive_noise.AdditiveNoiseParameters object&gt;</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/e_greedy.html#EGreedy"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.e_greedy.EGreedy" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.exploration_policies.e_greedy.</code><code class="sig-name descname">EGreedy</code><span class="sig-paren">(</span><em class="sig-param">action_space: rl_coach.spaces.ActionSpace</em>, <em class="sig-param">epsilon_schedule: rl_coach.schedules.Schedule</em>, <em class="sig-param">evaluation_epsilon: float</em>, <em class="sig-param">continuous_exploration_policy_parameters: rl_coach.exploration_policies.exploration_policy.ExplorationParameters = &lt;rl_coach.exploration_policies.additive_noise.AdditiveNoiseParameters object&gt;</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/e_greedy.html#EGreedy"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.e_greedy.EGreedy" title="Permalink to this definition"></a></dt>
<dd><p>e-greedy is an exploration policy that is intended for both discrete and continuous action spaces.</p>
<p>For discrete action spaces, it assumes that each action is assigned a value, and it selects the action with the
highest value with probability 1 - epsilon. Otherwise, it selects a action sampled uniformly out of all the
@@ -463,7 +468,7 @@ if the e-greedy is used for a continuous policy</p></li>
<h2>Greedy<a class="headerlink" href="#greedy" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.exploration_policies.greedy.Greedy">
<em class="property">class </em><code class="descclassname">rl_coach.exploration_policies.greedy.</code><code class="descname">Greedy</code><span class="sig-paren">(</span><em>action_space: rl_coach.spaces.ActionSpace</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/greedy.html#Greedy"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.greedy.Greedy" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.exploration_policies.greedy.</code><code class="sig-name descname">Greedy</code><span class="sig-paren">(</span><em class="sig-param">action_space: rl_coach.spaces.ActionSpace</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/greedy.html#Greedy"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.greedy.Greedy" title="Permalink to this definition"></a></dt>
<dd><p>The Greedy exploration policy is intended for both discrete and continuous action spaces.
For discrete action spaces, it always selects the action with the maximum value, as given by the agent.
For continuous action spaces, it always return the exact action, as it was given by the agent.</p>
@@ -479,7 +484,7 @@ For continuous action spaces, it always return the exact action, as it was given
<h2>OUProcess<a class="headerlink" href="#ouprocess" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.exploration_policies.ou_process.OUProcess">
<em class="property">class </em><code class="descclassname">rl_coach.exploration_policies.ou_process.</code><code class="descname">OUProcess</code><span class="sig-paren">(</span><em>action_space: rl_coach.spaces.ActionSpace</em>, <em>mu: float = 0</em>, <em>theta: float = 0.15</em>, <em>sigma: float = 0.2</em>, <em>dt: float = 0.01</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/ou_process.html#OUProcess"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.ou_process.OUProcess" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.exploration_policies.ou_process.</code><code class="sig-name descname">OUProcess</code><span class="sig-paren">(</span><em class="sig-param">action_space: rl_coach.spaces.ActionSpace</em>, <em class="sig-param">mu: float = 0</em>, <em class="sig-param">theta: float = 0.15</em>, <em class="sig-param">sigma: float = 0.2</em>, <em class="sig-param">dt: float = 0.01</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/ou_process.html#OUProcess"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.ou_process.OUProcess" title="Permalink to this definition"></a></dt>
<dd><p>OUProcess exploration policy is intended for continuous action spaces, and selects the action according to
an Ornstein-Uhlenbeck process. The Ornstein-Uhlenbeck process implements the action as a Gaussian process, where
the samples are correlated between consequent time steps.</p>
@@ -495,7 +500,7 @@ the samples are correlated between consequent time steps.</p>
<h2>ParameterNoise<a class="headerlink" href="#parameternoise" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.exploration_policies.parameter_noise.ParameterNoise">
<em class="property">class </em><code class="descclassname">rl_coach.exploration_policies.parameter_noise.</code><code class="descname">ParameterNoise</code><span class="sig-paren">(</span><em>network_params: Dict[str, rl_coach.base_parameters.NetworkParameters], action_space: rl_coach.spaces.ActionSpace</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/parameter_noise.html#ParameterNoise"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.parameter_noise.ParameterNoise" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.exploration_policies.parameter_noise.</code><code class="sig-name descname">ParameterNoise</code><span class="sig-paren">(</span><em class="sig-param">network_params: Dict[str, rl_coach.base_parameters.NetworkParameters], action_space: rl_coach.spaces.ActionSpace</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/parameter_noise.html#ParameterNoise"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.parameter_noise.ParameterNoise" title="Permalink to this definition"></a></dt>
<dd><p>The ParameterNoise exploration policy is intended for both discrete and continuous action spaces.
It applies the exploration policy by replacing all the dense network layers with noisy layers.
The noisy layers have both weight means and weight standard deviations, and for each forward pass of the network
@@ -514,7 +519,7 @@ values.</p>
<h2>TruncatedNormal<a class="headerlink" href="#truncatednormal" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.exploration_policies.truncated_normal.TruncatedNormal">
<em class="property">class </em><code class="descclassname">rl_coach.exploration_policies.truncated_normal.</code><code class="descname">TruncatedNormal</code><span class="sig-paren">(</span><em>action_space: rl_coach.spaces.ActionSpace</em>, <em>noise_percentage_schedule: rl_coach.schedules.Schedule</em>, <em>evaluation_noise_percentage: float</em>, <em>clip_low: float</em>, <em>clip_high: float</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/truncated_normal.html#TruncatedNormal"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.truncated_normal.TruncatedNormal" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.exploration_policies.truncated_normal.</code><code class="sig-name descname">TruncatedNormal</code><span class="sig-paren">(</span><em class="sig-param">action_space: rl_coach.spaces.ActionSpace</em>, <em class="sig-param">noise_schedule: rl_coach.schedules.Schedule</em>, <em class="sig-param">evaluation_noise: float</em>, <em class="sig-param">clip_low: float</em>, <em class="sig-param">clip_high: float</em>, <em class="sig-param">noise_as_percentage_from_action_space: bool = True</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/truncated_normal.html#TruncatedNormal"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.truncated_normal.TruncatedNormal" title="Permalink to this definition"></a></dt>
<dd><p>The TruncatedNormal exploration policy is intended for continuous action spaces. It samples the action from a
normal distribution, where the mean action is given by the agent, and the standard deviation can be given in t
wo different ways:
@@ -527,9 +532,10 @@ is within the bounds.</p>
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>action_space</strong> the action space used by the environment</p></li>
<li><p><strong>noise_percentage_schedule</strong> the schedule for the noise variance percentage relative to the absolute range
of the action space</p></li>
<li><p><strong>evaluation_noise_percentage</strong> the noise variance percentage that will be used during evaluation phases</p></li>
<li><p><strong>noise_schedule</strong> the schedule for the noise variance</p></li>
<li><p><strong>evaluation_noise</strong> the noise variance that will be used during evaluation phases</p></li>
<li><p><strong>noise_as_percentage_from_action_space</strong> whether to consider the noise as a percentage of the action space
or absolute value</p></li>
</ul>
</dd>
</dl>
@@ -540,7 +546,7 @@ of the action space</p></li>
<h2>UCB<a class="headerlink" href="#ucb" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.exploration_policies.ucb.UCB">
<em class="property">class </em><code class="descclassname">rl_coach.exploration_policies.ucb.</code><code class="descname">UCB</code><span class="sig-paren">(</span><em>action_space: rl_coach.spaces.ActionSpace</em>, <em>epsilon_schedule: rl_coach.schedules.Schedule</em>, <em>evaluation_epsilon: float</em>, <em>architecture_num_q_heads: int</em>, <em>lamb: int</em>, <em>continuous_exploration_policy_parameters: rl_coach.exploration_policies.exploration_policy.ExplorationParameters = &lt;rl_coach.exploration_policies.additive_noise.AdditiveNoiseParameters object&gt;</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/ucb.html#UCB"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.ucb.UCB" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.exploration_policies.ucb.</code><code class="sig-name descname">UCB</code><span class="sig-paren">(</span><em class="sig-param">action_space: rl_coach.spaces.ActionSpace</em>, <em class="sig-param">epsilon_schedule: rl_coach.schedules.Schedule</em>, <em class="sig-param">evaluation_epsilon: float</em>, <em class="sig-param">architecture_num_q_heads: int</em>, <em class="sig-param">lamb: int</em>, <em class="sig-param">continuous_exploration_policy_parameters: rl_coach.exploration_policies.exploration_policy.ExplorationParameters = &lt;rl_coach.exploration_policies.additive_noise.AdditiveNoiseParameters object&gt;</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/exploration_policies/ucb.html#UCB"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.exploration_policies.ucb.UCB" title="Permalink to this definition"></a></dt>
<dd><p>UCB exploration policy is following the upper confidence bound heuristic to sample actions in discrete action spaces.
It assumes that there are multiple network heads that are predicting action values, and that the standard deviation
between the heads predictions represents the uncertainty of the agent in each of the actions.
+14 -14
View File
@@ -221,7 +221,7 @@
<h3>ObservationClippingFilter<a class="headerlink" href="#observationclippingfilter" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.filters.observation.ObservationClippingFilter">
<em class="property">class </em><code class="descclassname">rl_coach.filters.observation.</code><code class="descname">ObservationClippingFilter</code><span class="sig-paren">(</span><em>clipping_low: float = -inf</em>, <em>clipping_high: float = inf</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/observation/observation_clipping_filter.html#ObservationClippingFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.observation.ObservationClippingFilter" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.filters.observation.</code><code class="sig-name descname">ObservationClippingFilter</code><span class="sig-paren">(</span><em class="sig-param">clipping_low: float = -inf</em>, <em class="sig-param">clipping_high: float = inf</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/observation/observation_clipping_filter.html#ObservationClippingFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.observation.ObservationClippingFilter" title="Permalink to this definition"></a></dt>
<dd><p>Clips the observation values to a given range of values.
For example, if the observation consists of measurements in an arbitrary range,
and we want to control the minimum and maximum values of these observations,
@@ -241,7 +241,7 @@ we can define a range and clip the values of the measurements.</p>
<h3>ObservationCropFilter<a class="headerlink" href="#observationcropfilter" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.filters.observation.ObservationCropFilter">
<em class="property">class </em><code class="descclassname">rl_coach.filters.observation.</code><code class="descname">ObservationCropFilter</code><span class="sig-paren">(</span><em>crop_low: numpy.ndarray = None</em>, <em>crop_high: numpy.ndarray = None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/observation/observation_crop_filter.html#ObservationCropFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.observation.ObservationCropFilter" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.filters.observation.</code><code class="sig-name descname">ObservationCropFilter</code><span class="sig-paren">(</span><em class="sig-param">crop_low: numpy.ndarray = None</em>, <em class="sig-param">crop_high: numpy.ndarray = None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/observation/observation_crop_filter.html#ObservationCropFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.observation.ObservationCropFilter" title="Permalink to this definition"></a></dt>
<dd><p>Crops the size of the observation to a given crop window. For example, in Atari, the
observations are images with a shape of 210x160. Usually, we will want to crop the size of the observation to a
square of 160x160 before rescaling them.</p>
@@ -262,7 +262,7 @@ corresponding dimension. a negative value of -1 will be mapped to the max size</
<h3>ObservationMoveAxisFilter<a class="headerlink" href="#observationmoveaxisfilter" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.filters.observation.ObservationMoveAxisFilter">
<em class="property">class </em><code class="descclassname">rl_coach.filters.observation.</code><code class="descname">ObservationMoveAxisFilter</code><span class="sig-paren">(</span><em>axis_origin: int = None</em>, <em>axis_target: int = None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/observation/observation_move_axis_filter.html#ObservationMoveAxisFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.observation.ObservationMoveAxisFilter" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.filters.observation.</code><code class="sig-name descname">ObservationMoveAxisFilter</code><span class="sig-paren">(</span><em class="sig-param">axis_origin: int = None</em>, <em class="sig-param">axis_target: int = None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/observation/observation_move_axis_filter.html#ObservationMoveAxisFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.observation.ObservationMoveAxisFilter" title="Permalink to this definition"></a></dt>
<dd><p>Reorders the axes of the observation. This can be useful when the observation is an
image, and we want to move the channel axis to be the last axis instead of the first axis.</p>
<dl class="field-list simple">
@@ -280,7 +280,7 @@ image, and we want to move the channel axis to be the last axis instead of the f
<h3>ObservationNormalizationFilter<a class="headerlink" href="#observationnormalizationfilter" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.filters.observation.ObservationNormalizationFilter">
<em class="property">class </em><code class="descclassname">rl_coach.filters.observation.</code><code class="descname">ObservationNormalizationFilter</code><span class="sig-paren">(</span><em>clip_min: float = -5.0</em>, <em>clip_max: float = 5.0</em>, <em>name='observation_stats'</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/observation/observation_normalization_filter.html#ObservationNormalizationFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.observation.ObservationNormalizationFilter" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.filters.observation.</code><code class="sig-name descname">ObservationNormalizationFilter</code><span class="sig-paren">(</span><em class="sig-param">clip_min: float = -5.0</em>, <em class="sig-param">clip_max: float = 5.0</em>, <em class="sig-param">name='observation_stats'</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/observation/observation_normalization_filter.html#ObservationNormalizationFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.observation.ObservationNormalizationFilter" title="Permalink to this definition"></a></dt>
<dd><p>Normalizes the observation values with a running mean and standard deviation of
all the observations seen so far. The normalization is performed element-wise. Additionally, when working with
multiple workers, the statistics used for the normalization operation are accumulated over all the workers.</p>
@@ -299,7 +299,7 @@ multiple workers, the statistics used for the normalization operation are accumu
<h3>ObservationReductionBySubPartsNameFilter<a class="headerlink" href="#observationreductionbysubpartsnamefilter" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.filters.observation.ObservationReductionBySubPartsNameFilter">
<em class="property">class </em><code class="descclassname">rl_coach.filters.observation.</code><code class="descname">ObservationReductionBySubPartsNameFilter</code><span class="sig-paren">(</span><em>part_names: List[str], reduction_method: rl_coach.filters.observation.observation_reduction_by_sub_parts_name_filter.ObservationReductionBySubPartsNameFilter.ReductionMethod</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/observation/observation_reduction_by_sub_parts_name_filter.html#ObservationReductionBySubPartsNameFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.observation.ObservationReductionBySubPartsNameFilter" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.filters.observation.</code><code class="sig-name descname">ObservationReductionBySubPartsNameFilter</code><span class="sig-paren">(</span><em class="sig-param">part_names: List[str], reduction_method: rl_coach.filters.observation.observation_reduction_by_sub_parts_name_filter.ObservationReductionBySubPartsNameFilter.ReductionMethod</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/observation/observation_reduction_by_sub_parts_name_filter.html#ObservationReductionBySubPartsNameFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.observation.ObservationReductionBySubPartsNameFilter" title="Permalink to this definition"></a></dt>
<dd><p>Allows keeping only parts of the observation, by specifying their
name. This is useful when the environment has a measurements vector as observation which includes several different
measurements, but you want the agent to only see some of the measurements and not all.
@@ -321,7 +321,7 @@ This will currently work only for VectorObservationSpace observations</p>
<h3>ObservationRescaleSizeByFactorFilter<a class="headerlink" href="#observationrescalesizebyfactorfilter" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.filters.observation.ObservationRescaleSizeByFactorFilter">
<em class="property">class </em><code class="descclassname">rl_coach.filters.observation.</code><code class="descname">ObservationRescaleSizeByFactorFilter</code><span class="sig-paren">(</span><em>rescale_factor: float</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/observation/observation_rescale_size_by_factor_filter.html#ObservationRescaleSizeByFactorFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.observation.ObservationRescaleSizeByFactorFilter" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.filters.observation.</code><code class="sig-name descname">ObservationRescaleSizeByFactorFilter</code><span class="sig-paren">(</span><em class="sig-param">rescale_factor: float</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/observation/observation_rescale_size_by_factor_filter.html#ObservationRescaleSizeByFactorFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.observation.ObservationRescaleSizeByFactorFilter" title="Permalink to this definition"></a></dt>
<dd><p>Rescales an image observation by some factor. For example, the image size
can be reduced by a factor of 2.</p>
<dl class="field-list simple">
@@ -336,7 +336,7 @@ can be reduced by a factor of 2.</p>
<h3>ObservationRescaleToSizeFilter<a class="headerlink" href="#observationrescaletosizefilter" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.filters.observation.ObservationRescaleToSizeFilter">
<em class="property">class </em><code class="descclassname">rl_coach.filters.observation.</code><code class="descname">ObservationRescaleToSizeFilter</code><span class="sig-paren">(</span><em>output_observation_space: rl_coach.spaces.PlanarMapsObservationSpace</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/observation/observation_rescale_to_size_filter.html#ObservationRescaleToSizeFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.observation.ObservationRescaleToSizeFilter" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.filters.observation.</code><code class="sig-name descname">ObservationRescaleToSizeFilter</code><span class="sig-paren">(</span><em class="sig-param">output_observation_space: rl_coach.spaces.PlanarMapsObservationSpace</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/observation/observation_rescale_to_size_filter.html#ObservationRescaleToSizeFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.observation.ObservationRescaleToSizeFilter" title="Permalink to this definition"></a></dt>
<dd><p>Rescales an image observation to a given size. The target size does not
necessarily keep the aspect ratio of the original observation.
Warning: this requires the input observation to be of type uint8 due to scipy requirements!</p>
@@ -352,7 +352,7 @@ Warning: this requires the input observation to be of type uint8 due to scipy re
<h3>ObservationRGBToYFilter<a class="headerlink" href="#observationrgbtoyfilter" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.filters.observation.ObservationRGBToYFilter">
<em class="property">class </em><code class="descclassname">rl_coach.filters.observation.</code><code class="descname">ObservationRGBToYFilter</code><a class="reference internal" href="../../_modules/rl_coach/filters/observation/observation_rgb_to_y_filter.html#ObservationRGBToYFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.observation.ObservationRGBToYFilter" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.filters.observation.</code><code class="sig-name descname">ObservationRGBToYFilter</code><a class="reference internal" href="../../_modules/rl_coach/filters/observation/observation_rgb_to_y_filter.html#ObservationRGBToYFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.observation.ObservationRGBToYFilter" title="Permalink to this definition"></a></dt>
<dd><p>Converts a color image observation specified using the RGB encoding into a grayscale
image observation, by keeping only the luminance (Y) channel of the YUV encoding. This can be useful if the colors
in the original image are not relevant for solving the task at hand.
@@ -364,7 +364,7 @@ The channels axis is assumed to be the last axis</p>
<h3>ObservationSqueezeFilter<a class="headerlink" href="#observationsqueezefilter" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.filters.observation.ObservationSqueezeFilter">
<em class="property">class </em><code class="descclassname">rl_coach.filters.observation.</code><code class="descname">ObservationSqueezeFilter</code><span class="sig-paren">(</span><em>axis: int = None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/observation/observation_squeeze_filter.html#ObservationSqueezeFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.observation.ObservationSqueezeFilter" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.filters.observation.</code><code class="sig-name descname">ObservationSqueezeFilter</code><span class="sig-paren">(</span><em class="sig-param">axis: int = None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/observation/observation_squeeze_filter.html#ObservationSqueezeFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.observation.ObservationSqueezeFilter" title="Permalink to this definition"></a></dt>
<dd><p>Removes redundant axes from the observation, which are axes with a dimension of 1.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -378,7 +378,7 @@ The channels axis is assumed to be the last axis</p>
<h3>ObservationStackingFilter<a class="headerlink" href="#observationstackingfilter" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.filters.observation.ObservationStackingFilter">
<em class="property">class </em><code class="descclassname">rl_coach.filters.observation.</code><code class="descname">ObservationStackingFilter</code><span class="sig-paren">(</span><em>stack_size: int</em>, <em>stacking_axis: int = -1</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/observation/observation_stacking_filter.html#ObservationStackingFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.observation.ObservationStackingFilter" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.filters.observation.</code><code class="sig-name descname">ObservationStackingFilter</code><span class="sig-paren">(</span><em class="sig-param">stack_size: int</em>, <em class="sig-param">stacking_axis: int = -1</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/observation/observation_stacking_filter.html#ObservationStackingFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.observation.ObservationStackingFilter" title="Permalink to this definition"></a></dt>
<dd><p>Stacks several observations on top of each other. For image observation this will
create a 3D blob. The stacking is done in a lazy manner in order to reduce memory consumption. To achieve this,
a LazyStack object is used in order to wrap the observations in the stack. For this reason, the
@@ -403,7 +403,7 @@ and increase the memory footprint.</p>
<h3>ObservationToUInt8Filter<a class="headerlink" href="#observationtouint8filter" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.filters.observation.ObservationToUInt8Filter">
<em class="property">class </em><code class="descclassname">rl_coach.filters.observation.</code><code class="descname">ObservationToUInt8Filter</code><span class="sig-paren">(</span><em>input_low: float</em>, <em>input_high: float</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/observation/observation_to_uint8_filter.html#ObservationToUInt8Filter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.observation.ObservationToUInt8Filter" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.filters.observation.</code><code class="sig-name descname">ObservationToUInt8Filter</code><span class="sig-paren">(</span><em class="sig-param">input_low: float</em>, <em class="sig-param">input_high: float</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/observation/observation_to_uint8_filter.html#ObservationToUInt8Filter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.observation.ObservationToUInt8Filter" title="Permalink to this definition"></a></dt>
<dd><p>Converts a floating point observation into an unsigned int 8 bit observation. This is
mostly useful for reducing memory consumption and is usually used for image observations. The filter will first
spread the observation values over the range 0-255 and then discretize them into integer values.</p>
@@ -425,7 +425,7 @@ spread the observation values over the range 0-255 and then discretize them into
<h3>RewardClippingFilter<a class="headerlink" href="#rewardclippingfilter" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.filters.reward.RewardClippingFilter">
<em class="property">class </em><code class="descclassname">rl_coach.filters.reward.</code><code class="descname">RewardClippingFilter</code><span class="sig-paren">(</span><em>clipping_low: float = -inf</em>, <em>clipping_high: float = inf</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/reward/reward_clipping_filter.html#RewardClippingFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.reward.RewardClippingFilter" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.filters.reward.</code><code class="sig-name descname">RewardClippingFilter</code><span class="sig-paren">(</span><em class="sig-param">clipping_low: float = -inf</em>, <em class="sig-param">clipping_high: float = inf</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/reward/reward_clipping_filter.html#RewardClippingFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.reward.RewardClippingFilter" title="Permalink to this definition"></a></dt>
<dd><p>Clips the reward values into a given range. For example, in DQN, the Atari rewards are
clipped into the range -1 and 1 in order to control the scale of the returns.</p>
<dl class="field-list simple">
@@ -443,7 +443,7 @@ clipped into the range -1 and 1 in order to control the scale of the returns.</p
<h3>RewardNormalizationFilter<a class="headerlink" href="#rewardnormalizationfilter" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.filters.reward.RewardNormalizationFilter">
<em class="property">class </em><code class="descclassname">rl_coach.filters.reward.</code><code class="descname">RewardNormalizationFilter</code><span class="sig-paren">(</span><em>clip_min: float = -5.0</em>, <em>clip_max: float = 5.0</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/reward/reward_normalization_filter.html#RewardNormalizationFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.reward.RewardNormalizationFilter" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.filters.reward.</code><code class="sig-name descname">RewardNormalizationFilter</code><span class="sig-paren">(</span><em class="sig-param">clip_min: float = -5.0</em>, <em class="sig-param">clip_max: float = 5.0</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/reward/reward_normalization_filter.html#RewardNormalizationFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.reward.RewardNormalizationFilter" title="Permalink to this definition"></a></dt>
<dd><p>Normalizes the reward values with a running mean and standard deviation of
all the rewards seen so far. When working with multiple workers, the statistics used for the normalization operation
are accumulated over all the workers.</p>
@@ -462,7 +462,7 @@ are accumulated over all the workers.</p>
<h3>RewardRescaleFilter<a class="headerlink" href="#rewardrescalefilter" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.filters.reward.RewardRescaleFilter">
<em class="property">class </em><code class="descclassname">rl_coach.filters.reward.</code><code class="descname">RewardRescaleFilter</code><span class="sig-paren">(</span><em>rescale_factor: float</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/reward/reward_rescale_filter.html#RewardRescaleFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.reward.RewardRescaleFilter" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.filters.reward.</code><code class="sig-name descname">RewardRescaleFilter</code><span class="sig-paren">(</span><em class="sig-param">rescale_factor: float</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/reward/reward_rescale_filter.html#RewardRescaleFilter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.reward.RewardRescaleFilter" title="Permalink to this definition"></a></dt>
<dd><p>Rescales the reward by a given factor. Rescaling the rewards of the environment has been
observed to have a large effect (negative or positive) on the behavior of the learning process.</p>
<dl class="field-list simple">
+6 -6
View File
@@ -200,7 +200,7 @@
<h2>Action Filters<a class="headerlink" href="#action-filters" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.filters.action.AttentionDiscretization">
<em class="property">class </em><code class="descclassname">rl_coach.filters.action.</code><code class="descname">AttentionDiscretization</code><span class="sig-paren">(</span><em>num_bins_per_dimension: Union[int, List[int]], force_int_bins=False</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/action/attention_discretization.html#AttentionDiscretization"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.action.AttentionDiscretization" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.filters.action.</code><code class="sig-name descname">AttentionDiscretization</code><span class="sig-paren">(</span><em class="sig-param">num_bins_per_dimension: Union[int, List[int]], force_int_bins=False</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/action/attention_discretization.html#AttentionDiscretization"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.action.AttentionDiscretization" title="Permalink to this definition"></a></dt>
<dd><p>Discretizes an <strong>AttentionActionSpace</strong>. The attention action space defines the actions
as choosing sub-boxes in a given box. For example, consider an image of size 100x100, where the action is choosing
a crop window of size 20x20 to attend to in the image. AttentionDiscretization allows discretizing the possible crop
@@ -219,7 +219,7 @@ windows to choose into a finite number of options, and map a discrete action spa
<img alt="../../_images/attention_discretization.png" class="align-center" src="../../_images/attention_discretization.png" />
<dl class="class">
<dt id="rl_coach.filters.action.BoxDiscretization">
<em class="property">class </em><code class="descclassname">rl_coach.filters.action.</code><code class="descname">BoxDiscretization</code><span class="sig-paren">(</span><em>num_bins_per_dimension: Union[int, List[int]], force_int_bins=False</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/action/box_discretization.html#BoxDiscretization"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.action.BoxDiscretization" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.filters.action.</code><code class="sig-name descname">BoxDiscretization</code><span class="sig-paren">(</span><em class="sig-param">num_bins_per_dimension: Union[int, List[int]], force_int_bins=False</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/action/box_discretization.html#BoxDiscretization"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.action.BoxDiscretization" title="Permalink to this definition"></a></dt>
<dd><p>Discretizes a continuous action space into a discrete action space, allowing the usage of
agents such as DQN for continuous environments such as MuJoCo. Given the number of bins to discretize into, the
original continuous action space is uniformly separated into the given number of bins, each mapped to a discrete
@@ -242,7 +242,7 @@ instead of 0, 2.5, 5, 7.5, 10.</p></li>
<img alt="../../_images/box_discretization.png" class="align-center" src="../../_images/box_discretization.png" />
<dl class="class">
<dt id="rl_coach.filters.action.BoxMasking">
<em class="property">class </em><code class="descclassname">rl_coach.filters.action.</code><code class="descname">BoxMasking</code><span class="sig-paren">(</span><em>masked_target_space_low: Union[None, int, float, numpy.ndarray], masked_target_space_high: Union[None, int, float, numpy.ndarray]</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/action/box_masking.html#BoxMasking"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.action.BoxMasking" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.filters.action.</code><code class="sig-name descname">BoxMasking</code><span class="sig-paren">(</span><em class="sig-param">masked_target_space_low: Union[None, int, float, numpy.ndarray], masked_target_space_high: Union[None, int, float, numpy.ndarray]</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/action/box_masking.html#BoxMasking"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.action.BoxMasking" title="Permalink to this definition"></a></dt>
<dd><p>Masks part of the action space to enforce the agent to work in a defined space. For example,
if the original action space is between -1 and 1, then this filter can be used in order to constrain the agent actions
to the range 0 and 1 instead. This essentially masks the range -1 and 0 from the agent.
@@ -260,7 +260,7 @@ The resulting action space will be shifted and will always start from 0 and have
<img alt="../../_images/box_masking.png" class="align-center" src="../../_images/box_masking.png" />
<dl class="class">
<dt id="rl_coach.filters.action.PartialDiscreteActionSpaceMap">
<em class="property">class </em><code class="descclassname">rl_coach.filters.action.</code><code class="descname">PartialDiscreteActionSpaceMap</code><span class="sig-paren">(</span><em>target_actions: List[Union[int</em>, <em>float</em>, <em>numpy.ndarray</em>, <em>List]] = None</em>, <em>descriptions: List[str] = None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/action/partial_discrete_action_space_map.html#PartialDiscreteActionSpaceMap"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.action.PartialDiscreteActionSpaceMap" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.filters.action.</code><code class="sig-name descname">PartialDiscreteActionSpaceMap</code><span class="sig-paren">(</span><em class="sig-param">target_actions: List[Union[int</em>, <em class="sig-param">float</em>, <em class="sig-param">numpy.ndarray</em>, <em class="sig-param">List]] = None</em>, <em class="sig-param">descriptions: List[str] = None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/action/partial_discrete_action_space_map.html#PartialDiscreteActionSpaceMap"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.action.PartialDiscreteActionSpaceMap" title="Permalink to this definition"></a></dt>
<dd><p>Partial map of two countable action spaces. For example, consider an environment
with a MultiSelect action space (select multiple actions at the same time, such as jump and go right), with 8 actual
MultiSelect actions. If we want the agent to be able to select only 5 of those actions by their index (0-4), we can
@@ -279,7 +279,7 @@ use regular discrete actions, and mask 3 of the actions from the agent.</p>
<img alt="../../_images/partial_discrete_action_space_map.png" class="align-center" src="../../_images/partial_discrete_action_space_map.png" />
<dl class="class">
<dt id="rl_coach.filters.action.FullDiscreteActionSpaceMap">
<em class="property">class </em><code class="descclassname">rl_coach.filters.action.</code><code class="descname">FullDiscreteActionSpaceMap</code><a class="reference internal" href="../../_modules/rl_coach/filters/action/full_discrete_action_space_map.html#FullDiscreteActionSpaceMap"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.action.FullDiscreteActionSpaceMap" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.filters.action.</code><code class="sig-name descname">FullDiscreteActionSpaceMap</code><a class="reference internal" href="../../_modules/rl_coach/filters/action/full_discrete_action_space_map.html#FullDiscreteActionSpaceMap"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.action.FullDiscreteActionSpaceMap" title="Permalink to this definition"></a></dt>
<dd><p>Full map of two countable action spaces. This works in a similar way to the
PartialDiscreteActionSpaceMap, but maps the entire source action space into the entire target action space, without
masking any actions.
@@ -290,7 +290,7 @@ multiselect actions.</p>
<img alt="../../_images/full_discrete_action_space_map.png" class="align-center" src="../../_images/full_discrete_action_space_map.png" />
<dl class="class">
<dt id="rl_coach.filters.action.LinearBoxToBoxMap">
<em class="property">class </em><code class="descclassname">rl_coach.filters.action.</code><code class="descname">LinearBoxToBoxMap</code><span class="sig-paren">(</span><em>input_space_low: Union[None, int, float, numpy.ndarray], input_space_high: Union[None, int, float, numpy.ndarray]</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/action/linear_box_to_box_map.html#LinearBoxToBoxMap"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.action.LinearBoxToBoxMap" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.filters.action.</code><code class="sig-name descname">LinearBoxToBoxMap</code><span class="sig-paren">(</span><em class="sig-param">input_space_low: Union[None, int, float, numpy.ndarray], input_space_high: Union[None, int, float, numpy.ndarray]</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/filters/action/linear_box_to_box_map.html#LinearBoxToBoxMap"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.filters.action.LinearBoxToBoxMap" title="Permalink to this definition"></a></dt>
<dd><p>A linear mapping of two box action spaces. For example, if the action space of the
environment consists of continuous actions between 0 and 1, and we want the agent to choose actions between -1 and 1,
the LinearBoxToBoxMap can be used to map the range -1 and 1 to the range 0 and 1 in a linear way. This means that the
+9 -9
View File
@@ -209,7 +209,7 @@
<h3>EpisodicExperienceReplay<a class="headerlink" href="#episodicexperiencereplay" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.memories.episodic.EpisodicExperienceReplay">
<em class="property">class </em><code class="descclassname">rl_coach.memories.episodic.</code><code class="descname">EpisodicExperienceReplay</code><span class="sig-paren">(</span><em>max_size: Tuple[rl_coach.memories.memory.MemoryGranularity</em>, <em>int] = (&lt;MemoryGranularity.Transitions: 0&gt;</em>, <em>1000000)</em>, <em>n_step=-1</em>, <em>train_to_eval_ratio: int = 1</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/memories/episodic/episodic_experience_replay.html#EpisodicExperienceReplay"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.episodic.EpisodicExperienceReplay" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.memories.episodic.</code><code class="sig-name descname">EpisodicExperienceReplay</code><span class="sig-paren">(</span><em class="sig-param">max_size: Tuple[rl_coach.memories.memory.MemoryGranularity</em>, <em class="sig-param">int] = (&lt;MemoryGranularity.Transitions: 0&gt;</em>, <em class="sig-param">1000000)</em>, <em class="sig-param">n_step=-1</em>, <em class="sig-param">train_to_eval_ratio: int = 1</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/memories/episodic/episodic_experience_replay.html#EpisodicExperienceReplay"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.episodic.EpisodicExperienceReplay" title="Permalink to this definition"></a></dt>
<dd><p>A replay buffer that stores episodes of transitions. The additional structure allows performing various
calculations of total return and other values that depend on the sequential behavior of the transitions
in the episode.</p>
@@ -225,7 +225,7 @@ in the episode.</p>
<h3>EpisodicHindsightExperienceReplay<a class="headerlink" href="#episodichindsightexperiencereplay" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.memories.episodic.EpisodicHindsightExperienceReplay">
<em class="property">class </em><code class="descclassname">rl_coach.memories.episodic.</code><code class="descname">EpisodicHindsightExperienceReplay</code><span class="sig-paren">(</span><em>max_size: Tuple[rl_coach.memories.memory.MemoryGranularity, int], hindsight_transitions_per_regular_transition: int, hindsight_goal_selection_method: rl_coach.memories.episodic.episodic_hindsight_experience_replay.HindsightGoalSelectionMethod, goals_space: rl_coach.spaces.GoalsSpace</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/memories/episodic/episodic_hindsight_experience_replay.html#EpisodicHindsightExperienceReplay"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.episodic.EpisodicHindsightExperienceReplay" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.memories.episodic.</code><code class="sig-name descname">EpisodicHindsightExperienceReplay</code><span class="sig-paren">(</span><em class="sig-param">max_size: Tuple[rl_coach.memories.memory.MemoryGranularity, int], hindsight_transitions_per_regular_transition: int, hindsight_goal_selection_method: rl_coach.memories.episodic.episodic_hindsight_experience_replay.HindsightGoalSelectionMethod, goals_space: rl_coach.spaces.GoalsSpace</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/memories/episodic/episodic_hindsight_experience_replay.html#EpisodicHindsightExperienceReplay"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.episodic.EpisodicHindsightExperienceReplay" title="Permalink to this definition"></a></dt>
<dd><p>Implements Hindsight Experience Replay as described in the following paper: <a class="reference external" href="https://arxiv.org/pdf/1707.01495.pdf">https://arxiv.org/pdf/1707.01495.pdf</a></p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -246,7 +246,7 @@ hindsight transitions. Should be one of HindsightGoalSelectionMethod</p></li>
<h3>EpisodicHRLHindsightExperienceReplay<a class="headerlink" href="#episodichrlhindsightexperiencereplay" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.memories.episodic.EpisodicHRLHindsightExperienceReplay">
<em class="property">class </em><code class="descclassname">rl_coach.memories.episodic.</code><code class="descname">EpisodicHRLHindsightExperienceReplay</code><span class="sig-paren">(</span><em>max_size: Tuple[rl_coach.memories.memory.MemoryGranularity, int], hindsight_transitions_per_regular_transition: int, hindsight_goal_selection_method: rl_coach.memories.episodic.episodic_hindsight_experience_replay.HindsightGoalSelectionMethod, goals_space: rl_coach.spaces.GoalsSpace</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/memories/episodic/episodic_hrl_hindsight_experience_replay.html#EpisodicHRLHindsightExperienceReplay"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.episodic.EpisodicHRLHindsightExperienceReplay" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.memories.episodic.</code><code class="sig-name descname">EpisodicHRLHindsightExperienceReplay</code><span class="sig-paren">(</span><em class="sig-param">max_size: Tuple[rl_coach.memories.memory.MemoryGranularity, int], hindsight_transitions_per_regular_transition: int, hindsight_goal_selection_method: rl_coach.memories.episodic.episodic_hindsight_experience_replay.HindsightGoalSelectionMethod, goals_space: rl_coach.spaces.GoalsSpace</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/memories/episodic/episodic_hrl_hindsight_experience_replay.html#EpisodicHRLHindsightExperienceReplay"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.episodic.EpisodicHRLHindsightExperienceReplay" title="Permalink to this definition"></a></dt>
<dd><p>Implements HRL Hindsight Experience Replay as described in the following paper: <a class="reference external" href="https://arxiv.org/abs/1805.08180">https://arxiv.org/abs/1805.08180</a></p>
<p>This is the memory you should use if you want a shared hindsight experience replay buffer between multiple workers</p>
<dl class="field-list simple">
@@ -269,7 +269,7 @@ hindsight transitions. Should be one of HindsightGoalSelectionMethod</p></li>
<h3>SingleEpisodeBuffer<a class="headerlink" href="#singleepisodebuffer" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.memories.episodic.SingleEpisodeBuffer">
<em class="property">class </em><code class="descclassname">rl_coach.memories.episodic.</code><code class="descname">SingleEpisodeBuffer</code><a class="reference internal" href="../../_modules/rl_coach/memories/episodic/single_episode_buffer.html#SingleEpisodeBuffer"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.episodic.SingleEpisodeBuffer" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.memories.episodic.</code><code class="sig-name descname">SingleEpisodeBuffer</code><a class="reference internal" href="../../_modules/rl_coach/memories/episodic/single_episode_buffer.html#SingleEpisodeBuffer"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.episodic.SingleEpisodeBuffer" title="Permalink to this definition"></a></dt>
<dd></dd></dl>
</div>
@@ -280,7 +280,7 @@ hindsight transitions. Should be one of HindsightGoalSelectionMethod</p></li>
<h3>BalancedExperienceReplay<a class="headerlink" href="#balancedexperiencereplay" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.memories.non_episodic.BalancedExperienceReplay">
<em class="property">class </em><code class="descclassname">rl_coach.memories.non_episodic.</code><code class="descname">BalancedExperienceReplay</code><span class="sig-paren">(</span><em>max_size: Tuple[rl_coach.memories.memory.MemoryGranularity, int], allow_duplicates_in_batch_sampling: bool = True, num_classes: int = 0, state_key_with_the_class_index: Any = 'class'</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/memories/non_episodic/balanced_experience_replay.html#BalancedExperienceReplay"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.non_episodic.BalancedExperienceReplay" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.memories.non_episodic.</code><code class="sig-name descname">BalancedExperienceReplay</code><span class="sig-paren">(</span><em class="sig-param">max_size: Tuple[rl_coach.memories.memory.MemoryGranularity, int], allow_duplicates_in_batch_sampling: bool = True, num_classes: int = 0, state_key_with_the_class_index: Any = 'class'</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/memories/non_episodic/balanced_experience_replay.html#BalancedExperienceReplay"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.non_episodic.BalancedExperienceReplay" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
@@ -299,7 +299,7 @@ this parameter determines the key to retrieve the class index value</p></li>
<h3>QDND<a class="headerlink" href="#qdnd" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.memories.non_episodic.QDND">
<em class="property">class </em><code class="descclassname">rl_coach.memories.non_episodic.</code><code class="descname">QDND</code><span class="sig-paren">(</span><em>dict_size</em>, <em>key_width</em>, <em>num_actions</em>, <em>new_value_shift_coefficient=0.1</em>, <em>key_error_threshold=0.01</em>, <em>learning_rate=0.01</em>, <em>num_neighbors=50</em>, <em>return_additional_data=False</em>, <em>override_existing_keys=False</em>, <em>rebuild_on_every_update=False</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/memories/non_episodic/differentiable_neural_dictionary.html#QDND"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.non_episodic.QDND" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.memories.non_episodic.</code><code class="sig-name descname">QDND</code><span class="sig-paren">(</span><em class="sig-param">dict_size</em>, <em class="sig-param">key_width</em>, <em class="sig-param">num_actions</em>, <em class="sig-param">new_value_shift_coefficient=0.1</em>, <em class="sig-param">key_error_threshold=0.01</em>, <em class="sig-param">learning_rate=0.01</em>, <em class="sig-param">num_neighbors=50</em>, <em class="sig-param">return_additional_data=False</em>, <em class="sig-param">override_existing_keys=False</em>, <em class="sig-param">rebuild_on_every_update=False</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/memories/non_episodic/differentiable_neural_dictionary.html#QDND"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.non_episodic.QDND" title="Permalink to this definition"></a></dt>
<dd></dd></dl>
</div>
@@ -307,7 +307,7 @@ this parameter determines the key to retrieve the class index value</p></li>
<h3>ExperienceReplay<a class="headerlink" href="#experiencereplay" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.memories.non_episodic.ExperienceReplay">
<em class="property">class </em><code class="descclassname">rl_coach.memories.non_episodic.</code><code class="descname">ExperienceReplay</code><span class="sig-paren">(</span><em>max_size: Tuple[rl_coach.memories.memory.MemoryGranularity, int], allow_duplicates_in_batch_sampling: bool = True</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/memories/non_episodic/experience_replay.html#ExperienceReplay"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.non_episodic.ExperienceReplay" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.memories.non_episodic.</code><code class="sig-name descname">ExperienceReplay</code><span class="sig-paren">(</span><em class="sig-param">max_size: Tuple[rl_coach.memories.memory.MemoryGranularity, int], allow_duplicates_in_batch_sampling: bool = True</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/memories/non_episodic/experience_replay.html#ExperienceReplay"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.non_episodic.ExperienceReplay" title="Permalink to this definition"></a></dt>
<dd><p>A regular replay buffer which stores transition without any additional structure</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -324,7 +324,7 @@ this parameter determines the key to retrieve the class index value</p></li>
<h3>PrioritizedExperienceReplay<a class="headerlink" href="#prioritizedexperiencereplay" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.memories.non_episodic.PrioritizedExperienceReplay">
<em class="property">class </em><code class="descclassname">rl_coach.memories.non_episodic.</code><code class="descname">PrioritizedExperienceReplay</code><span class="sig-paren">(</span><em>max_size: Tuple[rl_coach.memories.memory.MemoryGranularity, int], alpha: float = 0.6, beta: rl_coach.schedules.Schedule = &lt;rl_coach.schedules.ConstantSchedule object&gt;, epsilon: float = 1e-06, allow_duplicates_in_batch_sampling: bool = True</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/memories/non_episodic/prioritized_experience_replay.html#PrioritizedExperienceReplay"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.non_episodic.PrioritizedExperienceReplay" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.memories.non_episodic.</code><code class="sig-name descname">PrioritizedExperienceReplay</code><span class="sig-paren">(</span><em class="sig-param">max_size: Tuple[rl_coach.memories.memory.MemoryGranularity, int], alpha: float = 0.6, beta: rl_coach.schedules.Schedule = &lt;rl_coach.schedules.ConstantSchedule object&gt;, epsilon: float = 1e-06, allow_duplicates_in_batch_sampling: bool = True</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/memories/non_episodic/prioritized_experience_replay.html#PrioritizedExperienceReplay"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.non_episodic.PrioritizedExperienceReplay" title="Permalink to this definition"></a></dt>
<dd><p>This is the proportional sampling variant of the prioritized experience replay as described
in <a class="reference external" href="https://arxiv.org/pdf/1511.05952.pdf">https://arxiv.org/pdf/1511.05952.pdf</a>.</p>
<dl class="field-list simple">
@@ -345,7 +345,7 @@ in <a class="reference external" href="https://arxiv.org/pdf/1511.05952.pdf">htt
<h3>TransitionCollection<a class="headerlink" href="#transitioncollection" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.memories.non_episodic.TransitionCollection">
<em class="property">class </em><code class="descclassname">rl_coach.memories.non_episodic.</code><code class="descname">TransitionCollection</code><a class="reference internal" href="../../_modules/rl_coach/memories/non_episodic/transition_collection.html#TransitionCollection"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.non_episodic.TransitionCollection" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.memories.non_episodic.</code><code class="sig-name descname">TransitionCollection</code><a class="reference internal" href="../../_modules/rl_coach/memories/non_episodic/transition_collection.html#TransitionCollection"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.non_episodic.TransitionCollection" title="Permalink to this definition"></a></dt>
<dd><p>Simple python implementation of transitions collection non-episodic memories
are constructed on top of.</p>
</dd></dl>
+1 -1
View File
@@ -193,7 +193,7 @@
<h2>RedisPubSubBackend<a class="headerlink" href="#redispubsubbackend" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.memories.backend.redis.RedisPubSubBackend">
<em class="property">class </em><code class="descclassname">rl_coach.memories.backend.redis.</code><code class="descname">RedisPubSubBackend</code><span class="sig-paren">(</span><em>params: rl_coach.memories.backend.redis.RedisPubSubMemoryBackendParameters</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/memories/backend/redis.html#RedisPubSubBackend"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.backend.redis.RedisPubSubBackend" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.memories.backend.redis.</code><code class="sig-name descname">RedisPubSubBackend</code><span class="sig-paren">(</span><em class="sig-param">params: rl_coach.memories.backend.redis.RedisPubSubMemoryBackendParameters</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/memories/backend/redis.html#RedisPubSubBackend"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.backend.redis.RedisPubSubBackend" title="Permalink to this definition"></a></dt>
<dd><p>A memory backend which transfers the experiences from the rollout to the training worker using Redis Pub/Sub in
Coach when distributed mode is used.</p>
<dl class="field-list simple">
+1 -1
View File
@@ -193,7 +193,7 @@
<h2>Kubernetes<a class="headerlink" href="#kubernetes" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.orchestrators.kubernetes_orchestrator.Kubernetes">
<em class="property">class </em><code class="descclassname">rl_coach.orchestrators.kubernetes_orchestrator.</code><code class="descname">Kubernetes</code><span class="sig-paren">(</span><em>params: rl_coach.orchestrators.kubernetes_orchestrator.KubernetesParameters</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/orchestrators/kubernetes_orchestrator.html#Kubernetes"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.orchestrators.kubernetes_orchestrator.Kubernetes" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.orchestrators.kubernetes_orchestrator.</code><code class="sig-name descname">Kubernetes</code><span class="sig-paren">(</span><em class="sig-param">params: rl_coach.orchestrators.kubernetes_orchestrator.KubernetesParameters</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/orchestrators/kubernetes_orchestrator.html#Kubernetes"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.orchestrators.kubernetes_orchestrator.Kubernetes" title="Permalink to this definition"></a></dt>
<dd><p>An orchestrator implmentation which uses Kubernetes to deploy the components such as training and rollout workers
and Redis Pub/Sub in Coach when used in the distributed mode.</p>
<dl class="field-list simple">
+32 -32
View File
@@ -208,7 +208,7 @@
<h2>Space<a class="headerlink" href="#space" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.spaces.Space">
<em class="property">class </em><code class="descclassname">rl_coach.spaces.</code><code class="descname">Space</code><span class="sig-paren">(</span><em>shape: Union[int, tuple, list, numpy.ndarray], low: Union[None, int, float, numpy.ndarray] = -inf, high: Union[None, int, float, numpy.ndarray] = inf</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#Space"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.Space" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.spaces.</code><code class="sig-name descname">Space</code><span class="sig-paren">(</span><em class="sig-param">shape: Union[int, tuple, list, numpy.ndarray], low: Union[None, int, float, numpy.ndarray] = -inf, high: Union[None, int, float, numpy.ndarray] = inf</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#Space"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.Space" title="Permalink to this definition"></a></dt>
<dd><p>A space defines a set of valid values</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -223,7 +223,7 @@ or a single value defining the general highest values</p></li>
</dl>
<dl class="method">
<dt id="rl_coach.spaces.Space.contains">
<code class="descname">contains</code><span class="sig-paren">(</span><em>val: Union[int, float, numpy.ndarray]</em><span class="sig-paren">)</span> &#x2192; bool<a class="reference internal" href="../_modules/rl_coach/spaces.html#Space.contains"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.Space.contains" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">contains</code><span class="sig-paren">(</span><em class="sig-param">val: Union[int, float, numpy.ndarray]</em><span class="sig-paren">)</span> &#x2192; bool<a class="reference internal" href="../_modules/rl_coach/spaces.html#Space.contains"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.Space.contains" title="Permalink to this definition"></a></dt>
<dd><p>Checks if value is contained by this space. The shape must match and
all of the values must be within the low and high bounds.</p>
<dl class="field-list simple">
@@ -238,7 +238,7 @@ all of the values must be within the low and high bounds.</p>
<dl class="method">
<dt id="rl_coach.spaces.Space.is_valid_index">
<code class="descname">is_valid_index</code><span class="sig-paren">(</span><em>index: numpy.ndarray</em><span class="sig-paren">)</span> &#x2192; bool<a class="reference internal" href="../_modules/rl_coach/spaces.html#Space.is_valid_index"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.Space.is_valid_index" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">is_valid_index</code><span class="sig-paren">(</span><em class="sig-param">index: numpy.ndarray</em><span class="sig-paren">)</span> &#x2192; bool<a class="reference internal" href="../_modules/rl_coach/spaces.html#Space.is_valid_index"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.Space.is_valid_index" title="Permalink to this definition"></a></dt>
<dd><p>Checks if a given multidimensional index is within the bounds of the shape of the space</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -252,7 +252,7 @@ all of the values must be within the low and high bounds.</p>
<dl class="method">
<dt id="rl_coach.spaces.Space.sample">
<code class="descname">sample</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="reference internal" href="../_modules/rl_coach/spaces.html#Space.sample"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.Space.sample" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">sample</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="reference internal" href="../_modules/rl_coach/spaces.html#Space.sample"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.Space.sample" title="Permalink to this definition"></a></dt>
<dd><p>Sample the defined space, either uniformly, if space bounds are defined, or Normal distributed if no
bounds are defined</p>
<dl class="field-list simple">
@@ -269,10 +269,10 @@ bounds are defined</p>
<h2>Observation Spaces<a class="headerlink" href="#observation-spaces" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.spaces.ObservationSpace">
<em class="property">class </em><code class="descclassname">rl_coach.spaces.</code><code class="descname">ObservationSpace</code><span class="sig-paren">(</span><em>shape: Union[int, numpy.ndarray], low: Union[None, int, float, numpy.ndarray] = -inf, high: Union[None, int, float, numpy.ndarray] = inf</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#ObservationSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.ObservationSpace" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.spaces.</code><code class="sig-name descname">ObservationSpace</code><span class="sig-paren">(</span><em class="sig-param">shape: Union[int, numpy.ndarray], low: Union[None, int, float, numpy.ndarray] = -inf, high: Union[None, int, float, numpy.ndarray] = inf</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#ObservationSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.ObservationSpace" title="Permalink to this definition"></a></dt>
<dd><dl class="method">
<dt id="rl_coach.spaces.ObservationSpace.contains">
<code class="descname">contains</code><span class="sig-paren">(</span><em>val: Union[int, float, numpy.ndarray]</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.ObservationSpace.contains" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">contains</code><span class="sig-paren">(</span><em class="sig-param">val: Union[int, float, numpy.ndarray]</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.ObservationSpace.contains" title="Permalink to this definition"></a></dt>
<dd><p>Checks if value is contained by this space. The shape must match and
all of the values must be within the low and high bounds.</p>
<dl class="field-list simple">
@@ -287,7 +287,7 @@ all of the values must be within the low and high bounds.</p>
<dl class="method">
<dt id="rl_coach.spaces.ObservationSpace.is_valid_index">
<code class="descname">is_valid_index</code><span class="sig-paren">(</span><em>index: numpy.ndarray</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.ObservationSpace.is_valid_index" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">is_valid_index</code><span class="sig-paren">(</span><em class="sig-param">index: numpy.ndarray</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.ObservationSpace.is_valid_index" title="Permalink to this definition"></a></dt>
<dd><p>Checks if a given multidimensional index is within the bounds of the shape of the space</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -301,7 +301,7 @@ all of the values must be within the low and high bounds.</p>
<dl class="method">
<dt id="rl_coach.spaces.ObservationSpace.sample">
<code class="descname">sample</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="headerlink" href="#rl_coach.spaces.ObservationSpace.sample" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">sample</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="headerlink" href="#rl_coach.spaces.ObservationSpace.sample" title="Permalink to this definition"></a></dt>
<dd><p>Sample the defined space, either uniformly, if space bounds are defined, or Normal distributed if no
bounds are defined</p>
<dl class="field-list simple">
@@ -317,7 +317,7 @@ bounds are defined</p>
<h3>VectorObservationSpace<a class="headerlink" href="#vectorobservationspace" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.spaces.VectorObservationSpace">
<em class="property">class </em><code class="descclassname">rl_coach.spaces.</code><code class="descname">VectorObservationSpace</code><span class="sig-paren">(</span><em>shape: int</em>, <em>low: Union[None</em>, <em>int</em>, <em>float</em>, <em>numpy.ndarray] = -inf</em>, <em>high: Union[None</em>, <em>int</em>, <em>float</em>, <em>numpy.ndarray] = inf</em>, <em>measurements_names: List[str] = None</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#VectorObservationSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.VectorObservationSpace" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.spaces.</code><code class="sig-name descname">VectorObservationSpace</code><span class="sig-paren">(</span><em class="sig-param">shape: int</em>, <em class="sig-param">low: Union[None</em>, <em class="sig-param">int</em>, <em class="sig-param">float</em>, <em class="sig-param">numpy.ndarray] = -inf</em>, <em class="sig-param">high: Union[None</em>, <em class="sig-param">int</em>, <em class="sig-param">float</em>, <em class="sig-param">numpy.ndarray] = inf</em>, <em class="sig-param">measurements_names: List[str] = None</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#VectorObservationSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.VectorObservationSpace" title="Permalink to this definition"></a></dt>
<dd><p>An observation space which is defined as a vector of elements. This can be particularly useful for environments
which return measurements, such as in robotic environments.</p>
</dd></dl>
@@ -327,7 +327,7 @@ which return measurements, such as in robotic environments.</p>
<h3>PlanarMapsObservationSpace<a class="headerlink" href="#planarmapsobservationspace" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.spaces.PlanarMapsObservationSpace">
<em class="property">class </em><code class="descclassname">rl_coach.spaces.</code><code class="descname">PlanarMapsObservationSpace</code><span class="sig-paren">(</span><em>shape: numpy.ndarray</em>, <em>low: int</em>, <em>high: int</em>, <em>channels_axis: int = -1</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#PlanarMapsObservationSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.PlanarMapsObservationSpace" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.spaces.</code><code class="sig-name descname">PlanarMapsObservationSpace</code><span class="sig-paren">(</span><em class="sig-param">shape: numpy.ndarray</em>, <em class="sig-param">low: int</em>, <em class="sig-param">high: int</em>, <em class="sig-param">channels_axis: int = -1</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#PlanarMapsObservationSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.PlanarMapsObservationSpace" title="Permalink to this definition"></a></dt>
<dd><p>An observation space which defines a stack of 2D observations. For example, an environment which returns
a stack of segmentation maps like in Starcraft.</p>
</dd></dl>
@@ -337,7 +337,7 @@ a stack of segmentation maps like in Starcraft.</p>
<h3>ImageObservationSpace<a class="headerlink" href="#imageobservationspace" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.spaces.ImageObservationSpace">
<em class="property">class </em><code class="descclassname">rl_coach.spaces.</code><code class="descname">ImageObservationSpace</code><span class="sig-paren">(</span><em>shape: numpy.ndarray</em>, <em>high: int</em>, <em>channels_axis: int = -1</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#ImageObservationSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.ImageObservationSpace" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.spaces.</code><code class="sig-name descname">ImageObservationSpace</code><span class="sig-paren">(</span><em class="sig-param">shape: numpy.ndarray</em>, <em class="sig-param">high: int</em>, <em class="sig-param">channels_axis: int = -1</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#ImageObservationSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.ImageObservationSpace" title="Permalink to this definition"></a></dt>
<dd><p>An observation space which is a private case of the PlanarMapsObservationSpace, where the stack of 2D observations
represent a RGB image, or a grayscale image.</p>
</dd></dl>
@@ -348,10 +348,10 @@ represent a RGB image, or a grayscale image.</p>
<h2>Action Spaces<a class="headerlink" href="#action-spaces" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.spaces.ActionSpace">
<em class="property">class </em><code class="descclassname">rl_coach.spaces.</code><code class="descname">ActionSpace</code><span class="sig-paren">(</span><em>shape: Union[int, numpy.ndarray], low: Union[None, int, float, numpy.ndarray] = -inf, high: Union[None, int, float, numpy.ndarray] = inf, descriptions: Union[None, List, Dict] = None, default_action: Union[int, float, numpy.ndarray, List] = None</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#ActionSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.ActionSpace" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.spaces.</code><code class="sig-name descname">ActionSpace</code><span class="sig-paren">(</span><em class="sig-param">shape: Union[int, numpy.ndarray], low: Union[None, int, float, numpy.ndarray] = -inf, high: Union[None, int, float, numpy.ndarray] = inf, descriptions: Union[None, List, Dict] = None, default_action: Union[int, float, numpy.ndarray, List] = None</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#ActionSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.ActionSpace" title="Permalink to this definition"></a></dt>
<dd><dl class="method">
<dt id="rl_coach.spaces.ActionSpace.clip_action_to_space">
<code class="descname">clip_action_to_space</code><span class="sig-paren">(</span><em>action: Union[int, float, numpy.ndarray, List]</em><span class="sig-paren">)</span> &#x2192; Union[int, float, numpy.ndarray, List]<a class="reference internal" href="../_modules/rl_coach/spaces.html#ActionSpace.clip_action_to_space"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.ActionSpace.clip_action_to_space" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">clip_action_to_space</code><span class="sig-paren">(</span><em class="sig-param">action: Union[int, float, numpy.ndarray, List]</em><span class="sig-paren">)</span> &#x2192; Union[int, float, numpy.ndarray, List]<a class="reference internal" href="../_modules/rl_coach/spaces.html#ActionSpace.clip_action_to_space"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.ActionSpace.clip_action_to_space" title="Permalink to this definition"></a></dt>
<dd><p>Given an action, clip its values to fit to the action space ranges</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -365,7 +365,7 @@ represent a RGB image, or a grayscale image.</p>
<dl class="method">
<dt id="rl_coach.spaces.ActionSpace.contains">
<code class="descname">contains</code><span class="sig-paren">(</span><em>val: Union[int, float, numpy.ndarray]</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.ActionSpace.contains" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">contains</code><span class="sig-paren">(</span><em class="sig-param">val: Union[int, float, numpy.ndarray]</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.ActionSpace.contains" title="Permalink to this definition"></a></dt>
<dd><p>Checks if value is contained by this space. The shape must match and
all of the values must be within the low and high bounds.</p>
<dl class="field-list simple">
@@ -380,7 +380,7 @@ all of the values must be within the low and high bounds.</p>
<dl class="method">
<dt id="rl_coach.spaces.ActionSpace.is_valid_index">
<code class="descname">is_valid_index</code><span class="sig-paren">(</span><em>index: numpy.ndarray</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.ActionSpace.is_valid_index" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">is_valid_index</code><span class="sig-paren">(</span><em class="sig-param">index: numpy.ndarray</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.ActionSpace.is_valid_index" title="Permalink to this definition"></a></dt>
<dd><p>Checks if a given multidimensional index is within the bounds of the shape of the space</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -394,7 +394,7 @@ all of the values must be within the low and high bounds.</p>
<dl class="method">
<dt id="rl_coach.spaces.ActionSpace.sample">
<code class="descname">sample</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="headerlink" href="#rl_coach.spaces.ActionSpace.sample" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">sample</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="headerlink" href="#rl_coach.spaces.ActionSpace.sample" title="Permalink to this definition"></a></dt>
<dd><p>Sample the defined space, either uniformly, if space bounds are defined, or Normal distributed if no
bounds are defined</p>
<dl class="field-list simple">
@@ -406,7 +406,7 @@ bounds are defined</p>
<dl class="method">
<dt id="rl_coach.spaces.ActionSpace.sample_with_info">
<code class="descname">sample_with_info</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.ActionInfo<a class="reference internal" href="../_modules/rl_coach/spaces.html#ActionSpace.sample_with_info"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.ActionSpace.sample_with_info" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">sample_with_info</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.ActionInfo<a class="reference internal" href="../_modules/rl_coach/spaces.html#ActionSpace.sample_with_info"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.ActionSpace.sample_with_info" title="Permalink to this definition"></a></dt>
<dd><p>Get a random action with additional “fake” info</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -421,7 +421,7 @@ bounds are defined</p>
<h3>AttentionActionSpace<a class="headerlink" href="#attentionactionspace" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.spaces.AttentionActionSpace">
<em class="property">class </em><code class="descclassname">rl_coach.spaces.</code><code class="descname">AttentionActionSpace</code><span class="sig-paren">(</span><em>shape: int</em>, <em>low: Union[None</em>, <em>int</em>, <em>float</em>, <em>numpy.ndarray] = -inf</em>, <em>high: Union[None</em>, <em>int</em>, <em>float</em>, <em>numpy.ndarray] = inf</em>, <em>descriptions: Union[None</em>, <em>List</em>, <em>Dict] = None</em>, <em>default_action: numpy.ndarray = None</em>, <em>forced_attention_size: Union[None</em>, <em>int</em>, <em>float</em>, <em>numpy.ndarray] = None</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#AttentionActionSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.AttentionActionSpace" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.spaces.</code><code class="sig-name descname">AttentionActionSpace</code><span class="sig-paren">(</span><em class="sig-param">shape: int</em>, <em class="sig-param">low: Union[None</em>, <em class="sig-param">int</em>, <em class="sig-param">float</em>, <em class="sig-param">numpy.ndarray] = -inf</em>, <em class="sig-param">high: Union[None</em>, <em class="sig-param">int</em>, <em class="sig-param">float</em>, <em class="sig-param">numpy.ndarray] = inf</em>, <em class="sig-param">descriptions: Union[None</em>, <em class="sig-param">List</em>, <em class="sig-param">Dict] = None</em>, <em class="sig-param">default_action: numpy.ndarray = None</em>, <em class="sig-param">forced_attention_size: Union[None</em>, <em class="sig-param">int</em>, <em class="sig-param">float</em>, <em class="sig-param">numpy.ndarray] = None</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#AttentionActionSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.AttentionActionSpace" title="Permalink to this definition"></a></dt>
<dd><p>A box selection continuous action space, meaning that the actions are defined as selecting a multidimensional box
from a given range.
The actions will be in the form:
@@ -433,7 +433,7 @@ The actions will be in the form:
<h3>BoxActionSpace<a class="headerlink" href="#boxactionspace" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.spaces.BoxActionSpace">
<em class="property">class </em><code class="descclassname">rl_coach.spaces.</code><code class="descname">BoxActionSpace</code><span class="sig-paren">(</span><em>shape: Union[int, numpy.ndarray], low: Union[None, int, float, numpy.ndarray] = -inf, high: Union[None, int, float, numpy.ndarray] = inf, descriptions: Union[None, List, Dict] = None, default_action: numpy.ndarray = None</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#BoxActionSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.BoxActionSpace" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.spaces.</code><code class="sig-name descname">BoxActionSpace</code><span class="sig-paren">(</span><em class="sig-param">shape: Union[int, numpy.ndarray], low: Union[None, int, float, numpy.ndarray] = -inf, high: Union[None, int, float, numpy.ndarray] = inf, descriptions: Union[None, List, Dict] = None, default_action: numpy.ndarray = None</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#BoxActionSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.BoxActionSpace" title="Permalink to this definition"></a></dt>
<dd><p>A multidimensional bounded or unbounded continuous action space</p>
</dd></dl>
@@ -442,7 +442,7 @@ The actions will be in the form:
<h3>DiscreteActionSpace<a class="headerlink" href="#discreteactionspace" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.spaces.DiscreteActionSpace">
<em class="property">class </em><code class="descclassname">rl_coach.spaces.</code><code class="descname">DiscreteActionSpace</code><span class="sig-paren">(</span><em>num_actions: int</em>, <em>descriptions: Union[None</em>, <em>List</em>, <em>Dict] = None</em>, <em>default_action: numpy.ndarray = None</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#DiscreteActionSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.DiscreteActionSpace" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.spaces.</code><code class="sig-name descname">DiscreteActionSpace</code><span class="sig-paren">(</span><em class="sig-param">num_actions: int</em>, <em class="sig-param">descriptions: Union[None</em>, <em class="sig-param">List</em>, <em class="sig-param">Dict] = None</em>, <em class="sig-param">default_action: numpy.ndarray = None</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#DiscreteActionSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.DiscreteActionSpace" title="Permalink to this definition"></a></dt>
<dd><p>A discrete action space with action indices as actions</p>
</dd></dl>
@@ -451,7 +451,7 @@ The actions will be in the form:
<h3>MultiSelectActionSpace<a class="headerlink" href="#multiselectactionspace" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.spaces.MultiSelectActionSpace">
<em class="property">class </em><code class="descclassname">rl_coach.spaces.</code><code class="descname">MultiSelectActionSpace</code><span class="sig-paren">(</span><em>size: int</em>, <em>max_simultaneous_selected_actions: int = 1</em>, <em>descriptions: Union[None</em>, <em>List</em>, <em>Dict] = None</em>, <em>default_action: numpy.ndarray = None</em>, <em>allow_no_action_to_be_selected=True</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#MultiSelectActionSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.MultiSelectActionSpace" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.spaces.</code><code class="sig-name descname">MultiSelectActionSpace</code><span class="sig-paren">(</span><em class="sig-param">size: int</em>, <em class="sig-param">max_simultaneous_selected_actions: int = 1</em>, <em class="sig-param">descriptions: Union[None</em>, <em class="sig-param">List</em>, <em class="sig-param">Dict] = None</em>, <em class="sig-param">default_action: numpy.ndarray = None</em>, <em class="sig-param">allow_no_action_to_be_selected=True</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#MultiSelectActionSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.MultiSelectActionSpace" title="Permalink to this definition"></a></dt>
<dd><p>A discrete action space where multiple actions can be selected at once. The actions are encoded as multi-hot vectors</p>
</dd></dl>
@@ -460,7 +460,7 @@ The actions will be in the form:
<h3>CompoundActionSpace<a class="headerlink" href="#compoundactionspace" title="Permalink to this headline"></a></h3>
<dl class="class">
<dt id="rl_coach.spaces.CompoundActionSpace">
<em class="property">class </em><code class="descclassname">rl_coach.spaces.</code><code class="descname">CompoundActionSpace</code><span class="sig-paren">(</span><em>sub_spaces: List[rl_coach.spaces.ActionSpace]</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#CompoundActionSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.CompoundActionSpace" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.spaces.</code><code class="sig-name descname">CompoundActionSpace</code><span class="sig-paren">(</span><em class="sig-param">sub_spaces: List[rl_coach.spaces.ActionSpace]</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#CompoundActionSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.CompoundActionSpace" title="Permalink to this definition"></a></dt>
<dd><p>An action space which consists of multiple sub-action spaces.
For example, in Starcraft the agent should choose an action identifier from ~550 options (Discrete(550)),
but it also needs to choose 13 different arguments for the selected action identifier, where each argument is
@@ -473,7 +473,7 @@ by itself an action space. In Starcraft, the arguments are Discrete action space
<h2>Goal Spaces<a class="headerlink" href="#goal-spaces" title="Permalink to this headline"></a></h2>
<dl class="class">
<dt id="rl_coach.spaces.GoalsSpace">
<em class="property">class </em><code class="descclassname">rl_coach.spaces.</code><code class="descname">GoalsSpace</code><span class="sig-paren">(</span><em>goal_name: str, reward_type: rl_coach.spaces.GoalToRewardConversion, distance_metric: Union[rl_coach.spaces.GoalsSpace.DistanceMetric, Callable]</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#GoalsSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.GoalsSpace" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.spaces.</code><code class="sig-name descname">GoalsSpace</code><span class="sig-paren">(</span><em class="sig-param">goal_name: str, reward_type: rl_coach.spaces.GoalToRewardConversion, distance_metric: Union[rl_coach.spaces.GoalsSpace.DistanceMetric, Callable]</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#GoalsSpace"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.GoalsSpace" title="Permalink to this definition"></a></dt>
<dd><p>A multidimensional space with a goal type definition. It also behaves as an action space, so that hierarchical
agents can use it as an output action space.
The class acts as a wrapper to the target space. So after setting the target space, all the values of the class
@@ -491,13 +491,13 @@ returns the distance between them</p></li>
</dl>
<dl class="class">
<dt id="rl_coach.spaces.GoalsSpace.DistanceMetric">
<em class="property">class </em><code class="descname">DistanceMetric</code><a class="reference internal" href="../_modules/rl_coach/spaces.html#GoalsSpace.DistanceMetric"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.GoalsSpace.DistanceMetric" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-name descname">DistanceMetric</code><a class="reference internal" href="../_modules/rl_coach/spaces.html#GoalsSpace.DistanceMetric"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.GoalsSpace.DistanceMetric" title="Permalink to this definition"></a></dt>
<dd><p>An enumeration.</p>
</dd></dl>
<dl class="method">
<dt id="rl_coach.spaces.GoalsSpace.clip_action_to_space">
<code class="descname">clip_action_to_space</code><span class="sig-paren">(</span><em>action: Union[int, float, numpy.ndarray, List]</em><span class="sig-paren">)</span> &#x2192; Union[int, float, numpy.ndarray, List]<a class="headerlink" href="#rl_coach.spaces.GoalsSpace.clip_action_to_space" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">clip_action_to_space</code><span class="sig-paren">(</span><em class="sig-param">action: Union[int, float, numpy.ndarray, List]</em><span class="sig-paren">)</span> &#x2192; Union[int, float, numpy.ndarray, List]<a class="headerlink" href="#rl_coach.spaces.GoalsSpace.clip_action_to_space" title="Permalink to this definition"></a></dt>
<dd><p>Given an action, clip its values to fit to the action space ranges</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -511,7 +511,7 @@ returns the distance between them</p></li>
<dl class="method">
<dt id="rl_coach.spaces.GoalsSpace.contains">
<code class="descname">contains</code><span class="sig-paren">(</span><em>val: Union[int, float, numpy.ndarray]</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.GoalsSpace.contains" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">contains</code><span class="sig-paren">(</span><em class="sig-param">val: Union[int, float, numpy.ndarray]</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.GoalsSpace.contains" title="Permalink to this definition"></a></dt>
<dd><p>Checks if value is contained by this space. The shape must match and
all of the values must be within the low and high bounds.</p>
<dl class="field-list simple">
@@ -526,7 +526,7 @@ all of the values must be within the low and high bounds.</p>
<dl class="method">
<dt id="rl_coach.spaces.GoalsSpace.distance_from_goal">
<code class="descname">distance_from_goal</code><span class="sig-paren">(</span><em>goal: numpy.ndarray</em>, <em>state: dict</em><span class="sig-paren">)</span> &#x2192; float<a class="reference internal" href="../_modules/rl_coach/spaces.html#GoalsSpace.distance_from_goal"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.GoalsSpace.distance_from_goal" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">distance_from_goal</code><span class="sig-paren">(</span><em class="sig-param">goal: numpy.ndarray</em>, <em class="sig-param">state: dict</em><span class="sig-paren">)</span> &#x2192; float<a class="reference internal" href="../_modules/rl_coach/spaces.html#GoalsSpace.distance_from_goal"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.GoalsSpace.distance_from_goal" title="Permalink to this definition"></a></dt>
<dd><p>Given a state, check its distance from the goal</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -543,7 +543,7 @@ all of the values must be within the low and high bounds.</p>
<dl class="method">
<dt id="rl_coach.spaces.GoalsSpace.get_reward_for_goal_and_state">
<code class="descname">get_reward_for_goal_and_state</code><span class="sig-paren">(</span><em>goal: numpy.ndarray</em>, <em>state: dict</em><span class="sig-paren">)</span> &#x2192; Tuple[float, bool]<a class="reference internal" href="../_modules/rl_coach/spaces.html#GoalsSpace.get_reward_for_goal_and_state"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.GoalsSpace.get_reward_for_goal_and_state" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">get_reward_for_goal_and_state</code><span class="sig-paren">(</span><em class="sig-param">goal: numpy.ndarray</em>, <em class="sig-param">state: dict</em><span class="sig-paren">)</span> &#x2192; Tuple[float, bool]<a class="reference internal" href="../_modules/rl_coach/spaces.html#GoalsSpace.get_reward_for_goal_and_state"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.GoalsSpace.get_reward_for_goal_and_state" title="Permalink to this definition"></a></dt>
<dd><p>Given a state, check if the goal was reached and return a reward accordingly</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -560,7 +560,7 @@ all of the values must be within the low and high bounds.</p>
<dl class="method">
<dt id="rl_coach.spaces.GoalsSpace.goal_from_state">
<code class="descname">goal_from_state</code><span class="sig-paren">(</span><em>state: Dict</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#GoalsSpace.goal_from_state"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.GoalsSpace.goal_from_state" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">goal_from_state</code><span class="sig-paren">(</span><em class="sig-param">state: Dict</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/spaces.html#GoalsSpace.goal_from_state"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.spaces.GoalsSpace.goal_from_state" title="Permalink to this definition"></a></dt>
<dd><p>Given a state, extract an observation according to the goal_name</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -574,7 +574,7 @@ all of the values must be within the low and high bounds.</p>
<dl class="method">
<dt id="rl_coach.spaces.GoalsSpace.is_valid_index">
<code class="descname">is_valid_index</code><span class="sig-paren">(</span><em>index: numpy.ndarray</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.GoalsSpace.is_valid_index" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">is_valid_index</code><span class="sig-paren">(</span><em class="sig-param">index: numpy.ndarray</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.spaces.GoalsSpace.is_valid_index" title="Permalink to this definition"></a></dt>
<dd><p>Checks if a given multidimensional index is within the bounds of the shape of the space</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -588,7 +588,7 @@ all of the values must be within the low and high bounds.</p>
<dl class="method">
<dt id="rl_coach.spaces.GoalsSpace.sample">
<code class="descname">sample</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="headerlink" href="#rl_coach.spaces.GoalsSpace.sample" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">sample</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="headerlink" href="#rl_coach.spaces.GoalsSpace.sample" title="Permalink to this definition"></a></dt>
<dd><p>Sample the defined space, either uniformly, if space bounds are defined, or Normal distributed if no
bounds are defined</p>
<dl class="field-list simple">
@@ -600,7 +600,7 @@ bounds are defined</p>
<dl class="method">
<dt id="rl_coach.spaces.GoalsSpace.sample_with_info">
<code class="descname">sample_with_info</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.ActionInfo<a class="headerlink" href="#rl_coach.spaces.GoalsSpace.sample_with_info" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">sample_with_info</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.ActionInfo<a class="headerlink" href="#rl_coach.spaces.GoalsSpace.sample_with_info" title="Permalink to this definition"></a></dt>
<dd><p>Get a random action with additional “fake” info</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
+36 -14
View File
@@ -221,7 +221,7 @@
<li><a href="test.html#rl_coach.agents.dqn_agent.DQNAgent.act">(rl_coach.agents.dqn_agent.DQNAgent method)</a>
</li>
</ul></li>
<li><a href="components/environments/index.html#rl_coach.environments.environment.Environment.action_space">action_space (rl_coach.environments.environment.Environment attribute)</a>
<li><a href="components/environments/index.html#rl_coach.environments.environment.Environment.action_space">action_space() (rl_coach.environments.environment.Environment property)</a>
</li>
<li><a href="components/core_types.html#rl_coach.core_types.ActionInfo">ActionInfo (class in rl_coach.core_types)</a>
</li>
@@ -408,6 +408,14 @@
<h2 id="F">F</h2>
<table style="width: 100%" class="indextable genindextable"><tr>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="components/agents/index.html#rl_coach.agents.agent.Agent.freeze_memory">freeze_memory() (rl_coach.agents.agent.Agent method)</a>
<ul>
<li><a href="test.html#rl_coach.agents.dqn_agent.DQNAgent.freeze_memory">(rl_coach.agents.dqn_agent.DQNAgent method)</a>
</li>
</ul></li>
</ul></td>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="components/filters/output_filters.html#rl_coach.filters.action.FullDiscreteActionSpaceMap">FullDiscreteActionSpaceMap (class in rl_coach.filters.action)</a>
</li>
@@ -461,7 +469,7 @@
</li>
<li><a href="components/spaces.html#rl_coach.spaces.GoalsSpace.goal_from_state">goal_from_state() (rl_coach.spaces.GoalsSpace method)</a>
</li>
<li><a href="components/environments/index.html#rl_coach.environments.environment.Environment.goal_space">goal_space (rl_coach.environments.environment.Environment attribute)</a>
<li><a href="components/environments/index.html#rl_coach.environments.environment.Environment.goal_space">goal_space() (rl_coach.environments.environment.Environment property)</a>
</li>
<li><a href="components/core_types.html#rl_coach.core_types.Batch.goals">goals() (rl_coach.core_types.Batch method)</a>
</li>
@@ -505,6 +513,12 @@
<ul>
<li><a href="test.html#rl_coach.agents.dqn_agent.DQNAgent.init_environment_dependent_modules">(rl_coach.agents.dqn_agent.DQNAgent method)</a>
</li>
</ul></li>
<li><a href="components/agents/index.html#rl_coach.agents.agent.Agent.initialize_session_dependent_components">initialize_session_dependent_components() (rl_coach.agents.agent.Agent method)</a>
<ul>
<li><a href="test.html#rl_coach.agents.dqn_agent.DQNAgent.initialize_session_dependent_components">(rl_coach.agents.dqn_agent.DQNAgent method)</a>
</li>
</ul></li>
</ul></td>
@@ -537,7 +551,7 @@
<h2 id="L">L</h2>
<table style="width: 100%" class="indextable genindextable"><tr>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="components/environments/index.html#rl_coach.environments.environment.Environment.last_env_response">last_env_response (rl_coach.environments.environment.Environment attribute)</a>
<li><a href="components/environments/index.html#rl_coach.environments.environment.Environment.last_env_response">last_env_response() (rl_coach.environments.environment.Environment property)</a>
</li>
<li><a href="components/agents/index.html#rl_coach.agents.agent.Agent.learn_from_batch">learn_from_batch() (rl_coach.agents.agent.Agent method)</a>
@@ -545,12 +559,18 @@
<li><a href="test.html#rl_coach.agents.dqn_agent.DQNAgent.learn_from_batch">(rl_coach.agents.dqn_agent.DQNAgent method)</a>
</li>
</ul></li>
</ul></td>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="components/core_types.html#rl_coach.core_types.Episode.length">length() (rl_coach.core_types.Episode method)</a>
</li>
</ul></td>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="components/filters/output_filters.html#rl_coach.filters.action.LinearBoxToBoxMap">LinearBoxToBoxMap (class in rl_coach.filters.action)</a>
</li>
<li><a href="components/agents/index.html#rl_coach.agents.agent.Agent.load_memory_from_file">load_memory_from_file() (rl_coach.agents.agent.Agent method)</a>
<ul>
<li><a href="test.html#rl_coach.agents.dqn_agent.DQNAgent.load_memory_from_file">(rl_coach.agents.dqn_agent.DQNAgent method)</a>
</li>
</ul></li>
<li><a href="components/agents/index.html#rl_coach.agents.agent.Agent.log_to_screen">log_to_screen() (rl_coach.agents.agent.Agent method)</a>
<ul>
@@ -647,20 +667,20 @@
</li>
<li><a href="components/exploration_policies/index.html#rl_coach.exploration_policies.parameter_noise.ParameterNoise">ParameterNoise (class in rl_coach.exploration_policies.parameter_noise)</a>
</li>
<li><a href="components/agents/index.html#rl_coach.agents.agent.Agent.parent">parent (rl_coach.agents.agent.Agent attribute)</a>
<li><a href="components/agents/index.html#rl_coach.agents.agent.Agent.parent">parent() (rl_coach.agents.agent.Agent property)</a>
<ul>
<li><a href="test.html#rl_coach.agents.dqn_agent.DQNAgent.parent">(rl_coach.agents.dqn_agent.DQNAgent attribute)</a>
<li><a href="test.html#rl_coach.agents.dqn_agent.DQNAgent.parent">(rl_coach.agents.dqn_agent.DQNAgent property)</a>
</li>
</ul></li>
<li><a href="components/filters/output_filters.html#rl_coach.filters.action.PartialDiscreteActionSpaceMap">PartialDiscreteActionSpaceMap (class in rl_coach.filters.action)</a>
</li>
<li><a href="components/agents/index.html#rl_coach.agents.agent.Agent.phase">phase (rl_coach.agents.agent.Agent attribute)</a>
<li><a href="components/agents/index.html#rl_coach.agents.agent.Agent.phase">phase() (rl_coach.agents.agent.Agent property)</a>
<ul>
<li><a href="test.html#rl_coach.agents.dqn_agent.DQNAgent.phase">(rl_coach.agents.dqn_agent.DQNAgent attribute)</a>
<li><a href="test.html#rl_coach.agents.dqn_agent.DQNAgent.phase">(rl_coach.agents.dqn_agent.DQNAgent property)</a>
</li>
<li><a href="components/environments/index.html#rl_coach.environments.environment.Environment.phase">(rl_coach.environments.environment.Environment attribute)</a>
<li><a href="components/environments/index.html#rl_coach.environments.environment.Environment.phase">(rl_coach.environments.environment.Environment property)</a>
</li>
</ul></li>
</ul></td>
@@ -835,7 +855,7 @@
</li>
<li><a href="components/memories/index.html#rl_coach.memories.episodic.SingleEpisodeBuffer">SingleEpisodeBuffer (class in rl_coach.memories.episodic)</a>
</li>
<li><a href="components/core_types.html#rl_coach.core_types.Batch.size">size (rl_coach.core_types.Batch attribute)</a>
<li><a href="components/core_types.html#rl_coach.core_types.Batch.size">size() (rl_coach.core_types.Batch property)</a>
</li>
<li><a href="components/core_types.html#rl_coach.core_types.Batch.slice">slice() (rl_coach.core_types.Batch method)</a>
</li>
@@ -845,7 +865,7 @@
</li>
<li><a href="components/environments/index.html#rl_coach.environments.starcraft2_environment.StarCraft2Environment">StarCraft2Environment (class in rl_coach.environments.starcraft2_environment)</a>
</li>
<li><a href="components/environments/index.html#rl_coach.environments.environment.Environment.state_space">state_space (rl_coach.environments.environment.Environment attribute)</a>
<li><a href="components/environments/index.html#rl_coach.environments.environment.Environment.state_space">state_space() (rl_coach.environments.environment.Environment property)</a>
</li>
<li><a href="components/core_types.html#rl_coach.core_types.Batch.states">states() (rl_coach.core_types.Batch method)</a>
</li>
@@ -866,6 +886,8 @@
<table style="width: 100%" class="indextable genindextable"><tr>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="components/additional_parameters.html#rl_coach.base_parameters.TaskParameters">TaskParameters (class in rl_coach.base_parameters)</a>
</li>
<li><a href="components/agents/policy_optimization/td3.html#rl_coach.agents.td3_agent.TD3AlgorithmParameters">TD3AlgorithmParameters (class in rl_coach.agents.td3_agent)</a>
</li>
<li><a href="components/agents/index.html#rl_coach.agents.agent.Agent.train">train() (rl_coach.agents.agent.Agent method)</a>
@@ -873,10 +895,10 @@
<li><a href="test.html#rl_coach.agents.dqn_agent.DQNAgent.train">(rl_coach.agents.dqn_agent.DQNAgent method)</a>
</li>
</ul></li>
<li><a href="components/architectures/index.html#rl_coach.architectures.network_wrapper.NetworkWrapper.train_and_sync_networks">train_and_sync_networks() (rl_coach.architectures.network_wrapper.NetworkWrapper method)</a>
</li>
</ul></td>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="components/architectures/index.html#rl_coach.architectures.network_wrapper.NetworkWrapper.train_and_sync_networks">train_and_sync_networks() (rl_coach.architectures.network_wrapper.NetworkWrapper method)</a>
</li>
<li><a href="components/architectures/index.html#rl_coach.architectures.architecture.Architecture.train_on_batch">train_on_batch() (rl_coach.architectures.architecture.Architecture method)</a>
</li>
<li><a href="components/core_types.html#rl_coach.core_types.Transition">Transition (class in rl_coach.core_types)</a>
BIN
View File
Binary file not shown.
+1 -1
View File
File diff suppressed because one or more lines are too long
+10
View File
@@ -391,6 +391,16 @@ $(document).ready(function() {
and therefore it is able to use a replay buffer in order to improve sample efficiency.
</span>
</div>
<div class="algorithm continuous off-policy" data-year="201509">
<span class="badge">
<a href="components/agents/policy_optimization/td3.html">TD3</a>
<br>
Very similar to DDPG, i.e. an actor-critic for continuous action spaces, that uses a replay buffer in
order to improve sample efficiency. TD3 uses two critic networks in order to mitigate the overestimation
in the Q state-action value prediction, slows down the actor updates in order to increase stability and
adds noise to actions while training the critic in order to smooth out the critic's predictions.
</span>
</div>
<div class="algorithm continuous discrete on-policy" data-year="201706">
<span class="badge">
<a href="components/agents/policy_optimization/ppo.html">PPO</a>
+65 -36
View File
@@ -190,10 +190,10 @@
</div>
<dl class="class">
<dt id="rl_coach.agents.dqn_agent.DQNAgent">
<em class="property">class </em><code class="descclassname">rl_coach.agents.dqn_agent.</code><code class="descname">DQNAgent</code><span class="sig-paren">(</span><em>agent_parameters</em>, <em>parent: Union[LevelManager</em>, <em>CompositeAgent] = None</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/rl_coach/agents/dqn_agent.html#DQNAgent"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent" title="Permalink to this definition"></a></dt>
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.agents.dqn_agent.</code><code class="sig-name descname">DQNAgent</code><span class="sig-paren">(</span><em class="sig-param">agent_parameters</em>, <em class="sig-param">parent: Union[LevelManager</em>, <em class="sig-param">CompositeAgent] = None</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/rl_coach/agents/dqn_agent.html#DQNAgent"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent" title="Permalink to this definition"></a></dt>
<dd><dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.act">
<code class="descname">act</code><span class="sig-paren">(</span><em>action: Union[None</em>, <em>int</em>, <em>float</em>, <em>numpy.ndarray</em>, <em>List] = None</em><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.ActionInfo<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.act" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">act</code><span class="sig-paren">(</span><em class="sig-param">action: Union[None</em>, <em class="sig-param">int</em>, <em class="sig-param">float</em>, <em class="sig-param">numpy.ndarray</em>, <em class="sig-param">List] = None</em><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.ActionInfo<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.act" title="Permalink to this definition"></a></dt>
<dd><p>Given the agents current knowledge, decide on the next action to apply to the environment</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -207,7 +207,7 @@
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.call_memory">
<code class="descname">call_memory</code><span class="sig-paren">(</span><em>func</em>, <em>args=()</em><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.call_memory" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">call_memory</code><span class="sig-paren">(</span><em class="sig-param">func</em>, <em class="sig-param">args=()</em><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.call_memory" title="Permalink to this definition"></a></dt>
<dd><p>This function is a wrapper to allow having the same calls for shared or unshared memories.
It should be used instead of calling the memory directly in order to allow different algorithms to work
both with a shared and a local memory.</p>
@@ -226,7 +226,7 @@ both with a shared and a local memory.</p>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.choose_action">
<code class="descname">choose_action</code><span class="sig-paren">(</span><em>curr_state</em><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.choose_action" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">choose_action</code><span class="sig-paren">(</span><em class="sig-param">curr_state</em><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.choose_action" title="Permalink to this definition"></a></dt>
<dd><p>choose an action to act with in the current episode being played. Different behavior might be exhibited when
training or testing.</p>
<dl class="field-list simple">
@@ -241,7 +241,7 @@ training or testing.</p>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.collect_savers">
<code class="descname">collect_savers</code><span class="sig-paren">(</span><em>parent_path_suffix: str</em><span class="sig-paren">)</span> &#x2192; rl_coach.saver.SaverCollection<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.collect_savers" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">collect_savers</code><span class="sig-paren">(</span><em class="sig-param">parent_path_suffix: str</em><span class="sig-paren">)</span> &#x2192; rl_coach.saver.SaverCollection<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.collect_savers" title="Permalink to this definition"></a></dt>
<dd><p>Collect all of agents network savers
:param parent_path_suffix: path suffix of the parent of the agent
(could be name of level manager or composite agent)
@@ -250,7 +250,7 @@ training or testing.</p>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.create_networks">
<code class="descname">create_networks</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; Dict[str, rl_coach.architectures.network_wrapper.NetworkWrapper]<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.create_networks" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">create_networks</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; Dict[str, rl_coach.architectures.network_wrapper.NetworkWrapper]<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.create_networks" title="Permalink to this definition"></a></dt>
<dd><p>Create all the networks of the agent.
The network creation will be done after setting the environment parameters for the agent, since they are needed
for creating the network.</p>
@@ -261,9 +261,16 @@ for creating the network.</p>
</dl>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.freeze_memory">
<code class="sig-name descname">freeze_memory</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.freeze_memory" title="Permalink to this definition"></a></dt>
<dd><p>Shuffle episodes in the memory and freeze it to make sure that no extra data is being pushed anymore.
:return: None</p>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.get_predictions">
<code class="descname">get_predictions</code><span class="sig-paren">(</span><em>states: List[Dict[str, numpy.ndarray]], prediction_type: rl_coach.core_types.PredictionType</em><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.get_predictions" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">get_predictions</code><span class="sig-paren">(</span><em class="sig-param">states: List[Dict[str, numpy.ndarray]], prediction_type: rl_coach.core_types.PredictionType</em><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.get_predictions" title="Permalink to this definition"></a></dt>
<dd><p>Get a prediction from the agent with regard to the requested prediction_type.
If the agent cannot predict this type of prediction_type, or if there is more than possible way to do so,
raise a ValueException.</p>
@@ -282,7 +289,7 @@ raise a ValueException.</p>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.get_state_embedding">
<code class="descname">get_state_embedding</code><span class="sig-paren">(</span><em>state: dict</em><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.get_state_embedding" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">get_state_embedding</code><span class="sig-paren">(</span><em class="sig-param">state: dict</em><span class="sig-paren">)</span> &#x2192; numpy.ndarray<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.get_state_embedding" title="Permalink to this definition"></a></dt>
<dd><p>Given a state, get the corresponding state embedding from the main network</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -296,7 +303,7 @@ raise a ValueException.</p>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.handle_episode_ended">
<code class="descname">handle_episode_ended</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.handle_episode_ended" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">handle_episode_ended</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.handle_episode_ended" title="Permalink to this definition"></a></dt>
<dd><p>Make any changes needed when each episode is ended.
This includes incrementing counters, updating full episode dependent values, updating logs, etc.
This function is called right after each episode is ended.</p>
@@ -309,7 +316,7 @@ This function is called right after each episode is ended.</p>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.improve_reward_model">
<code class="descname">improve_reward_model</code><span class="sig-paren">(</span><em>epochs: int</em><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.improve_reward_model" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">improve_reward_model</code><span class="sig-paren">(</span><em class="sig-param">epochs: int</em><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.improve_reward_model" title="Permalink to this definition"></a></dt>
<dd><p>Train a reward model to be used by the doubly-robust estimator</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -323,7 +330,7 @@ This function is called right after each episode is ended.</p>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.init_environment_dependent_modules">
<code class="descname">init_environment_dependent_modules</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.init_environment_dependent_modules" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">init_environment_dependent_modules</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.init_environment_dependent_modules" title="Permalink to this definition"></a></dt>
<dd><p>Initialize any modules that depend on knowing information about the environment such as the action space or
the observation space</p>
<dl class="field-list simple">
@@ -333,9 +340,20 @@ the observation space</p>
</dl>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.initialize_session_dependent_components">
<code class="sig-name descname">initialize_session_dependent_components</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.initialize_session_dependent_components" title="Permalink to this definition"></a></dt>
<dd><p>Initialize components which require a session as part of their initialization.</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
<dd class="field-odd"><p>None</p>
</dd>
</dl>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.learn_from_batch">
<code class="descname">learn_from_batch</code><span class="sig-paren">(</span><em>batch</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/rl_coach/agents/dqn_agent.html#DQNAgent.learn_from_batch"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.learn_from_batch" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">learn_from_batch</code><span class="sig-paren">(</span><em class="sig-param">batch</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/rl_coach/agents/dqn_agent.html#DQNAgent.learn_from_batch"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.learn_from_batch" title="Permalink to this definition"></a></dt>
<dd><p>Given a batch of transitions, calculates their target values and updates the network.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -347,9 +365,20 @@ the observation space</p>
</dl>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.load_memory_from_file">
<code class="sig-name descname">load_memory_from_file</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.load_memory_from_file" title="Permalink to this definition"></a></dt>
<dd><p>Load memory transitions from a file.</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
<dd class="field-odd"><p>None</p>
</dd>
</dl>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.log_to_screen">
<code class="descname">log_to_screen</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.log_to_screen" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">log_to_screen</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.log_to_screen" title="Permalink to this definition"></a></dt>
<dd><p>Write an episode summary line to the terminal</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -360,7 +389,7 @@ the observation space</p>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.observe">
<code class="descname">observe</code><span class="sig-paren">(</span><em>env_response: rl_coach.core_types.EnvResponse</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.observe" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">observe</code><span class="sig-paren">(</span><em class="sig-param">env_response: rl_coach.core_types.EnvResponse</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.observe" title="Permalink to this definition"></a></dt>
<dd><p>Given a response from the environment, distill the observation from it and store it for later use.
The response should be a dictionary containing the performed action, the new observation and measurements,
the reward, a game over flag and any additional information necessary.</p>
@@ -375,9 +404,9 @@ given observation</p>
</dl>
</dd></dl>
<dl class="attribute">
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.parent">
<code class="descname">parent</code><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.parent" title="Permalink to this definition"></a></dt>
<em class="property">property </em><code class="sig-name descname">parent</code><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.parent" title="Permalink to this definition"></a></dt>
<dd><p>Get the parent class of the agent</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -386,9 +415,9 @@ given observation</p>
</dl>
</dd></dl>
<dl class="attribute">
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.phase">
<code class="descname">phase</code><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.phase" title="Permalink to this definition"></a></dt>
<em class="property">property </em><code class="sig-name descname">phase</code><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.phase" title="Permalink to this definition"></a></dt>
<dd><p>The current running phase of the agent</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -399,7 +428,7 @@ given observation</p>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.post_training_commands">
<code class="descname">post_training_commands</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.post_training_commands" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">post_training_commands</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.post_training_commands" title="Permalink to this definition"></a></dt>
<dd><p>A function which allows adding any functionality that is required to run right after the training phase ends.</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -410,7 +439,7 @@ given observation</p>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.prepare_batch_for_inference">
<code class="descname">prepare_batch_for_inference</code><span class="sig-paren">(</span><em>states: Union[Dict[str, numpy.ndarray], List[Dict[str, numpy.ndarray]]], network_name: str</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.array]<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.prepare_batch_for_inference" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">prepare_batch_for_inference</code><span class="sig-paren">(</span><em class="sig-param">states: Union[Dict[str, numpy.ndarray], List[Dict[str, numpy.ndarray]]], network_name: str</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.core.multiarray.array]<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.prepare_batch_for_inference" title="Permalink to this definition"></a></dt>
<dd><p>Convert curr_state into input tensors tensorflow is expecting. i.e. if we have several inputs states, stack all
observations together, measurements together, etc.</p>
<dl class="field-list simple">
@@ -430,7 +459,7 @@ the observation relevant for the network from the states.</p></li>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.register_signal">
<code class="descname">register_signal</code><span class="sig-paren">(</span><em>signal_name: str</em>, <em>dump_one_value_per_episode: bool = True</em>, <em>dump_one_value_per_step: bool = False</em><span class="sig-paren">)</span> &#x2192; rl_coach.utils.Signal<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.register_signal" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">register_signal</code><span class="sig-paren">(</span><em class="sig-param">signal_name: str</em>, <em class="sig-param">dump_one_value_per_episode: bool = True</em>, <em class="sig-param">dump_one_value_per_step: bool = False</em><span class="sig-paren">)</span> &#x2192; rl_coach.utils.Signal<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.register_signal" title="Permalink to this definition"></a></dt>
<dd><p>Register a signal such that its statistics will be dumped and be viewable through dashboard</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -448,7 +477,7 @@ the observation relevant for the network from the states.</p></li>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.reset_evaluation_state">
<code class="descname">reset_evaluation_state</code><span class="sig-paren">(</span><em>val: rl_coach.core_types.RunPhase</em><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.reset_evaluation_state" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">reset_evaluation_state</code><span class="sig-paren">(</span><em class="sig-param">val: rl_coach.core_types.RunPhase</em><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.reset_evaluation_state" title="Permalink to this definition"></a></dt>
<dd><p>Perform accumulators initialization when entering an evaluation phase, and signal dumping when exiting an
evaluation phase. Entering or exiting the evaluation phase is determined according to the new phase given
by val, and by the current phase set in self.phase.</p>
@@ -464,7 +493,7 @@ by val, and by the current phase set in self.phase.</p>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.reset_internal_state">
<code class="descname">reset_internal_state</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.reset_internal_state" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">reset_internal_state</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.reset_internal_state" title="Permalink to this definition"></a></dt>
<dd><p>Reset all the episodic parameters. This function is called right before each episode starts.</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -475,7 +504,7 @@ by val, and by the current phase set in self.phase.</p>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.restore_checkpoint">
<code class="descname">restore_checkpoint</code><span class="sig-paren">(</span><em>checkpoint_dir: str</em><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.restore_checkpoint" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">restore_checkpoint</code><span class="sig-paren">(</span><em class="sig-param">checkpoint_dir: str</em><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.restore_checkpoint" title="Permalink to this definition"></a></dt>
<dd><p>Allows agents to store additional information when saving checkpoints.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -489,7 +518,7 @@ by val, and by the current phase set in self.phase.</p>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.run_off_policy_evaluation">
<code class="descname">run_off_policy_evaluation</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.run_off_policy_evaluation" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">run_off_policy_evaluation</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.run_off_policy_evaluation" title="Permalink to this definition"></a></dt>
<dd><p>Run the off-policy evaluation estimators to get a prediction for the performance of the current policy based on
an evaluation dataset, which was collected by another policy(ies).
:return: None</p>
@@ -497,7 +526,7 @@ an evaluation dataset, which was collected by another policy(ies).
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.run_pre_network_filter_for_inference">
<code class="descname">run_pre_network_filter_for_inference</code><span class="sig-paren">(</span><em>state: Dict[str, numpy.ndarray], update_filter_internal_state: bool = True</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.ndarray]<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.run_pre_network_filter_for_inference" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">run_pre_network_filter_for_inference</code><span class="sig-paren">(</span><em class="sig-param">state: Dict[str, numpy.ndarray], update_filter_internal_state: bool = True</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.ndarray]<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.run_pre_network_filter_for_inference" title="Permalink to this definition"></a></dt>
<dd><p>Run filters which where defined for being applied right before using the state for inference.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -514,7 +543,7 @@ an evaluation dataset, which was collected by another policy(ies).
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.save_checkpoint">
<code class="descname">save_checkpoint</code><span class="sig-paren">(</span><em>checkpoint_prefix: str</em><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.save_checkpoint" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">save_checkpoint</code><span class="sig-paren">(</span><em class="sig-param">checkpoint_prefix: str</em><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.save_checkpoint" title="Permalink to this definition"></a></dt>
<dd><p>Allows agents to store additional information when saving checkpoints.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -528,7 +557,7 @@ an evaluation dataset, which was collected by another policy(ies).
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.set_environment_parameters">
<code class="descname">set_environment_parameters</code><span class="sig-paren">(</span><em>spaces: rl_coach.spaces.SpacesDefinition</em><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.set_environment_parameters" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">set_environment_parameters</code><span class="sig-paren">(</span><em class="sig-param">spaces: rl_coach.spaces.SpacesDefinition</em><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.set_environment_parameters" title="Permalink to this definition"></a></dt>
<dd><p>Sets the parameters that are environment dependent. As a side effect, initializes all the components that are
dependent on those values, by calling init_environment_dependent_modules</p>
<dl class="field-list simple">
@@ -543,7 +572,7 @@ dependent on those values, by calling init_environment_dependent_modules</p>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.set_incoming_directive">
<code class="descname">set_incoming_directive</code><span class="sig-paren">(</span><em>action: Union[int, float, numpy.ndarray, List]</em><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.set_incoming_directive" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">set_incoming_directive</code><span class="sig-paren">(</span><em class="sig-param">action: Union[int, float, numpy.ndarray, List]</em><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.set_incoming_directive" title="Permalink to this definition"></a></dt>
<dd><p>Allows setting a directive for the agent to follow. This is useful in hierarchy structures, where the agent
has another master agent that is controlling it. In such cases, the master agent can define the goals for the
slave agent, define its observation, possible actions, etc. The directive type is defined by the agent
@@ -560,7 +589,7 @@ in-action-space.</p>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.set_session">
<code class="descname">set_session</code><span class="sig-paren">(</span><em>sess</em><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.set_session" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">set_session</code><span class="sig-paren">(</span><em class="sig-param">sess</em><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.set_session" title="Permalink to this definition"></a></dt>
<dd><p>Set the deep learning framework session for all the agents in the composite agent</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -571,7 +600,7 @@ in-action-space.</p>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.setup_logger">
<code class="descname">setup_logger</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.setup_logger" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">setup_logger</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.setup_logger" title="Permalink to this definition"></a></dt>
<dd><p>Setup the logger for the agent</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -582,7 +611,7 @@ in-action-space.</p>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.sync">
<code class="descname">sync</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.sync" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">sync</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.sync" title="Permalink to this definition"></a></dt>
<dd><p>Sync the global network parameters to local networks</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -593,7 +622,7 @@ in-action-space.</p>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.train">
<code class="descname">train</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; float<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.train" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">train</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; float<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.train" title="Permalink to this definition"></a></dt>
<dd><p>Check if a training phase should be done as configured by num_consecutive_playing_steps.
If it should, then do several training steps as configured by num_consecutive_training_steps.
A single training iteration: Sample a batch, train on it and update target networks.</p>
@@ -606,7 +635,7 @@ A single training iteration: Sample a batch, train on it and update target netwo
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.update_log">
<code class="descname">update_log</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.update_log" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">update_log</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.update_log" title="Permalink to this definition"></a></dt>
<dd><p>Updates the episodic log file with all the signal values from the most recent episode.
Additional signals for logging can be set by the creating a new signal using self.register_signal,
and then updating it with some internal agent values.</p>
@@ -619,7 +648,7 @@ and then updating it with some internal agent values.</p>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.update_step_in_episode_log">
<code class="descname">update_step_in_episode_log</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.update_step_in_episode_log" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">update_step_in_episode_log</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; None<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.update_step_in_episode_log" title="Permalink to this definition"></a></dt>
<dd><p>Updates the in-episode log file with all the signal values from the most recent step.</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
@@ -630,7 +659,7 @@ and then updating it with some internal agent values.</p>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.update_transition_before_adding_to_replay_buffer">
<code class="descname">update_transition_before_adding_to_replay_buffer</code><span class="sig-paren">(</span><em>transition: rl_coach.core_types.Transition</em><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.Transition<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.update_transition_before_adding_to_replay_buffer" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">update_transition_before_adding_to_replay_buffer</code><span class="sig-paren">(</span><em class="sig-param">transition: rl_coach.core_types.Transition</em><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.Transition<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.update_transition_before_adding_to_replay_buffer" title="Permalink to this definition"></a></dt>
<dd><p>Allows agents to update the transition just before adding it to the replay buffer.
Can be useful for agents that want to tweak the reward, termination signal, etc.</p>
<dl class="field-list simple">
Binary file not shown.

Before

Width:  |  Height:  |  Size: 59 KiB

After

Width:  |  Height:  |  Size: 60 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 48 KiB

+1
View File
@@ -0,0 +1 @@
<mxfile modified="2019-06-13T11:04:47.252Z" host="www.draw.io" agent="Mozilla/5.0 (Macintosh; Intel Mac OS X 10_14_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/72.0.3626.121 Safari/537.36" etag="OGV5teY4xcR0Xj7nvQBA" version="10.7.7" type="device"><diagram id="Fja6IZyvrddIfr-74nt7" name="Page-1">7V1bk5s2FP41O9M+JIMQ18e9NduZbCdp0jR5ysgg26SAXMC73v76CnMVEphdkLGzbB5sDhJgne/7dCQdkQt4HezeRWizvicu9i9Uxd1dwJsLVQVAt+hHannKLIZuZ4ZV5Ll5ocrwyfsP50Ylt249F8dMwYQQP/E2rNEhYYidhLGhKCKPbLEl8dm7btAKc4ZPDvJ569+em6wzq6Urlf0Oe6t1cWeg5GcCVBTODfEaueSxZoK3F/A6IiTJvgW7a+ynjVe0S1bvt5az5YNFOEz6VPh+fW+gL8GXH+p7HP5156+j3dc3mp4/XPJU/GLs0gbID0mUrMmKhMi/raxXEdmGLk4vq9Cjqsx7QjbUCKjxB06Sp9ybaJsQalongZ+fxTsv+ZpWf6vnR99qZ252+ZX3B0/FQZhET7VK6eG3+rmq2v6oqBcnEfmn9F16j+wXpz+ztSVzU0y2kYM7mi8ttYckilY46SgItdLjlCqYBJg+I60YYR8l3gP7JCjH7Kosl1e9jCL0VCuwIV6YxLUrf0gNtEBOP6jl2MvJpxsNiDTLm1ZXefole4LiqPZTKtMeds+AYP6jH5C/zZuBg2QFuNR/j2svwZ82aO+ZRyo7LLiWnu9fE59E+7pwqaf/SiDUzhj7v7QGCZOaPftLaxSMTXGEIifHs8qhSi1R9YCjBO+6ccWjoKigs+4ChXQ8Vsqj5i5R1jXVgYXoiJBT894LnKNwzpj14Rn60FceBqrDIB+rHAHJIsbRA701CYdxkWEQT0DHwoulkE0NErsIW0uHYSEYiXRU9N7qDO0MnnW2zpNOlcY5wDX6zLn+nIM9OTe0Rx7kYyjo9AyfPu7Vgn5ZJfsmyQxp78Sgwfh3S4oTb+K9Py9pAaBvdtXJ4iq/h5ttWv02WGDXxVFxVfrU2YXZm1Fz7QHkUd9SF3Df9x6mvo4tV2Oob4zEfMVqMB9AnvpAhzz3LVnc146Fi3vPdX38SBtrRkQNETaDBxoBC/Cg8njQZOFBHzk4PgMfADYILrnG+EAQBJfkHd0JhgRSQhEpb24+vKMnL52EtrOq3GHkDmEndULScD/j8ZCEuOHa3IR8bxXSQ4c6jfYZ8Cp1qecg/zI/EezlowVvbCzCQG4EgABDacg2FNBUEyBEXsSm/ewR24jxmdkzPrOmjM/MVsr3ZbgtInirbiyiUgNwGGMxvZtF/9gG9EZkm9AQL87uGm6D78hJB24xX/6AltQLyutjXNNeKPwgQNDHLJfYcCQN+lRWQVRL0NGropkWXZaEzBOxQ0TF6ikq9pSiYo0fRwhV5jMK754VN7w+/utqM4aA8C2vAYZo3gfKkgB7loABEmD3lAAwaWBhcxrwlfe673ubuC26rvkOxZtszXPp7VIMXLE8HZ2IYzDPZId30OSHd1C4xCGLdXBm3QDWaT1ZZ05JOmBO6WNQ83Dl70M+ZjxcOXwKHwPjHKIr0DlNcziUskSh1GXbWOpPFK56DtM47L2aeRlz4nkZYxjvwSvXdtB3VJXNSh87u6UctOdw06wD2S2ws7yk7BZ+0JfNz1BbgNFrXF/XtIkX2IsLzylHKrDYcFyzBSlHotUWaEjzjtraj/deXVFEfblDQgclrRMjLx2EncGYy26k9QHeyaZglVsaAeHYOX9nsKxpaKwPBLGRcGkZyPOCiGg/uRcaCbGmIMMSiDIsAZCW8WHNMeqQbK++6V5Z8DdZvld7wtcJLSjOC4Rt05S6LVgcMARhkaz1QW0eyg6Sib7zlNCYVCba8/9mmThBmYAaO7SdXCaKCHKWiZfJhN5TJrRJp7qhKC10lomTlYlmNGEIZsCOKxPDtpi8epnouyCmK2JcHEkm2lfEZpk4QZloRhOTy4Q2RxPH2R0+dP/ni9bPnr07XIFd5eWsn2nGDMH5BQWTvqBAH7a7YxoIqqeDQa1vZqY+iQyqjSVQ07QPyGBneUkYPJtVgiNAaShCDkLAMhurPRl081qNoGqMTo7PCp4kUBcG5WXlkN9uBOZwvmux0WSFQrC7/Lij/mF7jKbpyMzT6ciKPJgT7chgMatTdkwHginV7CovqSMT5V7JkLrxZiAGPVeXaM4SCUEzOXB6kTybASevdSewu1vvm49hqFJCOQ2wotY3lHu22AJ2MsTSDoitAbrKSxJbCa8ZEqrcvRceELXXvuHUMliZsy3BdtNipuEo20314YMOAERg+DjOZYqrxNsSOPWes2amnyhIMREu4vQj+0UH8diZ/vvTpeFr7MyZJdh3KdztLG1vjioJfoEX/jIBBGkrKB8FFVRxhV9nfLKRoAobkaAIoeZREdqeZVG68fMNlWfly0X6EM1OrfH6pkWrL/n9gCM0KLAbgXUxyVTfjCdI+i/W/MZvzh6v0RK+le7YDWexEaQhyBHXRB21tA0xZ7P+cJIjkhzPR9rEKNh12Nhh1Xx35EgDkuZuyOJF9G0DEq3xXM8tf+TdlvSwelV+Vrz6Dwfg7f8=</diagram></mxfile>
@@ -21,6 +21,7 @@ A detailed description of those algorithms can be found by navigating to each of
imitation/cil
policy_optimization/cppo
policy_optimization/ddpg
policy_optimization/td3
policy_optimization/sac
other/dfp
value_optimization/double_dqn
@@ -0,0 +1,55 @@
Twin Delayed Deep Deterministic Policy Gradient
==================================
**Actions space:** Continuous
**References:** `Addressing Function Approximation Error in Actor-Critic Methods <https://arxiv.org/pdf/1802.09477>`_
Network Structure
-----------------
.. image:: /_static/img/design_imgs/td3.png
:align: center
Algorithm Description
---------------------
Choosing an action
++++++++++++++++++
Pass the current states through the actor network, and get an action mean vector :math:`\mu`.
While in training phase, use a continuous exploration policy, such as a small zero-meaned gaussian noise,
to add exploration noise to the action. When testing, use the mean vector :math:`\mu` as-is.
Training the network
++++++++++++++++++++
Start by sampling a batch of transitions from the experience replay.
* To train the two **critic networks**, use the following targets:
:math:`y_t=r(s_t,a_t )+\gamma \cdot \min_{i=1,2} Q_{i}(s_{t+1},\mu(s_{t+1} )+[\mathcal{N}(0,\,\sigma^{2})]^{MAX\_NOISE}_{MIN\_NOISE})`
First run the actor target network, using the next states as the inputs, and get :math:`\mu (s_{t+1} )`. Then, add a
clipped gaussian noise to these actions, and clip the resulting actions to the actions space.
Next, run the critic target networks using the next states and :math:`\mu (s_{t+1} )+[\mathcal{N}(0,\,\sigma^{2})]^{MAX\_NOISE}_{MIN\_NOISE}`,
and use the minimum between the two critic networks predictions in order to calculate :math:`y_t` according to the
equation above. To train the networks, use the current states and actions as the inputs, and :math:`y_t`
as the targets.
* To train the **actor network**, use the following equation:
:math:`\nabla_{\theta^\mu } J \approx E_{s_t \tilde{} \rho^\beta } [\nabla_a Q_{1}(s,a)|_{s=s_t,a=\mu (s_t ) } \cdot \nabla_{\theta^\mu} \mu(s)|_{s=s_t} ]`
Use the actor's online network to get the action mean values using the current states as the inputs.
Then, use the first critic's online network in order to get the gradients of the critic output with respect to the
action mean values :math:`\nabla _a Q_{1}(s,a)|_{s=s_t,a=\mu(s_t ) }`.
Using the chain rule, calculate the gradients of the actor's output, with respect to the actor weights,
given :math:`\nabla_a Q(s,a)`. Finally, apply those gradients to the actor network.
The actor's training is done at a slower frequency than the critic's training, in order to allow the critic to better fit the
current policy, before exercising the critic in order to train the actor.
Following the same, delayed, actor's training cadence, do a soft update of the critic and actor target networks' weights
from the online networks.
.. autoclass:: rl_coach.agents.td3_agent.TD3AlgorithmParameters
File diff suppressed because one or more lines are too long
@@ -214,6 +214,16 @@ The algorithms are ordered by their release date in descending order.
and therefore it is able to use a replay buffer in order to improve sample efficiency.
</span>
</div>
<div class="algorithm continuous off-policy" data-year="201509">
<span class="badge">
<a href="components/agents/policy_optimization/td3.html">TD3</a>
<br>
Very similar to DDPG, i.e. an actor-critic for continuous action spaces, that uses a replay buffer in
order to improve sample efficiency. TD3 uses two critic networks in order to mitigate the overestimation
in the Q state-action value prediction, slows down the actor updates in order to increase stability and
adds noise to actions while training the critic in order to smooth out the critic's predictions.
</span>
</div>
<div class="algorithm continuous discrete on-policy" data-year="201706">
<span class="badge">
<a href="components/agents/policy_optimization/ppo.html">PPO</a>
Binary file not shown.

Before

Width:  |  Height:  |  Size: 59 KiB

+4 -3
View File
@@ -41,14 +41,15 @@ class DDPGCriticNetworkParameters(NetworkParameters):
self.middleware_parameters = FCMiddlewareParameters()
self.heads_parameters = [DDPGVHeadParameters()]
self.optimizer_type = 'Adam'
self.adam_optimizer_beta2 = 0.999
self.optimizer_epsilon = 1e-8
self.batch_size = 64
self.async_training = False
self.learning_rate = 0.001
self.adam_optimizer_beta2 = 0.999
self.optimizer_epsilon = 1e-8
self.create_target_network = True
self.shared_optimizer = True
self.scale_down_gradients_by_number_of_workers_for_sync_training = False
# self.l2_regularization = 1e-2
class DDPGActorNetworkParameters(NetworkParameters):
@@ -58,9 +59,9 @@ class DDPGActorNetworkParameters(NetworkParameters):
self.middleware_parameters = FCMiddlewareParameters(batchnorm=True)
self.heads_parameters = [DDPGActorHeadParameters()]
self.optimizer_type = 'Adam'
self.batch_size = 64
self.adam_optimizer_beta2 = 0.999
self.optimizer_epsilon = 1e-8
self.batch_size = 64
self.async_training = False
self.learning_rate = 0.0001
self.create_target_network = True
+2 -2
View File
@@ -90,7 +90,7 @@ class DDQNBCQAgent(DQNAgent):
if self.ap.algorithm.action_drop_method_parameters.use_state_embedding_instead_of_state:
return self.networks['reward_model'].online_network.predict(
states,
outputs=[self.networks['reward_model'].online_network.state_embedding])
outputs=[self.networks['reward_model'].online_network.state_embedding[0]])
else:
return states['observation']
self.embedding = to_embedding
@@ -189,7 +189,7 @@ class DDQNBCQAgent(DQNAgent):
if self.ap.algorithm.action_drop_method_parameters.use_state_embedding_instead_of_state:
self.knn_trees = [AnnoyDictionary(
dict_size=knn_size,
key_width=int(self.networks['reward_model'].online_network.state_embedding.shape[-1]),
key_width=int(self.networks['reward_model'].online_network.state_embedding[0].shape[-1]),
batch_size=knn_size)
for _ in range(len(self.spaces.action.actions))]
else:
+1 -1
View File
@@ -194,7 +194,7 @@ class NECAgent(ValueOptimizationAgent):
)
if self.phase != RunPhase.TEST:
# store the state embedding for inserting it to the DND later
self.current_episode_state_embeddings.append(embedding.squeeze())
self.current_episode_state_embeddings.append(embedding[0].squeeze())
actions_q_values = actions_q_values[0][0]
return actions_q_values
+223
View File
@@ -0,0 +1,223 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
from typing import Union
from collections import OrderedDict
import numpy as np
from rl_coach.agents.agent import Agent
from rl_coach.agents.ddpg_agent import DDPGAgent
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.head_parameters import DDPGActorHeadParameters, TD3VHeadParameters
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters
from rl_coach.base_parameters import NetworkParameters, AlgorithmParameters, \
AgentParameters, EmbedderScheme
from rl_coach.core_types import ActionInfo, TrainingSteps, Transition
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
from rl_coach.spaces import BoxActionSpace, GoalsSpace
class TD3CriticNetworkParameters(NetworkParameters):
def __init__(self, num_q_networks):
super().__init__()
self.input_embedders_parameters = {'observation': InputEmbedderParameters(),
'action': InputEmbedderParameters(scheme=EmbedderScheme.Shallow)}
self.middleware_parameters = FCMiddlewareParameters(num_streams=num_q_networks)
self.heads_parameters = [TD3VHeadParameters()]
self.optimizer_type = 'Adam'
self.adam_optimizer_beta2 = 0.999
self.optimizer_epsilon = 1e-8
self.batch_size = 100
self.async_training = False
self.learning_rate = 0.001
self.create_target_network = True
self.shared_optimizer = True
self.scale_down_gradients_by_number_of_workers_for_sync_training = False
class TD3ActorNetworkParameters(NetworkParameters):
def __init__(self):
super().__init__()
self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
self.middleware_parameters = FCMiddlewareParameters()
self.heads_parameters = [DDPGActorHeadParameters(batchnorm=False)]
self.optimizer_type = 'Adam'
self.adam_optimizer_beta2 = 0.999
self.optimizer_epsilon = 1e-8
self.batch_size = 100
self.async_training = False
self.learning_rate = 0.001
self.create_target_network = True
self.shared_optimizer = True
self.scale_down_gradients_by_number_of_workers_for_sync_training = False
class TD3AlgorithmParameters(AlgorithmParameters):
"""
:param num_steps_between_copying_online_weights_to_target: (StepMethod)
The number of steps between copying the online network weights to the target network weights.
:param rate_for_copying_weights_to_target: (float)
When copying the online network weights to the target network weights, a soft update will be used, which
weight the new online network weights by rate_for_copying_weights_to_target
:param num_consecutive_playing_steps: (StepMethod)
The number of consecutive steps to act between every two training iterations
:param use_target_network_for_evaluation: (bool)
If set to True, the target network will be used for predicting the actions when choosing actions to act.
Since the target network weights change more slowly, the predicted actions will be more consistent.
:param action_penalty: (float)
The amount by which to penalize the network on high action feature (pre-activation) values.
This can prevent the actions features from saturating the TanH activation function, and therefore prevent the
gradients from becoming very low.
:param clip_critic_targets: (Tuple[float, float] or None)
The range to clip the critic target to in order to prevent overestimation of the action values.
:param use_non_zero_discount_for_terminal_states: (bool)
If set to True, the discount factor will be used for terminal states to bootstrap the next predicted state
values. If set to False, the terminal states reward will be taken as the target return for the network.
"""
def __init__(self):
super().__init__()
self.rate_for_copying_weights_to_target = 0.005
self.use_target_network_for_evaluation = False
self.action_penalty = 0
self.clip_critic_targets = None # expected to be a tuple of the form (min_clip_value, max_clip_value) or None
self.use_non_zero_discount_for_terminal_states = False
self.act_for_full_episodes = True
self.update_policy_every_x_episode_steps = 2
self.num_steps_between_copying_online_weights_to_target = TrainingSteps(self.update_policy_every_x_episode_steps)
self.policy_noise = 0.2
self.noise_clipping = 0.5
self.num_q_networks = 2
class TD3AgentExplorationParameters(AdditiveNoiseParameters):
def __init__(self):
super().__init__()
self.noise_as_percentage_from_action_space = False
class TD3AgentParameters(AgentParameters):
def __init__(self):
td3_algorithm_params = TD3AlgorithmParameters()
super().__init__(algorithm=td3_algorithm_params,
exploration=TD3AgentExplorationParameters(),
memory=EpisodicExperienceReplayParameters(),
networks=OrderedDict([("actor", TD3ActorNetworkParameters()),
("critic",
TD3CriticNetworkParameters(td3_algorithm_params.num_q_networks))]))
@property
def path(self):
return 'rl_coach.agents.td3_agent:TD3Agent'
# Twin Delayed DDPG - https://arxiv.org/pdf/1802.09477.pdf
class TD3Agent(DDPGAgent):
def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
super().__init__(agent_parameters, parent)
self.q_values = self.register_signal("Q")
self.TD_targets_signal = self.register_signal("TD targets")
self.action_signal = self.register_signal("actions")
def learn_from_batch(self, batch):
actor = self.networks['actor']
critic = self.networks['critic']
actor_keys = self.ap.network_wrappers['actor'].input_embedders_parameters.keys()
critic_keys = self.ap.network_wrappers['critic'].input_embedders_parameters.keys()
# TD error = r + discount*max(q_st_plus_1) - q_st
next_actions, actions_mean = actor.parallel_prediction([
(actor.target_network, batch.next_states(actor_keys)),
(actor.online_network, batch.states(actor_keys))
])
# add noise to the next_actions
noise = np.random.normal(0, self.ap.algorithm.policy_noise, next_actions.shape).clip(
-self.ap.algorithm.noise_clipping, self.ap.algorithm.noise_clipping)
next_actions = self.spaces.action.clip_action_to_space(next_actions + noise)
critic_inputs = copy.copy(batch.next_states(critic_keys))
critic_inputs['action'] = next_actions
q_st_plus_1 = critic.target_network.predict(critic_inputs)[2] # output #2 is the min (Q1, Q2)
# calculate the bootstrapped TD targets while discounting terminal states according to
# use_non_zero_discount_for_terminal_states
if self.ap.algorithm.use_non_zero_discount_for_terminal_states:
TD_targets = batch.rewards(expand_dims=True) + self.ap.algorithm.discount * q_st_plus_1
else:
TD_targets = batch.rewards(expand_dims=True) + \
(1.0 - batch.game_overs(expand_dims=True)) * self.ap.algorithm.discount * q_st_plus_1
# clip the TD targets to prevent overestimation errors
if self.ap.algorithm.clip_critic_targets:
TD_targets = np.clip(TD_targets, *self.ap.algorithm.clip_critic_targets)
self.TD_targets_signal.add_sample(TD_targets)
# train the critic
critic_inputs = copy.copy(batch.states(critic_keys))
critic_inputs['action'] = batch.actions(len(batch.actions().shape) == 1)
result = critic.train_and_sync_networks(critic_inputs, TD_targets)
total_loss, losses, unclipped_grads = result[:3]
if self.training_iteration % self.ap.algorithm.update_policy_every_x_episode_steps == 0:
# get the gradients of output #3 (=mean of Q1 network) w.r.t the action
critic_inputs = copy.copy(batch.states(critic_keys))
critic_inputs['action'] = actions_mean
action_gradients = critic.online_network.predict(critic_inputs,
outputs=critic.online_network.gradients_wrt_inputs[3]['action'])
# apply the gradients from the critic to the actor
initial_feed_dict = {actor.online_network.gradients_weights_ph[0]: -action_gradients}
gradients = actor.online_network.predict(batch.states(actor_keys),
outputs=actor.online_network.weighted_gradients[0],
initial_feed_dict=initial_feed_dict)
if actor.has_global:
actor.apply_gradients_to_global_network(gradients)
actor.update_online_network()
else:
actor.apply_gradients_to_online_network(gradients)
return total_loss, losses, unclipped_grads
def train(self):
self.ap.algorithm.num_consecutive_training_steps = self.current_episode_steps_counter
return Agent.train(self)
def update_transition_before_adding_to_replay_buffer(self, transition: Transition) -> Transition:
"""
Allows agents to update the transition just before adding it to the replay buffer.
Can be useful for agents that want to tweak the reward, termination signal, etc.
:param transition: the transition to update
:return: the updated transition
"""
transition.game_over = False if self.current_episode_steps_counter ==\
self.parent_level_manager.environment.env._max_episode_steps\
else transition.game_over
return transition
+11
View File
@@ -221,3 +221,14 @@ class SACQHeadParameters(HeadParameters):
super().__init__(parameterized_class_name='SACQHead', activation_function=activation_function, name=name,
dense_layer=dense_layer)
self.network_layers_sizes = layers_sizes
class TD3VHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='td3_v_head_params',
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
loss_weight: float = 1.0, dense_layer=None, initializer='xavier'):
super().__init__(parameterized_class_name="TD3VHead", activation_function=activation_function, name=name,
dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
loss_weight=loss_weight)
self.initializer = initializer
@@ -41,10 +41,11 @@ class FCMiddlewareParameters(MiddlewareParameters):
def __init__(self, activation_function='relu',
scheme: Union[List, MiddlewareScheme] = MiddlewareScheme.Medium,
batchnorm: bool = False, dropout_rate: float = 0.0,
name="middleware_fc_embedder", dense_layer=None, is_training=False):
name="middleware_fc_embedder", dense_layer=None, is_training=False, num_streams=1):
super().__init__(parameterized_class_name="FCMiddleware", activation_function=activation_function,
scheme=scheme, batchnorm=batchnorm, dropout_rate=dropout_rate, name=name, dense_layer=dense_layer,
is_training=is_training)
self.num_streams = num_streams
class LSTMMiddlewareParameters(MiddlewareParameters):
@@ -203,7 +203,6 @@ class TensorFlowArchitecture(Architecture):
self._create_gradient_accumulators()
# gradients of the outputs w.r.t. the inputs
# at the moment, this is only used by ddpg
self.gradients_wrt_inputs = [{name: tf.gradients(output, input_ph) for name, input_ph in
self.inputs.items()} for output in self.outputs]
self.gradients_weights_ph = [tf.placeholder('float32', self.outputs[i].shape, 'output_gradient_weights')
@@ -16,6 +16,7 @@ from .sac_head import SACPolicyHead
from .sac_q_head import SACQHead
from .classification_head import ClassificationHead
from .cil_head import RegressionHead
from .td3_v_head import TD3VHead
from .ddpg_v_head import DDPGVHead
__all__ = [
@@ -37,5 +38,6 @@ __all__ = [
'SACQHead',
'ClassificationHead',
'RegressionHead',
'TD3VHead'
'DDPGVHead'
]
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Type
import numpy as np
import tensorflow as tf
@@ -22,7 +21,7 @@ from rl_coach.architectures.tensorflow_components.layers import Dense, convert_l
from rl_coach.base_parameters import AgentParameters
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import force_list
from rl_coach.architectures.tensorflow_components.utils import squeeze_tensor
# Used to initialize weights for policy and value output layers
def normalized_columns_initializer(std=1.0):
@@ -72,8 +71,9 @@ class Head(object):
:param input_layer: the input to the graph
:return: the output of the last layer and the target placeholder
"""
with tf.variable_scope(self.get_name(), initializer=tf.contrib.layers.xavier_initializer()):
self._build_module(input_layer)
self._build_module(squeeze_tensor(input_layer))
self.output = force_list(self.output)
self.target = force_list(self.target)
@@ -0,0 +1,67 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.layers import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head, normalized_columns_initializer
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import VStateValue
from rl_coach.spaces import SpacesDefinition
class TD3VHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
dense_layer=Dense, initializer='xavier'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer)
self.name = 'td3_v_values_head'
self.return_type = VStateValue
self.loss_type = []
self.initializer = initializer
self.loss = []
self.output = []
def _build_module(self, input_layer):
# Standard V Network
q_outputs = []
self.target = tf.placeholder(tf.float32, shape=(None, 1), name="q_networks_min_placeholder")
for i in range(input_layer.shape[0]): # assuming that the actual size is 2, as there are two critic networks
if self.initializer == 'normalized_columns':
q_outputs.append(self.dense_layer(1)(input_layer[i], name='q_output_{}'.format(i + 1),
kernel_initializer=normalized_columns_initializer(1.0)))
elif self.initializer == 'xavier' or self.initializer is None:
q_outputs.append(self.dense_layer(1)(input_layer[i], name='q_output_{}'.format(i + 1)))
self.output.append(q_outputs[i])
self.loss.append(tf.reduce_mean((self.target-q_outputs[i])**2))
self.output.append(tf.reduce_min(q_outputs, axis=0))
self.output.append(tf.reduce_mean(self.output[0]))
self.loss = sum(self.loss)
tf.losses.add_loss(self.loss)
def __str__(self):
result = [
"Q1 Action-Value Stream",
"\tDense (num outputs = 1)",
"Q2 Action-Value Stream",
"\tDense (num outputs = 1)",
"Min (Q1, Q2)"
]
return '\n'.join(result)
@@ -28,23 +28,28 @@ class FCMiddleware(Middleware):
def __init__(self, activation_function=tf.nn.relu,
scheme: MiddlewareScheme = MiddlewareScheme.Medium,
batchnorm: bool = False, dropout_rate: float = 0.0,
name="middleware_fc_embedder", dense_layer=Dense, is_training=False):
name="middleware_fc_embedder", dense_layer=Dense, is_training=False, num_streams: int = 1):
super().__init__(activation_function=activation_function, batchnorm=batchnorm,
dropout_rate=dropout_rate, scheme=scheme, name=name, dense_layer=dense_layer,
is_training=is_training)
self.return_type = Middleware_FC_Embedding
self.layers = []
assert(isinstance(num_streams, int) and num_streams >= 1)
self.num_streams = num_streams
def _build_module(self):
self.layers.append(self.input)
self.output = []
for stream_idx in range(self.num_streams):
layers = [self.input]
for idx, layer_params in enumerate(self.layers_params):
self.layers.extend(force_list(
layer_params(self.layers[-1], name='{}_{}'.format(layer_params.__class__.__name__, idx),
layers.extend(force_list(
layer_params(layers[-1], name='{}_{}'.format(layer_params.__class__.__name__,
idx + stream_idx * len(self.layers_params)),
is_training=self.is_training)
))
self.output = self.layers[-1]
self.output.append((layers[-1]))
@property
def schemes(self):
@@ -72,3 +77,15 @@ class FCMiddleware(Middleware):
]
}
def __str__(self):
stream = [str(l) for l in self.layers_params]
if self.layers_params:
if self.num_streams > 1:
stream = [''] + ['\t' + l for l in stream]
result = stream * self.num_streams
result[0::len(stream)] = ['Stream {}'.format(i) for i in range(self.num_streams)]
else:
result = stream
return '\n'.join(result)
else:
return 'No layers'
@@ -38,3 +38,10 @@ def get_activation_function(activation_function_string: str):
"Activation function must be one of the following {}. instead it was: {}" \
.format(activation_functions.keys(), activation_function_string)
return activation_functions[activation_function_string]
def squeeze_tensor(tensor):
if tensor.shape[0] == 1:
return tensor[0]
else:
return tensor
+25 -19
View File
@@ -17,7 +17,6 @@
from typing import List
import numpy as np
import scipy.stats
from rl_coach.core_types import RunPhase, ActionType
from rl_coach.exploration_policies.exploration_policy import ContinuousActionExplorationPolicy, ExplorationParameters
@@ -31,8 +30,9 @@ from rl_coach.spaces import ActionSpace, BoxActionSpace
class AdditiveNoiseParameters(ExplorationParameters):
def __init__(self):
super().__init__()
self.noise_percentage_schedule = LinearSchedule(0.1, 0.1, 50000)
self.evaluation_noise_percentage = 0.05
self.noise_schedule = LinearSchedule(0.1, 0.1, 50000)
self.evaluation_noise = 0.05
self.noise_as_percentage_from_action_space = True
@property
def path(self):
@@ -48,17 +48,19 @@ class AdditiveNoise(ContinuousActionExplorationPolicy):
2. Specified by the agents action. In case the agents action is a list with 2 values, the 1st one is assumed to
be the mean of the action, and 2nd is assumed to be its standard deviation.
"""
def __init__(self, action_space: ActionSpace, noise_percentage_schedule: Schedule,
evaluation_noise_percentage: float):
def __init__(self, action_space: ActionSpace, noise_schedule: Schedule,
evaluation_noise: float, noise_as_percentage_from_action_space: bool = True):
"""
:param action_space: the action space used by the environment
:param noise_percentage_schedule: the schedule for the noise variance percentage relative to the absolute range
of the action space
:param evaluation_noise_percentage: the noise variance percentage that will be used during evaluation phases
:param noise_schedule: the schedule for the noise
:param evaluation_noise: the noise variance that will be used during evaluation phases
:param noise_as_percentage_from_action_space: a bool deciding whether the noise is absolute or as a percentage
from the action space
"""
super().__init__(action_space)
self.noise_percentage_schedule = noise_percentage_schedule
self.evaluation_noise_percentage = evaluation_noise_percentage
self.noise_schedule = noise_schedule
self.evaluation_noise = evaluation_noise
self.noise_as_percentage_from_action_space = noise_as_percentage_from_action_space
if not isinstance(action_space, BoxActionSpace):
raise ValueError("Additive noise exploration works only for continuous controls."
@@ -68,19 +70,20 @@ class AdditiveNoise(ContinuousActionExplorationPolicy):
or not np.all(-np.inf < action_space.low) or not np.all(action_space.low < np.inf):
raise ValueError("Additive noise exploration requires bounded actions")
# TODO: allow working with unbounded actions by defining the noise in terms of range and not percentage
def get_action(self, action_values: List[ActionType]) -> ActionType:
# TODO-potential-bug consider separating internally defined stdev and externally defined stdev into 2 policies
# set the current noise percentage
# set the current noise
if self.phase == RunPhase.TEST:
current_noise_precentage = self.evaluation_noise_percentage
current_noise = self.evaluation_noise
else:
current_noise_precentage = self.noise_percentage_schedule.current_value
current_noise = self.noise_schedule.current_value
# scale the noise to the action space range
action_values_std = current_noise_precentage * (self.action_space.high - self.action_space.low)
if self.noise_as_percentage_from_action_space:
action_values_std = current_noise * (self.action_space.high - self.action_space.low)
else:
action_values_std = current_noise
# extract the mean values
if isinstance(action_values, list):
@@ -92,15 +95,18 @@ class AdditiveNoise(ContinuousActionExplorationPolicy):
# step the noise schedule
if self.phase is not RunPhase.TEST:
self.noise_percentage_schedule.step()
self.noise_schedule.step()
# the second element of the list is assumed to be the standard deviation
if isinstance(action_values, list) and len(action_values) > 1:
action_values_std = action_values[1].squeeze()
# add noise to the action means
if self.phase is not RunPhase.TEST:
action = np.random.normal(action_values_mean, action_values_std)
else:
action = action_values_mean
return action
return np.atleast_1d(action)
def get_control_param(self):
return np.ones(self.action_space.shape)*self.noise_percentage_schedule.current_value
return np.ones(self.action_space.shape)*self.noise_schedule.current_value
+1 -1
View File
@@ -32,7 +32,7 @@ class EGreedyParameters(ExplorationParameters):
self.epsilon_schedule = LinearSchedule(0.5, 0.01, 50000)
self.evaluation_epsilon = 0.05
self.continuous_exploration_policy_parameters = AdditiveNoiseParameters()
self.continuous_exploration_policy_parameters.noise_percentage_schedule = LinearSchedule(0.1, 0.1, 50000)
self.continuous_exploration_policy_parameters.noise_schedule = LinearSchedule(0.1, 0.1, 50000)
# for continuous control -
# (see http://www.cs.ubc.ca/~van/papers/2017-TOG-deepLoco/2017-TOG-deepLoco.pdf)
@@ -28,10 +28,11 @@ from rl_coach.spaces import ActionSpace, BoxActionSpace
class TruncatedNormalParameters(ExplorationParameters):
def __init__(self):
super().__init__()
self.noise_percentage_schedule = LinearSchedule(0.1, 0.1, 50000)
self.evaluation_noise_percentage = 0.05
self.noise_schedule = LinearSchedule(0.1, 0.1, 50000)
self.evaluation_noise = 0.05
self.clip_low = 0
self.clip_high = 1
self.noise_as_percentage_from_action_space = True
@property
def path(self):
@@ -49,17 +50,20 @@ class TruncatedNormal(ContinuousActionExplorationPolicy):
When the sampled action is outside of the action bounds given by the user, it is sampled again and again, until it
is within the bounds.
"""
def __init__(self, action_space: ActionSpace, noise_percentage_schedule: Schedule,
evaluation_noise_percentage: float, clip_low: float, clip_high: float):
def __init__(self, action_space: ActionSpace, noise_schedule: Schedule,
evaluation_noise: float, clip_low: float, clip_high: float,
noise_as_percentage_from_action_space: bool = True):
"""
:param action_space: the action space used by the environment
:param noise_percentage_schedule: the schedule for the noise variance percentage relative to the absolute range
of the action space
:param evaluation_noise_percentage: the noise variance percentage that will be used during evaluation phases
:param noise_schedule: the schedule for the noise variance
:param evaluation_noise: the noise variance that will be used during evaluation phases
:param noise_as_percentage_from_action_space: whether to consider the noise as a percentage of the action space
or absolute value
"""
super().__init__(action_space)
self.noise_percentage_schedule = noise_percentage_schedule
self.evaluation_noise_percentage = evaluation_noise_percentage
self.noise_schedule = noise_schedule
self.evaluation_noise = evaluation_noise
self.noise_as_percentage_from_action_space = noise_as_percentage_from_action_space
self.clip_low = clip_low
self.clip_high = clip_high
@@ -71,17 +75,21 @@ class TruncatedNormal(ContinuousActionExplorationPolicy):
or not np.all(-np.inf < action_space.low) or not np.all(action_space.low < np.inf):
raise ValueError("Additive noise exploration requires bounded actions")
# TODO: allow working with unbounded actions by defining the noise in terms of range and not percentage
def get_action(self, action_values: List[ActionType]) -> ActionType:
# set the current noise percentage
# set the current noise
if self.phase == RunPhase.TEST:
current_noise_precentage = self.evaluation_noise_percentage
current_noise = self.evaluation_noise
else:
current_noise_precentage = self.noise_percentage_schedule.current_value
current_noise = self.noise_schedule.current_value
# scale the noise to the action space range
action_values_std = current_noise_precentage * (self.action_space.high - self.action_space.low)
if self.noise_as_percentage_from_action_space:
action_values_std = current_noise * (self.action_space.high - self.action_space.low)
else:
action_values_std = current_noise
# scale the noise to the action space range
action_values_std = current_noise * (self.action_space.high - self.action_space.low)
# extract the mean values
if isinstance(action_values, list):
@@ -93,7 +101,7 @@ class TruncatedNormal(ContinuousActionExplorationPolicy):
# step the noise schedule
if self.phase is not RunPhase.TEST:
self.noise_percentage_schedule.step()
self.noise_schedule.step()
# the second element of the list is assumed to be the standard deviation
if isinstance(action_values, list) and len(action_values) > 1:
action_values_std = action_values[1].squeeze()
@@ -107,4 +115,4 @@ class TruncatedNormal(ContinuousActionExplorationPolicy):
return action
def get_control_param(self):
return np.ones(self.action_space.shape)*self.noise_percentage_schedule.current_value
return np.ones(self.action_space.shape)*self.noise_schedule.current_value

Some files were not shown because too many files have changed in this diff Show More