mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Batch RL Tutorial (#372)

Gal Leibovich, 2019-07-14 18:43:48 +03:00 (committed by GitHub)
parent b82414138d, commit 19ad2d60a7
40 changed files with 1155 additions and 182 deletions

----- changed file -----

@@ -217,9 +217,9 @@
 class DDPGCriticNetworkParameters(NetworkParameters):
-    def __init__(self):
+    def __init__(self, use_batchnorm=False):
         super().__init__()
-        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=True),
+        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=use_batchnorm),
                                            'action': InputEmbedderParameters(scheme=EmbedderScheme.Shallow)}
         self.middleware_parameters = FCMiddlewareParameters()
         self.heads_parameters = [DDPGVHeadParameters()]

@@ -236,11 +236,11 @@
 class DDPGActorNetworkParameters(NetworkParameters):
-    def __init__(self):
+    def __init__(self, use_batchnorm=False):
         super().__init__()
-        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=True)}
-        self.middleware_parameters = FCMiddlewareParameters(batchnorm=True)
-        self.heads_parameters = [DDPGActorHeadParameters()]
+        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=use_batchnorm)}
+        self.middleware_parameters = FCMiddlewareParameters(batchnorm=use_batchnorm)
+        self.heads_parameters = [DDPGActorHeadParameters(batchnorm=use_batchnorm)]
         self.optimizer_type = 'Adam'
         self.batch_size = 64
         self.adam_optimizer_beta2 = 0.999

@@ -292,12 +292,12 @@
 class DDPGAgentParameters(AgentParameters):
-    def __init__(self):
+    def __init__(self, use_batchnorm=False):
         super().__init__(algorithm=DDPGAlgorithmParameters(),
                          exploration=OUProcessParameters(),
                          memory=EpisodicExperienceReplayParameters(),
-                         networks=OrderedDict([("actor", DDPGActorNetworkParameters()),
-                                               ("critic", DDPGCriticNetworkParameters())]))
+                         networks=OrderedDict([("actor", DDPGActorNetworkParameters(use_batchnorm=use_batchnorm)),
+                                               ("critic", DDPGCriticNetworkParameters(use_batchnorm=use_batchnorm))]))

     @property
     def path(self):
@@ -353,7 +353,9 @@
         # train the critic
         critic_inputs = copy.copy(batch.states(critic_keys))
         critic_inputs['action'] = batch.actions(len(batch.actions().shape) == 1)
-        result = critic.train_and_sync_networks(critic_inputs, TD_targets)
+        # also need the inputs for when applying gradients so batchnorm's update of running mean and stddev will work
+        result = critic.train_and_sync_networks(critic_inputs, TD_targets, use_inputs_for_apply_gradients=True)
         total_loss, losses, unclipped_grads = result[:3]

         # apply the gradients from the critic to the actor
@@ -362,11 +364,12 @@
                                 outputs=actor.online_network.weighted_gradients[0],
                                 initial_feed_dict=initial_feed_dict)

+        # also need the inputs for when applying gradients so batchnorm's update of running mean and stddev will work
         if actor.has_global:
-            actor.apply_gradients_to_global_network(gradients)
+            actor.apply_gradients_to_global_network(gradients, additional_inputs=copy.copy(batch.states(critic_keys)))
             actor.update_online_network()
         else:
-            actor.apply_gradients_to_online_network(gradients)
+            actor.apply_gradients_to_online_network(gradients, additional_inputs=copy.copy(batch.states(critic_keys)))

         return total_loss, losses, unclipped_grads

----- next changed file (rl_coach.architectures.network_wrapper, per the docs cross-references) -----

@@ -307,31 +307,37 @@
         if self.global_network:
             self.online_network.set_weights(self.global_network.get_weights(), rate)

-    def apply_gradients_to_global_network(self, gradients=None):
+    def apply_gradients_to_global_network(self, gradients=None, additional_inputs=None):
         """
         Apply gradients from the online network on the global network
         :param gradients: optional gradients that will be used instead of the accumulated gradients
+        :param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm's
+                                  update ops also require the inputs)
         :return:
         """
         if gradients is None:
             gradients = self.online_network.accumulated_gradients
         if self.network_parameters.shared_optimizer:
-            self.global_network.apply_gradients(gradients)
+            self.global_network.apply_gradients(gradients, additional_inputs=additional_inputs)
         else:
-            self.online_network.apply_gradients(gradients)
+            self.online_network.apply_gradients(gradients, additional_inputs=additional_inputs)

-    def apply_gradients_to_online_network(self, gradients=None):
+    def apply_gradients_to_online_network(self, gradients=None, additional_inputs=None):
         """
         Apply gradients from the online network on itself
         :param gradients: optional gradients that will be used instead of the accumulated gradients
+        :param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm's
+                                  update ops also require the inputs)
         :return:
         """
         if gradients is None:
             gradients = self.online_network.accumulated_gradients
-        self.online_network.apply_gradients(gradients)
+        self.online_network.apply_gradients(gradients, additional_inputs=additional_inputs)

-    def train_and_sync_networks(self, inputs, targets, additional_fetches=[], importance_weights=None):
+    def train_and_sync_networks(self, inputs, targets, additional_fetches=[], importance_weights=None,
+                                use_inputs_for_apply_gradients=False):
         """
         A generic training function that enables multi-threading training using a global network if necessary.
@@ -340,14 +346,20 @@
         :param additional_fetches: Any additional tensor the user wants to fetch
         :param importance_weights: A coefficient for each sample in the batch, which will be used to rescale the loss
                                    error of this sample. If it is not given, the samples' losses won't be scaled
+        :param use_inputs_for_apply_gradients: Add the inputs also for when applying gradients
+                                               (e.g. for incorporating batchnorm update ops)
         :return: The loss of the training iteration
         """
         result = self.online_network.accumulate_gradients(inputs, targets, additional_fetches=additional_fetches,
                                                           importance_weights=importance_weights, no_accumulation=True)
-        self.apply_gradients_and_sync_networks(reset_gradients=False)
+        if use_inputs_for_apply_gradients:
+            self.apply_gradients_and_sync_networks(reset_gradients=False, additional_inputs=inputs)
+        else:
+            self.apply_gradients_and_sync_networks(reset_gradients=False)

         return result
<div class="viewcode-block" id="NetworkWrapper.apply_gradients_and_sync_networks"><a class="viewcode-back" href="../../../components/architectures/index.html#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_and_sync_networks">[docs]</a> <span class="k">def</span> <span class="nf">apply_gradients_and_sync_networks</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">reset_gradients</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<div class="viewcode-block" id="NetworkWrapper.apply_gradients_and_sync_networks"><a class="viewcode-back" href="../../../components/architectures/index.html#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_and_sync_networks">[docs]</a> <span class="k">def</span> <span class="nf">apply_gradients_and_sync_networks</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">reset_gradients</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">additional_inputs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Applies the gradients accumulated in the online network to the global network or to itself and syncs the</span>
<span class="sd"> networks if necessary</span>
@@ -356,17 +368,22 @@
             the network. this is useful when the accumulated gradients are overwritten instead
             of accumulated by the accumulate_gradients function. this allows reducing time
             complexity for this function by around 10%
+        :param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm's
+                                  update ops also require the inputs)
         """
         if self.global_network:
-            self.apply_gradients_to_global_network()
+            self.apply_gradients_to_global_network(additional_inputs=additional_inputs)
             if reset_gradients:
                 self.online_network.reset_accumulated_gradients()
             self.update_online_network()
         else:
             if reset_gradients:
-                self.online_network.apply_and_reset_gradients(self.online_network.accumulated_gradients)
+                self.online_network.apply_and_reset_gradients(self.online_network.accumulated_gradients,
+                                                              additional_inputs=additional_inputs)
             else:
-                self.online_network.apply_gradients(self.online_network.accumulated_gradients)
+                self.online_network.apply_gradients(self.online_network.accumulated_gradients,
+                                                    additional_inputs=additional_inputs)

     def parallel_prediction(self, network_input_tuples: List[Tuple]):
         """

----- next changed file -----

@@ -213,7 +213,7 @@
     failed_imports.append("RoboSchool")

 try:
-    from rl_coach.gym_extensions.continuous import mujoco
+    from gym_extensions.continuous import mujoco
 except:
     from rl_coach.logger import failed_imports
     failed_imports.append("GymExtensions")

@@ -575,9 +575,6 @@
         else:
             screen.error("Error: Environment {} does not support human control.".format(self.env), crash=True)

-        # initialize the state by getting a new state from the environment
-        self.reset_internal_state(True)

         # render
         if self.is_rendered:
             image = self.get_rendered_image()

@@ -588,7 +585,6 @@
             self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale)

         # the info is only updated after the first step
-        self.state = self.step(self.action_space.default_action).next_state
         self.state_space['measurements'] = VectorObservationSpace(shape=len(self.info.keys()))

         if self.env.spec and custom_reward_threshold is None:

----- next changed file -----

@@ -247,15 +247,14 @@
     def filter(self, reward: RewardType, update_internal_state: bool=True) -> RewardType:
         if update_internal_state:
+            if not isinstance(reward, np.ndarray) or len(reward.shape) < 2:
+                reward = np.array([[reward]])
             self.running_rewards_stats.push(reward)

-        reward = (reward - self.running_rewards_stats.mean) / \
-                 (self.running_rewards_stats.std + 1e-15)
-        reward = np.clip(reward, self.clip_min, self.clip_max)
-        return reward
+        return self.running_rewards_stats.normalize(reward).squeeze()

     def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
         self.running_rewards_stats.set_params(shape=(1,), clip_values=(self.clip_min, self.clip_max))
         return input_reward_space

     def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):

----- next changed file -----

@@ -198,6 +198,8 @@
 # limitations under the License.
 #
 import ast
+import pickle
 from copy import deepcopy
 import math
@@ -324,14 +326,27 @@
     def shuffle_episodes(self):
         """
-        Shuffle all the episodes in the replay buffer
+        Shuffle all the complete episodes in the replay buffer, while deleting the last non-complete episode
        :return:
         """
         self.reader_writer_lock.lock_writing()
         self.assert_not_frozen()

+        # unlike the standard usage of the EpisodicExperienceReplay, where we always leave an empty episode after
+        # the last full one, so that new transitions will have somewhere to be added, in this case we deliberately
+        # remove that empty last episode, as we are about to shuffle the memory, and we don't want it shuffled in
+        self.remove_last_episode(lock=False)
         random.shuffle(self._buffer)
         self.transitions = [t for e in self._buffer for t in e.transitions]

+        # create a new Episode for the next transitions to be placed into
+        self._buffer.append(Episode(n_step=self.n_step))
+        self._length += 1

         self.reader_writer_lock.release_writing()

     def get_shuffled_training_data_generator(self, size: int) -> List[Transition]:
         """
         Get a generator for iterating through the shuffled replay buffer, for processing the data in epochs.
@@ -384,10 +399,10 @@
         granularity, size = self.max_size
         if granularity == MemoryGranularity.Transitions:
             while size != 0 and self.num_transitions() > size:
-                self._remove_episode(0)
+                self.remove_first_episode(lock=False)
         elif granularity == MemoryGranularity.Episodes:
             while self.length() > size:
-                self._remove_episode(0)
+                self.remove_first_episode(lock=False)

     def _update_episode(self, episode: Episode) -> None:
         episode.update_transitions_rewards_and_bootstrap_data()
@@ -504,31 +519,53 @@
     def _remove_episode(self, episode_index: int) -> None:
         """
-        Remove the episode in the given index (even if it is not complete yet)
-        :param episode_index: the index of the episode to remove
+        Remove either the first or the last episode
+        :param episode_index: the index of the episode to remove (either 0 or -1)
         :return: None
         """
         self.assert_not_frozen()
+        assert episode_index == 0 or episode_index == -1, "_remove_episode only supports removing the first or the last " \
+                                                          "episode"

-        if len(self._buffer) > episode_index:
+        if len(self._buffer) > 0:
             episode_length = self._buffer[episode_index].length()
             self._length -= 1
             self._num_transitions -= episode_length
             self._num_transitions_in_complete_episodes -= episode_length
-            del self.transitions[:episode_length]
+            if episode_index == 0:
+                del self.transitions[:episode_length]
+            else:  # episode_index = -1
+                del self.transitions[-episode_length:]
             del self._buffer[episode_index]

-    def remove_episode(self, episode_index: int) -> None:
+    def remove_first_episode(self, lock: bool = True) -> None:
         """
-        Remove the episode in the given index (even if it is not complete yet)
-        :param episode_index: the index of the episode to remove
+        Remove the first episode (even if it is not complete yet)
+        :param lock: if true, will lock the readers-writers lock. this can cause a deadlock if an inheriting class
+                     locks and then calls store with lock = True
         :return: None
         """
-        self.reader_writer_lock.lock_writing_and_reading()
-        self._remove_episode(episode_index)
-        self.reader_writer_lock.release_writing_and_reading()
+        if lock:
+            self.reader_writer_lock.lock_writing_and_reading()
+
+        self._remove_episode(0)
+
+        if lock:
+            self.reader_writer_lock.release_writing_and_reading()
+
+    def remove_last_episode(self, lock: bool = True) -> None:
+        """
+        Remove the last episode (even if it is not complete yet)
+        :param lock: if true, will lock the readers-writers lock. this can cause a deadlock if an inheriting class
+                     locks and then calls store with lock = True
+        :return: None
+        """
+        if lock:
+            self.reader_writer_lock.lock_writing_and_reading()
+
+        self._remove_episode(-1)
+
+        if lock:
+            self.reader_writer_lock.release_writing_and_reading()
<span class="c1"># for API compatibility</span>
<span class="k">def</span> <span class="nf">get</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">episode_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="n">Episode</span><span class="p">]:</span>
@@ -555,15 +592,6 @@
         return episode

-    # for API compatibility
-    def remove(self, episode_index: int):
-        """
-        Remove the episode in the given index (even if it is not complete yet)
-        :param episode_index: the index of the episode to remove
-        :return: None
-        """
-        self.remove_episode(episode_index)

     def clean(self) -> None:
         """
         Clean the memory by removing all the episodes
@@ -629,7 +657,7 @@
             transitions.append(
                 Transition(state={'observation': state},
-                           action=current_transition['action'], reward=current_transition['reward'],
+                           action=int(current_transition['action']), reward=current_transition['reward'],
                            next_state={'observation': next_state}, game_over=False,
                            info={'all_action_probabilities':
                                      ast.literal_eval(current_transition['all_action_probabilities'])}),
@@ -698,7 +726,40 @@
<span class="n">episode_num</span><span class="p">,</span> <span class="n">episode</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_episode_for_transition</span><span class="p">(</span><span class="n">transition</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_episode_id</span> <span class="o">=</span> <span class="n">episode_num</span>
<span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_transition_id</span> <span class="o">=</span> \
<span class="nb">len</span><span class="p">([</span><span class="n">t</span> <span class="k">for</span> <span class="n">e</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_all_complete_episodes_from_to</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_episode_id</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">e</span><span class="p">])</span></div>
<span class="nb">len</span><span class="p">([</span><span class="n">t</span> <span class="k">for</span> <span class="n">e</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_all_complete_episodes_from_to</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_training_set_episode_id</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">e</span><span class="p">])</span>
<span class="k">def</span> <span class="nf">save</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">file_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Save the replay buffer contents to a pickle file</span>
<span class="sd"> :param file_path: the path to the file that will be used to store the pickled transitions</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">file_path</span><span class="p">,</span> <span class="s1">&#39;wb&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">file</span><span class="p">:</span>
<span class="n">pickle</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">get_all_complete_episodes</span><span class="p">(),</span> <span class="n">file</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">load_pickled</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">file_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Restore the replay buffer contents from a pickle file.</span>
<span class="sd"> The pickle file is assumed to include a list of transitions.</span>
<span class="sd"> :param file_path: The path to a pickle file to restore</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">assert_not_frozen</span><span class="p">()</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">file_path</span><span class="p">,</span> <span class="s1">&#39;rb&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">file</span><span class="p">:</span>
<span class="n">episodes</span> <span class="o">=</span> <span class="n">pickle</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">file</span><span class="p">)</span>
<span class="n">num_transitions</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">([</span><span class="nb">len</span><span class="p">(</span><span class="n">e</span><span class="o">.</span><span class="n">transitions</span><span class="p">)</span> <span class="k">for</span> <span class="n">e</span> <span class="ow">in</span> <span class="n">episodes</span><span class="p">])</span>
<span class="k">if</span> <span class="n">num_transitions</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]:</span>
<span class="n">screen</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;Warning! The number of transition to load into the replay buffer (</span><span class="si">{}</span><span class="s2">) is &quot;</span>
<span class="s2">&quot;bigger than the max size of the replay buffer (</span><span class="si">{}</span><span class="s2">). The excessive transitions will &quot;</span>
<span class="s2">&quot;not be stored.&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">num_transitions</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span>
<span class="n">progress_bar</span> <span class="o">=</span> <span class="n">ProgressBar</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">episodes</span><span class="p">))</span>
<span class="k">for</span> <span class="n">episode_idx</span><span class="p">,</span> <span class="n">episode</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">episodes</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">store_episode</span><span class="p">(</span><span class="n">episode</span><span class="p">)</span>
<span class="c1"># print progress</span>
<span class="n">progress_bar</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">episode_idx</span><span class="p">)</span>
<span class="n">progress_bar</span><span class="o">.</span><span class="n">close</span><span class="p">()</span></div>
</pre></div>
</div>
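A hedged round-trip example for the save/load_pickled pair added above, assuming the default EpisodicExperienceReplay constructor arguments; the file path is illustrative.

from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplay

memory = EpisodicExperienceReplay()
# ... store episodes into `memory` during a data-collection run ...
memory.save('/tmp/replay_buffer.p')            # pickles all complete episodes

restored = EpisodicExperienceReplay()
restored.load_pickled('/tmp/replay_buffer.p')  # warns if the transitions exceed the buffer's max size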

View File

@@ -381,15 +381,6 @@
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_transition</span><span class="p">(</span><span class="n">transition_index</span><span class="p">,</span> <span class="n">lock</span><span class="p">)</span>
<span class="c1"># for API compatibility</span>
<span class="k">def</span> <span class="nf">remove</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">transition_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Remove the transition in the given index</span>
<span class="sd"> :param transition_index: the index of the transition to remove</span>
<span class="sd"> :return: None</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">remove_transition</span><span class="p">(</span><span class="n">transition_index</span><span class="p">,</span> <span class="n">lock</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">clean</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Clean the memory by removing all the episodes</span>

View File

@@ -0,0 +1,18 @@
Batch Reinforcement Learning
============================
Coach supports Batch Reinforcement Learning, where learning is based solely on a fixed batch of data.
In Batch RL, we are given a dataset of experience that was collected by one or more deployed policies, and we would
like to use it to learn a better policy than the ones that collected it.
There is no simulator to interact with, so we cannot collect any new data, and we often cannot explore the MDP any further.
To make things even harder, we would also like to use the same dataset to evaluate the newly learned policy
(using off-policy evaluation), since we have no simulator on which to evaluate it.
Batch RL is also often beneficial when we simply want to separate inference (data collection) from the
training of a new policy. This is common when we have a system on which we can easily deploy a policy
and collect experience data, but cannot easily use that system's setup to train a new policy online (as more
standard RL algorithms would require).
Coach supports Batch RL with (almost) all of its integrated off-policy algorithms.
Many more details and example usage can be found in the
`tutorial <https://github.com/NervanaSystems/coach/blob/master/tutorials/4.%20Batch%20Reinforcement%20Learning.ipynb>`_.
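To make the off-policy evaluation point concrete, here is a minimal, self-contained sketch (not Coach's implementation) of a one-step importance sampling estimate, using the kind of logged action probabilities that such a dataset typically stores:

import numpy as np

def one_step_is_estimate(rewards, behavior_probs, target_probs):
    """Ordinary importance sampling estimate of the new policy's mean reward.

    rewards[i] is the logged reward for sample i, behavior_probs[i] is the
    probability the deployed policy assigned to the logged action, and
    target_probs[i] is the probability the newly learned policy assigns to
    that same action. Trajectory structure is ignored (bandit-style estimate).
    """
    weights = np.asarray(target_probs) / np.asarray(behavior_probs)
    return float(np.mean(weights * np.asarray(rewards)))

# Toy numbers, purely illustrative:
print(one_step_is_estimate(rewards=[1.0, 0.0, 1.0],
                           behavior_probs=[0.5, 0.5, 0.25],
                           target_probs=[0.9, 0.1, 0.5]))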

View File

@@ -7,4 +7,5 @@ Features
algorithms
environments
benchmarks
benchmarks
batch_rl

View File

@@ -544,26 +544,34 @@ multi-process distributed mode. The network wrapper contains functionality for m
between them.</p>
<dl class="method">
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_and_sync_networks">
<code class="sig-name descname">apply_gradients_and_sync_networks</code><span class="sig-paren">(</span><em class="sig-param">reset_gradients=True</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_and_sync_networks"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_and_sync_networks" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">apply_gradients_and_sync_networks</code><span class="sig-paren">(</span><em class="sig-param">reset_gradients=True</em>, <em class="sig-param">additional_inputs=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_and_sync_networks"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_and_sync_networks" title="Permalink to this definition"></a></dt>
<dd><p>Applies the gradients accumulated in the online network to the global network or to itself and syncs the
networks if necessary</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>reset_gradients</strong> If set to True, the accumulated gradients wont be reset to 0 after applying them to
<dd class="field-odd"><ul class="simple">
<li><p><strong>reset_gradients</strong> If set to True, the accumulated gradients wont be reset to 0 after applying them to
the network. this is useful when the accumulated gradients are overwritten instead
if accumulated by the accumulate_gradients function. this allows reducing time
complexity for this function by around 10%</p>
complexity for this function by around 10%</p></li>
<li><p><strong>additional_inputs</strong> optional additional inputs required when applying the gradients (e.g. batchnorm
update ops also require the inputs)</p></li>
</ul>
</dd>
</dl>
</dd></dl>
<dl class="method">
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_global_network">
<code class="sig-name descname">apply_gradients_to_global_network</code><span class="sig-paren">(</span><em class="sig-param">gradients=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_to_global_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_global_network" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">apply_gradients_to_global_network</code><span class="sig-paren">(</span><em class="sig-param">gradients=None</em>, <em class="sig-param">additional_inputs=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_to_global_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_global_network" title="Permalink to this definition"></a></dt>
<dd><p>Apply gradients from the online network on the global network</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>gradients</strong> optional gradients that will be used instead of teh accumulated gradients</p>
<dd class="field-odd"><ul class="simple">
<li><p><strong>gradients</strong> optional gradients that will be used instead of teh accumulated gradients</p></li>
<li><p><strong>additional_inputs</strong> optional additional inputs required for when applying the gradients (e.g. batchnorms
update ops also requires the inputs)</p></li>
</ul>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p></p>
@@ -573,8 +581,13 @@ complexity for this function by around 10%</p>
<dl class="method">
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_online_network">
<code class="sig-name descname">apply_gradients_to_online_network</code><span class="sig-paren">(</span><em class="sig-param">gradients=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_to_online_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_online_network" title="Permalink to this definition"></a></dt>
<dd><p>Apply gradients from the online network on itself</p>
<code class="sig-name descname">apply_gradients_to_online_network</code><span class="sig-paren">(</span><em class="sig-param">gradients=None</em>, <em class="sig-param">additional_inputs=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_to_online_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_online_network" title="Permalink to this definition"></a></dt>
<dd><p>Apply gradients from the online network on itself</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>gradients</strong> optional gradients that will be used instead of the accumulated gradients</p></li>
<li><p><strong>additional_inputs</strong> optional additional inputs required when applying the gradients (e.g. batchnorm
update ops also require the inputs)</p></li>
</ul>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p></p>
@@ -650,7 +663,7 @@ target_network or global_network) and the second element is the inputs</p>
<dl class="method">
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.train_and_sync_networks">
<code class="sig-name descname">train_and_sync_networks</code><span class="sig-paren">(</span><em class="sig-param">inputs</em>, <em class="sig-param">targets</em>, <em class="sig-param">additional_fetches=[]</em>, <em class="sig-param">importance_weights=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.train_and_sync_networks"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.train_and_sync_networks" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">train_and_sync_networks</code><span class="sig-paren">(</span><em class="sig-param">inputs</em>, <em class="sig-param">targets</em>, <em class="sig-param">additional_fetches=[]</em>, <em class="sig-param">importance_weights=None</em>, <em class="sig-param">use_inputs_for_apply_gradients=False</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.train_and_sync_networks"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.train_and_sync_networks" title="Permalink to this definition"></a></dt>
<dd><p>A generic training function that enables multi-threaded training using a global network if necessary.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
@@ -660,6 +673,8 @@ target_network or global_network) and the second element is the inputs</p>
<li><p><strong>additional_fetches</strong> Any additional tensor the user wants to fetch</p></li>
<li><p><strong>importance_weights</strong> A coefficient for each sample in the batch, which will be used to rescale
that sample's loss. If not given, the samples' losses won't be scaled</p></li>
<li><p><strong>use_inputs_for_apply_gradients</strong> Pass the inputs also when applying gradients
(e.g. to incorporate the batchnorm update ops; a usage sketch follows below)</p></li>
</ul>
</dd>
<dt class="field-even">Returns</dt>

261
docs/features/batch_rl.html Normal file
View File

@@ -0,0 +1,261 @@
<!DOCTYPE html>
<html class="no-js" lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Batch Reinforcement Learning &mdash; Reinforcement Learning Coach 0.12.0 documentation</title>
<script type="text/javascript" src="../_static/js/modernizr.min.js"></script>
<script type="text/javascript" id="documentation_options" data-url_root="../" src="../_static/documentation_options.js"></script>
<script type="text/javascript" src="../_static/jquery.js"></script>
<script type="text/javascript" src="../_static/underscore.js"></script>
<script type="text/javascript" src="../_static/doctools.js"></script>
<script type="text/javascript" src="../_static/language_data.js"></script>
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../_static/js/theme.js"></script>
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../_static/css/custom.css" type="text/css" />
<link rel="index" title="Index" href="../genindex.html" />
<link rel="search" title="Search" href="../search.html" />
<link rel="next" title="Selecting an Algorithm" href="../selecting_an_algorithm.html" />
<link rel="prev" title="Benchmarks" href="benchmarks.html" />
<link href="../_static/css/custom.css" rel="stylesheet" type="text/css">
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../index.html" class="icon icon-home"> Reinforcement Learning Coach
<img src="../_static/dark_logo.png" class="logo" alt="Logo"/>
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption"><span class="caption-text">Intro</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../usage.html">Usage</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dist_usage.html">Usage - Distributed Coach</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Features</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="algorithms.html">Algorithms</a></li>
<li class="toctree-l2"><a class="reference internal" href="environments.html">Environments</a></li>
<li class="toctree-l2"><a class="reference internal" href="benchmarks.html">Benchmarks</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Batch Reinforcement Learning</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../selecting_an_algorithm.html">Selecting an Algorithm</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dashboard.html">Coach Dashboard</a></li>
</ul>
<p class="caption"><span class="caption-text">Design</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../design/control_flow.html">Control Flow</a></li>
<li class="toctree-l1"><a class="reference internal" href="../design/network.html">Network Design</a></li>
<li class="toctree-l1"><a class="reference internal" href="../design/horizontal_scaling.html">Distributed Coach - Horizontal Scale-Out</a></li>
</ul>
<p class="caption"><span class="caption-text">Contributing</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../contributing/add_agent.html">Adding a New Agent</a></li>
<li class="toctree-l1"><a class="reference internal" href="../contributing/add_env.html">Adding a New Environment</a></li>
</ul>
<p class="caption"><span class="caption-text">Components</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../components/agents/index.html">Agents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../components/architectures/index.html">Architectures</a></li>
<li class="toctree-l1"><a class="reference internal" href="../components/data_stores/index.html">Data Stores</a></li>
<li class="toctree-l1"><a class="reference internal" href="../components/environments/index.html">Environments</a></li>
<li class="toctree-l1"><a class="reference internal" href="../components/exploration_policies/index.html">Exploration Policies</a></li>
<li class="toctree-l1"><a class="reference internal" href="../components/filters/index.html">Filters</a></li>
<li class="toctree-l1"><a class="reference internal" href="../components/memories/index.html">Memories</a></li>
<li class="toctree-l1"><a class="reference internal" href="../components/memory_backends/index.html">Memory Backends</a></li>
<li class="toctree-l1"><a class="reference internal" href="../components/orchestrators/index.html">Orchestrators</a></li>
<li class="toctree-l1"><a class="reference internal" href="../components/core_types.html">Core Types</a></li>
<li class="toctree-l1"><a class="reference internal" href="../components/spaces.html">Spaces</a></li>
<li class="toctree-l1"><a class="reference internal" href="../components/additional_parameters.html">Additional Parameters</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../index.html">Reinforcement Learning Coach</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../index.html">Docs</a> &raquo;</li>
<li><a href="index.html">Features</a> &raquo;</li>
<li>Batch Reinforcement Learning</li>
<li class="wy-breadcrumbs-aside">
<a href="../_sources/features/batch_rl.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="section" id="batch-reinforcement-learning">
<h1>Batch Reinforcement Learning<a class="headerlink" href="#batch-reinforcement-learning" title="Permalink to this headline"></a></h1>
<p>Coach supports Batch Reinforcement Learning, where learning is based solely on a fixed batch of data.
In Batch RL, we are given a dataset of experience that was collected by one or more deployed policies, and we would
like to use it to learn a better policy than the ones that collected it.
There is no simulator to interact with, so we cannot collect any new data, and we often cannot explore the MDP any further.
To make things even harder, we would also like to use the same dataset to evaluate the newly learned policy
(using off-policy evaluation), since we have no simulator on which to evaluate it.
Batch RL is also often beneficial when we simply want to separate inference (data collection) from the
training of a new policy. This is common when we have a system on which we can easily deploy a policy
and collect experience data, but cannot easily use that system's setup to train a new policy online (as more
standard RL algorithms would require).</p>
<p>Coach supports Batch RL with (almost) all of its integrated off-policy algorithms.</p>
<p>Many more details and example usage can be found in the
<a class="reference external" href="https://github.com/NervanaSystems/coach/blob/master/tutorials/4.%20Batch%20Reinforcement%20Learning.ipynb">tutorial</a>.</p>
</div>
</div>
</div>
</div>
</div>
</section>
</div>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>

View File

@@ -95,6 +95,7 @@
<li class="toctree-l2"><a class="reference internal" href="algorithms.html">Algorithms</a></li>
<li class="toctree-l2"><a class="reference internal" href="environments.html">Environments</a></li>
<li class="toctree-l2"><a class="reference internal" href="benchmarks.html">Benchmarks</a></li>
<li class="toctree-l2"><a class="reference internal" href="batch_rl.html">Batch Reinforcement Learning</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../selecting_an_algorithm.html">Selecting an Algorithm</a></li>
@@ -197,6 +198,7 @@
<li class="toctree-l1"><a class="reference internal" href="algorithms.html">Algorithms</a></li>
<li class="toctree-l1"><a class="reference internal" href="environments.html">Environments</a></li>
<li class="toctree-l1"><a class="reference internal" href="benchmarks.html">Benchmarks</a></li>
<li class="toctree-l1"><a class="reference internal" href="batch_rl.html">Batch Reinforcement Learning</a></li>
</ul>
</div>
</div>

View File

@@ -210,6 +210,7 @@ Coach collects statistics from the training process and supports advanced visual
<li class="toctree-l2"><a class="reference internal" href="features/algorithms.html">Algorithms</a></li>
<li class="toctree-l2"><a class="reference internal" href="features/environments.html">Environments</a></li>
<li class="toctree-l2"><a class="reference internal" href="features/benchmarks.html">Benchmarks</a></li>
<li class="toctree-l2"><a class="reference internal" href="features/batch_rl.html">Batch Reinforcement Learning</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="selecting_an_algorithm.html">Selecting an Algorithm</a></li>

Binary file not shown.

File diff suppressed because one or more lines are too long