Mirror of https://github.com/gryf/coach.git (synced 2025-12-18 03:30:19 +01:00)
SAC algorithm (#282)
* SAC algorithm
* SAC - updates to the agent (learn_from_batch), sac_head and sac_q_head to fix a problem in the gradient calculation. The SAC agent is now able to train. gym_environment - fixed an error in access to gym.spaces
* Soft Actor Critic - code cleanup
* code cleanup
* V-head initialization fix
* SAC benchmarks
* SAC documentation
* typo fix
* documentation fixes
* documentation and version update
* README typo
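For context, the losses a SAC agent optimizes follow the Soft Actor-Critic formulation with a separate V head (consistent with the V-head and sac_q_head components mentioned above). The sketch below is a generic numpy illustration of those objectives; the function and its array inputs are hypothetical placeholders, and it is not the Coach implementation.

# Generic sketch of the SAC objectives (soft Q target, soft V target, policy loss),
# for illustration only -- not the Coach code. All inputs are hypothetical numpy
# arrays of shape (batch_size,).
import numpy as np

def sac_losses(rewards, game_overs, v_target_next, q1, q2, v, log_pi,
               discount=0.99, alpha=0.2):
    # soft Q target: r + gamma * V_target(s'), zeroed out for terminal transitions
    q_target = rewards + discount * (1.0 - game_overs) * v_target_next
    # soft V target: min(Q1, Q2) - alpha * log pi(a|s), with a sampled from the policy
    min_q = np.minimum(q1, q2)
    v_target = min_q - alpha * log_pi
    # mean-squared losses for the Q and V heads, and the policy objective
    q_loss = np.mean((q1 - q_target) ** 2) + np.mean((q2 - q_target) ** 2)
    v_loss = np.mean((v - v_target) ** 2)
    policy_loss = np.mean(alpha * log_pi - min_q)
    return q_loss, v_loss, policy_loss

The hunks below show the changes this commit makes to the ValueOptimizationAgent code.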
@@ -193,16 +193,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #

 from collections import OrderedDict
 from typing import Union

 import numpy as np

 from rl_coach.agents.agent import Agent
-from rl_coach.core_types import ActionInfo, StateType
+from rl_coach.core_types import ActionInfo, StateType, Batch
 from rl_coach.logger import screen
 from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplay
 from rl_coach.spaces import DiscreteActionSpace

 from copy import deepcopy


 ## This is an abstract agent - there is no learn_from_batch method ##
@@ -229,8 +230,9 @@
             actions_q_values = None
         return actions_q_values

-    def get_prediction(self, states):
-        return self.networks['main'].online_network.predict(self.prepare_batch_for_inference(states, 'main'))
+    def get_prediction(self, states, outputs=None):
+        return self.networks['main'].online_network.predict(self.prepare_batch_for_inference(states, 'main'),
+                                                            outputs=outputs)

     def update_transition_priorities_and_get_weights(self, TD_errors, batch):
         # update errors in prioritized replay buffer
@@ -259,10 +261,12 @@
             # this is for bootstrapped dqn
             if type(actions_q_values) == list and len(actions_q_values) > 0:
                 actions_q_values = self.exploration_policy.last_action_values
-            actions_q_values = actions_q_values.squeeze()

             # store the q values statistics for logging
             self.q_values.add_sample(actions_q_values)
+
+            actions_q_values = actions_q_values.squeeze()
+
             for i, q_value in enumerate(actions_q_values):
                 self.q_value_for_action[i].add_sample(q_value)

@@ -276,6 +280,77 @@
 
     def learn_from_batch(self, batch):
         raise NotImplementedError("ValueOptimizationAgent is an abstract agent. Not to be used directly.")
 
+    def run_off_policy_evaluation(self):
+        """
+        Run the off-policy evaluation estimators to get a prediction for the performance of the current policy based on
+        an evaluation dataset, which was collected by another policy(ies).
+        :return: None
+        """
+        assert self.ope_manager
+        dataset_as_episodes = self.call_memory('get_all_complete_episodes_from_to',
+                                               (self.call_memory('get_last_training_set_episode_id') + 1,
+                                                self.call_memory('num_complete_episodes')))
+        if len(dataset_as_episodes) == 0:
+            raise ValueError('train_to_eval_ratio is too high causing the evaluation set to be empty. '
+                             'Consider decreasing its value.')
+
+        ips, dm, dr, seq_dr = self.ope_manager.evaluate(
+            dataset_as_episodes=dataset_as_episodes,
+            batch_size=self.ap.network_wrappers['main'].batch_size,
+            discount_factor=self.ap.algorithm.discount,
+            reward_model=self.networks['reward_model'].online_network,
+            q_network=self.networks['main'].online_network,
+            network_keys=list(self.ap.network_wrappers['main'].input_embedders_parameters.keys()))
+
+        # get the estimators out to the screen
+        log = OrderedDict()
+        log['Epoch'] = self.training_epoch
+        log['IPS'] = ips
+        log['DM'] = dm
+        log['DR'] = dr
+        log['Sequential-DR'] = seq_dr
+        screen.log_dict(log, prefix='Off-Policy Evaluation')
+
+        # get the estimators out to dashboard
+        self.agent_logger.set_current_time(self.get_current_time() + 1)
+        self.agent_logger.create_signal_value('Inverse Propensity Score', ips)
+        self.agent_logger.create_signal_value('Direct Method Reward', dm)
+        self.agent_logger.create_signal_value('Doubly Robust', dr)
+        self.agent_logger.create_signal_value('Sequential Doubly Robust', seq_dr)
+
+    def get_reward_model_loss(self, batch: Batch):
+        network_keys = self.ap.network_wrappers['reward_model'].input_embedders_parameters.keys()
+        current_rewards_prediction_for_all_actions = self.networks['reward_model'].online_network.predict(
+            batch.states(network_keys))
+        current_rewards_prediction_for_all_actions[range(batch.size), batch.actions()] = batch.rewards()
+
+        return self.networks['reward_model'].train_and_sync_networks(
+            batch.states(network_keys), current_rewards_prediction_for_all_actions)[0]
+
+    def improve_reward_model(self, epochs: int):
+        """
+        Train a reward model to be used by the doubly-robust estimator
+
+        :param epochs: The total number of epochs to use for training a reward model
+        :return: None
+        """
+        batch_size = self.ap.network_wrappers['reward_model'].batch_size
+
+        # this is fitted from the training dataset
+        for epoch in range(epochs):
+            loss = 0
+            total_transitions_processed = 0
+            for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)):
+                batch = Batch(batch)
+                loss += self.get_reward_model_loss(batch)
+                total_transitions_processed += batch.size
+
+            log = OrderedDict()
+            log['Epoch'] = epoch
+            log['loss'] = loss / total_transitions_processed
+            screen.log_dict(log, prefix='Training Reward Model')
 
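The docstring of run_off_policy_evaluation above refers to the off-policy evaluation estimators that the OpeManager reports (IPS, DM, DR, and Sequential-DR). As a rough guide to what those numbers mean, here is a generic per-step sketch of the first three estimators over a batch of logged transitions; it is a simplified illustration with hypothetical inputs, not Coach's OpeManager.

# Generic sketch of the per-step off-policy estimators (IPS, DM, DR) -- a
# simplified illustration, not the Coach implementation. Inputs are hypothetical
# numpy arrays for N logged transitions with a discrete action space of K actions.
import numpy as np

def ope_estimates(rewards, actions, behavior_probs, target_probs, reward_model_values):
    """
    rewards:             (N,)   observed rewards
    actions:             (N,)   actions taken by the behavior (logging) policy
    behavior_probs:      (N,)   pi_b(a_i | s_i) for the logged actions
    target_probs:        (N, K) pi_e(a | s_i) for every action
    reward_model_values: (N, K) fitted reward model R_hat(s_i, a)
    """
    n = len(rewards)
    rho = target_probs[np.arange(n), actions] / behavior_probs       # importance weights

    ips = np.mean(rho * rewards)                                     # inverse propensity scoring
    dm = np.mean(np.sum(target_probs * reward_model_values, axis=1)) # direct method
    # doubly robust: direct method plus an importance-weighted correction term
    dr = dm + np.mean(rho * (rewards - reward_model_values[np.arange(n), actions]))
    return ips, dm, dr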
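The reward model fitted by improve_reward_model feeds the doubly-robust estimators above. For the episodic Sequential-DR estimate, the standard approach is a backward recursion over each trajectory; the sketch below is a generic illustration of that recursion with hypothetical inputs, not Coach's code.

# Generic sketch of the sequential doubly-robust estimate for one logged episode
# (backward recursion over the trajectory) -- illustration only.
import numpy as np

def sequential_dr(rewards, actions, behavior_probs, target_probs, q_values, discount=0.99):
    """
    rewards:        (T,)   rewards of one logged episode
    actions:        (T,)   logged actions
    behavior_probs: (T,)   pi_b(a_t | s_t)
    target_probs:   (T, K) pi_e(a | s_t)
    q_values:       (T, K) fitted Q_hat(s_t, a)
    """
    v_dr = 0.0
    for t in reversed(range(len(rewards))):
        rho_t = target_probs[t, actions[t]] / behavior_probs[t]
        v_hat = np.dot(target_probs[t], q_values[t])   # estimated state value under pi_e
        q_hat = q_values[t, actions[t]]
        v_dr = v_hat + rho_t * (rewards[t] + discount * v_dr - q_hat)
    return v_dr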