SAC algorithm (#282)

* SAC algorithm
* SAC - updates to the agent (learn_from_batch), sac_head and sac_q_head to fix a problem in the gradient calculation. SAC agents are now able to train.
* gym_environment - fix an error in access to gym.spaces
* Soft Actor Critic - code cleanup
* code cleanup
* V-head initialization fix
* SAC benchmarks
* SAC documentation
* typo fix
* documentation fixes
* documentation and version update
* README typo
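
For context on the gradient-calculation fix mentioned above, here is a minimal illustrative sketch (NumPy, not Coach's TensorFlow sac_head/sac_q_head code) of the tanh-squashed Gaussian log-probability a SAC policy head differentiates through; the change-of-variables correction is the term that is easy to get wrong. All names below are illustrative assumptions.

```python
import numpy as np

def squashed_gaussian_log_prob(mu, log_std, noise):
    """log pi(a|s) for a = tanh(mu + exp(log_std) * noise), with noise ~ N(0, I)."""
    std = np.exp(log_std)
    pre_tanh = mu + std * noise  # reparameterized pre-squash action
    # per-dimension Gaussian log-density of the pre-squash action
    gaussian = -0.5 * (((pre_tanh - mu) / std) ** 2 + 2.0 * log_std + np.log(2.0 * np.pi))
    # tanh change-of-variables correction; omitting it breaks the policy gradient
    correction = np.log(1.0 - np.tanh(pre_tanh) ** 2 + 1e-6)
    return np.sum(gaussian - correction, axis=-1)

# example: one 2-dimensional action
print(squashed_gaussian_log_prob(np.zeros(2), np.zeros(2), np.array([0.1, -0.3])))
```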
@@ -25,11 +25,11 @@ coach -p CartPole_DQN -r
<img src="img/doom_health.gif" alt="Doom Health Gathering"/> <img src="img/minitaur.gif" alt="PyBullet Minitaur" width = "249" height ="200"/> <img src="img/ant.gif" alt="Gym Extensions Ant"/>

<br><br>

Blog posts from the Intel® AI website:

* [Release 0.8.0](https://ai.intel.com/reinforcement-learning-coach-intel/) (initial release)
* [Release 0.9.0](https://ai.intel.com/reinforcement-learning-coach-carla-qr-dqn/)
* [Release 0.10.0](https://ai.intel.com/introducing-reinforcement-learning-coach-0-10-0/)
-* [Release 0.11.0](https://ai.intel.com/rl-coach-data-science-at-scale) (current release)
+* [Release 0.11.0](https://ai.intel.com/rl-coach-data-science-at-scale)
+* Release 0.12.0 (current release)

Contacting the Coach development team is also possible through the email [coach@intel.com](coach@intel.com)
@@ -277,6 +277,7 @@ dashboard
* [Clipped Proximal Policy Optimization (CPPO)](https://arxiv.org/pdf/1707.06347.pdf) | **Multi Worker Single Node** ([code](rl_coach/agents/clipped_ppo_agent.py))
* [Generalized Advantage Estimation (GAE)](https://arxiv.org/abs/1506.02438) ([code](rl_coach/agents/actor_critic_agent.py#L86))
* [Sample Efficient Actor-Critic with Experience Replay (ACER)](https://arxiv.org/abs/1611.01224) | **Multi Worker Single Node** ([code](rl_coach/agents/acer_agent.py))
* [Soft Actor-Critic (SAC)](https://arxiv.org/abs/1801.01290) ([code](rl_coach/agents/soft_actor_critic_agent.py))

### General Agents

* [Direct Future Prediction (DFP)](https://arxiv.org/abs/1611.01779) | **Multi Worker Single Node** ([code](rl_coach/agents/dfp_agent.py))
@@ -37,6 +37,7 @@ The environments that were used for testing include:
|**[ACER](acer)** |  |Atari | |
|**[Clipped PPO](clipped_ppo)** |  |Mujoco | |
|**[DDPG](ddpg)** |  |Mujoco | |
|**[SAC](sac)** |  |Mujoco | |
|**[NEC](nec)** |  |Atari | |
|**[HER](ddpg_her)** |  |Fetch | |
|**[DFP](dfp)** |  |Doom | Doom Battle was not verified |
@@ -1,6 +1,6 @@
# Clipped PPO

-Each experiment uses 3 seeds and is trained for 10k environment steps.
+Each experiment uses 3 seeds and is trained for 10M environment steps.
The parameters used for Clipped PPO are the same parameters as described in the [original paper](https://arxiv.org/abs/1707.06347).

### Inverted Pendulum Clipped PPO - single worker
@@ -1,6 +1,6 @@
# DDPG

-Each experiment uses 3 seeds and is trained for 2k environment steps.
+Each experiment uses 3 seeds and is trained for 2M environment steps.
The parameters used for DDPG are the same parameters as described in the [original paper](https://arxiv.org/abs/1509.02971).

### Inverted Pendulum DDPG - single worker
@@ -0,0 +1,48 @@
# Soft Actor Critic

Each experiment uses 3 seeds and is trained for 3M environment steps.
The parameters used for SAC are the same parameters as described in the [original paper](https://arxiv.org/abs/1801.01290).

### Inverted Pendulum SAC - single worker

```bash
coach -p Mujoco_SAC -lvl inverted_pendulum
```

<img src="inverted_pendulum_sac.png" alt="Inverted Pendulum SAC" width="800"/>

### Hopper SAC - single worker

```bash
coach -p Mujoco_SAC -lvl hopper
```

<img src="hopper_sac.png" alt="Hopper SAC" width="800"/>

### Half Cheetah SAC - single worker

```bash
coach -p Mujoco_SAC -lvl half_cheetah
```

<img src="half_cheetah_sac.png" alt="Half Cheetah SAC" width="800"/>

### Walker 2D SAC - single worker

```bash
coach -p Mujoco_SAC -lvl walker2d
```

<img src="walker2d_sac.png" alt="Walker 2D SAC" width="800"/>

### Humanoid SAC - single worker

```bash
coach -p Mujoco_SAC -lvl humanoid
```

<img src="humanoid_sac.png" alt="Humanoid SAC" width="800"/>
(Binary image files: six new images added, one existing image updated.)
@@ -179,6 +179,7 @@
* [rl_coach.agents.acer_agent](rl_coach/agents/acer_agent.html)
* [rl_coach.agents.actor_critic_agent](rl_coach/agents/actor_critic_agent.html)
* [rl_coach.agents.agent](rl_coach/agents/agent.html)
* [rl_coach.agents.agent_interface](rl_coach/agents/agent_interface.html)
* [rl_coach.agents.bc_agent](rl_coach/agents/bc_agent.html)
* [rl_coach.agents.categorical_dqn_agent](rl_coach/agents/categorical_dqn_agent.html)
* [rl_coach.agents.cil_agent](rl_coach/agents/cil_agent.html)
@@ -195,6 +196,7 @@
* [rl_coach.agents.ppo_agent](rl_coach/agents/ppo_agent.html)
* [rl_coach.agents.qr_dqn_agent](rl_coach/agents/qr_dqn_agent.html)
* [rl_coach.agents.rainbow_dqn_agent](rl_coach/agents/rainbow_dqn_agent.html)
* [rl_coach.agents.soft_actor_critic_agent](rl_coach/agents/soft_actor_critic_agent.html)
* [rl_coach.agents.value_optimization_agent](rl_coach/agents/value_optimization_agent.html)
* [rl_coach.architectures.architecture](rl_coach/architectures/architecture.html)
* [rl_coach.architectures.network_wrapper](rl_coach/architectures/network_wrapper.html)
@@ -248,7 +248,7 @@
        self.num_steps_between_gradient_updates = 5000
        self.ratio_of_replay = 4
        self.num_transitions_to_start_replay = 10000
-        self.rate_for_copying_weights_to_target = 0.99
+        self.rate_for_copying_weights_to_target = 0.01
        self.importance_weight_truncation = 10.0
        self.use_trust_region_optimization = True
        self.max_KL_divergence = 1.0
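
The rate_for_copying_weights_to_target change above (from 0.99 to 0.01) reads as the coefficient of a soft (Polyak-averaged) target-network update; a minimal sketch under that assumption (illustrative only, not Coach's network code):

```python
import numpy as np

def soft_update(target_weights, online_weights, rate=0.01):
    """Polyak averaging: target <- rate * online + (1 - rate) * target."""
    return [rate * w_on + (1.0 - rate) * w_tgt
            for w_on, w_tgt in zip(online_weights, target_weights)]

target = [np.zeros((2, 2))]
online = [np.ones((2, 2))]
print(soft_update(target, online)[0])  # the target slowly tracks the online weights
```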
@@ -214,6 +214,9 @@
from rl_coach.utils import Signal, force_list
from rl_coach.utils import dynamic_import_and_instantiate_module_from_params
from rl_coach.memories.backend.memory_impl import get_memory_backend
from rl_coach.core_types import TimeTypes
from rl_coach.off_policy_evaluators.ope_manager import OpeManager
from rl_coach.core_types import PickledReplayBuffer, CsvDataset


class Agent(AgentInterface):
@@ -222,6 +225,15 @@
        :param agent_parameters: A AgentParameters class instance with all the agent parameters
        """
        super().__init__()
+        # use seed
+        if agent_parameters.task_parameters.seed is not None:
+            random.seed(agent_parameters.task_parameters.seed)
+            np.random.seed(agent_parameters.task_parameters.seed)
+        else:
+            # we need to seed the RNG since the different processes are initialized with the same parent seed
+            random.seed()
+            np.random.seed()

        self.ap = agent_parameters
        self.task_id = self.ap.task_parameters.task_index
        self.is_chief = self.task_id == 0
@@ -229,10 +241,10 @@
                and self.ap.memory.shared_memory
        if self.shared_memory:
            self.shared_memory_scratchpad = self.ap.task_parameters.shared_memory_scratchpad
-        self.name = agent_parameters.name
        self.parent = parent
        self.parent_level_manager = None
-        self.full_name_id = agent_parameters.full_name_id = self.name
+        # TODO this needs to be sorted out. Why the duplicates for the agent's name?
+        self.full_name_id = agent_parameters.full_name_id = self.name = agent_parameters.name

        if type(agent_parameters.task_parameters) == DistributedTaskParameters:
            screen.log_title("Creating agent - name: {} task id: {} (may take up to 30 seconds due to "
@@ -264,9 +276,17 @@
            self.memory.set_memory_backend(self.memory_backend)

-        if agent_parameters.memory.load_memory_from_file_path:
-            screen.log_title("Loading replay buffer from pickle. Pickle path: {}"
+        if isinstance(agent_parameters.memory.load_memory_from_file_path, PickledReplayBuffer):
+            screen.log_title("Loading a pickled replay buffer. Pickled file path: {}"
                             .format(agent_parameters.memory.load_memory_from_file_path.filepath))
+            self.memory.load_pickled(agent_parameters.memory.load_memory_from_file_path.filepath)
+        elif isinstance(agent_parameters.memory.load_memory_from_file_path, CsvDataset):
+            screen.log_title("Loading a replay buffer from a CSV file. CSV file path: {}"
+                             .format(agent_parameters.memory.load_memory_from_file_path.filepath))
+            self.memory.load_csv(agent_parameters.memory.load_memory_from_file_path)
+        else:
+            raise ValueError('Trying to load a replay buffer using an unsupported method - {}. '
+                             .format(agent_parameters.memory.load_memory_from_file_path))
-            self.memory.load(agent_parameters.memory.load_memory_from_file_path)

        if self.shared_memory and self.is_chief:
            self.shared_memory_scratchpad.add(self.memory_lookup_name, self.memory)
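
A hedged usage sketch of the two loading paths introduced above. The agent-parameters class and the constructor arguments are assumptions for illustration, not a documented API:

```python
from rl_coach.core_types import PickledReplayBuffer, CsvDataset
# assumed parameters class for the new SAC agent; any AgentParameters subclass
# with a memory section would work the same way
from rl_coach.agents.soft_actor_critic_agent import SoftActorCriticAgentParameters

agent_params = SoftActorCriticAgentParameters()
# a PickledReplayBuffer is routed to memory.load_pickled(...),
# a CsvDataset to memory.load_csv(...), anything else raises ValueError
agent_params.memory.load_memory_from_file_path = PickledReplayBuffer('experiment/replay_buffer.p')
# agent_params.memory.load_memory_from_file_path = CsvDataset('experiment/transitions.csv')
```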
@@ -327,6 +347,7 @@
        self.total_steps_counter = 0
        self.running_reward = None
        self.training_iteration = 0
+        self.training_epoch = 0
        self.last_target_network_update_step = 0
        self.last_training_phase_step = 0
        self.current_episode = self.ap.current_episode = 0
@@ -364,14 +385,9 @@
        self.discounted_return = self.register_signal('Discounted Return')
        if isinstance(self.in_action_space, GoalsSpace):
            self.distance_from_goal = self.register_signal('Distance From Goal', dump_one_value_per_step=True)
-        # use seed
-        if self.ap.task_parameters.seed is not None:
-            random.seed(self.ap.task_parameters.seed)
-            np.random.seed(self.ap.task_parameters.seed)
-        else:
-            # we need to seed the RNG since the different processes are initialized with the same parent seed
-            random.seed()
-            np.random.seed()

        # batch rl
        self.ope_manager = OpeManager() if self.ap.is_batch_rl_training else None

    @property
    def parent(self) -> 'LevelManager':
@@ -408,6 +424,7 @@
                format(graph_name=self.parent_level_manager.parent_graph_manager.name,
                       level_name=self.parent_level_manager.name,
                       agent_full_id='.'.join(self.full_name_id.split('/')))
+        self.agent_logger.set_index_name(self.parent_level_manager.parent_graph_manager.time_metric.value.name)
        self.agent_logger.set_logger_filenames(self.ap.task_parameters.experiment_path, logger_prefix=logger_prefix,
                                               add_timestamp=True, task_id=self.task_id)
        if self.ap.visualization.dump_in_episode_signals:
@@ -561,13 +578,17 @@
            self.accumulated_shaped_rewards_across_evaluation_episodes = 0
            self.num_successes_across_evaluation_episodes = 0
            self.num_evaluation_episodes_completed = 0
-            if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":

+            # TODO verbosity was mistakenly removed from task_parameters on release 0.11.0, need to bring it back
+            # if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
+            if self.ap.is_a_highest_level_agent:
                screen.log_title("{}: Starting evaluation phase".format(self.name))

        elif ending_evaluation:
            # we write to the next episode, because it could be that the current episode was already written
            # to disk and then we won't write it again
-            self.agent_logger.set_current_time(self.current_episode + 1)
+            self.agent_logger.set_current_time(self.get_current_time() + 1)

            evaluation_reward = self.accumulated_rewards_across_evaluation_episodes / self.num_evaluation_episodes_completed
            self.agent_logger.create_signal_value(
                'Evaluation Reward', evaluation_reward)
@@ -577,9 +598,11 @@
            success_rate = self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed
            self.agent_logger.create_signal_value(
                "Success Rate",
                success_rate
            )
-            if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
                success_rate)

+            # TODO verbosity was mistakenly removed from task_parameters on release 0.11.0, need to bring it back
+            # if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
+            if self.ap.is_a_highest_level_agent:
                screen.log_title("{}: Finished evaluation phase. Success rate = {}, Avg Total Reward = {}"
                                 .format(self.name, np.round(success_rate, 2), np.round(evaluation_reward, 2)))
@@ -652,8 +675,11 @@
        :return: None
        """
        # log all the signals to file
-        self.agent_logger.set_current_time(self.current_episode)
+        current_time = self.get_current_time()
+        self.agent_logger.set_current_time(current_time)
        self.agent_logger.create_signal_value('Training Iter', self.training_iteration)
        self.agent_logger.create_signal_value('Episode #', self.current_episode)
+        self.agent_logger.create_signal_value('Epoch', self.training_epoch)
        self.agent_logger.create_signal_value('In Heatup', int(self._phase == RunPhase.HEATUP))
        self.agent_logger.create_signal_value('ER #Transitions', self.call_memory('num_transitions'))
        self.agent_logger.create_signal_value('ER #Episodes', self.call_memory('length'))
@@ -666,12 +692,17 @@
                                              if self._phase == RunPhase.TRAIN else np.nan)

        self.agent_logger.create_signal_value('Update Target Network', 0, overwrite=False)
-        self.agent_logger.update_wall_clock_time(self.current_episode)
+        self.agent_logger.update_wall_clock_time(current_time)

        if self._phase != RunPhase.TEST:
            # The following signals are created with meaningful values only when an evaluation phase is completed.
            # Creating with default NaNs for any HEATUP/TRAIN/TEST episode which is not the last in an evaluation phase
            self.agent_logger.create_signal_value('Evaluation Reward', np.nan, overwrite=False)
            self.agent_logger.create_signal_value('Shaped Evaluation Reward', np.nan, overwrite=False)
            self.agent_logger.create_signal_value('Success Rate', np.nan, overwrite=False)
+            self.agent_logger.create_signal_value('Inverse Propensity Score', np.nan, overwrite=False)
+            self.agent_logger.create_signal_value('Direct Method Reward', np.nan, overwrite=False)
+            self.agent_logger.create_signal_value('Doubly Robust', np.nan, overwrite=False)
+            self.agent_logger.create_signal_value('Sequential Doubly Robust', np.nan, overwrite=False)

        for signal in self.episode_signals:
            self.agent_logger.create_signal_value("{}/Mean".format(signal.name), signal.get_mean())
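
For readers unfamiliar with the off-policy evaluation signals introduced above, here is a per-decision sketch of what the three estimators conventionally compute (standard definitions, not Coach's OpeManager), where rho is the importance ratio between the evaluated policy and the behavior policy for the logged action:

```python
def inverse_propensity_score(rho, reward):
    # importance-weighted observed reward
    return rho * reward

def direct_method(q_estimate):
    # purely model-based estimate of the logged action's value
    return q_estimate

def doubly_robust(rho, reward, q_estimate, v_estimate):
    # model-based baseline plus an importance-weighted correction of its error
    return v_estimate + rho * (reward - q_estimate)

print(doubly_robust(rho=1.2, reward=1.0, q_estimate=0.8, v_estimate=0.9))
```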
@@ -680,8 +711,7 @@
            self.agent_logger.create_signal_value("{}/Min".format(signal.name), signal.get_min())

        # dump
-        if self.current_episode % self.ap.visualization.dump_signals_to_csv_every_x_episodes == 0 \
-                and self.current_episode > 0:
+        if self.current_episode % self.ap.visualization.dump_signals_to_csv_every_x_episodes == 0:
            self.agent_logger.dump_output_csv()

    def handle_episode_ended(self) -> None:
@@ -717,10 +747,13 @@
                    self.total_reward_in_current_episode >= self.spaces.reward.reward_success_threshold:
                self.num_successes_across_evaluation_episodes += 1

-        if self.ap.visualization.dump_csv:
+        if self.ap.visualization.dump_csv and \
+                self.parent_level_manager.parent_graph_manager.time_metric == TimeTypes.EpisodeNumber:
            self.update_log()

-        if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
+        # TODO verbosity was mistakenly removed from task_parameters on release 0.11.0, need to bring it back
+        # if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
+        if self.ap.is_a_highest_level_agent:
            self.log_to_screen()

    def reset_internal_state(self) -> None:
@@ -831,18 +864,25 @@
        """
        loss = 0
        if self._should_train():
+            self.training_epoch += 1
            for network in self.networks.values():
                network.set_is_training(True)

-            for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
-                # TODO: this should be network dependent
-                network_parameters = list(self.ap.network_wrappers.values())[0]
+            # At the moment we only support a single batch size for all the networks
+            networks_parameters = list(self.ap.network_wrappers.values())
+            assert all(net.batch_size == networks_parameters[0].batch_size for net in networks_parameters)

+            batch_size = networks_parameters[0].batch_size

+            # we either go sequentially through the entire replay buffer in the batch RL mode,
+            # or sample randomly for the basic RL case.
+            training_schedule = self.call_memory('get_shuffled_data_generator', batch_size) if \
+                self.ap.is_batch_rl_training else [self.call_memory('sample', batch_size) for _ in
+                                                   range(self.ap.algorithm.num_consecutive_training_steps)]

+            for batch in training_schedule:
                # update counters
                self.training_iteration += 1

-                # sample a batch and train on it
-                batch = self.call_memory('sample', network_parameters.batch_size)
                if self.pre_network_filter is not None:
                    batch = self.pre_network_filter.filter(batch, update_internal_state=False, deep_copy=False)
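
A small, self-contained sketch of the two sampling regimes contrasted in the comment above (plain Python, not Coach's memory API; the function names are illustrative):

```python
import random

def shuffled_data_generator(buffer, batch_size):
    """Batch-RL style: one shuffled pass over the entire replay buffer."""
    indices = list(range(len(buffer)))
    random.shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [buffer[i] for i in indices[start:start + batch_size]]

def random_batches(buffer, batch_size, num_batches):
    """Online style: independent, uniformly sampled batches."""
    return [random.sample(buffer, batch_size) for _ in range(num_batches)]

transitions = list(range(10))
print(list(shuffled_data_generator(transitions, 4)))
print(random_batches(transitions, 4, 2))
```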
@@ -853,15 +893,19 @@
|
||||
<span class="n">batch</span> <span class="o">=</span> <span class="n">Batch</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
|
||||
<span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">learn_from_batch</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
|
||||
<span class="n">loss</span> <span class="o">+=</span> <span class="n">total_loss</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">unclipped_grads</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">unclipped_grads</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># TODO: the learning rate decay should be done through the network instead of here</span>
|
||||
<span class="c1"># TODO: this only deals with the main network (if exists), need to do the same for other networks</span>
|
||||
<span class="c1"># for instance, for DDPG, the LR signal is currently not shown. Probably should be done through the</span>
|
||||
<span class="c1"># network directly instead of here</span>
|
||||
<span class="c1"># decay learning rate</span>
|
||||
<span class="k">if</span> <span class="n">network_parameters</span><span class="o">.</span><span class="n">learning_rate_decay_rate</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="s1">'main'</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span> <span class="ow">and</span> \
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">learning_rate_decay_rate</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">curr_learning_rate</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">sess</span><span class="o">.</span><span class="n">run</span><span class="p">(</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">current_learning_rate</span><span class="p">))</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">curr_learning_rate</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">network_parameters</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">curr_learning_rate</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">networks_parameters</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="nb">any</span><span class="p">([</span><span class="n">network</span><span class="o">.</span><span class="n">has_target</span> <span class="k">for</span> <span class="n">network</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="o">.</span><span class="n">values</span><span class="p">()])</span> \
|
||||
<span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">_should_update_online_weights_to_target</span><span class="p">():</span>
|
||||
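The TODOs above note that the learning rate decay really belongs inside the network rather than in `Agent.train()`. As a rough, hypothetical sketch of that idea with TensorFlow 1.x (the constants stand in for the values carried by `network_parameters`; this is not Coach's actual network code):

```python
import tensorflow as tf

# Hypothetical sketch only: let the graph own the decayed learning rate so the
# agent can just sess.run() it for logging, as the 'main' branch above already does.
global_step = tf.train.get_or_create_global_step()
current_learning_rate = tf.train.exponential_decay(
    learning_rate=0.001,   # stand-in for network_parameters.learning_rate
    global_step=global_step,
    decay_steps=10000,     # stand-in for the decay schedule length
    decay_rate=0.96)       # stand-in for network_parameters.learning_rate_decay_rate
optimizer = tf.train.AdamOptimizer(learning_rate=current_learning_rate)
```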
@@ -877,6 +921,12 @@
if self.imitation:
    self.log_to_screen()

if self.ap.visualization.dump_csv and \
        self.parent_level_manager.parent_graph_manager.time_metric == TimeTypes.Epoch:
    # in BatchRL, or imitation learning, the agent never acts, so we have to get the stats out here.
    # we dump the data out every epoch
    self.update_log()

for network in self.networks.values():
    network.set_is_training(False)
@@ -919,10 +969,11 @@
    return batches_dict

def act(self) -> ActionInfo:
def act(self, action: Union[None, ActionType]=None) -> ActionInfo:
    """
    Given the agent's current knowledge, decide on the next action to apply to the environment

    :param action: An action to take, overriding whatever the current policy is
    :return: An ActionInfo object, which contains the action and any additional info from the action decision process
    """
    if self.phase == RunPhase.TRAIN and self.ap.algorithm.num_consecutive_playing_steps.num_steps == 0:
@@ -935,9 +986,10 @@
    self.current_episode_steps_counter += 1

    # decide on the action
    if action is None:
        if self.phase == RunPhase.HEATUP and not self.ap.algorithm.heatup_using_network_decisions:
            # random action
            self.last_action_info = self.spaces.action.sample_with_info()
            action = self.spaces.action.sample_with_info()
        else:
            # informed action
            if self.pre_network_filter is not None:
@@ -947,8 +999,15 @@
            else:
                curr_state = self.curr_state
            self.last_action_info = self.choose_action(curr_state)
            action = self.choose_action(curr_state)

    assert isinstance(action, ActionInfo)

    self.last_action_info = action

    # output filters are explicitly applied after recording self.last_action_info. This is
    # because the output filters may change the representation of the action so that the agent
    # can no longer use the transition in its replay buffer. It is possible that these filters
    # could be moved to the environment instead, but they are here now for historical reasons.
    filtered_action_info = self.output_filter.filter(self.last_action_info)

    return filtered_action_info
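The new `action` argument lets a caller bypass the policy, e.g. to replay a recorded trajectory or to inject human input. A rough usage sketch (it assumes a fully constructed agent, and that the parts of `act()` not shown in this fragment wrap a raw action into an `ActionInfo` where needed):

```python
# let the current policy decide
action_info = agent.act()

# force a specific action; the agent still records it in last_action_info
# and still passes it through the output filters
forced_info = agent.act(action=2)
```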
@@ -1030,24 +1089,22 @@
    # make agent specific changes to the transition if needed
    transition = self.update_transition_before_adding_to_replay_buffer(transition)

    # merge the intrinsic reward in
    if self.ap.algorithm.scale_external_reward_by_intrinsic_reward_value:
        transition.reward = transition.reward * (1 + self.last_action_info.action_intrinsic_reward)
    else:
        transition.reward = transition.reward + self.last_action_info.action_intrinsic_reward

    # sum up the total shaped reward
    self.total_shaped_reward_in_current_episode += transition.reward
    self.total_reward_in_current_episode += env_response.reward
    self.shaped_reward.add_sample(transition.reward)
    self.reward.add_sample(env_response.reward)

    # add action info to transition
    if type(self.parent).__name__ == 'CompositeAgent':
        transition.add_info(self.parent.last_action_info.__dict__)
    else:
        transition.add_info(self.last_action_info.__dict__)

    self.total_reward_in_current_episode += env_response.reward
    self.reward.add_sample(env_response.reward)

    return self.observe_transition(transition)

def observe_transition(self, transition):
    # sum up the total shaped reward
    self.total_shaped_reward_in_current_episode += transition.reward
    self.shaped_reward.add_sample(transition.reward)

    # create and store the transition
    if self.phase in [RunPhase.TRAIN, RunPhase.HEATUP]:
        # for episodic memories we keep the transitions in a local buffer until the episode is ended.
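The two merge modes above differ only in whether the intrinsic reward scales or shifts the external reward. A small numeric sketch with made-up values:

```python
external_reward = 2.0
intrinsic_reward = 0.5   # e.g. a curiosity bonus carried on last_action_info

# scale_external_reward_by_intrinsic_reward_value == True
shaped_multiplicative = external_reward * (1 + intrinsic_reward)   # 3.0

# scale_external_reward_by_intrinsic_reward_value == False
shaped_additive = external_reward + intrinsic_reward               # 2.5
```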
@@ -1060,7 +1117,7 @@
    if self.ap.visualization.dump_in_episode_signals:
        self.update_step_in_episode_log()

    return transition.game_over

def post_training_commands(self) -> None:
    """
@@ -1145,60 +1202,6 @@
    for network in self.networks.values():
        network.sync()

# TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
# an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
def emulate_observe_on_trainer(self, transition: Transition) -> bool:
    """
    This emulates the observe using the transition obtained from the rollout worker on the training worker
    in case of distributed training.
    Given a response from the environment, distill the observation from it and store it for later use.
    The response should be a dictionary containing the performed action, the new observation and measurements,
    the reward, a game over flag and any additional information necessary.
    :return:
    """

    # sum up the total shaped reward
    self.total_shaped_reward_in_current_episode += transition.reward
    self.total_reward_in_current_episode += transition.reward
    self.shaped_reward.add_sample(transition.reward)
    self.reward.add_sample(transition.reward)

    # create and store the transition
    if self.phase in [RunPhase.TRAIN, RunPhase.HEATUP]:
        # for episodic memories we keep the transitions in a local buffer until the episode is ended.
        # for regular memories we insert the transitions directly to the memory
        self.current_episode_buffer.insert(transition)
        if not isinstance(self.memory, EpisodicExperienceReplay) \
                and not self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
            self.call_memory('store', transition)

    if self.ap.visualization.dump_in_episode_signals:
        self.update_step_in_episode_log()

    return transition.game_over

# TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
# an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
def emulate_act_on_trainer(self, transition: Transition) -> ActionInfo:
    """
    This emulates the act using the transition obtained from the rollout worker on the training worker
    in case of distributed training.
    Given the agent's current knowledge, decide on the next action to apply to the environment
    :return: an action and a dictionary containing any additional info from the action decision process
    """
    if self.phase == RunPhase.TRAIN and self.ap.algorithm.num_consecutive_playing_steps.num_steps == 0:
        # This agent never plays while training (e.g. behavioral cloning)
        return None

    # count steps (only when training or if we are in the evaluation worker)
    if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only:
        self.total_steps_counter += 1
        self.current_episode_steps_counter += 1

    self.last_action_info = transition.action

    return self.last_action_info

def get_success_rate(self) -> float:
    return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed
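Taken together, the two emulate_* methods let a training worker replay transitions collected by rollout workers as if it had acted and observed them itself. A hedged sketch of how a trainer loop might use them; `fetch_transitions_from_rollout_workers` is a hypothetical helper, not part of Coach:

```python
def trainer_step(agent, fetch_transitions_from_rollout_workers):
    """One hypothetical training-worker iteration in distributed Coach."""
    for transition in fetch_transitions_from_rollout_workers():
        # pretend the trainer chose the action the rollout worker actually took
        agent.emulate_act_on_trainer(transition)
        # store the transition so it becomes available for sampling
        agent.emulate_observe_on_trainer(transition)
    # train on whatever has accumulated in memory
    return agent.train()
```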
@@ -1213,7 +1216,16 @@
    savers = SaverCollection()
    for network in self.networks.values():
        savers.update(network.collect_savers(parent_path_suffix))
    return savers

def get_current_time(self):
    pass
    return {
        TimeTypes.EpisodeNumber: self.current_episode,
        TimeTypes.TrainingIteration: self.training_iteration,
        TimeTypes.EnvironmentSteps: self.total_steps_counter,
        TimeTypes.WallClockTime: self.agent_logger.get_current_wall_clock_time(),
        TimeTypes.Epoch: self.training_epoch}[self.parent_level_manager.parent_graph_manager.time_metric]
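`get_current_time` simply indexes the agent's counters by the graph manager's configured time metric. A brief usage sketch, assuming `graph_manager` is the agent's parent graph manager and the agent is already constructed:

```python
graph_manager.time_metric = TimeTypes.EnvironmentSteps
assert agent.get_current_time() == agent.total_steps_counter
```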
@@ -0,0 +1,387 @@
Source code for rl_coach.agents.agent_interface
<span></span><span class="c1">#</span>
|
||||
<span class="c1"># Copyright (c) 2017 Intel Corporation</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Licensed under the Apache License, Version 2.0 (the "License");</span>
|
||||
<span class="c1"># you may not use this file except in compliance with the License.</span>
|
||||
<span class="c1"># You may obtain a copy of the License at</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
|
||||
<span class="c1"># distributed under the License is distributed on an "AS IS" BASIS,</span>
|
||||
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
|
||||
<span class="c1"># See the License for the specific language governing permissions and</span>
|
||||
<span class="c1"># limitations under the License.</span>
|
||||
<span class="c1">#</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Dict</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">EnvResponse</span><span class="p">,</span> <span class="n">ActionInfo</span><span class="p">,</span> <span class="n">RunPhase</span><span class="p">,</span> <span class="n">PredictionType</span><span class="p">,</span> <span class="n">ActionType</span><span class="p">,</span> <span class="n">Transition</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.saver</span> <span class="k">import</span> <span class="n">SaverCollection</span>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">AgentInterface</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_phase</span> <span class="o">=</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">HEATUP</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_parent</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">spaces</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
|
||||
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">parent</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Get the parent class of the agent</span>
|
||||
<span class="sd"> :return: the current phase</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_parent</span>
|
||||
|
||||
<span class="nd">@parent</span><span class="o">.</span><span class="n">setter</span>
|
||||
<span class="k">def</span> <span class="nf">parent</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">val</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Change the parent class of the agent</span>
|
||||
<span class="sd"> :param val: the new parent</span>
|
||||
<span class="sd"> :return: None</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_parent</span> <span class="o">=</span> <span class="n">val</span>
|
||||
|
||||
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">phase</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">RunPhase</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Get the phase of the agent</span>
|
||||
<span class="sd"> :return: the current phase</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_phase</span>
|
||||
|
||||
<span class="nd">@phase</span><span class="o">.</span><span class="n">setter</span>
|
||||
<span class="k">def</span> <span class="nf">phase</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">val</span><span class="p">:</span> <span class="n">RunPhase</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Change the phase of the agent</span>
|
||||
<span class="sd"> :param val: the new phase</span>
|
||||
<span class="sd"> :return: None</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_phase</span> <span class="o">=</span> <span class="n">val</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">reset_internal_state</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Reset the episode parameters for the agent</span>
|
||||
<span class="sd"> :return: None</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">""</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">List</span><span class="p">]:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Train the agents network</span>
|
||||
<span class="sd"> :return: The loss of the training</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">""</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">act</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">ActionInfo</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Get a decision of the next action to take.</span>
|
||||
<span class="sd"> The action is dependent on the current state which the agent holds from resetting the environment or from</span>
|
||||
<span class="sd"> the observe function.</span>
|
||||
<span class="sd"> :return: A tuple containing the actual action and additional info on the action</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">""</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">observe</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">env_response</span><span class="p">:</span> <span class="n">EnvResponse</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Gets a response from the environment.</span>
|
||||
<span class="sd"> Processes this information for later use. For example, create a transition and store it in memory.</span>
|
||||
<span class="sd"> The action info (a class containing any info the agent wants to store regarding its action decision process) is</span>
|
||||
<span class="sd"> stored by the agent itself when deciding on the action.</span>
|
||||
<span class="sd"> :param env_response: a EnvResponse containing the response from the environment</span>
|
||||
<span class="sd"> :return: a done signal which is based on the agent knowledge. This can be different from the done signal from</span>
|
||||
<span class="sd"> the environment. For example, an agent can decide to finish the episode each time it gets some</span>
|
||||
<span class="sd"> intrinsic reward</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">""</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">save_checkpoint</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">checkpoint_prefix</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Save the model of the agent to the disk. This can contain the network parameters, the memory of the agent, etc.</span>
|
||||
<span class="sd"> :param checkpoint_prefix: The prefix of the checkpoint file to save</span>
|
||||
<span class="sd"> :return: None</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">""</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">get_predictions</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">:</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">prediction_type</span><span class="p">:</span> <span class="n">PredictionType</span><span class="p">)</span> <span class="o">-></span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Get a prediction from the agent with regard to the requested prediction_type. If the agent cannot predict this</span>
|
||||
<span class="sd"> type of prediction_type, or if there is more than possible way to do so, raise a ValueException.</span>
|
||||
<span class="sd"> :param states:</span>
|
||||
<span class="sd"> :param prediction_type:</span>
|
||||
<span class="sd"> :return: the agent's prediction</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">""</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">set_incoming_directive</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action</span><span class="p">:</span> <span class="n">ActionType</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Pass a higher level command (directive) to the agent.</span>
|
||||
<span class="sd"> For example, a higher level agent can set the goal of the agent.</span>
|
||||
<span class="sd"> :param action: the directive to pass to the agent</span>
|
||||
<span class="sd"> :return: None</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">""</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">collect_savers</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">parent_path_suffix</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="n">SaverCollection</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Collect all of agent savers</span>
|
||||
<span class="sd"> :param parent_path_suffix: path suffix of the parent of the agent</span>
|
||||
<span class="sd"> (could be name of level manager or composite agent)</span>
|
||||
<span class="sd"> :return: collection of all agent savers</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">""</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">handle_episode_ended</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Make any changes needed when each episode is ended.</span>
|
||||
<span class="sd"> This includes incrementing counters, updating full episode dependent values, updating logs, etc.</span>
|
||||
<span class="sd"> This function is called right after each episode is ended.</span>
|
||||
|
||||
<span class="sd"> :return: None</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">""</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">run_off_policy_evaluation</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Run off-policy evaluation estimators to evaluate the trained policy performance against a dataset.</span>
|
||||
<span class="sd"> Should only be implemented for off-policy RL algorithms.</span>
|
||||
|
||||
<span class="sd"> :return: None</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">""</span><span class="p">)</span>
|
||||
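AgentInterface is a pure interface: every method raises NotImplementedError, and concrete agents (such as Agent above) override them. A minimal, purely illustrative subclass that satisfies just the act/observe/train contract could look like the following; it assumes `spaces` is a SpacesDefinition-like object exposing an action space:

```python
from rl_coach.core_types import ActionInfo, EnvResponse


class RandomAgent(AgentInterface):
    """Toy example only: acts uniformly at random and learns nothing."""

    def __init__(self, spaces):
        super().__init__()
        self.spaces = spaces

    def act(self) -> ActionInfo:
        return self.spaces.action.sample_with_info()

    def observe(self, env_response: EnvResponse) -> bool:
        # nothing is stored; just report the environment's done flag
        return env_response.game_over

    def train(self) -> float:
        return 0.0
```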
@@ -264,8 +264,8 @@
# prediction's format is (batch, actions, atoms)
def get_all_q_values_for_states(self, states: StateType):
    if self.exploration_policy.requires_action_values():
        prediction = self.get_prediction(states)
        q_values = self.distribution_prediction_to_q_values(prediction)
        q_values = self.get_prediction(states,
                                       outputs=[self.networks['main'].online_network.output_heads[0].q_values])
    else:
        q_values = None
    return q_values
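`distribution_prediction_to_q_values` collapses the categorical prediction into scalar Q values by taking the expectation over the atom support. A small numpy sketch of that reduction (the atom values and probabilities are illustrative):

```python
import numpy as np

z_values = np.array([-10.0, 0.0, 10.0])        # atom support
probabilities = np.array([[0.1, 0.2, 0.7],     # one state, two actions
                          [0.6, 0.3, 0.1]])
q_values = probabilities.dot(z_values)         # expectation per action -> [ 6., -5.]
```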
@@ -280,11 +280,14 @@
        (self.networks['main'].online_network, batch.states(network_keys))
    ])

    # add Q value samples for logging
    self.q_values.add_sample(self.distribution_prediction_to_q_values(TD_targets))

    # select the optimal actions for the next state
    target_actions = np.argmax(self.distribution_prediction_to_q_values(distributional_q_st_plus_1), axis=1)
    m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
    m = np.zeros((batch.size, self.z_values.size))

    batches = np.arange(self.ap.network_wrappers['main'].batch_size)
    batches = np.arange(batch.size)

    # an alternative to the for loop. 3.7x perf improvement vs. the same code done with for looping.
    # only 10% speedup overall - leaving commented out as the code is not as clear.
@@ -297,7 +300,7 @@
    # bj_ = (tzj_ - self.z_values[0]) / (self.z_values[1] - self.z_values[0])
    # u_ = (np.ceil(bj_)).astype(int)
    # l_ = (np.floor(bj_)).astype(int)
    # m_ = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
    # m_ = np.zeros((batch.size, self.z_values.size))
    # np.add.at(m_, [batches, l_],
    #           np.transpose(distributional_q_st_plus_1[batches, target_actions], (1, 0)) * (u_ - bj_))
    # np.add.at(m_, [batches, u_],
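The commented-out block above is a vectorised form of the standard C51 projection of the Bellman-updated support back onto the fixed atoms. For reference, a hedged scalar sketch of the same projection for a single transition (reward, gamma and the next-state distribution are made up):

```python
import numpy as np

z_values = np.linspace(-10.0, 10.0, 51)                   # atom support
delta_z = z_values[1] - z_values[0]
m = np.zeros_like(z_values)                               # projected target distribution

reward, gamma, done = 1.0, 0.99, False
next_probs = np.full(len(z_values), 1.0 / len(z_values))  # stand-in target-network prediction

for z_j, p_j in zip(z_values, next_probs):
    tz_j = np.clip(reward + (1.0 - done) * gamma * z_j, z_values[0], z_values[-1])
    b_j = (tz_j - z_values[0]) / delta_z
    l, u = int(np.floor(b_j)), int(np.ceil(b_j))
    if l == u:
        m[l] += p_j                  # Tz_j falls exactly on an atom
    else:
        m[l] += p_j * (u - b_j)      # split the probability between neighbouring atoms
        m[u] += p_j * (b_j - l)
```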
@@ -387,6 +387,8 @@
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">likelihood_ratio</span><span class="p">,</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">clipped_likelihood_ratio</span><span class="p">]</span>
|
||||
|
||||
<span class="c1"># TODO-fixme if batch.size / self.ap.network_wrappers['main'].batch_size is not an integer, we do not train on</span>
|
||||
<span class="c1"># some of the data</span>
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)):</span>
|
||||
<span class="n">start</span> <span class="o">=</span> <span class="n">i</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span>
|
||||
<span class="n">end</span> <span class="o">=</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'main'</span><span class="p">]</span><span class="o">.</span><span class="n">batch_size</span>
@@ -326,13 +326,13 @@

network_inputs = batch.states(network_keys)
network_inputs['goal'] = np.repeat(np.expand_dims(self.current_goal, 0),
                                   self.ap.network_wrappers['main'].batch_size, axis=0)
                                   batch.size, axis=0)

# get the current outputs of the network
targets = self.networks['main'].online_network.predict(network_inputs)

# change the targets for the taken actions
for i in range(self.ap.network_wrappers['main'].batch_size):
for i in range(batch.size):
    targets[i, batch.actions()[i]] = batch[i].info['future_measurements'].flatten()

result = self.networks['main'].train_and_sync_networks(network_inputs, targets)

@@ -250,6 +250,9 @@

def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
    super().__init__(agent_parameters, parent)

def select_actions(self, next_states, q_st_plus_1):
    return np.argmax(q_st_plus_1, 1)

def learn_from_batch(self, batch):
    network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()

@@ -261,11 +264,16 @@

    (self.networks['main'].online_network, batch.states(network_keys))
])

selected_actions = self.select_actions(batch.next_states(network_keys), q_st_plus_1)

# add Q value samples for logging
self.q_values.add_sample(TD_targets)

# only update the action that we have actually done in this transition
TD_errors = []
for i in range(self.ap.network_wrappers['main'].batch_size):
for i in range(batch.size):
    new_target = batch.rewards()[i] + \
                 (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * np.max(q_st_plus_1[i], 0)
                 (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * q_st_plus_1[i][selected_actions[i]]
    TD_errors.append(np.abs(new_target - TD_targets[i, batch.actions()[i]]))
    TD_targets[i, batch.actions()[i]] = new_target
@@ -245,7 +245,7 @@

total_returns = batch.n_step_discounted_rewards()

for i in range(self.ap.network_wrappers['main'].batch_size):
for i in range(batch.size):
    one_step_target = batch.rewards()[i] + \
                      (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * \
                      q_st_plus_1[i][selected_actions[i]]

@@ -303,6 +303,9 @@

else:
    assert True, 'The available values for targets_horizon are: 1-Step, N-Step'

# add Q value samples for logging
self.q_values.add_sample(state_value_head_targets)

# train
result = self.networks['main'].online_network.accumulate_gradients(batch.states(network_keys), [state_value_head_targets])

@@ -313,7 +313,7 @@

TD_targets = self.networks['main'].online_network.predict(batch.states(network_keys))
bootstrapped_return_from_old_policy = batch.n_step_discounted_rewards()
# only update the action that we have actually done in this transition
for i in range(self.ap.network_wrappers['main'].batch_size):
for i in range(batch.size):
    TD_targets[i, batch.actions()[i]] = bootstrapped_return_from_old_policy[i]

# set the gradients to fetch for the DND update

@@ -342,7 +342,7 @@

embedding = self.networks['main'].online_network.predict(
    self.prepare_batch_for_inference(self.curr_state, 'main'),
    outputs=self.networks['main'].online_network.state_embedding)
self.current_episode_state_embeddings.append(embedding)
self.current_episode_state_embeddings.append(embedding.squeeze())

return super().act()

@@ -264,7 +264,7 @@

# calculate TD error
TD_targets = np.copy(q_st_online)
total_returns = batch.n_step_discounted_rewards()
for i in range(self.ap.network_wrappers['main'].batch_size):
for i in range(batch.size):
    TD_targets[i, batch.actions()[i]] = batch.rewards()[i] + \
                                        (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * \
                                        q_st_plus_1_target[i][selected_actions[i]]

@@ -268,11 +268,14 @@

    (self.networks['main'].online_network, batch.states(network_keys))
])

# add Q value samples for logging
self.q_values.add_sample(self.get_q_values(current_quantiles))

# get the optimal actions to take for the next states
target_actions = np.argmax(self.get_q_values(next_state_quantiles), axis=1)

# calculate the Bellman update
batch_idx = list(range(self.ap.network_wrappers['main'].batch_size))
batch_idx = list(range(batch.size))

TD_targets = batch.rewards(True) + (1.0 - batch.game_overs(True)) * self.ap.algorithm.discount \
             * next_state_quantiles[batch_idx, target_actions]

@@ -283,9 +286,9 @@

# calculate the cumulative quantile probabilities and reorder them to fit the sorted quantiles order
cumulative_probabilities = np.array(range(self.ap.algorithm.atoms + 1)) / float(self.ap.algorithm.atoms)  # tau_i
quantile_midpoints = 0.5*(cumulative_probabilities[1:] + cumulative_probabilities[:-1])  # tau^hat_i
quantile_midpoints = np.tile(quantile_midpoints, (self.ap.network_wrappers['main'].batch_size, 1))
quantile_midpoints = np.tile(quantile_midpoints, (batch.size, 1))
sorted_quantiles = np.argsort(current_quantiles[batch_idx, batch.actions()])
for idx in range(self.ap.network_wrappers['main'].batch_size):
for idx in range(batch.size):
    quantile_midpoints[idx, :] = quantile_midpoints[idx, sorted_quantiles[idx]]

# train

@@ -240,9 +240,12 @@

def __init__(self):
    super().__init__()
    self.algorithm = RainbowDQNAlgorithmParameters()

    # ParameterNoiseParameters is changing the network wrapper parameters. This line needs to be done first.
    self.network_wrappers = {"main": RainbowDQNNetworkParameters()}

    self.exploration = ParameterNoiseParameters(self)
    self.memory = PrioritizedExperienceReplayParameters()
    self.network_wrappers = {"main": RainbowDQNNetworkParameters()}

@property
def path(self):

@@ -275,11 +278,14 @@

    (self.networks['main'].online_network, batch.states(network_keys))
])

# add Q value samples for logging
self.q_values.add_sample(self.distribution_prediction_to_q_values(TD_targets))

# only update the action that we have actually done in this transition (using the Double-DQN selected actions)
target_actions = ddqn_selected_actions
m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
m = np.zeros((batch.size, self.z_values.size))

batches = np.arange(self.ap.network_wrappers['main'].batch_size)
batches = np.arange(batch.size)
for j in range(self.z_values.size):
    # we use batch.info('should_bootstrap_next_state') instead of (1 - batch.game_overs()) since with n-step,
    # we will not bootstrap for the last n-step transitions in the episode

@@ -0,0 +1,554 @@
# Source code for rl_coach.agents.soft_actor_critic_agent

#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Union
import copy
import numpy as np
from collections import OrderedDict

from rl_coach.agents.agent import Agent
from rl_coach.agents.policy_optimization_agent import PolicyOptimizationAgent

from rl_coach.architectures.head_parameters import SACQHeadParameters, SACPolicyHeadParameters, VHeadParameters
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters
from rl_coach.base_parameters import AlgorithmParameters, NetworkParameters, AgentParameters, EmbedderScheme, MiddlewareScheme
from rl_coach.core_types import ActionInfo, EnvironmentSteps, RunPhase
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.spaces import BoxActionSpace


# There are 3 networks in the SAC implementation. All have the same topology, but parameters are not shared.
# The networks are:
# 1. State Value Network - SACValueNetwork
# 2. Soft Q Value Network - SACCriticNetwork
# 3. Policy Network - SACPolicyNetwork - currently supporting only a Gaussian policy


# 1. State Value Network - SACValueNetwork
# This is the state value network in SAC.
# The network is trained to predict (by regression) the state value in the max-entropy setting.
# The objective to be minimized is given in equation (5) in the paper:
#
# J(psi) = E_(s~D)[0.5 * (V_psi(s) - y(s))^2]
# where y(s) = E_(a~pi)[Q_theta(s, a) - log(pi(a|s))]

# Default parameters for the value network:
# topology:
#   input embedder : EmbedderScheme.Medium (Dense(256)), relu activation
#   middleware     : EmbedderScheme.Medium (Dense(256)), relu activation

class SACValueNetworkParameters(NetworkParameters):
    def __init__(self):
        super().__init__()
        self.input_embedders_parameters = {'observation': InputEmbedderParameters(activation_function='relu')}
        self.middleware_parameters = FCMiddlewareParameters(activation_function='relu')
        self.heads_parameters = [VHeadParameters(initializer='xavier')]
        self.rescale_gradient_from_head_by_factor = [1]
        self.optimizer_type = 'Adam'
        self.batch_size = 256
        self.async_training = False
        self.learning_rate = 0.0003  # 3e-4, see appendix D in the paper
        self.create_target_network = True  # tau is set in SoftActorCriticAlgorithmParameters.rate_for_copying_weights_to_target

# 2. Soft Q Value Network - SACCriticNetwork
# The whole network is built in SACQHeadParameters. We use an empty input embedder and middleware.
class SACCriticNetworkParameters(NetworkParameters):
    def __init__(self):
        super().__init__()
        self.input_embedders_parameters = {'observation': InputEmbedderParameters(scheme=EmbedderScheme.Empty)}
        self.middleware_parameters = FCMiddlewareParameters(scheme=MiddlewareScheme.Empty)
        self.heads_parameters = [SACQHeadParameters()]  # SACQHeadParameters includes the topology of the head
        self.rescale_gradient_from_head_by_factor = [1]
        self.optimizer_type = 'Adam'
        self.batch_size = 256
        self.async_training = False
        self.learning_rate = 0.0003
        self.create_target_network = False


# 3. Policy Network - SACPolicyNetwork
# Default parameters for the policy network:
# topology:
#   input embedder : EmbedderScheme.Medium (Dense(256)), relu activation
#   middleware     : EmbedderScheme = [Dense(256)], relu activation --> scheme should be overridden in the preset
class SACPolicyNetworkParameters(NetworkParameters):
    def __init__(self):
        super().__init__()
        self.input_embedders_parameters = {'observation': InputEmbedderParameters(activation_function='relu')}
        self.middleware_parameters = FCMiddlewareParameters(activation_function='relu')
        self.heads_parameters = [SACPolicyHeadParameters()]
        self.rescale_gradient_from_head_by_factor = [1]
        self.optimizer_type = 'Adam'
        self.batch_size = 256
        self.async_training = False
        self.learning_rate = 0.0003
        self.create_target_network = False
        self.l2_regularization = 0  # weight decay regularization. not used in the original paper


# Algorithm Parameters

class SoftActorCriticAlgorithmParameters(AlgorithmParameters):
    """
    :param num_steps_between_copying_online_weights_to_target: (StepMethod)
        The number of steps between copying the online network weights to the target network weights.

    :param rate_for_copying_weights_to_target: (float)
        When copying the online network weights to the target network weights, a soft update is used, which
        weights the new online network weights by rate_for_copying_weights_to_target (tau, as defined in the paper).

    :param use_deterministic_for_evaluation: (bool)
        If True, during the evaluation phase, actions are chosen deterministically according to the policy mean
        and not sampled from the policy distribution.
    """
    def __init__(self):
        super().__init__()
        self.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(1)
        self.rate_for_copying_weights_to_target = 0.005
        self.use_deterministic_for_evaluation = True  # evaluate the agent using the deterministic policy (i.e. take the mean value)
|
||||
|
||||
|
||||
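# Illustrative note (not part of the agent code): with num_steps_between_copying_online_weights_to_target
# set to EnvironmentSteps(1) and rate_for_copying_weights_to_target = tau, the V target network is softly
# updated after (roughly) every environment step. A minimal sketch of the assumed soft update rule:
#
#     def soft_update(target_weights, online_weights, tau=0.005):
#         return [tau * w_online + (1.0 - tau) * w_target
#                 for w_online, w_target in zip(online_weights, target_weights)]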
<span class="k">class</span> <span class="nc">SoftActorCriticAgentParameters</span><span class="p">(</span><span class="n">AgentParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">algorithm</span><span class="o">=</span><span class="n">SoftActorCriticAlgorithmParameters</span><span class="p">(),</span>
|
||||
<span class="n">exploration</span><span class="o">=</span><span class="n">AdditiveNoiseParameters</span><span class="p">(),</span>
|
||||
<span class="n">memory</span><span class="o">=</span><span class="n">ExperienceReplayParameters</span><span class="p">(),</span> <span class="c1"># SAC doesnt use episodic related data</span>
|
||||
<span class="c1"># network wrappers:</span>
|
||||
<span class="n">networks</span><span class="o">=</span><span class="n">OrderedDict</span><span class="p">([(</span><span class="s2">"policy"</span><span class="p">,</span> <span class="n">SACPolicyNetworkParameters</span><span class="p">()),</span>
|
||||
<span class="p">(</span><span class="s2">"q"</span><span class="p">,</span> <span class="n">SACCriticNetworkParameters</span><span class="p">()),</span>
|
||||
<span class="p">(</span><span class="s2">"v"</span><span class="p">,</span> <span class="n">SACValueNetworkParameters</span><span class="p">())]))</span>
|
||||
|
||||
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="s1">'rl_coach.agents.soft_actor_critic_agent:SoftActorCriticAgent'</span>
|
||||
|
||||
|
||||
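# Usage sketch (illustrative, not part of this module; the graph manager and environment level below are
# assumptions - see the Mujoco SAC preset shipped with Coach for the exact setup):
#
#     agent_params = SoftActorCriticAgentParameters()
#     agent_params.network_wrappers['policy'].learning_rate = 0.0003
#     graph_manager = BasicRLGraphManager(agent_params=agent_params,
#                                         env_params=GymVectorEnvironment(level='Hopper-v2'))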
<span class="c1"># Soft Actor Critic - https://arxiv.org/abs/1801.01290</span>
|
||||
<span class="k">class</span> <span class="nc">SoftActorCriticAgent</span><span class="p">(</span><span class="n">PolicyOptimizationAgent</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s1">'LevelManager'</span><span class="p">,</span> <span class="s1">'CompositeAgent'</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">last_gradient_update_step_idx</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
|
||||
<span class="c1"># register signals to track (in learn_from_batch)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">policy_means</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">'Policy_mu_avg'</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">policy_logsig</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">'Policy_logsig'</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">policy_logprob_sampled</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">'Policy_logp_sampled'</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">policy_grads</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">'Policy_grads_sumabs'</span><span class="p">)</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">q1_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s2">"Q1"</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">TD_err1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s2">"TD err1"</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">q2_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s2">"Q2"</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">TD_err2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s2">"TD err2"</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">v_tgt_ns</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">'V_tgt_ns'</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">v_onl_ys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s1">'V_onl_ys'</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">action_signal</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_signal</span><span class="p">(</span><span class="s2">"actions"</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">learn_from_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
|
||||
<span class="c1">#########################################</span>
|
||||
<span class="c1"># need to update the following networks:</span>
|
||||
<span class="c1"># 1. actor (policy)</span>
|
||||
<span class="c1"># 2. state value (v)</span>
|
||||
<span class="c1"># 3. critic (q1 and q2)</span>
|
||||
<span class="c1"># 4. target network - probably already handled by V</span>
|
||||
|
||||
<span class="c1">#########################################</span>
|
||||
<span class="c1"># define the networks to be used</span>
|
||||
|
||||
<span class="c1"># State Value Network</span>
|
||||
<span class="n">value_network</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'v'</span><span class="p">]</span>
|
||||
<span class="n">value_network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'v'</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># Critic Network</span>
|
||||
<span class="n">q_network</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'q'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span>
|
||||
<span class="n">q_head</span> <span class="o">=</span> <span class="n">q_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="n">q_network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'q'</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># Actor (policy) Network</span>
|
||||
<span class="n">policy_network</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'policy'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span>
|
||||
<span class="n">policy_network_keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">network_wrappers</span><span class="p">[</span><span class="s1">'policy'</span><span class="p">]</span><span class="o">.</span><span class="n">input_embedders_parameters</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
|
||||
|
||||
<span class="c1">##########################################</span>
|
||||
<span class="c1"># 1. updating the actor - according to (13) in the paper</span>
|
||||
<span class="n">policy_inputs</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">policy_network_keys</span><span class="p">))</span>
|
||||
<span class="n">policy_results</span> <span class="o">=</span> <span class="n">policy_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">policy_inputs</span><span class="p">)</span>
|
||||
|
||||
<span class="n">policy_mu</span><span class="p">,</span> <span class="n">policy_std</span><span class="p">,</span> <span class="n">sampled_raw_actions</span><span class="p">,</span> <span class="n">sampled_actions</span><span class="p">,</span> <span class="n">sampled_actions_logprob</span><span class="p">,</span> \
|
||||
<span class="n">sampled_actions_logprob_mean</span> <span class="o">=</span> <span class="n">policy_results</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">policy_means</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">policy_mu</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">policy_logsig</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">policy_std</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">policy_logprob_sampled</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">sampled_actions_logprob_mean</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># get the state-action values for the replayed states and their corresponding actions from the policy</span>
|
||||
<span class="n">q_inputs</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">q_network_keys</span><span class="p">))</span>
|
||||
<span class="n">q_inputs</span><span class="p">[</span><span class="s1">'output_0_0'</span><span class="p">]</span> <span class="o">=</span> <span class="n">sampled_actions</span>
|
||||
<span class="n">log_target</span> <span class="o">=</span> <span class="n">q_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">q_inputs</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># log internal q values</span>
|
||||
<span class="n">q1_vals</span><span class="p">,</span> <span class="n">q2_vals</span> <span class="o">=</span> <span class="n">q_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">q_inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="p">[</span><span class="n">q_head</span><span class="o">.</span><span class="n">q1_output</span><span class="p">,</span> <span class="n">q_head</span><span class="o">.</span><span class="n">q2_output</span><span class="p">])</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">q1_values</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">q1_vals</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">q2_values</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">q2_vals</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># calculate the gradients according to (13)</span>
|
||||
<span class="c1"># get the gradients of log_prob w.r.t the weights (parameters) - indicated as phi in the paper</span>
|
||||
<span class="n">initial_feed_dict</span> <span class="o">=</span> <span class="p">{</span><span class="n">policy_network</span><span class="o">.</span><span class="n">gradients_weights_ph</span><span class="p">[</span><span class="mi">5</span><span class="p">]:</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)}</span>
|
||||
<span class="n">dlogp_dphi</span> <span class="o">=</span> <span class="n">policy_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">policy_inputs</span><span class="p">,</span>
|
||||
<span class="n">outputs</span><span class="o">=</span><span class="n">policy_network</span><span class="o">.</span><span class="n">weighted_gradients</span><span class="p">[</span><span class="mi">5</span><span class="p">],</span>
|
||||
<span class="n">initial_feed_dict</span><span class="o">=</span><span class="n">initial_feed_dict</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># calculate dq_da</span>
|
||||
<span class="n">dq_da</span> <span class="o">=</span> <span class="n">q_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">q_inputs</span><span class="p">,</span>
|
||||
<span class="n">outputs</span><span class="o">=</span><span class="n">q_network</span><span class="o">.</span><span class="n">gradients_wrt_inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">][</span><span class="s1">'output_0_0'</span><span class="p">])</span>
|
||||
|
||||
<span class="c1"># calculate da_dphi</span>
|
||||
<span class="n">initial_feed_dict</span> <span class="o">=</span> <span class="p">{</span><span class="n">policy_network</span><span class="o">.</span><span class="n">gradients_weights_ph</span><span class="p">[</span><span class="mi">3</span><span class="p">]:</span> <span class="n">dq_da</span><span class="p">}</span>
|
||||
<span class="n">dq_dphi</span> <span class="o">=</span> <span class="n">policy_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">policy_inputs</span><span class="p">,</span>
|
||||
<span class="n">outputs</span><span class="o">=</span><span class="n">policy_network</span><span class="o">.</span><span class="n">weighted_gradients</span><span class="p">[</span><span class="mi">3</span><span class="p">],</span>
|
||||
<span class="n">initial_feed_dict</span><span class="o">=</span><span class="n">initial_feed_dict</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># now given dlogp_dphi, dq_dphi we need to calculate the policy gradients according to (13)</span>
|
||||
<span class="n">policy_grads</span> <span class="o">=</span> <span class="p">[</span><span class="n">dlogp_dphi</span><span class="p">[</span><span class="n">l</span><span class="p">]</span> <span class="o">-</span> <span class="n">dq_dphi</span><span class="p">[</span><span class="n">l</span><span class="p">]</span> <span class="k">for</span> <span class="n">l</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">dlogp_dphi</span><span class="p">))]</span>
|
||||
|
||||
<span class="c1"># apply the gradients to policy networks</span>
|
||||
<span class="n">policy_network</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="n">policy_grads</span><span class="p">)</span>
|
||||
<span class="n">grads_sumabs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">([</span><span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">policy_grads</span><span class="p">[</span><span class="n">l</span><span class="p">]))</span> <span class="k">for</span> <span class="n">l</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">policy_grads</span><span class="p">))])</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">policy_grads</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">grads_sumabs</span><span class="p">)</span>
|
||||
|
||||
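        # For reference (illustrative): the gradient applied above follows equation (13) in the paper,
        #     grad_phi J_pi ~= grad_phi log pi(a|s) - grad_a Q(s, a) * grad_phi f_phi(eps; s),  with a = f_phi(eps; s)
        # i.e. dlogp_dphi minus dq_da backpropagated through the sampled action (dq_dphi), which pushes the
        # policy towards actions with high Q-value and high entropy.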
<span class="c1">##########################################</span>
|
||||
<span class="c1"># 2. updating the state value online network weights</span>
|
||||
<span class="c1"># done by calculating the targets for the v head according to (5) in the paper</span>
|
||||
<span class="c1"># value_targets = log_targets-sampled_actions_logprob</span>
|
||||
<span class="n">value_inputs</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">value_network_keys</span><span class="p">))</span>
|
||||
<span class="n">value_targets</span> <span class="o">=</span> <span class="n">log_target</span> <span class="o">-</span> <span class="n">sampled_actions_logprob</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">v_onl_ys</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">value_targets</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># call value_network apply gradients with this target</span>
|
||||
<span class="n">value_loss</span> <span class="o">=</span> <span class="n">value_network</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">train_on_batch</span><span class="p">(</span><span class="n">value_inputs</span><span class="p">,</span> <span class="n">value_targets</span><span class="p">[:,</span><span class="kc">None</span><span class="p">])[</span><span class="mi">0</span><span class="p">]</span>
|
||||
|
||||
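        # For reference (illustrative): equation (5) trains V towards E_{a~pi}[Q(s, a) - log pi(a|s)],
        # approximated here with a single action sampled from the current policy; log_target is assumed to be
        # the critic's combined output (the minimum over the two Q heads) for that sampled action.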
<span class="c1">##########################################</span>
|
||||
<span class="c1"># 3. updating the critic (q networks)</span>
|
||||
<span class="c1"># updating q networks according to (7) in the paper</span>
|
||||
|
||||
<span class="c1"># define the input to the q network: state has been already updated previously, but now we need</span>
|
||||
<span class="c1"># the actions from the batch (and not those sampled by the policy)</span>
|
||||
<span class="n">q_inputs</span><span class="p">[</span><span class="s1">'output_0_0'</span><span class="p">]</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># define the targets : scale_reward * reward + (1-terminal)*discount*v_target_next_state</span>
|
||||
<span class="c1"># define v_target_next_state</span>
|
||||
<span class="n">value_inputs</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">next_states</span><span class="p">(</span><span class="n">value_network_keys</span><span class="p">))</span>
|
||||
<span class="n">v_target_next_state</span> <span class="o">=</span> <span class="n">value_network</span><span class="o">.</span><span class="n">target_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">value_inputs</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">v_tgt_ns</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">v_target_next_state</span><span class="p">)</span>
|
||||
<span class="c1"># Note: reward is assumed to be rescaled by RewardRescaleFilter in the preset parameters</span>
|
||||
<span class="n">TD_targets</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">rewards</span><span class="p">(</span><span class="n">expand_dims</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">+</span> \
|
||||
<span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">batch</span><span class="o">.</span><span class="n">game_overs</span><span class="p">(</span><span class="n">expand_dims</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">discount</span> <span class="o">*</span> <span class="n">v_target_next_state</span>
|
||||
|
||||
<span class="c1"># call critic network update</span>
|
||||
<span class="n">result</span> <span class="o">=</span> <span class="n">q_network</span><span class="o">.</span><span class="n">train_on_batch</span><span class="p">(</span><span class="n">q_inputs</span><span class="p">,</span> <span class="n">TD_targets</span><span class="p">,</span> <span class="n">additional_fetches</span><span class="o">=</span><span class="p">[</span><span class="n">q_head</span><span class="o">.</span><span class="n">q1_loss</span><span class="p">,</span> <span class="n">q_head</span><span class="o">.</span><span class="n">q2_loss</span><span class="p">])</span>
|
||||
<span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span> <span class="o">=</span> <span class="n">result</span><span class="p">[:</span><span class="mi">3</span><span class="p">]</span>
|
||||
<span class="n">q1_loss</span><span class="p">,</span> <span class="n">q2_loss</span> <span class="o">=</span> <span class="n">result</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">TD_err1</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">q1_loss</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">TD_err2</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">q2_loss</span><span class="p">)</span>
|
||||
|
||||
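        # For reference (illustrative): equation (7) minimizes
        #     J_Q = E[0.5 * (Q(s, a) - (r + gamma * (1 - done) * V_target(s')))^2]
        # which is exactly the TD_targets computed above; the entropy bonus enters implicitly through the
        # V target network.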
<span class="c1">##########################################</span>
|
||||
<span class="c1"># 4. updating the value target network</span>
|
||||
<span class="c1"># I just need to set the parameter rate_for_copying_weights_to_target in the agent parameters to be 1-tau</span>
|
||||
<span class="c1"># where tau is the hyper parameter as defined in sac original implementation</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">get_prediction</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">states</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> get the mean and stdev of the policy distribution given 'states'</span>
|
||||
<span class="sd"> :param states: the states for which we need to sample actions from the policy</span>
|
||||
<span class="sd"> :return: mean and stdev</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="n">tf_input_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_batch_for_inference</span><span class="p">(</span><span class="n">states</span><span class="p">,</span> <span class="s1">'policy'</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'policy'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">tf_input_state</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="c1"># since the algorithm works with experience replay buffer (non-episodic),</span>
|
||||
<span class="c1"># we cant use the policy optimization train method. we need Agent.train</span>
|
||||
<span class="c1"># note that since in Agent.train there is no apply_gradients, we need to do it in learn from batch</span>
|
||||
<span class="k">return</span> <span class="n">Agent</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">choose_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">curr_state</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> choose_action - chooses the most likely action</span>
|
||||
<span class="sd"> if 'deterministic' - take the mean of the policy which is the prediction of the policy network.</span>
|
||||
<span class="sd"> else - use the exploration policy</span>
|
||||
<span class="sd"> :param curr_state:</span>
|
||||
<span class="sd"> :return: action wrapped in ActionInfo</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">action</span><span class="p">,</span> <span class="n">BoxActionSpace</span><span class="p">):</span>
|
||||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"SAC works only for continuous control problems"</span><span class="p">)</span>
|
||||
<span class="c1"># convert to batch so we can run it through the network</span>
|
||||
<span class="n">tf_input_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_batch_for_inference</span><span class="p">(</span><span class="n">curr_state</span><span class="p">,</span> <span class="s1">'policy'</span><span class="p">)</span>
|
||||
<span class="c1"># use the online network for prediction</span>
|
||||
<span class="n">policy_network</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'policy'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span>
|
||||
<span class="n">policy_head</span> <span class="o">=</span> <span class="n">policy_network</span><span class="o">.</span><span class="n">output_heads</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="n">result</span> <span class="o">=</span> <span class="n">policy_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">tf_input_state</span><span class="p">,</span>
|
||||
<span class="n">outputs</span><span class="o">=</span><span class="p">[</span><span class="n">policy_head</span><span class="o">.</span><span class="n">policy_mean</span><span class="p">,</span> <span class="n">policy_head</span><span class="o">.</span><span class="n">actions</span><span class="p">])</span>
|
||||
<span class="n">action_mean</span><span class="p">,</span> <span class="n">action_sample</span> <span class="o">=</span> <span class="n">result</span>
|
||||
|
||||
<span class="c1"># if using deterministic policy, take the mean values. else, use exploration policy to sample from the pdf</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">==</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">use_deterministic_for_evaluation</span><span class="p">:</span>
|
||||
<span class="n">action</span> <span class="o">=</span> <span class="n">action_mean</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">action</span> <span class="o">=</span> <span class="n">action_sample</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">action_signal</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">action</span><span class="p">)</span>
|
||||
|
||||
<span class="n">action_info</span> <span class="o">=</span> <span class="n">ActionInfo</span><span class="p">(</span><span class="n">action</span><span class="o">=</span><span class="n">action</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">action_info</span>
|
||||
@@ -193,16 +193,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from collections import OrderedDict
from typing import Union

import numpy as np

from rl_coach.agents.agent import Agent
from rl_coach.core_types import ActionInfo, StateType
from rl_coach.core_types import ActionInfo, StateType, Batch
from rl_coach.logger import screen
from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplay
from rl_coach.spaces import DiscreteActionSpace

from copy import deepcopy

## This is an abstract agent - there is no learn_from_batch method ##
@@ -229,8 +230,9 @@
            actions_q_values = None
        return actions_q_values

    def get_prediction(self, states):
        return self.networks['main'].online_network.predict(self.prepare_batch_for_inference(states, 'main'))

    def get_prediction(self, states, outputs=None):
        return self.networks['main'].online_network.predict(self.prepare_batch_for_inference(states, 'main'),
                                                            outputs=outputs)

    def update_transition_priorities_and_get_weights(self, TD_errors, batch):
        # update errors in prioritized replay buffer
@@ -259,10 +261,12 @@
        # this is for bootstrapped dqn
        if type(actions_q_values) == list and len(actions_q_values) > 0:
            actions_q_values = self.exploration_policy.last_action_values
        actions_q_values = actions_q_values.squeeze()

        # store the q values statistics for logging
        self.q_values.add_sample(actions_q_values)

        actions_q_values = actions_q_values.squeeze()

        for i, q_value in enumerate(actions_q_values):
            self.q_value_for_action[i].add_sample(q_value)
@@ -276,6 +280,77 @@
    def learn_from_batch(self, batch):
        raise NotImplementedError("ValueOptimizationAgent is an abstract agent. Not to be used directly.")

    def run_off_policy_evaluation(self):
        """
        Run the off-policy evaluation estimators to get a prediction for the performance of the current policy,
        based on an evaluation dataset which was collected by another policy (or policies).
        :return: None
        """
        assert self.ope_manager
        dataset_as_episodes = self.call_memory('get_all_complete_episodes_from_to',
                                               (self.call_memory('get_last_training_set_episode_id') + 1,
                                                self.call_memory('num_complete_episodes')))
        if len(dataset_as_episodes) == 0:
            raise ValueError('train_to_eval_ratio is too high, causing the evaluation set to be empty. '
                             'Consider decreasing its value.')

        ips, dm, dr, seq_dr = self.ope_manager.evaluate(
            dataset_as_episodes=dataset_as_episodes,
            batch_size=self.ap.network_wrappers['main'].batch_size,
            discount_factor=self.ap.algorithm.discount,
            reward_model=self.networks['reward_model'].online_network,
            q_network=self.networks['main'].online_network,
            network_keys=list(self.ap.network_wrappers['main'].input_embedders_parameters.keys()))

        # print the estimators to the screen
        log = OrderedDict()
        log['Epoch'] = self.training_epoch
        log['IPS'] = ips
        log['DM'] = dm
        log['DR'] = dr
        log['Sequential-DR'] = seq_dr
        screen.log_dict(log, prefix='Off-Policy Evaluation')

        # send the estimators to the dashboard
        self.agent_logger.set_current_time(self.get_current_time() + 1)
        self.agent_logger.create_signal_value('Inverse Propensity Score', ips)
        self.agent_logger.create_signal_value('Direct Method Reward', dm)
        self.agent_logger.create_signal_value('Doubly Robust', dr)
        self.agent_logger.create_signal_value('Sequential Doubly Robust', seq_dr)
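    # Note (illustrative): the estimators logged above are standard off-policy evaluation methods -
    # IPS (Inverse Propensity Scoring), DM (Direct Method, based on the learned reward model),
    # DR (Doubly Robust, combining IPS and DM) and Sequential DR, a per-step extension of DR over full episodes.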
    def get_reward_model_loss(self, batch: Batch):
        network_keys = self.ap.network_wrappers['reward_model'].input_embedders_parameters.keys()
        current_rewards_prediction_for_all_actions = self.networks['reward_model'].online_network.predict(
            batch.states(network_keys))
        current_rewards_prediction_for_all_actions[range(batch.size), batch.actions()] = batch.rewards()

        return self.networks['reward_model'].train_and_sync_networks(
            batch.states(network_keys), current_rewards_prediction_for_all_actions)[0]

    def improve_reward_model(self, epochs: int):
        """
        Train a reward model to be used by the doubly-robust estimator
        :param epochs: The total number of epochs to use for training a reward model
        :return: None
        """
        batch_size = self.ap.network_wrappers['reward_model'].batch_size

        # this is fitted from the training dataset
        for epoch in range(epochs):
            loss = 0
            total_transitions_processed = 0
            for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)):
                batch = Batch(batch)
                loss += self.get_reward_model_loss(batch)
                total_transitions_processed += batch.size

            log = OrderedDict()
            log['Epoch'] = epoch
            log['loss'] = loss / total_transitions_processed
            screen.log_dict(log, prefix='Training Reward Model')
@@ -205,6 +205,7 @@
from rl_coach.core_types import TrainingSteps, EnvironmentSteps, GradientClippingMethod, RunPhase, \
    SelectedPhaseOnlyDumpFilter, MaxDumpFilter
from rl_coach.filters.filter import NoInputFilter
from rl_coach.logger import screen


class Frameworks(Enum):
@@ -379,9 +380,6 @@
        # distributed agents params
        self.share_statistics_between_workers = True

        # intrinsic reward
        self.scale_external_reward_by_intrinsic_reward_value = False

        # n-step returns
        self.n_step = -1  # calculate the total return (no bootstrap, by default)
@@ -470,7 +468,8 @@
                 batch_size=32,
                 replace_mse_with_huber_loss=False,
                 create_target_network=False,
                 tensorflow_support=True):
                 tensorflow_support=True,
                 softmax_temperature=1):
        """
        :param force_cpu:
            Force the neural networks to run on the CPU even if a GPU is available
@@ -553,6 +552,8 @@
            online network at will.
        :param tensorflow_support:
            A flag which specifies if the network is supported by the TensorFlow framework.
        :param softmax_temperature:
            If a softmax is present in the network head output, use this temperature
        """
        super().__init__()
        self.framework = Frameworks.tensorflow
@@ -583,16 +584,19 @@
        self.heads_parameters = heads_parameters
        self.use_separate_networks_per_head = use_separate_networks_per_head
        self.optimizer_type = optimizer_type
        self.replace_mse_with_huber_loss = replace_mse_with_huber_loss
        self.create_target_network = create_target_network

        # Framework support
        self.tensorflow_support = tensorflow_support

        # Hyper-Parameter values
        self.optimizer_epsilon = optimizer_epsilon
        self.adam_optimizer_beta1 = adam_optimizer_beta1
        self.adam_optimizer_beta2 = adam_optimizer_beta2
        self.rms_prop_optimizer_decay = rms_prop_optimizer_decay
        self.batch_size = batch_size
        self.replace_mse_with_huber_loss = replace_mse_with_huber_loss
        self.create_target_network = create_target_network

        # Framework support
        self.tensorflow_support = tensorflow_support
        self.softmax_temperature = softmax_temperature


class NetworkComponentParameters(Parameters):
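
The new `softmax_temperature` value is only stored on the network parameters here; per the docstring above it is applied when a head ends in a softmax. For intuition, a temperature-scaled softmax divides the logits by the temperature before normalizing, so a higher temperature flattens the output distribution. A plain NumPy sketch of that effect (not Coach code):

```python
import numpy as np

# Sketch of temperature-scaled softmax: higher temperature -> flatter distribution.
def softmax_with_temperature(logits, temperature=1.0):
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled -= scaled.max()  # numerical stability
    exp = np.exp(scaled)
    return exp / exp.sum()

print(softmax_with_temperature([2.0, 1.0, 0.1], temperature=1.0))   # peaked
print(softmax_with_temperature([2.0, 1.0, 0.1], temperature=10.0))  # nearly uniform
```
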
@@ -723,6 +727,7 @@
        self.is_a_highest_level_agent = True
        self.is_a_lowest_level_agent = True
        self.task_parameters = None
        self.is_batch_rl_training = False

    @property
    def path(self):
@@ -730,18 +735,22 @@


class TaskParameters(Parameters):
    def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: bool=False, use_cpu: bool=False,
    def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: int=None, use_cpu: bool=False,
                 experiment_path='/tmp', seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
                 checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False,
                 num_gpu: int=1):
                 checkpoint_restore_path=None, checkpoint_save_dir=None, export_onnx_graph: bool=False,
                 apply_stop_condition: bool=False, num_gpu: int=1):
        """
        :param framework_type: deep learning framework type. currently only tensorflow is supported
        :param evaluate_only: the task will be used only for evaluating the model
        :param evaluate_only: if not None, the task will be used only for evaluating the model for the given number of steps.
            A value of 0 means that the task will be evaluated for an infinite number of steps.
        :param use_cpu: use the cpu for this task
        :param experiment_path: the path to the directory which will store all the experiment outputs
        :param seed: a seed to use for the random numbers generator
        :param checkpoint_save_secs: the number of seconds between each checkpoint saving
        :param checkpoint_restore_dir: the directory to restore the checkpoints from
        :param checkpoint_restore_dir:
            [DEPRECATED - will be removed in one of the next releases - switch to checkpoint_restore_path]
            the dir to restore the checkpoints from
        :param checkpoint_restore_path: the path to restore the checkpoints from
        :param checkpoint_save_dir: the directory to store the checkpoints in
        :param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved
        :param apply_stop_condition: If set to True, this will apply the stop condition defined by reaching a target success rate
@@ -753,7 +762,13 @@
        self.use_cpu = use_cpu
        self.experiment_path = experiment_path
        self.checkpoint_save_secs = checkpoint_save_secs
        self.checkpoint_restore_dir = checkpoint_restore_dir
        if checkpoint_restore_dir:
            screen.warning('TaskParameters.checkpoint_restore_dir is DEPRECATED and will be removed in one of the next '
                           'releases. Please switch to using TaskParameters.checkpoint_restore_path, with your '
                           'directory path. ')
            self.checkpoint_restore_path = checkpoint_restore_dir
        else:
            self.checkpoint_restore_path = checkpoint_restore_path
        self.checkpoint_save_dir = checkpoint_save_dir
        self.seed = seed
        self.export_onnx_graph = export_onnx_graph
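
With the constructor shown above, callers should pass `checkpoint_restore_path`; the old `checkpoint_restore_dir` still works but logs the deprecation warning and is copied into `checkpoint_restore_path`. A minimal usage sketch (the checkpoint path below is a placeholder):

```python
from rl_coach.base_parameters import TaskParameters

# Preferred: the new argument name.
task_parameters = TaskParameters(
    evaluate_only=0,  # evaluate for an infinite number of steps; None means a regular training task
    checkpoint_restore_path='./experiments/my_exp/checkpoint',  # placeholder path
)

# Still accepted, but triggers the deprecation warning shown above.
legacy_parameters = TaskParameters(checkpoint_restore_dir='./experiments/my_exp/checkpoint')
```
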
@@ -763,13 +778,14 @@

class DistributedTaskParameters(TaskParameters):
    def __init__(self, framework_type: Frameworks, parameters_server_hosts: str, worker_hosts: str, job_type: str,
                 task_index: int, evaluate_only: bool=False, num_tasks: int=None,
                 task_index: int, evaluate_only: int=None, num_tasks: int=None,
                 num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None,
                 shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
                 shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_path=None,
                 checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False):
        """
        :param framework_type: deep learning framework type. currently only tensorflow is supported
        :param evaluate_only: the task will be used only for evaluating the model
        :param evaluate_only: if not None, the task will be used only for evaluating the model for the given number of steps.
            A value of 0 means that the task will be evaluated for an infinite number of steps.
        :param parameters_server_hosts: comma-separated list of hostname:port pairs to which the parameter servers are
                                        assigned
        :param worker_hosts: comma-separated list of hostname:port pairs to which the workers are assigned
@@ -782,7 +798,7 @@
        :param dnd: an external DND to use for NEC. This is a workaround needed for a shared DND not using the scratchpad.
        :param seed: a seed to use for the random numbers generator
        :param checkpoint_save_secs: the number of seconds between each checkpoint saving
        :param checkpoint_restore_dir: the directory to restore the checkpoints from
        :param checkpoint_restore_path: the path to restore the checkpoints from
        :param checkpoint_save_dir: the directory to store the checkpoints in
        :param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved
        :param apply_stop_condition: If set to True, this will apply the stop condition defined by reaching a target success rate
@@ -790,7 +806,7 @@
        """
        super().__init__(framework_type=framework_type, evaluate_only=evaluate_only, use_cpu=use_cpu,
                         experiment_path=experiment_path, seed=seed, checkpoint_save_secs=checkpoint_save_secs,
                         checkpoint_restore_dir=checkpoint_restore_dir, checkpoint_save_dir=checkpoint_save_dir,
                         checkpoint_restore_path=checkpoint_restore_path, checkpoint_save_dir=checkpoint_save_dir,
                         export_onnx_graph=export_onnx_graph, apply_stop_condition=apply_stop_condition)
        self.parameters_server_hosts = parameters_server_hosts
        self.worker_hosts = worker_hosts
@@ -193,9 +193,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from collections import namedtuple

import copy
import math
from enum import Enum
from random import shuffle
from typing import List, Union, Dict, Any, Type
@@ -218,6 +219,17 @@
    Measurements = 4


Record = namedtuple('Record', ['name', 'label'])


class TimeTypes(Enum):
    EpisodeNumber = Record(name='Episode #', label='Episode #')
    TrainingIteration = Record(name='Training Iter', label='Training Iteration')
    EnvironmentSteps = Record(name='Total steps', label='Total steps (per worker)')
    WallClockTime = Record(name='Wall-Clock Time', label='Wall-Clock Time (minutes)')
    Epoch = Record(name='Epoch', label='Epoch #')


# step methods

class StepMethod(object):
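
Each `TimeTypes` member wraps a `Record` namedtuple, so a dashboard can read both a short name and an axis label from the same enum value. A small example, assuming `TimeTypes` is importable from `rl_coach.core_types` like the other types in this file:

```python
from rl_coach.core_types import TimeTypes

time_type = TimeTypes.EnvironmentSteps
print(time_type.value.name)   # 'Total steps'
print(time_type.value.label)  # 'Total steps (per worker)'
```
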
@@ -232,6 +244,37 @@
    def num_steps(self, val: int) -> None:
        self._num_steps = val

    def __eq__(self, other):
        return self.num_steps == other.num_steps

    def __truediv__(self, other):
        """
        Divide this step method by `other`. If `other` is an integer, returns an object of the same
        type as `self`. If `other` is the same type as `self`, returns an integer. In either case, any
        floating point value is rounded up under the assumption that if we are dividing Steps, we
        would rather overestimate than underestimate.
        """
        if isinstance(other, type(self)):
            return math.ceil(self.num_steps / other.num_steps)
        elif isinstance(other, int):
            return type(self)(math.ceil(self.num_steps / other))
        else:
            raise TypeError("cannot divide {} by {}".format(type(self), type(other)))

    def __rtruediv__(self, other):
        """
        Divide `other` by this step method. If `other` is an integer, returns an object of the same
        type as `self`. If `other` is the same type as `self`, returns an integer. In either case, any
        floating point value is rounded up under the assumption that if we are dividing Steps, we
        would rather overestimate than underestimate.
        """
        if isinstance(other, type(self)):
            return math.ceil(other.num_steps / self.num_steps)
        elif isinstance(other, int):
            return type(self)(math.ceil(other / self.num_steps))
        else:
            raise TypeError("cannot divide {} by {}".format(type(other), type(self)))


class Frames(StepMethod):
    def __init__(self, num_steps):
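
The new `__truediv__`/`__rtruediv__` overloads let one schedule be derived from another: dividing two step counters of the same type yields a rounded-up integer, while dividing a counter by an integer yields a smaller counter of the same type. A sketch using `EnvironmentSteps`, which is exported from `rl_coach.core_types` per the imports earlier in this diff:

```python
from rl_coach.core_types import EnvironmentSteps

total = EnvironmentSteps(10000000)

# Same type on both sides -> integer count, rounded up: ceil(10000000 / 300000) == 34
num_evaluations = total / EnvironmentSteps(300000)

# Integer divisor -> a new EnvironmentSteps object, also rounded up: EnvironmentSteps(3333334)
per_worker = total / 3
```
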
@@ -429,6 +472,9 @@
                            .format(self.info.keys(), new_info.keys()))
        self.info.update(new_info)

    def update_info(self, new_info: Dict[str, Any]) -> None:
        self.info.update(new_info)

    def __copy__(self):
        new_transition = type(self)()
        new_transition.__dict__.update(self.__dict__)
@@ -510,8 +556,7 @@
    """

    def __init__(self, action: ActionType, all_action_probabilities: float=0,
                 action_value: float=0., state_value: float=0., max_action_value: float=None,
                 action_intrinsic_reward: float=0):
                 action_value: float=0., state_value: float=0., max_action_value: float=None):
        """
        :param action: the action
        :param all_action_probabilities: the probability that the action was given when selecting it
@@ -520,8 +565,6 @@
        :param max_action_value: in case this is an action that was selected randomly, this is the value of the action
                                 that received the maximum value. if no value is given, the action is assumed to be the
                                 action with the maximum value
        :param action_intrinsic_reward: can contain any intrinsic reward that the agent wants to add to this action
                                        selection
        """
        self.action = action
        self.all_action_probabilities = all_action_probabilities
@@ -530,8 +573,7 @@
        if not max_action_value:
            self.max_action_value = action_value
        else:
            self.max_action_value = max_action_value
        self.action_intrinsic_reward = action_intrinsic_reward


class Batch(object):
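
After this change `ActionInfo` no longer carries an `action_intrinsic_reward` field, so any caller that passed it needs to drop the argument. A minimal construction under the new signature (the values are illustrative):

```python
from rl_coach.core_types import ActionInfo

# New signature: action, all_action_probabilities, action_value, state_value, max_action_value.
action_info = ActionInfo(
    action=1,
    all_action_probabilities=0.25,
    action_value=0.7,
    state_value=0.65,
)
# max_action_value was not given, so it falls back to action_value (0.7), as in the code above.
```
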
@@ -1047,6 +1089,19 @@
            return True
        else:
            return False


# TODO move to a NamedTuple, once we move to Python3.6
# https://stackoverflow.com/questions/34269772/type-hints-in-namedtuple/34269877
class CsvDataset(object):
    def __init__(self, filepath: str, is_episodic: bool = True):
        self.filepath = filepath
        self.is_episodic = is_episodic


class PickledReplayBuffer(object):
    def __init__(self, filepath: str):
        self.filepath = filepath
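
`CsvDataset` and `PickledReplayBuffer` are small descriptors for offline data sources, which matches the new `is_batch_rl_training` flag added earlier in this diff. A usage sketch with placeholder file paths (how a preset consumes these descriptors is not shown in this excerpt):

```python
from rl_coach.core_types import CsvDataset, PickledReplayBuffer

# Describe an episodic dataset stored as CSV (placeholder path).
dataset = CsvDataset(filepath='./data/acrobot_transitions.csv', is_episodic=True)

# Or a previously pickled replay buffer (placeholder path).
replay_buffer = PickledReplayBuffer(filepath='./data/replay_buffer.p')
```
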
@@ -463,7 +463,12 @@
            print("Got exception: %s\n while deleting NFS PVC", e)
            return False

        return True

    def setup_checkpoint_dir(self, crd=None):
        if crd:
            # TODO: find a way to upload this to the deployed nfs store.
            pass
@@ -257,6 +257,9 @@
        return True

    def save_to_store(self):
        self._save_to_store(self.params.checkpoint_dir)

    def _save_to_store(self, checkpoint_dir):
        """
        save_to_store() uploads the policy checkpoint, gifs and videos to the S3 data store. It reads the checkpoint state files and
        uploads only the latest checkpoint files to S3. It is used by the trainer in Coach when used in the distributed mode.
@@ -268,24 +271,32 @@
            # Acquire lock
            self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)

            state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
            state_file = CheckpointStateFile(os.path.abspath(checkpoint_dir))
            if state_file.exists():
                ckpt_state = state_file.read()
                checkpoint_file = None
                for root, dirs, files in os.walk(self.params.checkpoint_dir):
                for root, dirs, files in os.walk(checkpoint_dir):
                    for filename in files:
                        if filename == CheckpointStateFile.checkpoint_state_filename:
                            checkpoint_file = (root, filename)
                            continue
                        if filename.startswith(ckpt_state.name):
                            abs_name = os.path.abspath(os.path.join(root, filename))
                            rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
                            rel_name = os.path.relpath(abs_name, checkpoint_dir)
                            self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)

                abs_name = os.path.abspath(os.path.join(checkpoint_file[0], checkpoint_file[1]))
                rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
                rel_name = os.path.relpath(abs_name, checkpoint_dir)
                self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)

            # upload Finished if present
            if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value)):
                self.mc.put_object(self.params.bucket_name, SyncFiles.FINISHED.value, io.BytesIO(b''), 0)

            # upload Ready if present
            if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value)):
                self.mc.put_object(self.params.bucket_name, SyncFiles.TRAINER_READY.value, io.BytesIO(b''), 0)

            # release lock
            self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
@@ -301,6 +312,7 @@
            if self.params.expt_dir and os.path.exists(os.path.join(self.params.expt_dir, 'gifs')):
                for filename in os.listdir(os.path.join(self.params.expt_dir, 'gifs')):
                    self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'gifs', filename))

        except ResponseError as e:
            print("Got exception: %s\n while saving to S3", e)
@@ -337,6 +349,18 @@
            except Exception as e:
                pass

            # Check if there's a ready file
            objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.TRAINER_READY.value)

            if next(objects, None) is not None:
                try:
                    self.mc.fget_object(
                        self.params.bucket_name, SyncFiles.TRAINER_READY.value,
                        os.path.abspath(os.path.join(self.params.checkpoint_dir, SyncFiles.TRAINER_READY.value))
                    )
                except Exception as e:
                    pass

            checkpoint_state = state_file.read()
            if checkpoint_state is not None:
                objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=checkpoint_state.name, recursive=True)
@@ -346,7 +370,11 @@
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">mc</span><span class="o">.</span><span class="n">fget_object</span><span class="p">(</span><span class="n">obj</span><span class="o">.</span><span class="n">bucket_name</span><span class="p">,</span> <span class="n">obj</span><span class="o">.</span><span class="n">object_name</span><span class="p">,</span> <span class="n">filename</span><span class="p">)</span>
|
||||
|
||||
<span class="k">except</span> <span class="n">ResponseError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="s2">"Got exception: </span><span class="si">%s</span><span class="se">\n</span><span class="s2"> while loading from S3"</span><span class="p">,</span> <span class="n">e</span><span class="p">)</span></div>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="s2">"Got exception: </span><span class="si">%s</span><span class="se">\n</span><span class="s2"> while loading from S3"</span><span class="p">,</span> <span class="n">e</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">setup_checkpoint_dir</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">crd</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="n">crd</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_save_to_store</span><span class="p">(</span><span class="n">crd</span><span class="p">)</span></div>
|
||||
</pre></div>
|
||||
|
||||
</div>
|
||||
|
||||
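The hunk above adds a `setup_checkpoint_dir` helper and fetches the trainer-ready marker before pulling checkpoint objects from the bucket. Below is a minimal, self-contained sketch of the same poll-and-fetch pattern using the `minio` client (older minio-py releases expose `list_objects_v2`, as the code above does); the endpoint, credentials, bucket name, marker name and local directory are placeholders, not values taken from rl_coach.

```python
# Sketch only: poll the bucket for a "ready" marker, then mirror every
# checkpoint object to a local directory. All names below are placeholders.
import os
from minio import Minio

client = Minio("minio.example.local:9000",
               access_key="ACCESS_KEY", secret_key="SECRET_KEY", secure=False)

bucket = "coach-checkpoints"          # placeholder bucket name
checkpoint_dir = "/tmp/checkpoints"   # placeholder local directory

# is the trainer-ready marker present?
ready_objects = client.list_objects_v2(bucket, prefix=".ready", recursive=True)
if next(ready_objects, None) is not None:
    # mirror every checkpoint object locally, keeping object names as file names
    for obj in client.list_objects_v2(bucket, recursive=True):
        local_path = os.path.join(checkpoint_dir, obj.object_name)
        client.fget_object(obj.bucket_name, obj.object_name, local_path)
```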
@@ -480,7 +480,6 @@
        random.seed(self.seed)

        # frame skip and max between consecutive frames
        self.is_robotics_env = 'robotics' in str(self.env.unwrapped.__class__)
        self.is_mujoco_env = 'mujoco' in str(self.env.unwrapped.__class__)
        self.is_roboschool_env = 'roboschool' in str(self.env.unwrapped.__class__)
        self.is_atari_env = 'Atari' in str(self.env.unwrapped.__class__)
@@ -501,7 +500,7 @@
        self.state_space = StateSpace({})

        # observations
        if not isinstance(self.env.observation_space, gym.spaces.dict_space.Dict):
        if not isinstance(self.env.observation_space, gym.spaces.dict.Dict):
            state_space = {'observation': self.env.observation_space}
        else:
            state_space = self.env.observation_space.spaces
@@ -665,38 +664,11 @@
        # initialize the number of lives
        self._update_ale_lives()

    def _set_mujoco_camera(self, camera_idx: int):
        """
        This function can be used to set the camera for rendering the mujoco simulator
        :param camera_idx: The index of the camera to use. Should be defined in the model
        :return: None
        """
        if self.env.unwrapped.viewer is not None and self.env.unwrapped.viewer.cam.fixedcamid != camera_idx:
            from mujoco_py.generated import const
            self.env.unwrapped.viewer.cam.type = const.CAMERA_FIXED
            self.env.unwrapped.viewer.cam.fixedcamid = camera_idx

    def _get_robotics_image(self):
        self.env.render()
        image = self.env.unwrapped._get_viewer().read_pixels(1600, 900, depth=False)[::-1, :, :]
        image = scipy.misc.imresize(image, (270, 480, 3))
        return image

    def _render(self):
        self.env.render(mode='human')
        # required for setting up a fixed camera for mujoco
        if self.is_mujoco_env and not self.is_roboschool_env:
            self._set_mujoco_camera(0)

    def get_rendered_image(self):
        if self.is_robotics_env:
            # necessary for fetch since the rendered image is cropped to an irrelevant part of the simulator
            image = self._get_robotics_image()
        else:
            image = self.env.render(mode='rgb_array')
        # required for setting up a fixed camera for mujoco
        if self.is_mujoco_env and not self.is_roboschool_env:
            self._set_mujoco_camera(0)
        return image

    def get_target_success_rate(self) -> float:
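The observation-space hunk above changes the isinstance check from `gym.spaces.dict_space.Dict` to `gym.spaces.dict.Dict`, matching the module layout of newer gym releases. A small sketch of the same normalization using the public `gym.spaces.Dict` alias follows; the environment name is just a placeholder.

```python
# Sketch: wrap a plain observation space into a single-key dict so that
# dict and non-dict observation spaces can be handled uniformly.
import gym

env = gym.make("CartPole-v0")  # placeholder environment

if not isinstance(env.observation_space, gym.spaces.Dict):
    state_space = {'observation': env.observation_space}
else:
    state_space = env.observation_space.spaces

print(state_space)
```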
@@ -179,6 +179,7 @@

Source code for rl_coach.memories.episodic.episodic_experience_replay

#
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -193,14 +194,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import ast
import math

from typing import List, Tuple, Union, Dict, Any

import pandas as pd
from typing import List, Tuple, Union
import numpy as np
import random

from rl_coach.core_types import Transition, Episode
from rl_coach.logger import screen
from rl_coach.memories.memory import Memory, MemoryGranularity, MemoryParameters
from rl_coach.utils import ReaderWriterLock
from rl_coach.utils import ReaderWriterLock, ProgressBar
from rl_coach.core_types import CsvDataset


class EpisodicExperienceReplayParameters(MemoryParameters):
@@ -208,6 +214,7 @@
        super().__init__()
        self.max_size = (MemoryGranularity.Transitions, 1000000)
        self.n_step = -1
        self.train_to_eval_ratio = 1  # for OPE we'll want a value < 1

    @property
    def path(self):
@@ -220,7 +227,9 @@
    calculations of total return and other values that depend on the sequential behavior of the transitions
    in the episode.
    """
    def __init__(self, max_size: Tuple[MemoryGranularity, int]=(MemoryGranularity.Transitions, 1000000), n_step=-1):

    def __init__(self, max_size: Tuple[MemoryGranularity, int] = (MemoryGranularity.Transitions, 1000000), n_step=-1,
                 train_to_eval_ratio: int = 1):
        """
        :param max_size: the maximum number of transitions or episodes to hold in the memory
        """
@@ -232,8 +241,11 @@
        self._num_transitions = 0
        self._num_transitions_in_complete_episodes = 0
        self.reader_writer_lock = ReaderWriterLock()
        self.last_training_set_episode_id = None  # used in batch-rl
        self.last_training_set_transition_id = None  # used in batch-rl
        self.train_to_eval_ratio = train_to_eval_ratio  # used in batch-rl

    def length(self, lock: bool=False) -> int:
    def length(self, lock: bool = False) -> int:
        """
        Get the number of episodes in the ER (even if they are not complete)
        """
@@ -255,6 +267,9 @@
    def num_transitions_in_complete_episodes(self):
        return self._num_transitions_in_complete_episodes

    def get_last_training_set_episode_id(self):
        return self.last_training_set_episode_id

    def sample(self, size: int, is_consecutive_transitions=False) -> List[Transition]:
        """
        Sample a batch of transitions from the replay buffer. If the requested size is larger than the number
@@ -272,7 +287,7 @@
                batch = self._buffer[episode_idx].transitions
            else:
                transition_idx = np.random.randint(size, self._buffer[episode_idx].length())
                batch = self._buffer[episode_idx].transitions[transition_idx-size:transition_idx]
                batch = self._buffer[episode_idx].transitions[transition_idx - size:transition_idx]
        else:
            transitions_idx = np.random.randint(self.num_transitions_in_complete_episodes(), size=size)
            batch = [self.transitions[i] for i in transitions_idx]
@@ -285,6 +300,78 @@

        return batch

    def get_episode_for_transition(self, transition: Transition) -> (int, Episode):
        """
        Get the episode that the given transition came from.
        :param transition: The transition to look up the episode for
        :return: (episode number, the episode) or (-1, None) if no matching episode was found.
        """
        for i, episode in enumerate(self._buffer):
            if transition in episode.transitions:
                return i, episode
        return -1, None

    def shuffle_episodes(self):
        """
        Shuffle all the episodes in the replay buffer
        :return:
        """
        random.shuffle(self._buffer)
        self.transitions = [t for e in self._buffer for t in e.transitions]

    def get_shuffled_data_generator(self, size: int) -> List[Transition]:
        """
        Get a generator for iterating through the shuffled replay buffer, for processing the data in epochs.
        If the requested size is larger than the number of samples available in the replay buffer then the batch will
        return empty. The last returned batch may be smaller than the size requested, to accommodate all the
        transitions in the replay buffer.

        :param size: the size of the batch to return
        :return: a batch (list) of selected transitions from the replay buffer
        """
        self.reader_writer_lock.lock_writing()
        if self.last_training_set_transition_id is None:
            if self.train_to_eval_ratio < 0 or self.train_to_eval_ratio >= 1:
                raise ValueError('train_to_eval_ratio should be in the (0, 1] range.')

            transition = self.transitions[round(self.train_to_eval_ratio * self.num_transitions_in_complete_episodes())]
            episode_num, episode = self.get_episode_for_transition(transition)
            self.last_training_set_episode_id = episode_num
            self.last_training_set_transition_id = \
                len([t for e in self.get_all_complete_episodes_from_to(0, self.last_training_set_episode_id + 1) for t in e])

        shuffled_transition_indices = list(range(self.last_training_set_transition_id))
        random.shuffle(shuffled_transition_indices)

        # The last batch drawn will usually be < batch_size (=the size variable)
        for i in range(math.ceil(len(shuffled_transition_indices) / size)):
            sample_data = [self.transitions[j] for j in shuffled_transition_indices[i * size: (i + 1) * size]]
            self.reader_writer_lock.release_writing()

            yield sample_data

    def get_all_complete_episodes_transitions(self) -> List[Transition]:
        """
        Get all the transitions from all the complete episodes in the buffer
        :return: a list of transitions
        """
        return self.transitions[:self.num_transitions_in_complete_episodes()]

    def get_all_complete_episodes(self) -> List[Episode]:
        """
        Get all the complete episodes in the buffer
        :return: a list of episodes
        """
        return self.get_all_complete_episodes_from_to(0, self.num_complete_episodes())

    def get_all_complete_episodes_from_to(self, start_episode_id, end_episode_id) -> List[Episode]:
        """
        Get all the complete episodes in the buffer matching the given episode range
        :return: a list of episodes
        """
        return self._buffer[start_episode_id:end_episode_id]

    def _enforce_max_length(self) -> None:
        """
        Make sure that the size of the replay buffer does not pass the maximum size allowed.
@@ -368,7 +455,7 @@

        self.reader_writer_lock.release_writing_and_reading()

    def store_episode(self, episode: Episode, lock: bool=True) -> None:
    def store_episode(self, episode: Episode, lock: bool = True) -> None:
        """
        Store a new episode in the memory.
        :param episode: the new episode to store
@@ -391,7 +478,7 @@
        if lock:
            self.reader_writer_lock.release_writing_and_reading()

    def get_episode(self, episode_index: int, lock: bool=True) -> Union[None, Episode]:
    def get_episode(self, episode_index: int, lock: bool = True) -> Union[None, Episode]:
        """
        Returns the episode in the given index. If the episode does not exist, returns None instead.
        :param episode_index: the index of the episode to return
@@ -436,7 +523,7 @@
        self.reader_writer_lock.release_writing_and_reading()

    # for API compatibility
    def get(self, episode_index: int, lock: bool=True) -> Union[None, Episode]:
    def get(self, episode_index: int, lock: bool = True) -> Union[None, Episode]:
        """
        Returns the episode in the given index. If the episode does not exist, returns None instead.
        :param episode_index: the index of the episode to return
@@ -494,7 +581,51 @@
        mean = np.mean([transition.reward for transition in self.transitions])

        self.reader_writer_lock.release_writing()
        return mean

    def load_csv(self, csv_dataset: CsvDataset) -> None:
        """
        Restore the replay buffer contents from a csv file.
        The csv file is assumed to include a list of transitions.
        :param csv_dataset: A construct which holds the dataset parameters
        """
        df = pd.read_csv(csv_dataset.filepath)
        if len(df) > self.max_size[1]:
            screen.warning("Warning! The number of transitions to load into the replay buffer ({}) is "
                           "bigger than the max size of the replay buffer ({}). The excessive transitions will "
                           "not be stored.".format(len(df), self.max_size[1]))

        episode_ids = df['episode_id'].unique()
        progress_bar = ProgressBar(len(episode_ids))
        state_columns = [col for col in df.columns if col.startswith('state_feature')]

        for e_id in episode_ids:
            progress_bar.update(e_id)
            df_episode_transitions = df[df['episode_id'] == e_id]
            episode = Episode()
            for (_, current_transition), (_, next_transition) in zip(df_episode_transitions[:-1].iterrows(),
                                                                     df_episode_transitions[1:].iterrows()):
                state = np.array([current_transition[col] for col in state_columns])
                next_state = np.array([next_transition[col] for col in state_columns])

                episode.insert(
                    Transition(state={'observation': state},
                               action=current_transition['action'], reward=current_transition['reward'],
                               next_state={'observation': next_state}, game_over=False,
                               info={'all_action_probabilities':
                                     ast.literal_eval(current_transition['all_action_probabilities'])}))

            # Set the last transition to end the episode
            if csv_dataset.is_episodic:
                episode.get_last_transition().game_over = True

            self.store_episode(episode)

        # close the progress bar
        progress_bar.update(len(episode_ids))
        progress_bar.close()

        self.shuffle_episodes()
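The new `get_shuffled_data_generator` above splits the episodic buffer for batch RL: it takes the first `train_to_eval_ratio` fraction of transitions, extends the cut to the end of the episode that contains it, and then yields shuffled batches from that training slice (the last batch may be smaller). The following plain-Python sketch, using illustrative names and toy data only, mirrors that logic without the rl_coach classes.

```python
# Illustrative sketch of the epoch generator's split-and-shuffle logic;
# episode/transition contents here are just toy strings.
import math
import random

def shuffled_training_batches(episodes, ratio, batch_size):
    transitions = [t for e in episodes for t in e]
    cut = round(ratio * len(transitions))
    # extend the cut to the end of the episode that contains it,
    # so the training set always ends on an episode boundary
    seen = 0
    for episode in episodes:
        seen += len(episode)
        if seen > cut:
            cut = seen
            break
    indices = list(range(cut))
    random.shuffle(indices)
    # the last batch may be smaller than batch_size
    for i in range(math.ceil(len(indices) / batch_size)):
        yield [transitions[j] for j in indices[i * batch_size:(i + 1) * batch_size]]

# toy usage: 3 episodes of 4 transitions each, ~60% of the data used for training
episodes = [[f"e{e}_t{t}" for t in range(4)] for e in range(3)]
for batch in shuffled_training_batches(episodes, ratio=0.6, batch_size=5):
    print(batch)
```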
@@ -194,10 +194,10 @@
|
||||
<span class="c1"># limitations under the License.</span>
|
||||
<span class="c1">#</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Any</span>
|
||||
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
|
||||
<span class="kn">import</span> <span class="nn">pickle</span>
|
||||
<span class="kn">import</span> <span class="nn">sys</span>
|
||||
<span class="kn">import</span> <span class="nn">time</span>
|
||||
<span class="kn">import</span> <span class="nn">random</span>
|
||||
<span class="kn">import</span> <span class="nn">math</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
||||
|
||||
@@ -252,7 +252,6 @@
|
||||
<span class="sd"> Sample a batch of transitions form the replay buffer. If the requested size is larger than the number</span>
|
||||
<span class="sd"> of samples available in the replay buffer then the batch will return empty.</span>
|
||||
<span class="sd"> :param size: the size of the batch to sample</span>
|
||||
<span class="sd"> :param beta: the beta parameter used for importance sampling</span>
|
||||
<span class="sd"> :return: a batch (list) of selected transitions from the replay buffer</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">lock_writing</span><span class="p">()</span>
|
||||
@@ -272,6 +271,28 @@
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing</span><span class="p">()</span>
|
||||
<span class="k">return</span> <span class="n">batch</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">get_shuffled_data_generator</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">size</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">Transition</span><span class="p">]:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Get an generator for iterating through the shuffled replay buffer, for processing the data in epochs.</span>
|
||||
<span class="sd"> If the requested size is larger than the number of samples available in the replay buffer then the batch will</span>
|
||||
<span class="sd"> return empty. The last returned batch may be smaller than the size requested, to accommodate for all the</span>
|
||||
<span class="sd"> transitions in the replay buffer.</span>
|
||||
|
||||
<span class="sd"> :param size: the size of the batch to return</span>
|
||||
<span class="sd"> :return: a batch (list) of selected transitions from the replay buffer</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">lock_writing</span><span class="p">()</span>
|
||||
<span class="n">shuffled_transition_indices</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">)))</span>
|
||||
<span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">shuffled_transition_indices</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># we deliberately drop some of the ending data which is left after dividing to batches of size `size`</span>
|
||||
<span class="c1"># for i in range(math.ceil(len(shuffled_transition_indices) / size)):</span>
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">shuffled_transition_indices</span><span class="p">)</span> <span class="o">/</span> <span class="n">size</span><span class="p">)):</span>
|
||||
<span class="n">sample_data</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="n">shuffled_transition_indices</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">size</span><span class="p">:</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">size</span><span class="p">]]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing</span><span class="p">()</span>
|
||||
|
||||
<span class="k">yield</span> <span class="n">sample_data</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">_enforce_max_length</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Make sure that the size of the replay buffer does not pass the maximum size allowed.</span>
|
||||
@@ -395,7 +416,7 @@
|
||||
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">file_path</span><span class="p">,</span> <span class="s1">'wb'</span><span class="p">)</span> <span class="k">as</span> <span class="n">file</span><span class="p">:</span>
|
||||
<span class="n">pickle</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">,</span> <span class="n">file</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">file_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="k">def</span> <span class="nf">load_pickled</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">file_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Restore the replay buffer contents from a pickle file.</span>
|
||||
<span class="sd"> The pickle file is assumed to include a list of transitions.</span>
|
||||
@@ -418,6 +439,7 @@
|
||||
<span class="n">progress_bar</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">transition_idx</span><span class="p">)</span>
|
||||
|
||||
<span class="n">progress_bar</span><span class="o">.</span><span class="n">close</span><span class="p">()</span></div>
|
||||
|
||||
</pre></div>
|
||||
|
||||
</div>
|
||||
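The sampling generator shown in the rendered diff above shuffles all transition indices once and then yields consecutive slices of size ``size``, deliberately dropping the trailing transitions that do not fill a whole batch. A minimal plain-Python sketch of that logic (illustrative only, not the exact Coach implementation, and without the reader/writer locking shown above):

.. code-block:: python

   import random

   def shuffled_batches(transitions, size):
       # Shuffle the indices once so batches are drawn without replacement.
       indices = list(range(len(transitions)))
       random.shuffle(indices)
       # Integer division drops the leftover transitions at the end.
       for i in range(len(indices) // size):
           yield [transitions[j] for j in indices[i * size:(i + 1) * size]]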
|
||||
@@ -298,7 +298,7 @@
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">s3_access_key</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'ACCESS_KEY_ID'</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">s3_secret_key</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'SECRET_ACCESS_KEY'</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||||
<span class="k">def</span> <span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">crd</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Deploys the memory backend and data stores if required.</span>
|
||||
<span class="sd"> """</span>
|
||||
@@ -308,6 +308,9 @@
|
||||
<span class="k">return</span> <span class="kc">False</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">"nfs"</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">nfs_pvc</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_store</span><span class="o">.</span><span class="n">get_info</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># Upload checkpoints in checkpoint_restore_dir (if provided) to the data store</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">data_store</span><span class="o">.</span><span class="n">setup_checkpoint_dir</span><span class="p">(</span><span class="n">crd</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="kc">True</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">deploy_trainer</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||||
@@ -321,7 +324,6 @@
|
||||
|
||||
<span class="n">trainer_params</span><span class="o">.</span><span class="n">command</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'--memory_backend_params'</span><span class="p">,</span> <span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">memory_backend_parameters</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)]</span>
|
||||
<span class="n">trainer_params</span><span class="o">.</span><span class="n">command</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'--data_store_params'</span><span class="p">,</span> <span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)]</span>
|
||||
|
||||
<span class="n">name</span> <span class="o">=</span> <span class="s2">"</span><span class="si">{}</span><span class="s2">-</span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">run_type</span><span class="p">,</span> <span class="n">uuid</span><span class="o">.</span><span class="n">uuid4</span><span class="p">())</span>
|
||||
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">"nfs"</span><span class="p">:</span>
|
||||
@@ -346,7 +348,7 @@
|
||||
<span class="n">name</span><span class="o">=</span><span class="s2">"nfs-pvc"</span><span class="p">,</span>
|
||||
<span class="n">persistent_volume_claim</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">nfs_pvc</span>
|
||||
<span class="p">)],</span>
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'OnFailure'</span>
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'Never'</span>
|
||||
<span class="p">),</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
@@ -365,7 +367,7 @@
|
||||
<span class="n">metadata</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1ObjectMeta</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="p">{</span><span class="s1">'app'</span><span class="p">:</span> <span class="n">name</span><span class="p">}),</span>
|
||||
<span class="n">spec</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1PodSpec</span><span class="p">(</span>
|
||||
<span class="n">containers</span><span class="o">=</span><span class="p">[</span><span class="n">container</span><span class="p">],</span>
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'OnFailure'</span>
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'Never'</span>
|
||||
<span class="p">),</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
@@ -427,7 +429,7 @@
|
||||
<span class="n">name</span><span class="o">=</span><span class="s2">"nfs-pvc"</span><span class="p">,</span>
|
||||
<span class="n">persistent_volume_claim</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">nfs_pvc</span>
|
||||
<span class="p">)],</span>
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'OnFailure'</span>
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'Never'</span>
|
||||
<span class="p">),</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
@@ -446,7 +448,7 @@
|
||||
<span class="n">metadata</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1ObjectMeta</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="p">{</span><span class="s1">'app'</span><span class="p">:</span> <span class="n">name</span><span class="p">}),</span>
|
||||
<span class="n">spec</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1PodSpec</span><span class="p">(</span>
|
||||
<span class="n">containers</span><span class="o">=</span><span class="p">[</span><span class="n">container</span><span class="p">],</span>
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'OnFailure'</span>
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'Never'</span>
|
||||
<span class="p">)</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
@@ -496,7 +498,7 @@
|
||||
<span class="k">return</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">pod</span> <span class="ow">in</span> <span class="n">pods</span><span class="o">.</span><span class="n">items</span><span class="p">:</span>
|
||||
<span class="n">Process</span><span class="p">(</span><span class="n">target</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tail_log_file</span><span class="p">,</span> <span class="n">args</span><span class="o">=</span><span class="p">(</span><span class="n">pod</span><span class="o">.</span><span class="n">metadata</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">api_client</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">namespace</span><span class="p">,</span> <span class="n">path</span><span class="p">))</span><span class="o">.</span><span class="n">start</span><span class="p">()</span>
|
||||
<span class="n">Process</span><span class="p">(</span><span class="n">target</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tail_log_file</span><span class="p">,</span> <span class="n">args</span><span class="o">=</span><span class="p">(</span><span class="n">pod</span><span class="o">.</span><span class="n">metadata</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">api_client</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">namespace</span><span class="p">,</span> <span class="n">path</span><span class="p">),</span> <span class="n">daemon</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span><span class="o">.</span><span class="n">start</span><span class="p">()</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">_tail_log_file</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pod_name</span><span class="p">,</span> <span class="n">api_client</span><span class="p">,</span> <span class="n">namespace</span><span class="p">,</span> <span class="n">path</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">path</span><span class="p">):</span>
|
||||
@@ -528,7 +530,7 @@
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="n">pod</span><span class="p">:</span>
|
||||
<span class="k">return</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">tail_log</span><span class="p">(</span><span class="n">pod</span><span class="o">.</span><span class="n">metadata</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">api_client</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">tail_log</span><span class="p">(</span><span class="n">pod</span><span class="o">.</span><span class="n">metadata</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">api_client</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">tail_log</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pod_name</span><span class="p">,</span> <span class="n">corev1_api</span><span class="p">):</span>
|
||||
<span class="k">while</span> <span class="kc">True</span><span class="p">:</span>
|
||||
@@ -562,9 +564,9 @@
|
||||
<span class="n">container_status</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">waiting</span><span class="o">.</span><span class="n">reason</span> <span class="o">==</span> <span class="s1">'CrashLoopBackOff'</span> <span class="ow">or</span> \
|
||||
<span class="n">container_status</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">waiting</span><span class="o">.</span><span class="n">reason</span> <span class="o">==</span> <span class="s1">'ImagePullBackOff'</span> <span class="ow">or</span> \
|
||||
<span class="n">container_status</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">waiting</span><span class="o">.</span><span class="n">reason</span> <span class="o">==</span> <span class="s1">'ErrImagePull'</span><span class="p">:</span>
|
||||
<span class="k">return</span>
|
||||
<span class="k">return</span> <span class="mi">1</span>
|
||||
<span class="k">if</span> <span class="n">container_status</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">terminated</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="k">return</span>
|
||||
<span class="k">return</span> <span class="n">container_status</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">terminated</span><span class="o">.</span><span class="n">exit_code</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">undeploy</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
|
||||
@@ -828,7 +828,7 @@
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
||||
<span class="n">state</span><span class="p">:</span> <span class="n">StateSpace</span><span class="p">,</span>
|
||||
<span class="n">goal</span><span class="p">:</span> <span class="n">ObservationSpace</span><span class="p">,</span>
|
||||
<span class="n">goal</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">ObservationSpace</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span>
|
||||
<span class="n">action</span><span class="p">:</span> <span class="n">ActionSpace</span><span class="p">,</span>
|
||||
<span class="n">reward</span><span class="p">:</span> <span class="n">RewardSpace</span><span class="p">):</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">state</span> <span class="o">=</span> <span class="n">state</span>
|
||||
|
||||
@@ -21,6 +21,7 @@ A detailed description of those algorithms can be found by navigating to each of
|
||||
imitation/cil
|
||||
policy_optimization/cppo
|
||||
policy_optimization/ddpg
|
||||
policy_optimization/sac
|
||||
other/dfp
|
||||
value_optimization/double_dqn
|
||||
value_optimization/dqn
|
||||
|
||||
@@ -38,6 +38,7 @@ Each update perform the following procedure:
|
||||
.. math:: \text{where} \quad \bar{\rho}_{t} = \min{\left\{c,\rho_t\right\}},\quad \rho_t=\frac{\pi (a_t \mid s_t)}{\mu (a_t \mid s_t)}
|
||||
|
||||
3. **Accumulate gradients:**
|
||||
|
||||
:math:`\bullet` **Policy gradients (with bias correction):**
|
||||
|
||||
.. math:: \hat{g}_t^{policy} & = & \bar{\rho}_{t} \nabla \log \pi (a_t \mid s_t) [Q^{ret}(s_t,a_t) - V(s_t)] \\
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
Soft Actor-Critic
=================
|
||||
|
||||
**Actions space:** Continuous
|
||||
|
||||
**References:** `Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor <https://arxiv.org/abs/1801.01290>`_
|
||||
|
||||
Network Structure
|
||||
-----------------
|
||||
|
||||
.. image:: /_static/img/design_imgs/sac.png
|
||||
:align: center
|
||||
|
||||
Algorithm Description
|
||||
---------------------
|
||||
|
||||
Choosing an action - Continuous actions
+++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
The policy network predicts a mean and a log standard deviation for each action. During training, an action is
sampled from a Gaussian distribution parameterized by these mean and std values. During evaluation, the agent can
either act deterministically by picking the mean value, or keep sampling from the Gaussian as it does in training.
|
||||
|
||||
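A minimal sketch of this action-selection rule (assuming a hypothetical ``policy_net`` callable that returns the per-action mean and log std as NumPy arrays; this is illustrative, not the exact Coach implementation):

.. code-block:: python

   import numpy as np

   def choose_action(policy_net, state, deterministic=False):
       # The policy head outputs a mean and a log std per action dimension.
       mean, log_std = policy_net(state)
       if deterministic:
           # Evaluation: act greedily by taking the mean action.
           return mean
       # Training: sample from the Gaussian defined by the predicted parameters.
       return np.random.normal(loc=mean, scale=np.exp(log_std))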
Training the network
|
||||
++++++++++++++++++++
|
||||
Start by sampling a batch :math:`B` of transitions from the experience replay.
|
||||
|
||||
* To train the **Q network**, use the following targets:
|
||||
|
||||
.. math:: y_t^Q=r(s_t,a_t)+\gamma \cdot V(s_{t+1})
|
||||
|
||||
The state value used in the above target is acquired by running the target state value network.
|
||||
|
||||
* To train the **State Value network**, use the following targets:
|
||||
|
||||
.. math:: y_t^V = \min_{i=1,2}Q_i(s_t,\tilde{a}) - \log\pi (\tilde{a} \vert s_t),\,\,\,\, \tilde{a} \sim \pi(\cdot \vert s_t)
|
||||
|
||||
The state value network is trained using a sample-based approximation of the relation between the state value and the
state-action values. The actions used for constructing this target are **not** sampled from the replay buffer, but are
instead sampled from the current policy.
|
||||
|
||||
* To train the **actor network**, use the following equation:
|
||||
|
||||
.. math:: \nabla_{\theta} J \approx \nabla_{\theta} \frac{1}{\vert B \vert} \sum_{s_t\in B} \left( Q \left(s_t, \tilde{a}_\theta(s_t)\right) - \log\pi_{\theta}(\tilde{a}_{\theta}(s_t)\vert s_t) \right),\,\,\,\, \tilde{a} \sim \pi(\cdot \vert s_t)
|
||||
|
||||
After every training step, do a soft update of the target V network's weights from the online V network (a schematic sketch of these update rules is given below).
|
||||
|
||||
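As a rough illustration of the training procedure above, the following schematic sketch assumes hypothetical ``q1`` and ``q2`` (two Q networks), ``v`` and ``v_target`` (online and target state value networks), a policy ``pi`` that samples actions together with their log-probabilities, and illustrative ``gamma``/``tau`` values; it is not the Coach implementation itself:

.. code-block:: python

   import numpy as np

   def sac_update_targets(batch, q1, q2, v, v_target, pi, gamma=0.99, tau=0.005):
       # Q-network target: r + gamma * V_target(s'), using the *target* value network
       # (terminal-state masking omitted to mirror the formula above).
       q_targets = batch.rewards + gamma * v_target(batch.next_states)

       # State value target: actions are re-sampled from the current policy, not the buffer.
       sampled_actions, log_probs = pi.sample(batch.states)
       v_targets = np.minimum(q1(batch.states, sampled_actions),
                              q2(batch.states, sampled_actions)) - log_probs

       # The Q heads regress towards q_targets, the V head towards v_targets, and the policy
       # is updated to maximize Q(s, a~pi) - log pi(a~pi | s) over the batch.

       # Soft (Polyak) update of the target V network from the online V network.
       v_target.weights = (1 - tau) * v_target.weights + tau * v.weights
       return q_targets, v_targets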
|
||||
.. autoclass:: rl_coach.agents.soft_actor_critic_agent.SoftActorCriticAlgorithmParameters
|
||||
@@ -198,6 +198,14 @@ The algorithms are ordered by their release date in descending order.
|
||||
improve stability it also employs bias correction and trust region optimization techniques.
|
||||
</span>
|
||||
</div>
|
||||
<div class="algorithm continuous off-policy" data-year="201808">
|
||||
<span class="badge">
|
||||
<a href="components/agents/policy_optimization/sac.html">SAC</a>
|
||||
<br>
|
||||
Soft Actor-Critic is an algorithm which optimizes a stochastic policy in an off-policy way.
|
||||
One of the key features of SAC is that it solves a maximum entropy reinforcement learning problem.
|
||||
</span>
|
||||
</div>
|
||||
<div class="algorithm continuous off-policy" data-year="201509">
|
||||
<span class="badge">
|
||||
<a href="components/agents/policy_optimization/ddpg.html">DDPG</a>
|
||||
|
||||
@@ -113,7 +113,8 @@ In Coach, this can be done in two steps -
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
coach -p Doom_Basic_BC -cp='agent.load_memory_from_file_path=\"<experiment dir>/replay_buffer.p\"'
|
||||
from rl_coach.core_types import PickledReplayBuffer
|
||||
coach -p Doom_Basic_BC -cp='agent.load_memory_from_file_path=PickledReplayBuffer(\"<experiment dir>/replay_buffer.p\")'
|
||||
|
||||
|
||||
Visualizations
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
/* Docs background */
|
||||
.wy-side-nav-search{
|
||||
background-color: #043c74;
|
||||
}
|
||||
|
||||
/* Mobile version */
|
||||
.wy-nav-top{
|
||||
background-color: #043c74;
|
||||
}
|
||||
|
||||
|
||||
.green {
|
||||
color: green;
|
||||
}
|
||||
|
||||
.red {
|
||||
color: red;
|
||||
}
|
||||
|
||||
.blue {
|
||||
color: blue;
|
||||
}
|
||||
|
||||
.yellow {
|
||||
color: yellow;
|
||||
}
|
||||
|
||||
.badge {
|
||||
border: 2px;
|
||||
border-style: solid;
|
||||
border-color: #6C8EBF;
|
||||
border-radius: 5px;
|
||||
padding: 3px 15px 3px 15px;
|
||||
margin: 5px;
|
||||
display: inline-block;
|
||||
font-weight: bold;
|
||||
font-size: 16px;
|
||||
background: #DAE8FC;
|
||||
}
|
||||
|
||||
.badge:hover {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.badge > a {
|
||||
color: black;
|
||||
}
|
||||
|
||||
.bordered-container {
|
||||
border: 0px;
|
||||
border-style: solid;
|
||||
border-radius: 8px;
|
||||
padding: 15px;
|
||||
margin-bottom: 20px;
|
||||
background: #f2f2f2;
|
||||
}
|
||||
|
||||
.questionnaire {
|
||||
font-size: 1.2em;
|
||||
line-height: 1.5em;
|
||||
}
|
||||
@@ -276,19 +276,22 @@ of the trace tests suite.</li>
|
||||
<h2>TaskParameters<a class="headerlink" href="#taskparameters" title="Permalink to this headline">¶</a></h2>
|
||||
<dl class="class">
|
||||
<dt id="rl_coach.base_parameters.TaskParameters">
|
||||
<em class="property">class </em><code class="descclassname">rl_coach.base_parameters.</code><code class="descname">TaskParameters</code><span class="sig-paren">(</span><em>framework_type: rl_coach.base_parameters.Frameworks = <Frameworks.tensorflow: 'TensorFlow'></em>, <em>evaluate_only: bool = False</em>, <em>use_cpu: bool = False</em>, <em>experiment_path='/tmp'</em>, <em>seed=None</em>, <em>checkpoint_save_secs=None</em>, <em>checkpoint_restore_dir=None</em>, <em>checkpoint_save_dir=None</em>, <em>export_onnx_graph: bool = False</em>, <em>apply_stop_condition: bool = False</em>, <em>num_gpu: int = 1</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/base_parameters.html#TaskParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.TaskParameters" title="Permalink to this definition">¶</a></dt>
|
||||
<em class="property">class </em><code class="descclassname">rl_coach.base_parameters.</code><code class="descname">TaskParameters</code><span class="sig-paren">(</span><em>framework_type: rl_coach.base_parameters.Frameworks = <Frameworks.tensorflow: 'TensorFlow'></em>, <em>evaluate_only: int = None</em>, <em>use_cpu: bool = False</em>, <em>experiment_path='/tmp'</em>, <em>seed=None</em>, <em>checkpoint_save_secs=None</em>, <em>checkpoint_restore_dir=None</em>, <em>checkpoint_restore_path=None</em>, <em>checkpoint_save_dir=None</em>, <em>export_onnx_graph: bool = False</em>, <em>apply_stop_condition: bool = False</em>, <em>num_gpu: int = 1</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/base_parameters.html#TaskParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.TaskParameters" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><table class="docutils field-list" frame="void" rules="none">
|
||||
<col class="field-name" />
|
||||
<col class="field-body" />
|
||||
<tbody valign="top">
|
||||
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple">
|
||||
<li><strong>framework_type</strong> – deep learning framework type. Currently only TensorFlow is supported</li>
|
||||
<li><strong>evaluate_only</strong> – the task will be used only for evaluating the model</li>
|
||||
<li><strong>evaluate_only</strong> – if not None, the task will be used only for evaluating the model for the given number of steps.
|
||||
A value of 0 means that the task will be evaluated for an infinite number of steps.</li>
|
||||
<li><strong>use_cpu</strong> – use the cpu for this task</li>
|
||||
<li><strong>experiment_path</strong> – the path to the directory which will store all the experiment outputs</li>
|
||||
<li><strong>seed</strong> – a seed to use for the random numbers generator</li>
|
||||
<li><strong>checkpoint_save_secs</strong> – the number of seconds between each checkpoint saving</li>
|
||||
<li><strong>checkpoint_restore_dir</strong> – the directory to restore the checkpoints from</li>
|
||||
<li><strong>checkpoint_restore_dir</strong> – [DEPRECATED - will be removed in one of the next releases - switch to checkpoint_restore_path]
the directory to restore the checkpoints from</li>
|
||||
<li><strong>checkpoint_restore_path</strong> – the path to restore the checkpoints from</li>
|
||||
<li><strong>checkpoint_save_dir</strong> – the directory to store the checkpoints in</li>
|
||||
<li><strong>export_onnx_graph</strong> – If set to True, this will export an onnx graph each time a checkpoint is saved</li>
|
||||
<li><strong>apply_stop_condition</strong> – If set to True, this will apply the stop condition defined by reaching a target success rate</li>
|
||||
@@ -305,14 +308,15 @@ of the trace tests suite.</li>
|
||||
<h2>DistributedTaskParameters<a class="headerlink" href="#distributedtaskparameters" title="Permalink to this headline">¶</a></h2>
|
||||
<dl class="class">
|
||||
<dt id="rl_coach.base_parameters.DistributedTaskParameters">
|
||||
<em class="property">class </em><code class="descclassname">rl_coach.base_parameters.</code><code class="descname">DistributedTaskParameters</code><span class="sig-paren">(</span><em>framework_type: rl_coach.base_parameters.Frameworks</em>, <em>parameters_server_hosts: str</em>, <em>worker_hosts: str</em>, <em>job_type: str</em>, <em>task_index: int</em>, <em>evaluate_only: bool = False</em>, <em>num_tasks: int = None</em>, <em>num_training_tasks: int = None</em>, <em>use_cpu: bool = False</em>, <em>experiment_path=None</em>, <em>dnd=None</em>, <em>shared_memory_scratchpad=None</em>, <em>seed=None</em>, <em>checkpoint_save_secs=None</em>, <em>checkpoint_restore_dir=None</em>, <em>checkpoint_save_dir=None</em>, <em>export_onnx_graph: bool = False</em>, <em>apply_stop_condition: bool = False</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/base_parameters.html#DistributedTaskParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.DistributedTaskParameters" title="Permalink to this definition">¶</a></dt>
|
||||
<em class="property">class </em><code class="descclassname">rl_coach.base_parameters.</code><code class="descname">DistributedTaskParameters</code><span class="sig-paren">(</span><em>framework_type: rl_coach.base_parameters.Frameworks</em>, <em>parameters_server_hosts: str</em>, <em>worker_hosts: str</em>, <em>job_type: str</em>, <em>task_index: int</em>, <em>evaluate_only: int = None</em>, <em>num_tasks: int = None</em>, <em>num_training_tasks: int = None</em>, <em>use_cpu: bool = False</em>, <em>experiment_path=None</em>, <em>dnd=None</em>, <em>shared_memory_scratchpad=None</em>, <em>seed=None</em>, <em>checkpoint_save_secs=None</em>, <em>checkpoint_restore_path=None</em>, <em>checkpoint_save_dir=None</em>, <em>export_onnx_graph: bool = False</em>, <em>apply_stop_condition: bool = False</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/base_parameters.html#DistributedTaskParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.DistributedTaskParameters" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><table class="docutils field-list" frame="void" rules="none">
|
||||
<col class="field-name" />
|
||||
<col class="field-body" />
|
||||
<tbody valign="top">
|
||||
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple">
|
||||
<li><strong>framework_type</strong> – deep learning framework type. Currently only TensorFlow is supported</li>
|
||||
<li><strong>evaluate_only</strong> – the task will be used only for evaluating the model</li>
|
||||
<li><strong>evaluate_only</strong> – if not None, the task will be used only for evaluating the model for the given number of steps.
|
||||
A value of 0 means that the task will be evaluated for an infinite number of steps.</li>
|
||||
<li><strong>parameters_server_hosts</strong> – comma-separated list of hostname:port pairs to which the parameter servers are
|
||||
assigned</li>
|
||||
<li><strong>worker_hosts</strong> – comma-separated list of hostname:port pairs to which the workers are assigned</li>
|
||||
@@ -325,7 +329,7 @@ assigned</li>
|
||||
<li><strong>dnd</strong> – an external DND to use for NEC. This is a workaround needed for a shared DND not using the scratchpad.</li>
|
||||
<li><strong>seed</strong> – a seed to use for the random numbers generator</li>
|
||||
<li><strong>checkpoint_save_secs</strong> – the number of seconds between each checkpoint saving</li>
|
||||
<li><strong>checkpoint_restore_dir</strong> – the directory to restore the checkpoints from</li>
|
||||
<li><strong>checkpoint_restore_path</strong> – the path to restore the checkpoints from</li>
|
||||
<li><strong>checkpoint_save_dir</strong> – the directory to store the checkpoints in</li>
|
||||
<li><strong>export_onnx_graph</strong> – If set to True, this will export an onnx graph each time a checkpoint is saved</li>
|
||||
<li><strong>apply_stop_condition</strong> – If set to True, this will apply the stop condition defined by reaching a target success rate</li>
|
||||
|
||||
@@ -121,6 +121,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="cil.html">Conditional Imitation Learning</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/dqn.html">Deep Q Networks</a></li>
|
||||
|
||||
@@ -121,6 +121,7 @@
|
||||
</li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/dqn.html">Deep Q Networks</a></li>
|
||||
|
||||
@@ -114,6 +114,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="imitation/cil.html">Conditional Imitation Learning</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="policy_optimization/sac.html">Soft Actor-Critic</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="other/dfp.html">Direct Future Prediction</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="value_optimization/double_dqn.html">Double DQN</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="value_optimization/dqn.html">Deep Q Networks</a></li>
|
||||
@@ -221,6 +222,7 @@ A detailed description of those algorithms can be found by navigating to each of
|
||||
<li class="toctree-l1"><a class="reference internal" href="imitation/cil.html">Conditional Imitation Learning</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="policy_optimization/sac.html">Soft Actor-Critic</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="other/dfp.html">Direct Future Prediction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="value_optimization/double_dqn.html">Double DQN</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="value_optimization/dqn.html">Deep Q Networks</a></li>
|
||||
@@ -280,13 +282,15 @@ used for visualization purposes, such as printing to the screen, rendering, and
|
||||
</table>
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.agent.Agent.act">
|
||||
<code class="descname">act</code><span class="sig-paren">(</span><span class="sig-paren">)</span> → rl_coach.core_types.ActionInfo<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.act"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.act" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="descname">act</code><span class="sig-paren">(</span><em>action: Union[None</em>, <em>int</em>, <em>float</em>, <em>numpy.ndarray</em>, <em>List] = None</em><span class="sig-paren">)</span> → rl_coach.core_types.ActionInfo<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.act"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.act" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Given the agent's current knowledge, decide on the next action to apply to the environment</p>
|
||||
<table class="docutils field-list" frame="void" rules="none">
|
||||
<col class="field-name" />
|
||||
<col class="field-body" />
|
||||
<tbody valign="top">
|
||||
<tr class="field-odd field"><th class="field-name">Returns:</th><td class="field-body">An ActionInfo object, which contains the action and any additional info from the action decision process</td>
|
||||
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>action</strong> – An action to take, overriding whatever the current policy is</td>
|
||||
</tr>
|
||||
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body">An ActionInfo object, which contains the action and any additional info from the action decision process</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
@@ -357,26 +361,6 @@ for creating the network.</p>
|
||||
</table>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.agent.Agent.emulate_act_on_trainer">
|
||||
<code class="descname">emulate_act_on_trainer</code><span class="sig-paren">(</span><em>transition: rl_coach.core_types.Transition</em><span class="sig-paren">)</span> → rl_coach.core_types.ActionInfo<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.emulate_act_on_trainer"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.emulate_act_on_trainer" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>This emulates the act using the transition obtained from the rollout worker on the training worker
|
||||
in case of distributed training.
|
||||
Given the agents current knowledge, decide on the next action to apply to the environment
|
||||
:return: an action and a dictionary containing any additional info from the action decision process</p>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.agent.Agent.emulate_observe_on_trainer">
|
||||
<code class="descname">emulate_observe_on_trainer</code><span class="sig-paren">(</span><em>transition: rl_coach.core_types.Transition</em><span class="sig-paren">)</span> → bool<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.emulate_observe_on_trainer"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.emulate_observe_on_trainer" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>This emulates the observe using the transition obtained from the rollout worker on the training worker
|
||||
in case of distributed training.
|
||||
Given a response from the environment, distill the observation from it and store it for later use.
|
||||
The response should be a dictionary containing the performed action, the new observation and measurements,
|
||||
the reward, a game over flag and any additional information necessary.
|
||||
:return:</p>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.agent.Agent.get_predictions">
|
||||
<code class="descname">get_predictions</code><span class="sig-paren">(</span><em>states: List[Dict[str, numpy.ndarray]], prediction_type: rl_coach.core_types.PredictionType</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.get_predictions"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.get_predictions" title="Permalink to this definition">¶</a></dt>
|
||||
@@ -540,7 +524,7 @@ given observation</td>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.agent.Agent.prepare_batch_for_inference">
|
||||
<code class="descname">prepare_batch_for_inference</code><span class="sig-paren">(</span><em>states: Union[Dict[str, numpy.ndarray], List[Dict[str, numpy.ndarray]]], network_name: str</em><span class="sig-paren">)</span> → Dict[str, numpy.core.multiarray.array]<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.prepare_batch_for_inference"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.prepare_batch_for_inference" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="descname">prepare_batch_for_inference</code><span class="sig-paren">(</span><em>states: Union[Dict[str, numpy.ndarray], List[Dict[str, numpy.ndarray]]], network_name: str</em><span class="sig-paren">)</span> → Dict[str, numpy.array]<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.prepare_batch_for_inference"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.prepare_batch_for_inference" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Convert curr_state into the input tensors TensorFlow expects, i.e. if we have several input states, stack all
|
||||
observations together, measurements together, etc.</p>
|
||||
<table class="docutils field-list" frame="void" rules="none">
|
||||
@@ -632,6 +616,21 @@ by val, and by the current phase set in self.phase.</p>
|
||||
</table>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.agent.Agent.run_off_policy_evaluation">
|
||||
<code class="descname">run_off_policy_evaluation</code><span class="sig-paren">(</span><span class="sig-paren">)</span> → None<a class="headerlink" href="#rl_coach.agents.agent.Agent.run_off_policy_evaluation" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Run off-policy evaluation estimators to evaluate the trained policy performance against a dataset.
|
||||
Should only be implemented for off-policy RL algorithms.</p>
|
||||
<table class="docutils field-list" frame="void" rules="none">
|
||||
<col class="field-name" />
|
||||
<col class="field-body" />
|
||||
<tbody valign="top">
|
||||
<tr class="field-odd field"><th class="field-name">Returns:</th><td class="field-body">None</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.agent.Agent.run_pre_network_filter_for_inference">
|
||||
<code class="descname">run_pre_network_filter_for_inference</code><span class="sig-paren">(</span><em>state: Dict[str, numpy.ndarray], update_filter_internal_state: bool = True</em><span class="sig-paren">)</span> → Dict[str, numpy.ndarray]<a class="reference internal" href="../../_modules/rl_coach/agents/agent.html#Agent.run_pre_network_filter_for_inference"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.agent.Agent.run_pre_network_filter_for_inference" title="Permalink to this definition">¶</a></dt>
|
||||
|
||||
@@ -30,7 +30,7 @@
|
||||
<link rel="index" title="Index" href="../../../genindex.html" />
|
||||
<link rel="search" title="Search" href="../../../search.html" />
|
||||
<link rel="next" title="Double DQN" href="../value_optimization/double_dqn.html" />
|
||||
<link rel="prev" title="Deep Deterministic Policy Gradient" href="../policy_optimization/ddpg.html" />
|
||||
<link rel="prev" title="Soft Actor-Critic" href="../policy_optimization/sac.html" />
|
||||
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
|
||||
|
||||
|
||||
@@ -114,6 +114,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/cppo.html">Clipped Proximal Policy Optimization</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/ddpg.html">Deep Deterministic Policy Gradient</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../policy_optimization/sac.html">Soft Actor-Critic</a></li>
|
||||
<li class="toctree-l2 current"><a class="current reference internal" href="#">Direct Future Prediction</a><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#network-structure">Network Structure</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#algorithm-description">Algorithm Description</a><ul>
|
||||
@@ -296,7 +297,7 @@ have a different scale and you want to normalize them to the same scale.</li>
|
||||
<a href="../value_optimization/double_dqn.html" class="btn btn-neutral float-right" title="Double DQN" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a>
|
||||
|
||||
|
||||
<a href="../policy_optimization/ddpg.html" class="btn btn-neutral" title="Deep Deterministic Policy Gradient" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a>
|
||||
<a href="../policy_optimization/sac.html" class="btn btn-neutral" title="Soft Actor-Critic" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
@@ -122,6 +122,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="cppo.html">Clipped Proximal Policy Optimization</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="ddpg.html">Deep Deterministic Policy Gradient</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="sac.html">Soft Actor-Critic</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/dqn.html">Deep Q Networks</a></li>
|
||||
|
||||
@@ -122,6 +122,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="cppo.html">Clipped Proximal Policy Optimization</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="ddpg.html">Deep Deterministic Policy Gradient</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="sac.html">Soft Actor-Critic</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/dqn.html">Deep Q Networks</a></li>
|
||||
@@ -247,21 +248,20 @@ and <span class="math notranslate nohighlight">\(n\)</span> (replay ratio) off-p
|
||||
\[\text{where} \quad \bar{\rho}_{t} = \min{\left\{c,\rho_t\right\}},\quad \rho_t=\frac{\pi (a_t \mid s_t)}{\mu (a_t \mid s_t)}\]</div>
|
||||
</div></blockquote>
|
||||
</li>
|
||||
<li><dl class="first docutils">
|
||||
<dt><strong>Accumulate gradients:</strong></dt>
|
||||
<dd><p class="first"><span class="math notranslate nohighlight">\(\bullet\)</span> <strong>Policy gradients (with bias correction):</strong></p>
|
||||
<li><p class="first"><strong>Accumulate gradients:</strong></p>
|
||||
<blockquote>
|
||||
<div><p><span class="math notranslate nohighlight">\(\bullet\)</span> <strong>Policy gradients (with bias correction):</strong></p>
|
||||
<blockquote>
|
||||
<div><div class="math notranslate nohighlight">
|
||||
\[\begin{split}\hat{g}_t^{policy} & = & \bar{\rho}_{t} \nabla \log \pi (a_t \mid s_t) [Q^{ret}(s_t,a_t) - V(s_t)] \\
|
||||
& & + \mathbb{E}_{a \sim \pi} \left(\left[\frac{\rho_t(a)-c}{\rho_t(a)}\right] \nabla \log \pi (a \mid s_t) [Q(s_t,a) - V(s_t)] \right)\end{split}\]</div>
|
||||
</div></blockquote>
|
||||
<p><span class="math notranslate nohighlight">\(\bullet\)</span> <strong>Q-Head gradients (MSE):</strong></p>
|
||||
<blockquote class="last">
|
||||
<blockquote>
|
||||
<div><div class="math notranslate nohighlight">
|
||||
\[\begin{split}\hat{g}_t^{Q} = (Q^{ret}(s_t,a_t) - Q(s_t,a_t)) \nabla Q(s_t,a_t)\\\end{split}\]</div>
|
||||
</div></blockquote>
|
||||
</dd>
|
||||
</dl>
|
||||
</div></blockquote>
|
||||
</li>
|
||||
<li><p class="first"><strong>(Optional) Trust region update:</strong> change the policy loss gradient w.r.t network output:</p>
|
||||
<blockquote>
|
||||
|
||||
@@ -122,6 +122,7 @@
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="ddpg.html">Deep Deterministic Policy Gradient</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="sac.html">Soft Actor-Critic</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/dqn.html">Deep Q Networks</a></li>
|
||||
|
||||
@@ -29,7 +29,7 @@
|
||||
<link rel="stylesheet" href="../../../_static/css/custom.css" type="text/css" />
|
||||
<link rel="index" title="Index" href="../../../genindex.html" />
|
||||
<link rel="search" title="Search" href="../../../search.html" />
|
||||
<link rel="next" title="Direct Future Prediction" href="../other/dfp.html" />
|
||||
<link rel="next" title="Soft Actor-Critic" href="sac.html" />
|
||||
<link rel="prev" title="Clipped Proximal Policy Optimization" href="cppo.html" />
|
||||
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
|
||||
|
||||
@@ -122,6 +122,7 @@
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="sac.html">Soft Actor-Critic</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/dqn.html">Deep Q Networks</a></li>
|
||||
@@ -297,7 +298,7 @@ values. If set to False, the terminal states reward will be taken as the target
|
||||
|
||||
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
|
||||
|
||||
<a href="../other/dfp.html" class="btn btn-neutral float-right" title="Direct Future Prediction" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a>
|
||||
<a href="sac.html" class="btn btn-neutral float-right" title="Soft Actor-Critic" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a>
|
||||
|
||||
|
||||
<a href="cppo.html" class="btn btn-neutral" title="Clipped Proximal Policy Optimization" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a>
|
||||
|
||||
@@ -114,6 +114,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="cppo.html">Clipped Proximal Policy Optimization</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="ddpg.html">Deep Deterministic Policy Gradient</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="sac.html">Soft Actor-Critic</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/dqn.html">Deep Q Networks</a></li>
|
||||
|
||||
@@ -114,6 +114,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="cppo.html">Clipped Proximal Policy Optimization</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="ddpg.html">Deep Deterministic Policy Gradient</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="sac.html">Soft Actor-Critic</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/dqn.html">Deep Q Networks</a></li>
|
||||
|
||||
@@ -0,0 +1,343 @@
|
||||
|
||||
|
||||
<!DOCTYPE html>
|
||||
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
|
||||
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
|
||||
<title>Soft Actor-Critic — Reinforcement Learning Coach 0.11.0 documentation</title>
|
||||
|
||||
<link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../../_static/css/custom.css" type="text/css" />
|
||||
<link rel="index" title="Index" href="../../../genindex.html" />
|
||||
<link rel="search" title="Search" href="../../../search.html" />
|
||||
<link rel="next" title="Direct Future Prediction" href="../other/dfp.html" />
|
||||
<link rel="prev" title="Deep Deterministic Policy Gradient" href="ddpg.html" />
|
||||
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
|
||||
|
||||
|
||||
|
||||
<script src="../../../_static/js/modernizr.min.js"></script>
|
||||
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
|
||||
|
||||
|
||||
<div class="wy-grid-for-nav">
|
||||
|
||||
|
||||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||||
<div class="wy-side-scroll">
|
||||
<div class="wy-side-nav-search">
|
||||
|
||||
|
||||
|
||||
<a href="../../../index.html" class="icon icon-home"> Reinforcement Learning Coach
|
||||
|
||||
|
||||
|
||||
|
||||
<img src="../../../_static/dark_logo.png" class="logo" alt="Logo"/>
|
||||
|
||||
</a>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="search">
|
||||
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
|
||||
<input type="text" name="q" placeholder="Search docs" />
|
||||
<input type="hidden" name="check_keywords" value="yes" />
|
||||
<input type="hidden" name="area" value="default" />
|
||||
</form>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Intro</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../dist_usage.html">Usage - Distributed Coach</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../features/index.html">Features</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../selecting_an_algorithm.html">Selecting an Algorithm</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../dashboard.html">Coach Dashboard</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Design</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../design/control_flow.html">Control Flow</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../design/network.html">Network Design</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../design/horizontal_scaling.html">Distributed Coach - Horizontal Scale-Out</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Contributing</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_agent.html">Adding a New Agent</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_env.html">Adding a New Environment</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Components</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../index.html">Agents</a><ul class="current">
|
||||
<li class="toctree-l2"><a class="reference internal" href="ac.html">Actor-Critic</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="acer.html">ACER</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../imitation/bc.html">Behavioral Cloning</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/bs_dqn.html">Bootstrapped DQN</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/categorical_dqn.html">Categorical DQN</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../imitation/cil.html">Conditional Imitation Learning</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="cppo.html">Clipped Proximal Policy Optimization</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="ddpg.html">Deep Deterministic Policy Gradient</a></li>
|
||||
<li class="toctree-l2 current"><a class="current reference internal" href="#">Soft Actor-Critic</a><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#network-structure">Network Structure</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#algorithm-description">Algorithm Description</a><ul>
|
||||
<li class="toctree-l4"><a class="reference internal" href="#choosing-an-action-continuous-actions">Choosing an action - Continuous actions</a></li>
|
||||
<li class="toctree-l4"><a class="reference internal" href="#training-the-network">Training the network</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../other/dfp.html">Direct Future Prediction</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/double_dqn.html">Double DQN</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/dqn.html">Deep Q Networks</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/dueling_dqn.html">Dueling DQN</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/mmc.html">Mixed Monte Carlo</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/n_step.html">N-Step Q Learning</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/naf.html">Normalized Advantage Functions</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/nec.html">Neural Episodic Control</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/pal.html">Persistent Advantage Learning</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="pg.html">Policy Gradient</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="ppo.html">Proximal Policy Optimization</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/rainbow.html">Rainbow</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="../value_optimization/qr_dqn.html">Quantile Regression DQN</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../architectures/index.html">Architectures</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../data_stores/index.html">Data Stores</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../environments/index.html">Environments</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../exploration_policies/index.html">Exploration Policies</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../filters/index.html">Filters</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../memories/index.html">Memories</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../memory_backends/index.html">Memory Backends</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../orchestrators/index.html">Orchestrators</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../core_types.html">Core Types</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../spaces.html">Spaces</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../additional_parameters.html">Additional Parameters</a></li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||||
|
||||
|
||||
<nav class="wy-nav-top" aria-label="top navigation">
|
||||
|
||||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||||
<a href="../../../index.html">Reinforcement Learning Coach</a>
|
||||
|
||||
</nav>
|
||||
|
||||
|
||||
<div class="wy-nav-content">
|
||||
|
||||
<div class="rst-content">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||||
|
||||
<ul class="wy-breadcrumbs">
|
||||
|
||||
<li><a href="../../../index.html">Docs</a> »</li>
|
||||
|
||||
<li><a href="../index.html">Agents</a> »</li>
|
||||
|
||||
<li>Soft Actor-Critic</li>
|
||||
|
||||
|
||||
<li class="wy-breadcrumbs-aside">
|
||||
|
||||
|
||||
<a href="../../../_sources/components/agents/policy_optimization/sac.rst.txt" rel="nofollow"> View page source</a>
|
||||
|
||||
|
||||
</li>
|
||||
|
||||
</ul>
|
||||
|
||||
|
||||
<hr/>
|
||||
</div>
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
<div class="section" id="soft-actor-critic">
|
||||
<h1>Soft Actor-Critic<a class="headerlink" href="#soft-actor-critic" title="Permalink to this headline">¶</a></h1>
|
||||
<p><strong>Actions space:</strong> Continuous</p>
|
||||
<p><strong>References:</strong> <a class="reference external" href="https://arxiv.org/abs/1801.01290">Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor</a></p>
|
||||
<div class="section" id="network-structure">
|
||||
<h2>Network Structure<a class="headerlink" href="#network-structure" title="Permalink to this headline">¶</a></h2>
|
||||
<img alt="../../../_images/sac.png" class="align-center" src="../../../_images/sac.png" />
|
||||
</div>
|
||||
<div class="section" id="algorithm-description">
|
||||
<h2>Algorithm Description<a class="headerlink" href="#algorithm-description" title="Permalink to this headline">¶</a></h2>
|
||||
<div class="section" id="choosing-an-action-continuous-actions">
|
||||
<h3>Choosing an action - Continuous actions<a class="headerlink" href="#choosing-an-action-continuous-actions" title="Permalink to this headline">¶</a></h3>
|
||||
<p>The policy network is used in order to predict mean and log std for each action. While training, a sample is taken
|
||||
from a Gaussian distribution with these mean and std values. When testing, the agent can choose deterministically
|
||||
by picking the mean value or sample from a gaussian distribution like in training.</p>
|
||||
</div>
|
||||
<div class="section" id="training-the-network">
|
||||
<h3>Training the network<a class="headerlink" href="#training-the-network" title="Permalink to this headline">¶</a></h3>
|
||||
<p>Start by sampling a batch <span class="math notranslate nohighlight">\(B\)</span> of transitions from the experience replay.</p>
|
||||
<ul>
|
||||
<li><p class="first">To train the <strong>Q network</strong>, use the following targets:</p>
|
||||
<div class="math notranslate nohighlight">
|
||||
\[y_t^Q=r(s_t,a_t)+\gamma \cdot V(s_{t+1})\]</div>
|
||||
<p>The state value used in the above target is acquired by running the target state value network.</p>
|
||||
</li>
|
||||
<li><p class="first">To train the <strong>State Value network</strong>, use the following targets:</p>
|
||||
<div class="math notranslate nohighlight">
|
||||
\[y_t^V = \min_{i=1,2}Q_i(s_t,\tilde{a}) - log\pi (\tilde{a} \vert s),\,\,\,\, \tilde{a} \sim \pi(\cdot \vert s_t)\]</div>
|
||||
<p>The state value network is trained using a sample-based approximation of the connection between and state value and state
|
||||
action values, The actions used for constructing the target are <strong>not</strong> sampled from the replay buffer, but rather sampled
|
||||
from the current policy.</p>
|
||||
</li>
|
||||
<li><p class="first">To train the <strong>actor network</strong>, use the following equation:</p>
|
||||
<div class="math notranslate nohighlight">
|
||||
\[\nabla_{\theta} J \approx \nabla_{\theta} \frac{1}{\vert B \vert} \sum_{s_t\in B} \left( Q \left(s_t, \tilde{a}_\theta(s_t)\right) - log\pi_{\theta}(\tilde{a}_{\theta}(s_t)\vert s_t) \right),\,\,\,\, \tilde{a} \sim \pi(\cdot \vert s_t)\]</div>
|
||||
</li>
|
||||
</ul>
|
||||
<p>After every training step, do a soft update of the V target network’s weights from the online networks.</p>
|
||||
<dl class="class">
|
||||
<dt id="rl_coach.agents.soft_actor_critic_agent.SoftActorCriticAlgorithmParameters">
|
||||
<em class="property">class </em><code class="descclassname">rl_coach.agents.soft_actor_critic_agent.</code><code class="descname">SoftActorCriticAlgorithmParameters</code><a class="reference internal" href="../../../_modules/rl_coach/agents/soft_actor_critic_agent.html#SoftActorCriticAlgorithmParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.soft_actor_critic_agent.SoftActorCriticAlgorithmParameters" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><table class="docutils field-list" frame="void" rules="none">
|
||||
<col class="field-name" />
|
||||
<col class="field-body" />
|
||||
<tbody valign="top">
|
||||
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple">
|
||||
<li><strong>num_steps_between_copying_online_weights_to_target</strong> – (StepMethod)
|
||||
The number of steps between copying the online network weights to the target network weights.</li>
|
||||
<li><strong>rate_for_copying_weights_to_target</strong> – (float)
|
||||
When copying the online network weights to the target network weights, a soft update will be used, which
|
||||
weight the new online network weights by rate_for_copying_weights_to_target. (Tau as defined in the paper)</li>
|
||||
<li><strong>use_deterministic_for_evaluation</strong> – (bool)
|
||||
If True, during the evaluation phase, action are chosen deterministically according to the policy mean
|
||||
and not sampled from the policy distribution.</li>
|
||||
</ul>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</dd></dl>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</div>
@@ -193,7 +193,7 @@ own components under a dedicated directory. For example, tensorflow components w
parts that are implemented using TensorFlow.</p>
<dl class="class">
<dt id="rl_coach.base_parameters.NetworkParameters">
<em class="property">class </em><code class="descclassname">rl_coach.base_parameters.</code><code class="descname">NetworkParameters</code><span class="sig-paren">(</span><em>force_cpu=False</em>, <em>async_training=False</em>, <em>shared_optimizer=True</em>, <em>scale_down_gradients_by_number_of_workers_for_sync_training=True</em>, <em>clip_gradients=None</em>, <em>gradients_clipping_method=<GradientClippingMethod.ClipByGlobalNorm: 0></em>, <em>l2_regularization=0</em>, <em>learning_rate=0.00025</em>, <em>learning_rate_decay_rate=0</em>, <em>learning_rate_decay_steps=0</em>, <em>input_embedders_parameters={}</em>, <em>embedding_merger_type=<EmbeddingMergerType.Concat: 0></em>, <em>middleware_parameters=None</em>, <em>heads_parameters=[]</em>, <em>use_separate_networks_per_head=False</em>, <em>optimizer_type='Adam'</em>, <em>optimizer_epsilon=0.0001</em>, <em>adam_optimizer_beta1=0.9</em>, <em>adam_optimizer_beta2=0.99</em>, <em>rms_prop_optimizer_decay=0.9</em>, <em>batch_size=32</em>, <em>replace_mse_with_huber_loss=False</em>, <em>create_target_network=False</em>, <em>tensorflow_support=True</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/base_parameters.html#NetworkParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.NetworkParameters" title="Permalink to this definition">¶</a></dt>
<em class="property">class </em><code class="descclassname">rl_coach.base_parameters.</code><code class="descname">NetworkParameters</code><span class="sig-paren">(</span><em>force_cpu=False</em>, <em>async_training=False</em>, <em>shared_optimizer=True</em>, <em>scale_down_gradients_by_number_of_workers_for_sync_training=True</em>, <em>clip_gradients=None</em>, <em>gradients_clipping_method=<GradientClippingMethod.ClipByGlobalNorm: 0></em>, <em>l2_regularization=0</em>, <em>learning_rate=0.00025</em>, <em>learning_rate_decay_rate=0</em>, <em>learning_rate_decay_steps=0</em>, <em>input_embedders_parameters={}</em>, <em>embedding_merger_type=<EmbeddingMergerType.Concat: 0></em>, <em>middleware_parameters=None</em>, <em>heads_parameters=[]</em>, <em>use_separate_networks_per_head=False</em>, <em>optimizer_type='Adam'</em>, <em>optimizer_epsilon=0.0001</em>, <em>adam_optimizer_beta1=0.9</em>, <em>adam_optimizer_beta2=0.99</em>, <em>rms_prop_optimizer_decay=0.9</em>, <em>batch_size=32</em>, <em>replace_mse_with_huber_loss=False</em>, <em>create_target_network=False</em>, <em>tensorflow_support=True</em>, <em>softmax_temperature=1</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/base_parameters.html#NetworkParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.NetworkParameters" title="Permalink to this definition">¶</a></dt>
<dd><table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
<col class="field-body" />
@@ -257,6 +257,7 @@ selected for this network.</li>
same weights as the online network. It can then be queried, and its weights can be synced from the
online network at will.</li>
<li><strong>tensorflow_support</strong> – A flag which specifies if the network is supported by the TensorFlow framework.</li>
<li><strong>softmax_temperature</strong> – If a softmax is present in the network head output, use this temperature.</li>
</ul>
</td>
</tr>
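The temperature only matters for heads whose output goes through a softmax. A plain NumPy sketch of the effect (the helper below is illustrative, not part of the Coach API):

.. code-block:: python

   import numpy as np

   def softmax_with_temperature(logits, temperature=1.0):
       # higher temperature -> flatter (more exploratory) distribution over outputs
       scaled = np.asarray(logits, dtype=float) / temperature
       exp = np.exp(scaled - scaled.max())
       return exp / exp.sum()

   print(softmax_with_temperature([2.0, 1.0, 0.1], temperature=1.0))  # peaked
   print(softmax_with_temperature([2.0, 1.0, 0.1], temperature=5.0))  # flatter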
@@ -194,7 +194,7 @@
<h2>ActionInfo<a class="headerlink" href="#actioninfo" title="Permalink to this headline">¶</a></h2>
<dl class="class">
<dt id="rl_coach.core_types.ActionInfo">
<em class="property">class </em><code class="descclassname">rl_coach.core_types.</code><code class="descname">ActionInfo</code><span class="sig-paren">(</span><em>action: Union[int, float, numpy.ndarray, List], all_action_probabilities: float = 0, action_value: float = 0.0, state_value: float = 0.0, max_action_value: float = None, action_intrinsic_reward: float = 0</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/core_types.html#ActionInfo"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.ActionInfo" title="Permalink to this definition">¶</a></dt>
<em class="property">class </em><code class="descclassname">rl_coach.core_types.</code><code class="descname">ActionInfo</code><span class="sig-paren">(</span><em>action: Union[int, float, numpy.ndarray, List], all_action_probabilities: float = 0, action_value: float = 0.0, state_value: float = 0.0, max_action_value: float = None</em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/rl_coach/core_types.html#ActionInfo"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.core_types.ActionInfo" title="Permalink to this definition">¶</a></dt>
<dd><p>Action info is a class that holds an action and various additional information details about it</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
@@ -208,8 +208,6 @@
<li><strong>max_action_value</strong> – in case this is an action that was selected randomly, this is the value of the action
that received the maximum value. if no value is given, the action is assumed to be the
action with the maximum value</li>
<li><strong>action_intrinsic_reward</strong> – can contain any intrinsic reward that the agent wants to add to this action
selection</li>
</ul>
</td>
</tr>
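A minimal construction sketch for the updated signature (the numeric values are illustrative only):

.. code-block:: python

   import numpy as np
   from rl_coach.core_types import ActionInfo

   # an action plus the auxiliary values an agent may attach to it
   info = ActionInfo(action=np.array([0.2, -0.5]), action_value=1.3, state_value=1.1)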
@@ -206,7 +206,7 @@
<h3>EpisodicExperienceReplay<a class="headerlink" href="#episodicexperiencereplay" title="Permalink to this headline">¶</a></h3>
<dl class="class">
<dt id="rl_coach.memories.episodic.EpisodicExperienceReplay">
<em class="property">class </em><code class="descclassname">rl_coach.memories.episodic.</code><code class="descname">EpisodicExperienceReplay</code><span class="sig-paren">(</span><em>max_size: Tuple[rl_coach.memories.memory.MemoryGranularity</em>, <em>int] = (<MemoryGranularity.Transitions: 0></em>, <em>1000000)</em>, <em>n_step=-1</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/memories/episodic/episodic_experience_replay.html#EpisodicExperienceReplay"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.episodic.EpisodicExperienceReplay" title="Permalink to this definition">¶</a></dt>
<em class="property">class </em><code class="descclassname">rl_coach.memories.episodic.</code><code class="descname">EpisodicExperienceReplay</code><span class="sig-paren">(</span><em>max_size: Tuple[rl_coach.memories.memory.MemoryGranularity</em>, <em>int] = (<MemoryGranularity.Transitions: 0></em>, <em>1000000)</em>, <em>n_step=-1</em>, <em>train_to_eval_ratio: int = 1</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/memories/episodic/episodic_experience_replay.html#EpisodicExperienceReplay"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.memories.episodic.EpisodicExperienceReplay" title="Permalink to this definition">¶</a></dt>
<dd><p>A replay buffer that stores episodes of transitions. The additional structure allows performing various
calculations of total return and other values that depend on the sequential behavior of the transitions
in the episode.</p>
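A usage sketch based only on the signature above (the size cap is an arbitrary example):

.. code-block:: python

   from rl_coach.memories.episodic import EpisodicExperienceReplay
   from rl_coach.memories.memory import MemoryGranularity

   # cap the buffer at 1M transitions; n_step and train_to_eval_ratio keep their defaults
   memory = EpisodicExperienceReplay(max_size=(MemoryGranularity.Transitions, 1000000))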
@@ -372,6 +372,14 @@ $(document).ready(function() {
improve stability it also employs bias correction and trust region optimization techniques.
</span>
</div>
<div class="algorithm continuous off-policy" data-year="201808">
<span class="badge">
<a href="components/agents/policy_optimization/sac.html">SAC</a>
<br>
Soft Actor-Critic is an algorithm which optimizes a stochastic policy in an off-policy way.
One of the key features of SAC is that it solves a maximum entropy reinforcement learning problem.
</span>
</div>
<div class="algorithm continuous off-policy" data-year="201509">
<span class="badge">
<a href="components/agents/policy_optimization/ddpg.html">DDPG</a>
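For reference, the maximum entropy objective referred to in the badge above is, as in the SAC paper, with a temperature coefficient :math:`\alpha` weighting the entropy bonus:

.. math:: J(\pi) = \sum_{t} \mathbb{E}_{(s_t,a_t) \sim \rho_\pi} \left[ r(s_t,a_t) + \alpha \mathcal{H}\left(\pi(\cdot \vert s_t)\right) \right]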
@@ -190,13 +190,15 @@
<em class="property">class </em><code class="descclassname">rl_coach.agents.dqn_agent.</code><code class="descname">DQNAgent</code><span class="sig-paren">(</span><em>agent_parameters</em>, <em>parent: Union[LevelManager</em>, <em>CompositeAgent] = None</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/rl_coach/agents/dqn_agent.html#DQNAgent"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent" title="Permalink to this definition">¶</a></dt>
<dd><dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.act">
<code class="descname">act</code><span class="sig-paren">(</span><span class="sig-paren">)</span> → rl_coach.core_types.ActionInfo<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.act" title="Permalink to this definition">¶</a></dt>
<code class="descname">act</code><span class="sig-paren">(</span><em>action: Union[None</em>, <em>int</em>, <em>float</em>, <em>numpy.ndarray</em>, <em>List] = None</em><span class="sig-paren">)</span> → rl_coach.core_types.ActionInfo<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.act" title="Permalink to this definition">¶</a></dt>
<dd><p>Given the agents current knowledge, decide on the next action to apply to the environment</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
<col class="field-body" />
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Returns:</th><td class="field-body">An ActionInfo object, which contains the action and any additional info from the action decision process</td>
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>action</strong> – An action to take, overriding whatever the current policy is</td>
</tr>
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body">An ActionInfo object, which contains the action and any additional info from the action decision process</td>
</tr>
</tbody>
</table>
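A usage sketch of the extended signature (assuming ``agent`` is an already constructed DQNAgent running inside a graph; the overriding action index is arbitrary):

.. code-block:: python

   # let the current policy choose the action
   info = agent.act()

   # or force a specific action while keeping the usual bookkeeping
   forced = agent.act(action=2)
   print(info.action, forced.action)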
@@ -267,26 +269,6 @@ for creating the network.</p>
|
||||
</table>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.dqn_agent.DQNAgent.emulate_act_on_trainer">
|
||||
<code class="descname">emulate_act_on_trainer</code><span class="sig-paren">(</span><em>transition: rl_coach.core_types.Transition</em><span class="sig-paren">)</span> → rl_coach.core_types.ActionInfo<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.emulate_act_on_trainer" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>This emulates the act using the transition obtained from the rollout worker on the training worker
|
||||
in case of distributed training.
|
||||
Given the agents current knowledge, decide on the next action to apply to the environment
|
||||
:return: an action and a dictionary containing any additional info from the action decision process</p>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.dqn_agent.DQNAgent.emulate_observe_on_trainer">
|
||||
<code class="descname">emulate_observe_on_trainer</code><span class="sig-paren">(</span><em>transition: rl_coach.core_types.Transition</em><span class="sig-paren">)</span> → bool<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.emulate_observe_on_trainer" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>This emulates the observe using the transition obtained from the rollout worker on the training worker
|
||||
in case of distributed training.
|
||||
Given a response from the environment, distill the observation from it and store it for later use.
|
||||
The response should be a dictionary containing the performed action, the new observation and measurements,
|
||||
the reward, a game over flag and any additional information necessary.
|
||||
:return:</p>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.dqn_agent.DQNAgent.get_predictions">
|
||||
<code class="descname">get_predictions</code><span class="sig-paren">(</span><em>states: List[Dict[str, numpy.ndarray]], prediction_type: rl_coach.core_types.PredictionType</em><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.get_predictions" title="Permalink to this definition">¶</a></dt>
|
||||
@@ -342,6 +324,22 @@ This function is called right after each episode is ended.</p>
|
||||
</table>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.dqn_agent.DQNAgent.improve_reward_model">
|
||||
<code class="descname">improve_reward_model</code><span class="sig-paren">(</span><em>epochs: int</em><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.improve_reward_model" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Train a reward model to be used by the doubly-robust estimator</p>
|
||||
<table class="docutils field-list" frame="void" rules="none">
|
||||
<col class="field-name" />
|
||||
<col class="field-body" />
|
||||
<tbody valign="top">
|
||||
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>epochs</strong> – The total number of epochs to use for training a reward model</td>
|
||||
</tr>
|
||||
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body">None</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.dqn_agent.DQNAgent.init_environment_dependent_modules">
|
||||
<code class="descname">init_environment_dependent_modules</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.init_environment_dependent_modules" title="Permalink to this definition">¶</a></dt>
|
||||
@@ -450,7 +448,7 @@ given observation</td>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.dqn_agent.DQNAgent.prepare_batch_for_inference">
|
||||
<code class="descname">prepare_batch_for_inference</code><span class="sig-paren">(</span><em>states: Union[Dict[str, numpy.ndarray], List[Dict[str, numpy.ndarray]]], network_name: str</em><span class="sig-paren">)</span> → Dict[str, numpy.core.multiarray.array]<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.prepare_batch_for_inference" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="descname">prepare_batch_for_inference</code><span class="sig-paren">(</span><em>states: Union[Dict[str, numpy.ndarray], List[Dict[str, numpy.ndarray]]], network_name: str</em><span class="sig-paren">)</span> → Dict[str, numpy.array]<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.prepare_batch_for_inference" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Convert curr_state into input tensors tensorflow is expecting. i.e. if we have several inputs states, stack all
|
||||
observations together, measurements together, etc.</p>
|
||||
<table class="docutils field-list" frame="void" rules="none">
|
||||
@@ -542,6 +540,14 @@ by val, and by the current phase set in self.phase.</p>
|
||||
</table>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.dqn_agent.DQNAgent.run_off_policy_evaluation">
|
||||
<code class="descname">run_off_policy_evaluation</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.run_off_policy_evaluation" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Run the off-policy evaluation estimators to get a prediction for the performance of the current policy based on
|
||||
an evaluation dataset, which was collected by another policy(ies).
|
||||
:return: None</p>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.dqn_agent.DQNAgent.run_pre_network_filter_for_inference">
|
||||
<code class="descname">run_pre_network_filter_for_inference</code><span class="sig-paren">(</span><em>state: Dict[str, numpy.ndarray], update_filter_internal_state: bool = True</em><span class="sig-paren">)</span> → Dict[str, numpy.ndarray]<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.run_pre_network_filter_for_inference" title="Permalink to this definition">¶</a></dt>
|
||||
|
||||
@@ -278,7 +278,8 @@ To do so, you should select an environment type and level through the command li
</dl>
</li>
</ol>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">coach</span> <span class="o">-</span><span class="n">p</span> <span class="n">Doom_Basic_BC</span> <span class="o">-</span><span class="n">cp</span><span class="o">=</span><span class="s1">'agent.load_memory_from_file_path=</span><span class="se">\"</span><span class="s1"><experiment dir>/replay_buffer.p</span><span class="se">\"</span><span class="s1">'</span>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="kn">import</span> <span class="n">PickledReplayBuffer</span>
<span class="n">coach</span> <span class="o">-</span><span class="n">p</span> <span class="n">Doom_Basic_BC</span> <span class="o">-</span><span class="n">cp</span><span class="o">=</span><span class="s1">'agent.load_memory_from_file_path=PickledReplayBuffer(</span><span class="se">\"</span><span class="s1"><experiment dir>/replay_buffer.p</span><span class="se">\"</span><span class="s1">)'</span>
</pre></div>
</div>
</div>
@@ -21,6 +21,7 @@ A detailed description of those algorithms can be found by navigating to each of
imitation/cil
policy_optimization/cppo
policy_optimization/ddpg
policy_optimization/sac
other/dfp
value_optimization/double_dqn
value_optimization/dqn

@@ -38,6 +38,7 @@ Each update perform the following procedure:
.. math:: \text{where} \quad \bar{\rho}_{t} = \min{\left\{c,\rho_t\right\}},\quad \rho_t=\frac{\pi (a_t \mid s_t)}{\mu (a_t \mid s_t)}

3. **Accumulate gradients:**

:math:`\bullet` **Policy gradients (with bias correction):**

.. math:: \hat{g}_t^{policy} & = & \bar{\rho}_{t} \nabla \log \pi (a_t \mid s_t) [Q^{ret}(s_t,a_t) - V(s_t)] \\
@@ -0,0 +1,49 @@
Soft Actor-Critic
=================

**Actions space:** Continuous

**References:** `Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor <https://arxiv.org/abs/1801.01290>`_

Network Structure
-----------------

.. image:: /_static/img/design_imgs/sac.png
   :align: center

Algorithm Description
---------------------

Choosing an action - Continuous actions
++++++++++++++++++++++++++++++++++++++++

The policy network predicts a mean and log standard deviation for each action. During training, an action is sampled
from a Gaussian distribution with these mean and std values. During evaluation, the agent can either act deterministically
by picking the mean, or sample from the Gaussian distribution as in training.
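
A minimal NumPy sketch of this action-selection rule (the array shapes and names are illustrative, not the agent's actual API):

.. code-block:: python

   import numpy as np

   def choose_action(mu, log_std, deterministic=False):
       # mu, log_std: per-action outputs of the policy head
       if deterministic:           # evaluation: take the distribution mean
           return mu
       std = np.exp(log_std)       # training: sample from N(mu, std^2)
       return np.random.normal(mu, std)

   action = choose_action(np.array([0.1, -0.3]), np.array([-1.0, -1.0]))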

Training the network
++++++++++++++++++++
Start by sampling a batch :math:`B` of transitions from the experience replay. A NumPy sketch of these updates appears after this list.

* To train the **Q network**, use the following targets:

.. math:: y_t^Q=r(s_t,a_t)+\gamma \cdot V(s_{t+1})

The state value used in the above target is acquired by running the target state value network.

* To train the **State Value network**, use the following targets:

.. math:: y_t^V = \min_{i=1,2}Q_i(s_t,\tilde{a}) - \log\pi (\tilde{a} \vert s_t),\,\,\,\, \tilde{a} \sim \pi(\cdot \vert s_t)

The state value network is trained using a sample-based approximation of the relation between the state value and the
state-action values. The actions used for constructing the target are **not** sampled from the replay buffer, but rather sampled
from the current policy.

* To train the **actor network**, use the following equation:

.. math:: \nabla_{\theta} J \approx \nabla_{\theta} \frac{1}{\vert B \vert} \sum_{s_t\in B} \left( Q \left(s_t, \tilde{a}_\theta(s_t)\right) - \log\pi_{\theta}(\tilde{a}_{\theta}(s_t)\vert s_t) \right),\,\,\,\, \tilde{a} \sim \pi(\cdot \vert s_t)

After every training step, do a soft update of the V target network's weights from the online V network.
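
The sketch below computes the three targets and the soft update in plain NumPy, assuming the per-transition rewards, Q values, state values and log-probabilities have already been evaluated by the networks (all names and numbers are illustrative placeholders):

.. code-block:: python

   import numpy as np

   gamma, tau = 0.99, 0.005

   r = np.array([1.0, 0.5])                 # r(s_t, a_t) from the sampled batch
   v_target_next = np.array([10.0, 8.0])    # V_target(s_{t+1})
   q1 = np.array([9.5, 7.9])                # Q_1(s_t, a~),  a~ ~ pi(.|s_t)
   q2 = np.array([9.8, 8.1])                # Q_2(s_t, a~)
   log_pi = np.array([-1.2, -0.7])          # log pi(a~|s_t)

   y_q = r + gamma * v_target_next          # regression target for the Q network
   y_v = np.minimum(q1, q2) - log_pi        # regression target for the V network

   # actor objective (to be minimized): E[log pi(a~|s_t) - Q(s_t, a~)]
   policy_loss = np.mean(log_pi - q1)

   # soft (Polyak) update of the V target network weights
   def soft_update(target_weights, online_weights, tau):
       return [(1 - tau) * t + tau * o for t, o in zip(target_weights, online_weights)]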


.. autoclass:: rl_coach.agents.soft_actor_critic_agent.SoftActorCriticAlgorithmParameters
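For completeness, a sketch (not part of the new sac.rst) of overriding these parameters in a preset, using only attributes defined by the agent added in this PR:

.. code-block:: python

   from rl_coach.agents.soft_actor_critic_agent import SoftActorCriticAgentParameters
   from rl_coach.core_types import EnvironmentSteps

   agent_params = SoftActorCriticAgentParameters()
   # soft-update the V target network after every environment step, with tau = 0.005
   agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(1)
   agent_params.algorithm.rate_for_copying_weights_to_target = 0.005
   # evaluate with the deterministic (mean) policy
   agent_params.algorithm.use_deterministic_for_evaluation = True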
@@ -0,0 +1,321 @@
|
||||
#
|
||||
# Copyright (c) 2019 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Union
|
||||
import copy
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
|
||||
from rl_coach.agents.agent import Agent
|
||||
from rl_coach.agents.policy_optimization_agent import PolicyOptimizationAgent
|
||||
|
||||
from rl_coach.architectures.head_parameters import SACQHeadParameters, SACPolicyHeadParameters, VHeadParameters
|
||||
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters
|
||||
from rl_coach.base_parameters import AlgorithmParameters, NetworkParameters, AgentParameters, EmbedderScheme, MiddlewareScheme
|
||||
from rl_coach.core_types import ActionInfo, EnvironmentSteps, RunPhase
|
||||
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
|
||||
from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters
|
||||
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
|
||||
from rl_coach.spaces import BoxActionSpace
|
||||
|
||||
|
||||
# There are 3 networks in SAC implementation. All have the same topology but parameters are not shared.
|
||||
# The networks are:
|
||||
# 1. State Value Network - SACValueNetwork
|
||||
# 2. Soft Q Value Network - SACCriticNetwork
|
||||
# 3. Policy Network - SACPolicyNetwork - currently supporting only Gaussian Policy
|
||||
|
||||
|
||||
# 1. State Value Network - SACValueNetwork
|
||||
# this is the state value network in SAC.
|
||||
# The network is trained to predict (regression) the state value in the max-entropy settings
|
||||
# The objective to be minimized is given in equation (5) in the paper:
|
||||
#
|
||||
# J(psi)= E_(s~D)[0.5*(V_psi(s)-y(s))^2]
|
||||
# where y(s) = E_(a~pi)[Q_theta(s,a)-log(pi(a|s))]
|
||||
|
||||
|
||||
# Default parameters for value network:
|
||||
# topology :
|
||||
# input embedder : EmbedderScheme.Medium (Dense(256)) , relu activation
|
||||
# middleware : EmbedderScheme.Medium (Dense(256)) , relu activation
|
||||
|
||||
|
||||
class SACValueNetworkParameters(NetworkParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.input_embedders_parameters = {'observation': InputEmbedderParameters(activation_function='relu')}
|
||||
self.middleware_parameters = FCMiddlewareParameters(activation_function='relu')
|
||||
self.heads_parameters = [VHeadParameters(initializer='xavier')]
|
||||
self.rescale_gradient_from_head_by_factor = [1]
|
||||
self.optimizer_type = 'Adam'
|
||||
self.batch_size = 256
|
||||
self.async_training = False
|
||||
self.learning_rate = 0.0003 # 3e-4 see appendix D in the paper
|
||||
self.create_target_network = True # tau is set in SoftActorCriticAlgorithmParameters.rate_for_copying_weights_to_target
|
||||
|
||||
|
||||
# 2. Soft Q Value Network - SACCriticNetwork
|
||||
# the whole network is built in the SACQHeadParameters. we use empty input embedder and middleware
|
||||
class SACCriticNetworkParameters(NetworkParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.input_embedders_parameters = {'observation': InputEmbedderParameters(scheme=EmbedderScheme.Empty)}
|
||||
self.middleware_parameters = FCMiddlewareParameters(scheme=MiddlewareScheme.Empty)
|
||||
self.heads_parameters = [SACQHeadParameters()] # SACQHeadParameters includes the topology of the head
|
||||
self.rescale_gradient_from_head_by_factor = [1]
|
||||
self.optimizer_type = 'Adam'
|
||||
self.batch_size = 256
|
||||
self.async_training = False
|
||||
self.learning_rate = 0.0003
|
||||
self.create_target_network = False
|
||||
|
||||
|
||||
# 3. policy Network
|
||||
# Default parameters for policy network:
|
||||
# topology :
|
||||
# input embedder : EmbedderScheme.Medium (Dense(256)) , relu activation
|
||||
# middleware : EmbedderScheme = [Dense(256)] , relu activation --> scheme should be overridden in preset
|
||||
class SACPolicyNetworkParameters(NetworkParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.input_embedders_parameters = {'observation': InputEmbedderParameters(activation_function='relu')}
|
||||
self.middleware_parameters = FCMiddlewareParameters(activation_function='relu')
|
||||
self.heads_parameters = [SACPolicyHeadParameters()]
|
||||
self.rescale_gradient_from_head_by_factor = [1]
|
||||
self.optimizer_type = 'Adam'
|
||||
self.batch_size = 256
|
||||
self.async_training = False
|
||||
self.learning_rate = 0.0003
|
||||
self.create_target_network = False
|
||||
self.l2_regularization = 0 # weight decay regularization. not used in the original paper
|
||||
|
||||
|
||||
# Algorithm Parameters
|
||||
|
||||
class SoftActorCriticAlgorithmParameters(AlgorithmParameters):
|
||||
"""
|
||||
:param num_steps_between_copying_online_weights_to_target: (StepMethod)
|
||||
The number of steps between copying the online network weights to the target network weights.
|
||||
|
||||
:param rate_for_copying_weights_to_target: (float)
|
||||
When copying the online network weights to the target network weights, a soft update will be used, which
|
||||
weight the new online network weights by rate_for_copying_weights_to_target. (Tau as defined in the paper)
|
||||
|
||||
:param use_deterministic_for_evaluation: (bool)
|
||||
If True, during the evaluation phase, action are chosen deterministically according to the policy mean
|
||||
and not sampled from the policy distribution.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(1)
|
||||
self.rate_for_copying_weights_to_target = 0.005
|
||||
self.use_deterministic_for_evaluation = True # evaluate agent using deterministic policy (i.e. take the mean value)
|
||||
|
||||
|
||||
class SoftActorCriticAgentParameters(AgentParameters):
    def __init__(self):
        super().__init__(algorithm=SoftActorCriticAlgorithmParameters(),
                         exploration=AdditiveNoiseParameters(),
                         memory=ExperienceReplayParameters(),  # SAC doesn't use episode-related data
                         # network wrappers:
                         networks=OrderedDict([("policy", SACPolicyNetworkParameters()),
                                               ("q", SACCriticNetworkParameters()),
                                               ("v", SACValueNetworkParameters())]))

    @property
    def path(self):
        return 'rl_coach.agents.soft_actor_critic_agent:SoftActorCriticAgent'


# Soft Actor Critic - https://arxiv.org/abs/1801.01290
class SoftActorCriticAgent(PolicyOptimizationAgent):
    def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
        super().__init__(agent_parameters, parent)
        self.last_gradient_update_step_idx = 0

        # register signals to track (in learn_from_batch)
        self.policy_means = self.register_signal('Policy_mu_avg')
        self.policy_logsig = self.register_signal('Policy_logsig')
        self.policy_logprob_sampled = self.register_signal('Policy_logp_sampled')
        self.policy_grads = self.register_signal('Policy_grads_sumabs')

        self.q1_values = self.register_signal("Q1")
        self.TD_err1 = self.register_signal("TD err1")
        self.q2_values = self.register_signal("Q2")
        self.TD_err2 = self.register_signal("TD err2")
        self.v_tgt_ns = self.register_signal('V_tgt_ns')
        self.v_onl_ys = self.register_signal('V_onl_ys')
        self.action_signal = self.register_signal("actions")

    def learn_from_batch(self, batch):
        #########################################
        # need to update the following networks:
        # 1. actor (policy)
        # 2. state value (v)
        # 3. critic (q1 and q2)
        # 4. value target network - handled below by a soft update (see step 4)

        #########################################
        # define the networks to be used

        # State Value Network
        value_network = self.networks['v']
        value_network_keys = self.ap.network_wrappers['v'].input_embedders_parameters.keys()

        # Critic Network
        q_network = self.networks['q'].online_network
        q_head = q_network.output_heads[0]
        q_network_keys = self.ap.network_wrappers['q'].input_embedders_parameters.keys()

        # Actor (policy) Network
        policy_network = self.networks['policy'].online_network
        policy_network_keys = self.ap.network_wrappers['policy'].input_embedders_parameters.keys()

        ##########################################
        # 1. updating the actor - according to (13) in the paper
        policy_inputs = copy.copy(batch.states(policy_network_keys))
        policy_results = policy_network.predict(policy_inputs)

        policy_mu, policy_std, sampled_raw_actions, sampled_actions, sampled_actions_logprob, \
            sampled_actions_logprob_mean = policy_results

        self.policy_means.add_sample(policy_mu)
        self.policy_logsig.add_sample(policy_std)
        self.policy_logprob_sampled.add_sample(sampled_actions_logprob_mean)

        # get the state-action values for the replayed states and their corresponding actions from the policy
        q_inputs = copy.copy(batch.states(q_network_keys))
        q_inputs['output_0_0'] = sampled_actions
        log_target = q_network.predict(q_inputs)[0].squeeze()  # min(q1, q2) for the sampled actions

        # log internal q values
        q1_vals, q2_vals = q_network.predict(q_inputs, outputs=[q_head.q1_output, q_head.q2_output])
        self.q1_values.add_sample(q1_vals)
        self.q2_values.add_sample(q2_vals)

        # calculate the gradients according to (13)
        # get the gradients of log_prob w.r.t the weights (parameters) - denoted as phi in the paper
        initial_feed_dict = {policy_network.gradients_weights_ph[5]: np.array(1.0)}
        dlogp_dphi = policy_network.predict(policy_inputs,
                                            outputs=policy_network.weighted_gradients[5],
                                            initial_feed_dict=initial_feed_dict)

        # calculate dq_da
        dq_da = q_network.predict(q_inputs,
                                  outputs=q_network.gradients_wrt_inputs[1]['output_0_0'])

        # calculate dq_dphi = dq_da * da_dphi (back-propagating dq_da through the sampled actions)
        initial_feed_dict = {policy_network.gradients_weights_ph[3]: dq_da}
        dq_dphi = policy_network.predict(policy_inputs,
                                         outputs=policy_network.weighted_gradients[3],
                                         initial_feed_dict=initial_feed_dict)

        # now, given dlogp_dphi and dq_dphi, we can calculate the policy gradients according to (13)
        policy_grads = [dlogp_dphi[l] - dq_dphi[l] for l in range(len(dlogp_dphi))]

        # apply the gradients to the policy network
        policy_network.apply_gradients(policy_grads)
        grads_sumabs = np.sum([np.sum(np.abs(policy_grads[l])) for l in range(len(policy_grads))])
        self.policy_grads.add_sample(grads_sumabs)

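        # Illustrative note (not part of this change): the two predict() passes above assemble the
        # reparameterized policy gradient of eq. (13) in the SAC paper,
        #     grad_phi J ~ d(log_prob)/d(phi) - d(min(Q1,Q2))/d(a) * d(a)/d(phi),
        # where d(log_prob)/d(phi) is the total derivative through the sampled (reparameterized)
        # actions. With hypothetical numpy lists of per-layer gradients, the combination step above
        # would read:
        #     policy_grads = [g_logp - g_q for g_logp, g_q in zip(dlogp_dphi, dq_dphi)]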
        ##########################################
        # 2. updating the state value online network weights
        # done by calculating the targets for the v head according to (5) in the paper:
        # value_targets = log_target - sampled_actions_logprob
        value_inputs = copy.copy(batch.states(value_network_keys))
        value_targets = log_target - sampled_actions_logprob

        self.v_onl_ys.add_sample(value_targets)

        # train the value network on these targets
        value_loss = value_network.online_network.train_on_batch(value_inputs, value_targets[:, None])[0]

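        # Illustrative sketch (not part of this change): eq. (5) defines the value target as the
        # soft state value under the current policy. With hypothetical numpy arrays q1, q2 and
        # log_probs for the sampled actions, the per-sample target is simply:
        #     value_targets = np.minimum(q1, q2) - log_probs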
        ##########################################
        # 3. updating the critic (q networks)
        # updating the q networks according to (7) in the paper

        # define the input to the q network: the states were already set above, but now we need
        # the actions from the batch (and not those sampled by the policy)
        q_inputs['output_0_0'] = batch.actions(len(batch.actions().shape) == 1)

        # define the targets: scale_reward * reward + (1 - terminal) * discount * v_target_next_state
        # define v_target_next_state
        value_inputs = copy.copy(batch.next_states(value_network_keys))
        v_target_next_state = value_network.target_network.predict(value_inputs)
        self.v_tgt_ns.add_sample(v_target_next_state)
        # Note: the reward is assumed to have been rescaled by a RewardRescaleFilter in the preset parameters
        TD_targets = batch.rewards(expand_dims=True) + \
                     (1.0 - batch.game_overs(expand_dims=True)) * self.ap.algorithm.discount * v_target_next_state

        # call the critic network update
        result = q_network.train_on_batch(q_inputs, TD_targets, additional_fetches=[q_head.q1_loss, q_head.q2_loss])
        total_loss, losses, unclipped_grads = result[:3]
        q1_loss, q2_loss = result[3]
        self.TD_err1.add_sample(q1_loss)
        self.TD_err2.add_sample(q2_loss)

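        # Illustrative sketch (not part of this change): the TD target of eq. (7) bootstraps from
        # the target V network on the next state. With hypothetical numpy arrays this is:
        #     td_targets = scaled_rewards + (1.0 - dones) * discount * v_target_next_states
        # where dones is 1.0 for terminal transitions and 0.0 otherwise.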
        ##########################################
        # 4. updating the value target network
        # nothing to do here explicitly - it is handled by setting rate_for_copying_weights_to_target
        # in the agent parameters to 1-tau, where tau is the hyperparameter defined in the original SAC implementation

        return total_loss, losses, unclipped_grads

    def get_prediction(self, states):
        """
        get the mean and stdev of the policy distribution given 'states'
        :param states: the states for which we need to sample actions from the policy
        :return: mean and stdev
        """
        tf_input_state = self.prepare_batch_for_inference(states, 'policy')
        return self.networks['policy'].online_network.predict(tf_input_state)

    def train(self):
        # since the algorithm works with an experience replay buffer (non-episodic),
        # we can't use the policy optimization train method - we need Agent.train.
        # note that since Agent.train does not call apply_gradients, we need to do it in learn_from_batch
        return Agent.train(self)

    def choose_action(self, curr_state):
        """
        choose_action - chooses an action according to the policy.
        if evaluating with a deterministic policy - take the policy mean, which is the prediction of the policy network.
        else - sample from the policy distribution
        :param curr_state: the current state
        :return: the action, wrapped in ActionInfo
        """
        if not isinstance(self.spaces.action, BoxActionSpace):
            raise ValueError("SAC works only for continuous control problems")
        # convert to batch so we can run it through the network
        tf_input_state = self.prepare_batch_for_inference(curr_state, 'policy')
        # use the online network for prediction
        policy_network = self.networks['policy'].online_network
        policy_head = policy_network.output_heads[0]
        result = policy_network.predict(tf_input_state,
                                        outputs=[policy_head.policy_mean, policy_head.actions])
        action_mean, action_sample = result

        # if using a deterministic policy, take the mean value. else, sample from the policy distribution
        if self.phase == RunPhase.TEST and self.ap.algorithm.use_deterministic_for_evaluation:
            action = action_mean[0]
        else:
            action = action_sample[0]

        self.action_signal.add_sample(action)

        action_info = ActionInfo(action=action)
        return action_info
@@ -36,7 +36,6 @@ class HeadParameters(NetworkComponentParameters):
        return 'rl_coach.architectures.tensorflow_components.heads:' + self.parameterized_class_name



class PPOHeadParameters(HeadParameters):
    def __init__(self, activation_function: str ='tanh', name: str='ppo_head_params',
                 num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
@@ -50,11 +49,12 @@ class PPOHeadParameters(HeadParameters):
class VHeadParameters(HeadParameters):
    def __init__(self, activation_function: str ='relu', name: str='v_head_params',
                 num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
                 loss_weight: float = 1.0, dense_layer=None):
                 loss_weight: float = 1.0, dense_layer=None, initializer='normalized_columns'):
        super().__init__(parameterized_class_name="VHead", activation_function=activation_function, name=name,
                         dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
                         rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
                         loss_weight=loss_weight)
        self.initializer = initializer


class CategoricalQHeadParameters(HeadParameters):
@@ -196,3 +196,17 @@ class ACERPolicyHeadParameters(HeadParameters):
                         dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
                         rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
                         loss_weight=loss_weight)


class SACPolicyHeadParameters(HeadParameters):
    def __init__(self, activation_function: str ='relu', name: str='sac_policy_head_params', dense_layer=None):
        super().__init__(parameterized_class_name='SACPolicyHead', activation_function=activation_function, name=name,
                         dense_layer=dense_layer)


class SACQHeadParameters(HeadParameters):
    def __init__(self, activation_function: str ='relu', name: str='sac_q_head_params', dense_layer=None,
                 layers_sizes: tuple = (256, 256)):
        super().__init__(parameterized_class_name='SACQHead', activation_function=activation_function, name=name,
                         dense_layer=dense_layer)
        self.network_layers_sizes = layers_sizes

@@ -12,6 +12,8 @@ from .quantile_regression_q_head import QuantileRegressionQHead
from .rainbow_q_head import RainbowQHead
from .v_head import VHead
from .acer_policy_head import ACERPolicyHead
from .sac_head import SACPolicyHead
from .sac_q_head import SACQHead
from .classification_head import ClassificationHead
from .cil_head import RegressionHead

@@ -30,6 +32,8 @@ __all__ = [
    'RainbowQHead',
    'VHead',
    'ACERPolicyHead',
    'ClassificationHead'
    'SACPolicyHead',
    'SACQHead',
    'ClassificationHead',
    'RegressionHead'
]

@@ -0,0 +1,107 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import tensorflow as tf

from rl_coach.architectures.tensorflow_components.layers import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import ActionProbabilities
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import eps

LOG_SIG_CAP_MAX = 2
LOG_SIG_CAP_MIN = -20


class SACPolicyHead(Head):
    def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
                 head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
                 squash: bool = True, dense_layer=Dense):
        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
                         dense_layer=dense_layer)
        self.name = 'sac_policy_head'
        self.return_type = ActionProbabilities
        self.num_actions = self.spaces.action.shape  # continuous actions
        self.squash = squash  # squashing using tanh

    def _build_module(self, input_layer):
        self.given_raw_actions = tf.placeholder(tf.float32, [None, self.num_actions], name="actions")
        self.input = [self.given_raw_actions]
        self.output = []

        # build the network
        self._build_continuous_net(input_layer, self.spaces.action)

    def _squash_correction(self, actions):
        '''
        correct the log probability for the squash operation (in the case of bounded actions),
        according to appendix C in the paper.
        NOTE: this correction assumes the squash is done with tanh.
        :param actions: unbounded (raw) actions
        :return: the correction to be subtracted from the log_prob of the actions, assuming tanh squash
        '''
        if not self.squash:
            return 0
        return tf.reduce_sum(tf.log(1 - tf.tanh(actions) ** 2 + eps), axis=1)

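    # Illustrative sketch (not part of this change): the correction above is the change-of-variables
    # term for a = tanh(u), i.e. log pi(a) = log mu(u) - sum_i log(1 - tanh(u_i)^2). A hypothetical,
    # self-contained numpy version of the same formula for a diagonal unit Gaussian:
    #
    #     import numpy as np
    #     from scipy.stats import norm
    #     u = np.array([0.3, -1.1])                                   # raw (pre-squash) action sample
    #     log_mu = norm.logpdf(u, loc=0.0, scale=1.0).sum()           # log prob of the raw sample
    #     log_pi = log_mu - np.log(1.0 - np.tanh(u) ** 2 + 1e-6).sum()  # log prob of the squashed action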
    def _build_continuous_net(self, input_layer, action_space):
        num_actions = action_space.shape[0]

        self.policy_mu_and_logsig = self.dense_layer(2 * num_actions)(input_layer, name='policy_mu_logsig')
        self.policy_mean = tf.identity(self.policy_mu_and_logsig[..., :num_actions], name='policy_mean')
        self.policy_log_std = tf.clip_by_value(self.policy_mu_and_logsig[..., num_actions:],
                                               LOG_SIG_CAP_MIN, LOG_SIG_CAP_MAX, name='policy_log_std')

        self.output.append(self.policy_mean)     # output[0]
        self.output.append(self.policy_log_std)  # output[1]

        # define the policy distribution
        # Tensorflow's multivariate normal distribution supports reparameterization
        tfd = tf.contrib.distributions
        self.policy_distribution = tfd.MultivariateNormalDiag(loc=self.policy_mean,
                                                              scale_diag=tf.exp(self.policy_log_std))

        # define the network outputs
        # note that tensorflow supports reparameterization,
        # i.e. self.raw_actions is a tensor through which gradients can flow
        self.raw_actions = self.policy_distribution.sample()

        if self.squash:
            self.actions = tf.tanh(self.raw_actions)
            # correct the log_prob for the squash (see appendix C in the paper)
            squash_correction = self._squash_correction(self.raw_actions)
        else:
            self.actions = self.raw_actions
            squash_correction = 0

        # sampled_actions_logprob is a tensor through which gradients can flow
        self.sampled_actions_logprob = self.policy_distribution.log_prob(self.raw_actions) - squash_correction
        self.sampled_actions_logprob_mean = tf.reduce_mean(self.sampled_actions_logprob)

        self.output.append(self.raw_actions)                    # output[2]: sampled raw action (before squash)
        self.output.append(self.actions)                        # output[3]: squashed (if needed) version of the sampled raw_actions
        self.output.append(self.sampled_actions_logprob)        # output[4]: log prob of the sampled action (squash corrected)
        self.output.append(self.sampled_actions_logprob_mean)   # output[5]: mean of the log probs of the sampled actions (squash corrected)
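        # Descriptive note (not part of this change): this output ordering is what
        # SoftActorCriticAgent.learn_from_batch unpacks from policy_network.predict(), and the
        # weighted_gradients[3] / weighted_gradients[5] used there correspond to the squashed
        # actions (output[3]) and the mean log probability (output[5]) defined above.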

    def __str__(self):
        result = [
            "policy head:",
            "\t\tDense (num outputs = 256)",
            "\t\tDense (num outputs = 256)",
            "\t\tDense (num outputs = {0})".format(2 * self.num_actions),
            "policy_mu = output[:num_actions], policy_std = output[num_actions:]"
        ]
        return '\n'.join(result)
@@ -0,0 +1,116 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import tensorflow as tf

from rl_coach.architectures.tensorflow_components.layers import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import QActionStateValue
from rl_coach.spaces import SpacesDefinition, BoxActionSpace


class SACQHead(Head):
    def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
                 head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
                 dense_layer=Dense):
        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
                         dense_layer=dense_layer)
        self.name = 'q_values_head'
        if isinstance(self.spaces.action, BoxActionSpace):
            self.num_actions = self.spaces.action.shape  # continuous actions
        else:
            raise ValueError(
                'SACQHead does not support action spaces of type: {class_name}'.format(
                    class_name=self.spaces.action.__class__.__name__,
                )
            )
        self.return_type = QActionStateValue
        # extract the topology from the SACQHeadParameters
        self.network_layers_sizes = agent_parameters.network_wrappers['q'].heads_parameters[0].network_layers_sizes

    def _build_module(self, input_layer):
        # the SAC Q network is essentially two networks running in parallel on the same (state, action) input.
        # the state is the observation fed through the input_layer, and the action is fed through a placeholder
        # directly to the head. each sub-network computes a q value: q1(s,a) and q2(s,a).
        # the output of the head is min(q1, q2)
        self.actions = tf.placeholder(tf.float32, [None, self.num_actions], name="actions")
        self.target = tf.placeholder(tf.float32, [None, 1], name="q_targets")
        self.input = [self.actions]
        self.output = []
        # Note (1): in the authors' implementation of SAC (in rllab), the observation embedding and the
        # action embedding are summed (broadcasting the bias) in the first layer of the network.

        # build the q1 network head
        with tf.variable_scope("q1_head"):
            layer_size = self.network_layers_sizes[0]
            qi_obs_emb = self.dense_layer(layer_size)(input_layer, activation=self.activation_function)
            qi_act_emb = self.dense_layer(layer_size)(self.actions, activation=self.activation_function)
            qi_output = qi_obs_emb + qi_act_emb  # merging the inputs by summing them (see Note (1))
            for layer_size in self.network_layers_sizes[1:]:
                qi_output = self.dense_layer(layer_size)(qi_output, activation=self.activation_function)
            # the output layer
            self.q1_output = self.dense_layer(1)(qi_output, name='q1_output')

        # build the q2 network head
        with tf.variable_scope("q2_head"):
            layer_size = self.network_layers_sizes[0]
            qi_obs_emb = self.dense_layer(layer_size)(input_layer, activation=self.activation_function)
            qi_act_emb = self.dense_layer(layer_size)(self.actions, activation=self.activation_function)
            qi_output = qi_obs_emb + qi_act_emb  # merging the inputs by summing them (see Note (1))
            for layer_size in self.network_layers_sizes[1:]:
                qi_output = self.dense_layer(layer_size)(qi_output, activation=self.activation_function)
            # the output layer
            self.q2_output = self.dense_layer(1)(qi_output, name='q2_output')

        # take the minimum as the network's output. this is the log_target (in the original implementation)
        self.q_output = tf.minimum(self.q1_output, self.q2_output, name='q_output')
        # used for the policy gradients
        # self.q_output_mean = tf.reduce_mean(self.q1_output)  # option 1: use q1
        self.q_output_mean = tf.reduce_mean(self.q_output)  # option 2: use min(q1,q2)

        self.output.append(self.q_output)
        self.output.append(self.q_output_mean)

        # defining the loss
        self.q1_loss = 0.5 * tf.reduce_mean(tf.square(self.q1_output - self.target))
        self.q2_loss = 0.5 * tf.reduce_mean(tf.square(self.q2_output - self.target))
        # the two losses depend on disjoint sets of parameters, so summing them trains
        # both critics independently on the same target
        self.loss = self.q1_loss + self.q2_loss
        tf.losses.add_loss(self.loss)

    def __str__(self):
        result = [
            "q1 output",
            "\t\tDense (num outputs = 256)",
            "\t\tDense (num outputs = 256)",
            "\t\tDense (num outputs = 1)",
            "q2 output",
            "\t\tDense (num outputs = 256)",
            "\t\tDense (num outputs = 256)",
            "\t\tDense (num outputs = 1)",
            "min(Q1,Q2)"
        ]
        return '\n'.join(result)

@@ -26,7 +26,7 @@ from rl_coach.spaces import SpacesDefinition
class VHead(Head):
    def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
                 head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
                 dense_layer=Dense):
                 dense_layer=Dense, initializer='normalized_columns'):
        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
                         dense_layer=dense_layer)
        self.name = 'v_values_head'
@@ -37,10 +37,15 @@ class VHead(Head):
        else:
            self.loss_type = tf.losses.mean_squared_error

        self.initializer = initializer

    def _build_module(self, input_layer):
        # Standard V Network
        if self.initializer == 'normalized_columns':
            self.output = self.dense_layer(1)(input_layer, name='output',
                                              kernel_initializer=normalized_columns_initializer(1.0))
        elif self.initializer == 'xavier' or self.initializer is None:
            self.output = self.dense_layer(1)(input_layer, name='output')

    def __str__(self):
        result = [

@@ -0,0 +1,71 @@
from rl_coach.agents.soft_actor_critic_agent import SoftActorCriticAgentParameters
from rl_coach.architectures.layers import Dense
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps
from rl_coach.filters.filter import InputFilter
from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter
from rl_coach.environments.environment import SingleLevelSelection
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters


####################
# Graph Scheduling #
####################

# see graph_manager.py for possible schedule parameters
schedule_params = ScheduleParameters()
schedule_params.improve_steps = EnvironmentSteps(3000000)
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(1000)
schedule_params.evaluation_steps = EnvironmentEpisodes(1)
schedule_params.heatup_steps = EnvironmentSteps(10000)


#########
# Agent #
#########
agent_params = SoftActorCriticAgentParameters()
# override the default parameters:
# value (v) network parameters
agent_params.network_wrappers['v'].batch_size = 256
agent_params.network_wrappers['v'].learning_rate = 0.0003
agent_params.network_wrappers['v'].middleware_parameters.scheme = [Dense(256)]

# critic (q) network parameters
agent_params.network_wrappers['q'].heads_parameters[0].network_layers_sizes = (256, 256)
agent_params.network_wrappers['q'].batch_size = 256
agent_params.network_wrappers['q'].learning_rate = 0.0003

# actor (policy) network parameters
agent_params.network_wrappers['policy'].batch_size = 256
agent_params.network_wrappers['policy'].learning_rate = 0.0003
agent_params.network_wrappers['policy'].middleware_parameters.scheme = [Dense(256)]

# Input Filter
# SAC requires reward scaling for the Mujoco environments.
# according to the paper:
#   Hopper, Walker-2d, HalfCheetah, Ant - require a scale of 5
#   Humanoid - requires a scale of 20

agent_params.input_filter = InputFilter()
agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(5))

###############
# Environment #
###############
env_params = GymVectorEnvironment(level=SingleLevelSelection(mujoco_v2))

########
# Test #
########
preset_validation_params = PresetValidationParameters()
preset_validation_params.test = True
preset_validation_params.min_reward_threshold = 400
preset_validation_params.max_episodes_to_achieve_reward = 2200
preset_validation_params.reward_test_level = 'inverted_pendulum'
preset_validation_params.trace_test_levels = ['inverted_pendulum', 'hopper']

graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
                                    schedule_params=schedule_params, vis_params=VisualizationParameters(),
                                    preset_validation_params=preset_validation_params)
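# Usage note (not part of this change): assuming this preset is saved under rl_coach/presets/
# as Mujoco_SAC.py (the file name is not shown in this diff), running it on a specific Mujoco
# level would look something like:
#
#     coach -p Mujoco_SAC -lvl inverted_pendulum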
@@ -85,7 +85,7 @@ extras['all'] = all_deps

setup(
    name='rl-coach' if not slim_package else 'rl-coach-slim',
    version='0.11.2',
    version='0.12.0',
    description='Reinforcement Learning Coach enables easy experimentation with state of the art Reinforcement Learning algorithms.',
    url='https://github.com/NervanaSystems/coach',
    author='Intel AI Lab',