mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

Batch RL Tutorial (#372)

Author: Gal Leibovich
Date: 2019-07-14 18:43:48 +03:00 (committed by GitHub)
Parent: b82414138d
Commit: 19ad2d60a7
40 changed files with 1155 additions and 182 deletions

Changed file: Sphinx viewcode page for rl_coach.architectures.network_wrapper

@@ -307,31 +307,37 @@
         if self.global_network:
             self.online_network.set_weights(self.global_network.get_weights(), rate)
 
-    def apply_gradients_to_global_network(self, gradients=None):
+    def apply_gradients_to_global_network(self, gradients=None, additional_inputs=None):
         """
         Apply gradients from the online network on the global network
         :param gradients: optional gradients that will be used instead of the accumulated gradients
+        :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
+                                  update ops also require the inputs)
         :return:
         """
         if gradients is None:
             gradients = self.online_network.accumulated_gradients
         if self.network_parameters.shared_optimizer:
-            self.global_network.apply_gradients(gradients)
+            self.global_network.apply_gradients(gradients, additional_inputs=additional_inputs)
         else:
-            self.online_network.apply_gradients(gradients)
+            self.online_network.apply_gradients(gradients, additional_inputs=additional_inputs)
 
-    def apply_gradients_to_online_network(self, gradients=None):
+    def apply_gradients_to_online_network(self, gradients=None, additional_inputs=None):
         """
         Apply gradients from the online network on itself
         :param gradients: optional gradients that will be used instead of the accumulated gradients
+        :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
+                                  update ops also require the inputs)
         :return:
         """
         if gradients is None:
             gradients = self.online_network.accumulated_gradients
-        self.online_network.apply_gradients(gradients)
+        self.online_network.apply_gradients(gradients, additional_inputs=additional_inputs)
<div class="viewcode-block" id="NetworkWrapper.train_and_sync_networks"><a class="viewcode-back" href="../../../components/architectures/index.html#rl_coach.architectures.network_wrapper.NetworkWrapper.train_and_sync_networks">[docs]</a> <span class="k">def</span> <span class="nf">train_and_sync_networks</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">additional_fetches</span><span class="o">=</span><span class="p">[],</span> <span class="n">importance_weights</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<div class="viewcode-block" id="NetworkWrapper.train_and_sync_networks"><a class="viewcode-back" href="../../../components/architectures/index.html#rl_coach.architectures.network_wrapper.NetworkWrapper.train_and_sync_networks">[docs]</a> <span class="k">def</span> <span class="nf">train_and_sync_networks</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">additional_fetches</span><span class="o">=</span><span class="p">[],</span> <span class="n">importance_weights</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">use_inputs_for_apply_gradients</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> A generic training function that enables multi-threading training using a global network if necessary.</span>
@@ -340,14 +346,20 @@
<span class="sd"> :param additional_fetches: Any additional tensor the user wants to fetch</span>
<span class="sd"> :param importance_weights: A coefficient for each sample in the batch, which will be used to rescale the loss</span>
<span class="sd"> error of this sample. If it is not given, the samples losses won&#39;t be scaled</span>
<span class="sd"> :param use_inputs_for_apply_gradients: Add the inputs also for when applying gradients</span>
<span class="sd"> (e.g. for incorporating batchnorm update ops)</span>
<span class="sd"> :return: The loss of the training iteration</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">accumulate_gradients</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">additional_fetches</span><span class="o">=</span><span class="n">additional_fetches</span><span class="p">,</span>
<span class="n">importance_weights</span><span class="o">=</span><span class="n">importance_weights</span><span class="p">,</span> <span class="n">no_accumulation</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">apply_gradients_and_sync_networks</span><span class="p">(</span><span class="n">reset_gradients</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="k">if</span> <span class="n">use_inputs_for_apply_gradients</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">apply_gradients_and_sync_networks</span><span class="p">(</span><span class="n">reset_gradients</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">additional_inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">apply_gradients_and_sync_networks</span><span class="p">(</span><span class="n">reset_gradients</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="k">return</span> <span class="n">result</span></div>
<div class="viewcode-block" id="NetworkWrapper.apply_gradients_and_sync_networks"><a class="viewcode-back" href="../../../components/architectures/index.html#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_and_sync_networks">[docs]</a> <span class="k">def</span> <span class="nf">apply_gradients_and_sync_networks</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">reset_gradients</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<div class="viewcode-block" id="NetworkWrapper.apply_gradients_and_sync_networks"><a class="viewcode-back" href="../../../components/architectures/index.html#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_and_sync_networks">[docs]</a> <span class="k">def</span> <span class="nf">apply_gradients_and_sync_networks</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">reset_gradients</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">additional_inputs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Applies the gradients accumulated in the online network to the global network or to itself and syncs the</span>
<span class="sd"> networks if necessary</span>
@@ -356,17 +368,22 @@
<span class="sd"> the network. this is useful when the accumulated gradients are overwritten instead</span>
<span class="sd"> if accumulated by the accumulate_gradients function. this allows reducing time</span>
<span class="sd"> complexity for this function by around 10%</span>
<span class="sd"> :param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm&#39;s</span>
<span class="sd"> update ops also requires the inputs)</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">global_network</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">apply_gradients_to_global_network</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">apply_gradients_to_global_network</span><span class="p">(</span><span class="n">additional_inputs</span><span class="o">=</span><span class="n">additional_inputs</span><span class="p">)</span>
<span class="k">if</span> <span class="n">reset_gradients</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">reset_accumulated_gradients</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">update_online_network</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="n">reset_gradients</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">apply_and_reset_gradients</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">accumulated_gradients</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">apply_and_reset_gradients</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">accumulated_gradients</span><span class="p">,</span>
<span class="n">additional_inputs</span><span class="o">=</span><span class="n">additional_inputs</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">accumulated_gradients</span><span class="p">)</span></div>
<span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">online_network</span><span class="o">.</span><span class="n">accumulated_gradients</span><span class="p">,</span>
<span class="n">additional_inputs</span><span class="o">=</span><span class="n">additional_inputs</span><span class="p">)</span></div>
<div class="viewcode-block" id="NetworkWrapper.parallel_prediction"><a class="viewcode-back" href="../../../components/architectures/index.html#rl_coach.architectures.network_wrapper.NetworkWrapper.parallel_prediction">[docs]</a> <span class="k">def</span> <span class="nf">parallel_prediction</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">network_input_tuples</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">]):</span>
<span class="sd">&quot;&quot;&quot;</span>