mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Batch RL Tutorial (#372)
This commit is contained in:
@@ -217,9 +217,9 @@
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">DDPGCriticNetworkParameters</span><span class="p">(</span><span class="n">NetworkParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">use_batchnorm</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">input_embedders_parameters</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'observation'</span><span class="p">:</span> <span class="n">InputEmbedderParameters</span><span class="p">(</span><span class="n">batchnorm</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">input_embedders_parameters</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'observation'</span><span class="p">:</span> <span class="n">InputEmbedderParameters</span><span class="p">(</span><span class="n">batchnorm</span><span class="o">=</span><span class="n">use_batchnorm</span><span class="p">),</span>
|
||||
<span class="s1">'action'</span><span class="p">:</span> <span class="n">InputEmbedderParameters</span><span class="p">(</span><span class="n">scheme</span><span class="o">=</span><span class="n">EmbedderScheme</span><span class="o">.</span><span class="n">Shallow</span><span class="p">)}</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">middleware_parameters</span> <span class="o">=</span> <span class="n">FCMiddlewareParameters</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">DDPGVHeadParameters</span><span class="p">()]</span>
|
||||
@@ -236,11 +236,11 @@
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">DDPGActorNetworkParameters</span><span class="p">(</span><span class="n">NetworkParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">use_batchnorm</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">input_embedders_parameters</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'observation'</span><span class="p">:</span> <span class="n">InputEmbedderParameters</span><span class="p">(</span><span class="n">batchnorm</span><span class="o">=</span><span class="kc">True</span><span class="p">)}</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">middleware_parameters</span> <span class="o">=</span> <span class="n">FCMiddlewareParameters</span><span class="p">(</span><span class="n">batchnorm</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">DDPGActorHeadParameters</span><span class="p">()]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">input_embedders_parameters</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'observation'</span><span class="p">:</span> <span class="n">InputEmbedderParameters</span><span class="p">(</span><span class="n">batchnorm</span><span class="o">=</span><span class="n">use_batchnorm</span><span class="p">)}</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">middleware_parameters</span> <span class="o">=</span> <span class="n">FCMiddlewareParameters</span><span class="p">(</span><span class="n">batchnorm</span><span class="o">=</span><span class="n">use_batchnorm</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">DDPGActorHeadParameters</span><span class="p">(</span><span class="n">batchnorm</span><span class="o">=</span><span class="n">use_batchnorm</span><span class="p">)]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer_type</span> <span class="o">=</span> <span class="s1">'Adam'</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">64</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">adam_optimizer_beta2</span> <span class="o">=</span> <span class="mf">0.999</span>
|
||||
@@ -292,12 +292,12 @@
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">DDPGAgentParameters</span><span class="p">(</span><span class="n">AgentParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">use_batchnorm</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">algorithm</span><span class="o">=</span><span class="n">DDPGAlgorithmParameters</span><span class="p">(),</span>
|
||||
<span class="n">exploration</span><span class="o">=</span><span class="n">OUProcessParameters</span><span class="p">(),</span>
|
||||
<span class="n">memory</span><span class="o">=</span><span class="n">EpisodicExperienceReplayParameters</span><span class="p">(),</span>
|
||||
<span class="n">networks</span><span class="o">=</span><span class="n">OrderedDict</span><span class="p">([(</span><span class="s2">"actor"</span><span class="p">,</span> <span class="n">DDPGActorNetworkParameters</span><span class="p">()),</span>
|
||||
<span class="p">(</span><span class="s2">"critic"</span><span class="p">,</span> <span class="n">DDPGCriticNetworkParameters</span><span class="p">())]))</span>
|
||||
<span class="n">networks</span><span class="o">=</span><span class="n">OrderedDict</span><span class="p">([(</span><span class="s2">"actor"</span><span class="p">,</span> <span class="n">DDPGActorNetworkParameters</span><span class="p">(</span><span class="n">use_batchnorm</span><span class="o">=</span><span class="n">use_batchnorm</span><span class="p">)),</span>
|
||||
<span class="p">(</span><span class="s2">"critic"</span><span class="p">,</span> <span class="n">DDPGCriticNetworkParameters</span><span class="p">(</span><span class="n">use_batchnorm</span><span class="o">=</span><span class="n">use_batchnorm</span><span class="p">))]))</span>
|
||||
|
||||
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
@@ -353,7 +353,9 @@
|
||||
<span class="c1"># train the critic</span>
|
||||
<span class="n">critic_inputs</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">critic_keys</span><span class="p">))</span>
|
||||
<span class="n">critic_inputs</span><span class="p">[</span><span class="s1">'action'</span><span class="p">]</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">()</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)</span>
|
||||
<span class="n">result</span> <span class="o">=</span> <span class="n">critic</span><span class="o">.</span><span class="n">train_and_sync_networks</span><span class="p">(</span><span class="n">critic_inputs</span><span class="p">,</span> <span class="n">TD_targets</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># also need the inputs for when applying gradients so batchnorm's update of running mean and stddev will work</span>
|
||||
<span class="n">result</span> <span class="o">=</span> <span class="n">critic</span><span class="o">.</span><span class="n">train_and_sync_networks</span><span class="p">(</span><span class="n">critic_inputs</span><span class="p">,</span> <span class="n">TD_targets</span><span class="p">,</span> <span class="n">use_inputs_for_apply_gradients</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span> <span class="o">=</span> <span class="n">result</span><span class="p">[:</span><span class="mi">3</span><span class="p">]</span>
|
||||
|
||||
<span class="c1"># apply the gradients from the critic to the actor</span>
|
||||
@@ -362,11 +364,12 @@
|
||||
<span class="n">outputs</span><span class="o">=</span><span class="n">actor</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">weighted_gradients</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
|
||||
<span class="n">initial_feed_dict</span><span class="o">=</span><span class="n">initial_feed_dict</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># also need the inputs for when applying gradients so batchnorm's update of running mean and stddev will work</span>
|
||||
<span class="k">if</span> <span class="n">actor</span><span class="o">.</span><span class="n">has_global</span><span class="p">:</span>
|
||||
<span class="n">actor</span><span class="o">.</span><span class="n">apply_gradients_to_global_network</span><span class="p">(</span><span class="n">gradients</span><span class="p">)</span>
|
||||
<span class="n">actor</span><span class="o">.</span><span class="n">apply_gradients_to_global_network</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="n">additional_inputs</span><span class="o">=</span><span class="n">copy</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">critic_keys</span><span class="p">)))</span>
|
||||
<span class="n">actor</span><span class="o">.</span><span class="n">update_online_network</span><span class="p">()</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">actor</span><span class="o">.</span><span class="n">apply_gradients_to_online_network</span><span class="p">(</span><span class="n">gradients</span><span class="p">)</span>
|
||||
<span class="n">actor</span><span class="o">.</span><span class="n">apply_gradients_to_online_network</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="n">additional_inputs</span><span class="o">=</span><span class="n">copy</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">states</span><span class="p">(</span><span class="n">critic_keys</span><span class="p">)))</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">unclipped_grads</span>
|
||||
|
||||
|
||||
@@ -307,31 +307,37 @@
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">global_network</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">set_weights</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">global_network</span><span class="o">.</span><span class="n">get_weights</span><span class="p">(),</span> <span class="n">rate</span><span class="p">)</span></div>
|
||||
|
||||
<div class="viewcode-block" id="NetworkWrapper.apply_gradients_to_global_network"><a class="viewcode-back" href="../../../components/architectures/index.html#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_global_network">[docs]</a> <span class="k">def</span> <span class="nf">apply_gradients_to_global_network</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">gradients</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<div class="viewcode-block" id="NetworkWrapper.apply_gradients_to_global_network"><a class="viewcode-back" href="../../../components/architectures/index.html#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_global_network">[docs]</a> <span class="k">def</span> <span class="nf">apply_gradients_to_global_network</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">gradients</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">additional_inputs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Apply gradients from the online network on the global network</span>
|
||||
|
||||
<span class="sd"> :param gradients: optional gradients that will be used instead of teh accumulated gradients</span>
|
||||
<span class="sd"> :param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm's</span>
|
||||
<span class="sd"> update ops also requires the inputs)</span>
|
||||
<span class="sd"> :return:</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">if</span> <span class="n">gradients</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="n">gradients</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">accumulated_gradients</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">network_parameters</span><span class="o">.</span><span class="n">shared_optimizer</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">global_network</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="n">gradients</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">global_network</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="n">additional_inputs</span><span class="o">=</span><span class="n">additional_inputs</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="n">gradients</span><span class="p">)</span></div>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="n">additional_inputs</span><span class="o">=</span><span class="n">additional_inputs</span><span class="p">)</span></div>
|
||||
|
||||
<div class="viewcode-block" id="NetworkWrapper.apply_gradients_to_online_network"><a class="viewcode-back" href="../../../components/architectures/index.html#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_online_network">[docs]</a> <span class="k">def</span> <span class="nf">apply_gradients_to_online_network</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">gradients</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<div class="viewcode-block" id="NetworkWrapper.apply_gradients_to_online_network"><a class="viewcode-back" href="../../../components/architectures/index.html#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_online_network">[docs]</a> <span class="k">def</span> <span class="nf">apply_gradients_to_online_network</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">gradients</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">additional_inputs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Apply gradients from the online network on itself</span>
|
||||
<span class="sd"> :param gradients: optional gradients that will be used instead of teh accumulated gradients</span>
|
||||
<span class="sd"> :param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm's</span>
|
||||
<span class="sd"> update ops also requires the inputs)</span>
|
||||
|
||||
<span class="sd"> :return:</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">if</span> <span class="n">gradients</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="n">gradients</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">accumulated_gradients</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="n">gradients</span><span class="p">)</span></div>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="n">additional_inputs</span><span class="o">=</span><span class="n">additional_inputs</span><span class="p">)</span></div>
|
||||
|
||||
<div class="viewcode-block" id="NetworkWrapper.train_and_sync_networks"><a class="viewcode-back" href="../../../components/architectures/index.html#rl_coach.architectures.network_wrapper.NetworkWrapper.train_and_sync_networks">[docs]</a> <span class="k">def</span> <span class="nf">train_and_sync_networks</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">additional_fetches</span><span class="o">=</span><span class="p">[],</span> <span class="n">importance_weights</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<div class="viewcode-block" id="NetworkWrapper.train_and_sync_networks"><a class="viewcode-back" href="../../../components/architectures/index.html#rl_coach.architectures.network_wrapper.NetworkWrapper.train_and_sync_networks">[docs]</a> <span class="k">def</span> <span class="nf">train_and_sync_networks</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">additional_fetches</span><span class="o">=</span><span class="p">[],</span> <span class="n">importance_weights</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">use_inputs_for_apply_gradients</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> A generic training function that enables multi-threading training using a global network if necessary.</span>
|
||||
|
||||
@@ -340,14 +346,20 @@
|
||||
<span class="sd"> :param additional_fetches: Any additional tensor the user wants to fetch</span>
|
||||
<span class="sd"> :param importance_weights: A coefficient for each sample in the batch, which will be used to rescale the loss</span>
|
||||
<span class="sd"> error of this sample. If it is not given, the samples losses won't be scaled</span>
|
||||
<span class="sd"> :param use_inputs_for_apply_gradients: Add the inputs also for when applying gradients</span>
|
||||
<span class="sd"> (e.g. for incorporating batchnorm update ops)</span>
|
||||
<span class="sd"> :return: The loss of the training iteration</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">accumulate_gradients</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">additional_fetches</span><span class="o">=</span><span class="n">additional_fetches</span><span class="p">,</span>
|
||||
<span class="n">importance_weights</span><span class="o">=</span><span class="n">importance_weights</span><span class="p">,</span> <span class="n">no_accumulation</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">apply_gradients_and_sync_networks</span><span class="p">(</span><span class="n">reset_gradients</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">use_inputs_for_apply_gradients</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">apply_gradients_and_sync_networks</span><span class="p">(</span><span class="n">reset_gradients</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">additional_inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">apply_gradients_and_sync_networks</span><span class="p">(</span><span class="n">reset_gradients</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">result</span></div>
|
||||
|
||||
<div class="viewcode-block" id="NetworkWrapper.apply_gradients_and_sync_networks"><a class="viewcode-back" href="../../../components/architectures/index.html#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_and_sync_networks">[docs]</a> <span class="k">def</span> <span class="nf">apply_gradients_and_sync_networks</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">reset_gradients</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||||
<div class="viewcode-block" id="NetworkWrapper.apply_gradients_and_sync_networks"><a class="viewcode-back" href="../../../components/architectures/index.html#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_and_sync_networks">[docs]</a> <span class="k">def</span> <span class="nf">apply_gradients_and_sync_networks</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">reset_gradients</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">additional_inputs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Applies the gradients accumulated in the online network to the global network or to itself and syncs the</span>
|
||||
<span class="sd"> networks if necessary</span>
|
||||
@@ -356,17 +368,22 @@
|
||||
<span class="sd"> the network. this is useful when the accumulated gradients are overwritten instead</span>
|
||||
<span class="sd"> if accumulated by the accumulate_gradients function. this allows reducing time</span>
|
||||
<span class="sd"> complexity for this function by around 10%</span>
|
||||
<span class="sd"> :param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm's</span>
|
||||
<span class="sd"> update ops also requires the inputs)</span>
|
||||
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">global_network</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">apply_gradients_to_global_network</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">apply_gradients_to_global_network</span><span class="p">(</span><span class="n">additional_inputs</span><span class="o">=</span><span class="n">additional_inputs</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">reset_gradients</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">reset_accumulated_gradients</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">update_online_network</span><span class="p">()</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">reset_gradients</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">apply_and_reset_gradients</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">accumulated_gradients</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">apply_and_reset_gradients</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">accumulated_gradients</span><span class="p">,</span>
|
||||
<span class="n">additional_inputs</span><span class="o">=</span><span class="n">additional_inputs</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">accumulated_gradients</span><span class="p">)</span></div>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">accumulated_gradients</span><span class="p">,</span>
|
||||
<span class="n">additional_inputs</span><span class="o">=</span><span class="n">additional_inputs</span><span class="p">)</span></div>
|
||||
|
||||
<div class="viewcode-block" id="NetworkWrapper.parallel_prediction"><a class="viewcode-back" href="../../../components/architectures/index.html#rl_coach.architectures.network_wrapper.NetworkWrapper.parallel_prediction">[docs]</a> <span class="k">def</span> <span class="nf">parallel_prediction</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">network_input_tuples</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">]):</span>
|
||||
<span class="sd">"""</span>
|
||||
|
||||
@@ -213,7 +213,7 @@
|
||||
<span class="n">failed_imports</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s2">"RoboSchool"</span><span class="p">)</span>
|
||||
|
||||
<span class="k">try</span><span class="p">:</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.gym_extensions.continuous</span> <span class="k">import</span> <span class="n">mujoco</span>
|
||||
<span class="kn">from</span> <span class="nn">gym_extensions.continuous</span> <span class="k">import</span> <span class="n">mujoco</span>
|
||||
<span class="k">except</span><span class="p">:</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.logger</span> <span class="k">import</span> <span class="n">failed_imports</span>
|
||||
<span class="n">failed_imports</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s2">"GymExtensions"</span><span class="p">)</span>
|
||||
@@ -575,9 +575,6 @@
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">screen</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="s2">"Error: Environment </span><span class="si">{}</span><span class="s2"> does not support human control."</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="p">),</span> <span class="n">crash</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># initialize the state by getting a new state from the environment</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reset_internal_state</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># render</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_rendered</span><span class="p">:</span>
|
||||
<span class="n">image</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_rendered_image</span><span class="p">()</span>
|
||||
@@ -588,7 +585,6 @@
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">renderer</span><span class="o">.</span><span class="n">create_screen</span><span class="p">(</span><span class="n">image</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">*</span><span class="n">scale</span><span class="p">,</span> <span class="n">image</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">*</span><span class="n">scale</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># the info is only updated after the first step</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">default_action</span><span class="p">)</span><span class="o">.</span><span class="n">next_state</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">state_space</span><span class="p">[</span><span class="s1">'measurements'</span><span class="p">]</span> <span class="o">=</span> <span class="n">VectorObservationSpace</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="o">.</span><span class="n">keys</span><span class="p">()))</span>
|
||||
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">env</span><span class="o">.</span><span class="n">spec</span> <span class="ow">and</span> <span class="n">custom_reward_threshold</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
|
||||
@@ -247,15 +247,14 @@
|
||||
|
||||
<span class="k">def</span> <span class="nf">filter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">reward</span><span class="p">:</span> <span class="n">RewardType</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">RewardType</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">update_internal_state</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">reward</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">reward</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o"><</span> <span class="mi">2</span><span class="p">:</span>
|
||||
<span class="n">reward</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="n">reward</span><span class="p">]])</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">running_rewards_stats</span><span class="o">.</span><span class="n">push</span><span class="p">(</span><span class="n">reward</span><span class="p">)</span>
|
||||
|
||||
<span class="n">reward</span> <span class="o">=</span> <span class="p">(</span><span class="n">reward</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">running_rewards_stats</span><span class="o">.</span><span class="n">mean</span><span class="p">)</span> <span class="o">/</span> \
|
||||
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">running_rewards_stats</span><span class="o">.</span><span class="n">std</span> <span class="o">+</span> <span class="mf">1e-15</span><span class="p">)</span>
|
||||
<span class="n">reward</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">reward</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">clip_min</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">clip_max</span><span class="p">)</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">reward</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">running_rewards_stats</span><span class="o">.</span><span class="n">normalize</span><span class="p">(</span><span class="n">reward</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">get_filtered_reward_space</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_reward_space</span><span class="p">:</span> <span class="n">RewardSpace</span><span class="p">)</span> <span class="o">-></span> <span class="n">RewardSpace</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">running_rewards_stats</span><span class="o">.</span><span class="n">set_params</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">clip_values</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">clip_min</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">clip_max</span><span class="p">))</span>
|
||||
<span class="k">return</span> <span class="n">input_reward_space</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">save_state_to_checkpoint</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">checkpoint_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">checkpoint_prefix</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
|
||||
|
||||
@@ -198,6 +198,8 @@
|
||||
<span class="c1"># limitations under the License.</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="kn">import</span> <span class="nn">ast</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">pickle</span>
|
||||
<span class="kn">from</span> <span class="nn">copy</span> <span class="k">import</span> <span class="n">deepcopy</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">math</span>
|
||||
@@ -324,14 +326,27 @@
|
||||
|
||||
<span class="k">def</span> <span class="nf">shuffle_episodes</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Shuffle all the episodes in the replay buffer</span>
|
||||
<span class="sd"> Shuffle all the complete episodes in the replay buffer, while deleting the last non-complete episode</span>
|
||||
<span class="sd"> :return:</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">lock_writing</span><span class="p">()</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">assert_not_frozen</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># unlike the standard usage of the EpisodicExperienceReplay, where we always leave an empty episode after</span>
|
||||
<span class="c1"># the last full one, so that new transitions will have where to be added, in this case we delibrately remove</span>
|
||||
<span class="c1"># that empty last episode, as we are about to shuffle the memory, and we don't want it to be shuffled in</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">remove_last_episode</span><span class="p">(</span><span class="n">lock</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||
|
||||
<span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">transitions</span> <span class="o">=</span> <span class="p">[</span><span class="n">t</span> <span class="k">for</span> <span class="n">e</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">e</span><span class="o">.</span><span class="n">transitions</span><span class="p">]</span>
|
||||
|
||||
<span class="c1"># create a new Episode for the next transitions to be placed into</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">Episode</span><span class="p">(</span><span class="n">n_step</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">n_step</span><span class="p">))</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_length</span> <span class="o">+=</span> <span class="mi">1</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing</span><span class="p">()</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">get_shuffled_training_data_generator</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">size</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">Transition</span><span class="p">]:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Get an generator for iterating through the shuffled replay buffer, for processing the data in epochs.</span>
|
||||
@@ -384,10 +399,10 @@
|
||||
<span class="n">granularity</span><span class="p">,</span> <span class="n">size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_size</span>
|
||||
<span class="k">if</span> <span class="n">granularity</span> <span class="o">==</span> <span class="n">MemoryGranularity</span><span class="o">.</span><span class="n">Transitions</span><span class="p">:</span>
|
||||
<span class="k">while</span> <span class="n">size</span> <span class="o">!=</span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_transitions</span><span class="p">()</span> <span class="o">></span> <span class="n">size</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_remove_episode</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">remove_first_episode</span><span class="p">(</span><span class="n">lock</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||
<span class="k">elif</span> <span class="n">granularity</span> <span class="o">==</span> <span class="n">MemoryGranularity</span><span class="o">.</span><span class="n">Episodes</span><span class="p">:</span>
|
||||
<span class="k">while</span> <span class="bp">self</span><span class="o">.</span><span class="n">length</span><span class="p">()</span> <span class="o">></span> <span class="n">size</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_remove_episode</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">remove_first_episode</span><span class="p">(</span><span class="n">lock</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">_update_episode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">episode</span><span class="p">:</span> <span class="n">Episode</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="n">episode</span><span class="o">.</span><span class="n">update_transitions_rewards_and_bootstrap_data</span><span class="p">()</span>
|
||||
@@ -504,31 +519,53 @@
|
||||
|
||||
<span class="k">def</span> <span class="nf">_remove_episode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">episode_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Remove the episode in the given index (even if it is not complete yet)</span>
|
||||
<span class="sd"> :param episode_index: the index of the episode to remove</span>
|
||||
<span class="sd"> Remove either the first or the last index</span>
|
||||
<span class="sd"> :param episode_index: the index of the episode to remove (either 0 or -1)</span>
|
||||
<span class="sd"> :return: None</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">assert_not_frozen</span><span class="p">()</span>
|
||||
<span class="k">assert</span> <span class="n">episode_index</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">episode_index</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="s2">"_remove_episode only supports removing the first or the last "</span> \
|
||||
<span class="s2">"episode"</span>
|
||||
|
||||
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">)</span> <span class="o">></span> <span class="n">episode_index</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">)</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
|
||||
<span class="n">episode_length</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">[</span><span class="n">episode_index</span><span class="p">]</span><span class="o">.</span><span class="n">length</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_length</span> <span class="o">-=</span> <span class="mi">1</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_num_transitions</span> <span class="o">-=</span> <span class="n">episode_length</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_num_transitions_in_complete_episodes</span> <span class="o">-=</span> <span class="n">episode_length</span>
|
||||
<span class="k">del</span> <span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">[:</span><span class="n">episode_length</span><span class="p">]</span>
|
||||
<span class="k">if</span> <span class="n">episode_index</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
||||
<span class="k">del</span> <span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">[:</span><span class="n">episode_length</span><span class="p">]</span>
|
||||
<span class="k">else</span><span class="p">:</span> <span class="c1"># episode_index = -1</span>
|
||||
<span class="k">del</span> <span class="bp">self</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="o">-</span><span class="n">episode_length</span><span class="p">:]</span>
|
||||
<span class="k">del</span> <span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">[</span><span class="n">episode_index</span><span class="p">]</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">remove_episode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">episode_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="k">def</span> <span class="nf">remove_first_episode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Remove the episode in the given index (even if it is not complete yet)</span>
|
||||
<span class="sd"> :param episode_index: the index of the episode to remove</span>
|
||||
<span class="sd"> Remove the first episode (even if it is not complete yet)</span>
|
||||
<span class="sd"> :param lock: if true, will lock the readers writers lock. this can cause a deadlock if an inheriting class</span>
|
||||
<span class="sd"> locks and then calls store with lock = True</span>
|
||||
<span class="sd"> :return: None</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">lock_writing_and_reading</span><span class="p">()</span>
|
||||
<span class="k">if</span> <span class="n">lock</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">lock_writing_and_reading</span><span class="p">()</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_remove_episode</span><span class="p">(</span><span class="n">episode_index</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_remove_episode</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">lock</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing_and_reading</span><span class="p">()</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing_and_reading</span><span class="p">()</span>
|
||||
<span class="k">def</span> <span class="nf">remove_last_episode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Remove the last episode (even if it is not complete yet)</span>
|
||||
<span class="sd"> :param lock: if true, will lock the readers writers lock. this can cause a deadlock if an inheriting class</span>
|
||||
<span class="sd"> locks and then calls store with lock = True</span>
|
||||
<span class="sd"> :return: None</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">if</span> <span class="n">lock</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">lock_writing_and_reading</span><span class="p">()</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_remove_episode</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">lock</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">reader_writer_lock</span><span class="o">.</span><span class="n">release_writing_and_reading</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># for API compatibility</span>
|
||||
<span class="k">def</span> <span class="nf">get</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">episode_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="n">Episode</span><span class="p">]:</span>
|
||||
@@ -555,15 +592,6 @@
|
||||
|
||||
<span class="k">return</span> <span class="n">episode</span>
|
||||
|
||||
<span class="c1"># for API compatibility</span>
|
||||
<span class="k">def</span> <span class="nf">remove</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">episode_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Remove the episode in the given index (even if it is not complete yet)</span>
|
||||
<span class="sd"> :param episode_index: the index of the episode to remove</span>
|
||||
<span class="sd"> :return: None</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">remove_episode</span><span class="p">(</span><span class="n">episode_index</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">clean</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Clean the memory by removing all the episodes</span>
|
||||
@@ -629,7 +657,7 @@
|
||||
|
||||
<span class="n">transitions</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
||||
<span class="n">Transition</span><span class="p">(</span><span class="n">state</span><span class="o">=</span><span class="p">{</span><span class="s1">'observation'</span><span class="p">:</span> <span class="n">state</span><span class="p">},</span>
|
||||
<span class="n">action</span><span class="o">=</span><span class="n">current_transition</span><span class="p">[</span><span class="s1">'action'</span><span class="p">],</span> <span class="n">reward</span><span class="o">=</span><span class="n">current_transition</span><span class="p">[</span><span class="s1">'reward'</span><span class="p">],</span>
|
||||
<span class="n">action</span><span class="o">=</span><span class="nb">int</span><span class="p">(</span><span class="n">current_transition</span><span class="p">[</span><span class="s1">'action'</span><span class="p">]),</span> <span class="n">reward</span><span class="o">=</span><span class="n">current_transition</span><span class="p">[</span><span class="s1">'reward'</span><span class="p">],</span>
|
||||
<span class="n">next_state</span><span class="o">=</span><span class="p">{</span><span class="s1">'observation'</span><span class="p">:</span> <span class="n">next_state</span><span class="p">},</span> <span class="n">game_over</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||||
<span class="n">info</span><span class="o">=</span><span class="p">{</span><span class="s1">'all_action_probabilities'</span><span class="p">:</span>
|
||||
<span class="n">ast</span><span class="o">.</span><span class="n">literal_eval</span><span class="p">(</span><span class="n">current_transition</span><span class="p">[</span><span class="s1">'all_action_probabilities'</span><span class="p">])}),</span>
|
||||
@@ -698,7 +726,40 @@
|
||||
<span class="n">episode_num</span><span class="p">,</span> <span class="n">episode</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_episode_for_transition</span><span class="p">(</span><span class="n">transition</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_episode_id</span> <span class="o">=</span> <span class="n">episode_num</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_transition_id</span> <span class="o">=</span> \
|
||||
<span class="nb">len</span><span class="p">([</span><span class="n">t</span> <span class="k">for</span> <span class="n">e</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_all_complete_episodes_from_to</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_episode_id</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">e</span><span class="p">])</span></div>
|
||||
<span class="nb">len</span><span class="p">([</span><span class="n">t</span> <span class="k">for</span> <span class="n">e</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_all_complete_episodes_from_to</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_episode_id</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">e</span><span class="p">])</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">save</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">file_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Save the replay buffer contents to a pickle file</span>
|
||||
<span class="sd"> :param file_path: the path to the file that will be used to store the pickled transitions</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">file_path</span><span class="p">,</span> <span class="s1">'wb'</span><span class="p">)</span> <span class="k">as</span> <span class="n">file</span><span class="p">:</span>
|
||||
<span class="n">pickle</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">get_all_complete_episodes</span><span class="p">(),</span> <span class="n">file</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">load_pickled</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">file_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Restore the replay buffer contents from a pickle file.</span>
|
||||
<span class="sd"> The pickle file is assumed to include a list of transitions.</span>
|
||||
<span class="sd"> :param file_path: The path to a pickle file to restore</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">assert_not_frozen</span><span class="p">()</span>
|
||||
|
||||
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">file_path</span><span class="p">,</span> <span class="s1">'rb'</span><span class="p">)</span> <span class="k">as</span> <span class="n">file</span><span class="p">:</span>
|
||||
<span class="n">episodes</span> <span class="o">=</span> <span class="n">pickle</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">file</span><span class="p">)</span>
|
||||
<span class="n">num_transitions</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">([</span><span class="nb">len</span><span class="p">(</span><span class="n">e</span><span class="o">.</span><span class="n">transitions</span><span class="p">)</span> <span class="k">for</span> <span class="n">e</span> <span class="ow">in</span> <span class="n">episodes</span><span class="p">])</span>
|
||||
<span class="k">if</span> <span class="n">num_transitions</span> <span class="o">></span> <span class="bp">self</span><span class="o">.</span><span class="n">max_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]:</span>
|
||||
<span class="n">screen</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">"Warning! The number of transition to load into the replay buffer (</span><span class="si">{}</span><span class="s2">) is "</span>
|
||||
<span class="s2">"bigger than the max size of the replay buffer (</span><span class="si">{}</span><span class="s2">). The excessive transitions will "</span>
|
||||
<span class="s2">"not be stored."</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">num_transitions</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span>
|
||||
|
||||
<span class="n">progress_bar</span> <span class="o">=</span> <span class="n">ProgressBar</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">episodes</span><span class="p">))</span>
|
||||
<span class="k">for</span> <span class="n">episode_idx</span><span class="p">,</span> <span class="n">episode</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">episodes</span><span class="p">):</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">store_episode</span><span class="p">(</span><span class="n">episode</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># print progress</span>
|
||||
<span class="n">progress_bar</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">episode_idx</span><span class="p">)</span>
|
||||
|
||||
<span class="n">progress_bar</span><span class="o">.</span><span class="n">close</span><span class="p">()</span></div>
|
||||
</pre></div>
|
||||
|
||||
</div>
|
||||
|
||||
@@ -381,15 +381,6 @@
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_transition</span><span class="p">(</span><span class="n">transition_index</span><span class="p">,</span> <span class="n">lock</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># for API compatibility</span>
|
||||
<span class="k">def</span> <span class="nf">remove</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">transition_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Remove the transition in the given index</span>
|
||||
<span class="sd"> :param transition_index: the index of the transition to remove</span>
|
||||
<span class="sd"> :return: None</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">remove_transition</span><span class="p">(</span><span class="n">transition_index</span><span class="p">,</span> <span class="n">lock</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">clean</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Clean the memory by removing all the episodes</span>
|
||||
|
||||
Reference in New Issue
Block a user