mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
SAC algorithm (#282)
* SAC algorithm * SAC - updates to agent (learn_from_batch), sac_head and sac_q_head to fix problem in gradient calculation. Now SAC agents is able to train. gym_environment - fixing an error in access to gym.spaces * Soft Actor Critic - code cleanup * code cleanup * V-head initialization fix * SAC benchmarks * SAC Documentation * typo fix * documentation fixes * documentation and version update * README typo
This commit is contained in:
@@ -205,6 +205,7 @@
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">TrainingSteps</span><span class="p">,</span> <span class="n">EnvironmentSteps</span><span class="p">,</span> <span class="n">GradientClippingMethod</span><span class="p">,</span> <span class="n">RunPhase</span><span class="p">,</span> \
|
||||
<span class="n">SelectedPhaseOnlyDumpFilter</span><span class="p">,</span> <span class="n">MaxDumpFilter</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.filters.filter</span> <span class="k">import</span> <span class="n">NoInputFilter</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.logger</span> <span class="k">import</span> <span class="n">screen</span>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">Frameworks</span><span class="p">(</span><span class="n">Enum</span><span class="p">):</span>
|
||||
@@ -379,9 +380,6 @@
|
||||
<span class="c1"># distributed agents params</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">share_statistics_between_workers</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
|
||||
<span class="c1"># intrinsic reward</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">scale_external_reward_by_intrinsic_reward_value</span> <span class="o">=</span> <span class="kc">False</span>
|
||||
|
||||
<span class="c1"># n-step returns</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">n_step</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span> <span class="c1"># calculate the total return (no bootstrap, by default)</span>
|
||||
|
||||
@@ -470,7 +468,8 @@
|
||||
<span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span>
|
||||
<span class="n">replace_mse_with_huber_loss</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||||
<span class="n">create_target_network</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||||
<span class="n">tensorflow_support</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||||
<span class="n">tensorflow_support</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||||
<span class="n">softmax_temperature</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> :param force_cpu:</span>
|
||||
<span class="sd"> Force the neural networks to run on the CPU even if a GPU is available</span>
|
||||
@@ -553,6 +552,8 @@
|
||||
<span class="sd"> online network at will.</span>
|
||||
<span class="sd"> :param tensorflow_support:</span>
|
||||
<span class="sd"> A flag which specifies if the network is supported by the TensorFlow framework.</span>
|
||||
<span class="sd"> :param softmax_temperature:</span>
|
||||
<span class="sd"> If a softmax is present in the network head output, use this temperature</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">framework</span> <span class="o">=</span> <span class="n">Frameworks</span><span class="o">.</span><span class="n">tensorflow</span>
|
||||
@@ -583,16 +584,19 @@
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="n">heads_parameters</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">use_separate_networks_per_head</span> <span class="o">=</span> <span class="n">use_separate_networks_per_head</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_type</span> <span class="o">=</span> <span class="n">optimizer_type</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">replace_mse_with_huber_loss</span> <span class="o">=</span> <span class="n">replace_mse_with_huber_loss</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">create_target_network</span> <span class="o">=</span> <span class="n">create_target_network</span>
|
||||
|
||||
<span class="c1"># Framework support</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">tensorflow_support</span> <span class="o">=</span> <span class="n">tensorflow_support</span>
|
||||
|
||||
<span class="c1"># Hyper-Parameter values</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_epsilon</span> <span class="o">=</span> <span class="n">optimizer_epsilon</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">adam_optimizer_beta1</span> <span class="o">=</span> <span class="n">adam_optimizer_beta1</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">adam_optimizer_beta2</span> <span class="o">=</span> <span class="n">adam_optimizer_beta2</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">rms_prop_optimizer_decay</span> <span class="o">=</span> <span class="n">rms_prop_optimizer_decay</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">replace_mse_with_huber_loss</span> <span class="o">=</span> <span class="n">replace_mse_with_huber_loss</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">create_target_network</span> <span class="o">=</span> <span class="n">create_target_network</span>
|
||||
|
||||
<span class="c1"># Framework support</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">tensorflow_support</span> <span class="o">=</span> <span class="n">tensorflow_support</span></div>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">softmax_temperature</span> <span class="o">=</span> <span class="n">softmax_temperature</span></div>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">NetworkComponentParameters</span><span class="p">(</span><span class="n">Parameters</span><span class="p">):</span>
|
||||
@@ -723,6 +727,7 @@
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">is_a_highest_level_agent</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">is_a_lowest_level_agent</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">task_parameters</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">is_batch_rl_training</span> <span class="o">=</span> <span class="kc">False</span>
|
||||
|
||||
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
@@ -730,18 +735,22 @@
|
||||
|
||||
|
||||
<div class="viewcode-block" id="TaskParameters"><a class="viewcode-back" href="../../components/additional_parameters.html#rl_coach.base_parameters.TaskParameters">[docs]</a><span class="k">class</span> <span class="nc">TaskParameters</span><span class="p">(</span><span class="n">Parameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">framework_type</span><span class="p">:</span> <span class="n">Frameworks</span><span class="o">=</span><span class="n">Frameworks</span><span class="o">.</span><span class="n">tensorflow</span><span class="p">,</span> <span class="n">evaluate_only</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">use_cpu</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">framework_type</span><span class="p">:</span> <span class="n">Frameworks</span><span class="o">=</span><span class="n">Frameworks</span><span class="o">.</span><span class="n">tensorflow</span><span class="p">,</span> <span class="n">evaluate_only</span><span class="p">:</span> <span class="nb">int</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">use_cpu</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||||
<span class="n">experiment_path</span><span class="o">=</span><span class="s1">'/tmp'</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">checkpoint_save_secs</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">checkpoint_restore_dir</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">checkpoint_save_dir</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">export_onnx_graph</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">apply_stop_condition</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||||
<span class="n">num_gpu</span><span class="p">:</span> <span class="nb">int</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
|
||||
<span class="n">checkpoint_restore_path</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">checkpoint_save_dir</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">export_onnx_graph</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||||
<span class="n">apply_stop_condition</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">num_gpu</span><span class="p">:</span> <span class="nb">int</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> :param framework_type: deep learning framework type. currently only tensorflow is supported</span>
|
||||
<span class="sd"> :param evaluate_only: the task will be used only for evaluating the model</span>
|
||||
<span class="sd"> :param evaluate_only: if not None, the task will be used only for evaluating the model for the given number of steps.</span>
|
||||
<span class="sd"> A value of 0 means that task will be evaluated for an infinite number of steps.</span>
|
||||
<span class="sd"> :param use_cpu: use the cpu for this task</span>
|
||||
<span class="sd"> :param experiment_path: the path to the directory which will store all the experiment outputs</span>
|
||||
<span class="sd"> :param seed: a seed to use for the random numbers generator</span>
|
||||
<span class="sd"> :param checkpoint_save_secs: the number of seconds between each checkpoint saving</span>
|
||||
<span class="sd"> :param checkpoint_restore_dir: the directory to restore the checkpoints from</span>
|
||||
<span class="sd"> :param checkpoint_restore_dir:</span>
|
||||
<span class="sd"> [DEPECRATED - will be removed in one of the next releases - switch to checkpoint_restore_path]</span>
|
||||
<span class="sd"> the dir to restore the checkpoints from</span>
|
||||
<span class="sd"> :param checkpoint_restore_path: the path to restore the checkpoints from</span>
|
||||
<span class="sd"> :param checkpoint_save_dir: the directory to store the checkpoints in</span>
|
||||
<span class="sd"> :param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved</span>
|
||||
<span class="sd"> :param apply_stop_condition: If set to True, this will apply the stop condition defined by reaching a target success rate</span>
|
||||
@@ -753,7 +762,13 @@
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">use_cpu</span> <span class="o">=</span> <span class="n">use_cpu</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">experiment_path</span> <span class="o">=</span> <span class="n">experiment_path</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_save_secs</span> <span class="o">=</span> <span class="n">checkpoint_save_secs</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_restore_dir</span> <span class="o">=</span> <span class="n">checkpoint_restore_dir</span>
|
||||
<span class="k">if</span> <span class="n">checkpoint_restore_dir</span><span class="p">:</span>
|
||||
<span class="n">screen</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s1">'TaskParameters.checkpoint_restore_dir is DEPECRATED and will be removed in one of the next '</span>
|
||||
<span class="s1">'releases. Please switch to using TaskParameters.checkpoint_restore_path, with your '</span>
|
||||
<span class="s1">'directory path. '</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_restore_path</span> <span class="o">=</span> <span class="n">checkpoint_restore_dir</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_restore_path</span> <span class="o">=</span> <span class="n">checkpoint_restore_path</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_save_dir</span> <span class="o">=</span> <span class="n">checkpoint_save_dir</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">seed</span> <span class="o">=</span> <span class="n">seed</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">export_onnx_graph</span> <span class="o">=</span> <span class="n">export_onnx_graph</span>
|
||||
@@ -763,13 +778,14 @@
|
||||
|
||||
<div class="viewcode-block" id="DistributedTaskParameters"><a class="viewcode-back" href="../../components/additional_parameters.html#rl_coach.base_parameters.DistributedTaskParameters">[docs]</a><span class="k">class</span> <span class="nc">DistributedTaskParameters</span><span class="p">(</span><span class="n">TaskParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">framework_type</span><span class="p">:</span> <span class="n">Frameworks</span><span class="p">,</span> <span class="n">parameters_server_hosts</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">worker_hosts</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">job_type</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
|
||||
<span class="n">task_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">evaluate_only</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">num_tasks</span><span class="p">:</span> <span class="nb">int</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">task_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">evaluate_only</span><span class="p">:</span> <span class="nb">int</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">num_tasks</span><span class="p">:</span> <span class="nb">int</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">num_training_tasks</span><span class="p">:</span> <span class="nb">int</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">use_cpu</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">experiment_path</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dnd</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">shared_memory_scratchpad</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">checkpoint_save_secs</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">checkpoint_restore_dir</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">shared_memory_scratchpad</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">checkpoint_save_secs</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">checkpoint_restore_path</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">checkpoint_save_dir</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">export_onnx_graph</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">apply_stop_condition</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> :param framework_type: deep learning framework type. currently only tensorflow is supported</span>
|
||||
<span class="sd"> :param evaluate_only: the task will be used only for evaluating the model</span>
|
||||
<span class="sd"> :param evaluate_only: if not None, the task will be used only for evaluating the model for the given number of steps.</span>
|
||||
<span class="sd"> A value of 0 means that task will be evaluated for an infinite number of steps.</span>
|
||||
<span class="sd"> :param parameters_server_hosts: comma-separated list of hostname:port pairs to which the parameter servers are</span>
|
||||
<span class="sd"> assigned</span>
|
||||
<span class="sd"> :param worker_hosts: comma-separated list of hostname:port pairs to which the workers are assigned</span>
|
||||
@@ -782,7 +798,7 @@
|
||||
<span class="sd"> :param dnd: an external DND to use for NEC. This is a workaround needed for a shared DND not using the scratchpad.</span>
|
||||
<span class="sd"> :param seed: a seed to use for the random numbers generator</span>
|
||||
<span class="sd"> :param checkpoint_save_secs: the number of seconds between each checkpoint saving</span>
|
||||
<span class="sd"> :param checkpoint_restore_dir: the directory to restore the checkpoints from</span>
|
||||
<span class="sd"> :param checkpoint_restore_path: the path to restore the checkpoints from</span>
|
||||
<span class="sd"> :param checkpoint_save_dir: the directory to store the checkpoints in</span>
|
||||
<span class="sd"> :param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved</span>
|
||||
<span class="sd"> :param apply_stop_condition: If set to True, this will apply the stop condition defined by reaching a target success rate</span>
|
||||
@@ -790,7 +806,7 @@
|
||||
<span class="sd"> """</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">framework_type</span><span class="o">=</span><span class="n">framework_type</span><span class="p">,</span> <span class="n">evaluate_only</span><span class="o">=</span><span class="n">evaluate_only</span><span class="p">,</span> <span class="n">use_cpu</span><span class="o">=</span><span class="n">use_cpu</span><span class="p">,</span>
|
||||
<span class="n">experiment_path</span><span class="o">=</span><span class="n">experiment_path</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="n">seed</span><span class="p">,</span> <span class="n">checkpoint_save_secs</span><span class="o">=</span><span class="n">checkpoint_save_secs</span><span class="p">,</span>
|
||||
<span class="n">checkpoint_restore_dir</span><span class="o">=</span><span class="n">checkpoint_restore_dir</span><span class="p">,</span> <span class="n">checkpoint_save_dir</span><span class="o">=</span><span class="n">checkpoint_save_dir</span><span class="p">,</span>
|
||||
<span class="n">checkpoint_restore_path</span><span class="o">=</span><span class="n">checkpoint_restore_path</span><span class="p">,</span> <span class="n">checkpoint_save_dir</span><span class="o">=</span><span class="n">checkpoint_save_dir</span><span class="p">,</span>
|
||||
<span class="n">export_onnx_graph</span><span class="o">=</span><span class="n">export_onnx_graph</span><span class="p">,</span> <span class="n">apply_stop_condition</span><span class="o">=</span><span class="n">apply_stop_condition</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">parameters_server_hosts</span> <span class="o">=</span> <span class="n">parameters_server_hosts</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">worker_hosts</span> <span class="o">=</span> <span class="n">worker_hosts</span>
|
||||
|
||||
Reference in New Issue
Block a user