Batch RL Tutorial (#372)
@@ -544,26 +544,34 @@ multi-process distributed mode. The network wrapper contains functionality for m
 between them.</p>
 <dl class="method">
 <dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_and_sync_networks">
-<code class="sig-name descname">apply_gradients_and_sync_networks</code><span class="sig-paren">(</span><em class="sig-param">reset_gradients=True</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_and_sync_networks"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_and_sync_networks" title="Permalink to this definition">¶</a></dt>
+<code class="sig-name descname">apply_gradients_and_sync_networks</code><span class="sig-paren">(</span><em class="sig-param">reset_gradients=True</em>, <em class="sig-param">additional_inputs=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_and_sync_networks"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_and_sync_networks" title="Permalink to this definition">¶</a></dt>
 <dd><p>Applies the gradients accumulated in the online network to the global network or to itself and syncs the
 networks if necessary</p>
 <dl class="field-list simple">
 <dt class="field-odd">Parameters</dt>
-<dd class="field-odd"><p><strong>reset_gradients</strong> – If set to True, the accumulated gradients wont be reset to 0 after applying them to
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>reset_gradients</strong> – If set to True, the accumulated gradients won’t be reset to 0 after applying them to
 the network. this is useful when the accumulated gradients are overwritten instead
 if accumulated by the accumulate_gradients function. this allows reducing time
-complexity for this function by around 10%</p>
+complexity for this function by around 10%</p></li>
+<li><p><strong>additional_inputs</strong> – optional additional inputs required when applying the gradients (e.g. batchnorm’s
+update ops also require the inputs)</p></li>
+</ul>
 </dd>
 </dl>
 </dd></dl>
 
 <dl class="method">
 <dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_global_network">
-<code class="sig-name descname">apply_gradients_to_global_network</code><span class="sig-paren">(</span><em class="sig-param">gradients=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_to_global_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_global_network" title="Permalink to this definition">¶</a></dt>
+<code class="sig-name descname">apply_gradients_to_global_network</code><span class="sig-paren">(</span><em class="sig-param">gradients=None</em>, <em class="sig-param">additional_inputs=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_to_global_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_global_network" title="Permalink to this definition">¶</a></dt>
 <dd><p>Apply gradients from the online network on the global network</p>
 <dl class="field-list simple">
 <dt class="field-odd">Parameters</dt>
-<dd class="field-odd"><p><strong>gradients</strong> – optional gradients that will be used instead of teh accumulated gradients</p>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>gradients</strong> – optional gradients that will be used instead of the accumulated gradients</p></li>
+<li><p><strong>additional_inputs</strong> – optional additional inputs required when applying the gradients (e.g. batchnorm’s
+update ops also require the inputs)</p></li>
+</ul>
 </dd>
 <dt class="field-even">Returns</dt>
 <dd class="field-even"><p></p>
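To make the new additional_inputs argument concrete, here is a minimal sketch of the accumulate-then-apply cycle these two methods serve, assuming an already-constructed NetworkWrapper (for example agent.networks['main']) and a prepared batch; all variable names are illustrative, not part of the API:

```python
# Illustrative only: `network` is assumed to be an existing rl_coach
# NetworkWrapper, and `batch_inputs` / `batch_targets` are placeholder
# names for data prepared elsewhere.

def manual_update_step(network, batch_inputs, batch_targets):
    # Compute gradients on the online network and accumulate them locally,
    # without applying them yet (accumulate_gradients is the function the
    # reset_gradients description above refers to).
    network.online_network.accumulate_gradients(batch_inputs, batch_targets)

    # Apply the accumulated gradients to the global network (distributed
    # mode) or to the online network itself (single-worker mode), then sync.
    # The new additional_inputs argument re-feeds the batch at apply time so
    # ops that need the inputs (e.g. batchnorm update ops) can run as well.
    network.apply_gradients_and_sync_networks(
        reset_gradients=True,
        additional_inputs=batch_inputs,
    )
```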
@@ -573,8 +581,13 @@ complexity for this function by around 10%</p>
 
 <dl class="method">
 <dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_online_network">
-<code class="sig-name descname">apply_gradients_to_online_network</code><span class="sig-paren">(</span><em class="sig-param">gradients=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_to_online_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_online_network" title="Permalink to this definition">¶</a></dt>
-<dd><p>Apply gradients from the online network on itself</p>
+<code class="sig-name descname">apply_gradients_to_online_network</code><span class="sig-paren">(</span><em class="sig-param">gradients=None</em>, <em class="sig-param">additional_inputs=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_to_online_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_online_network" title="Permalink to this definition">¶</a></dt>
+<dd><p>Apply gradients from the online network on itself
+:param gradients: optional gradients that will be used instead of the accumulated gradients
+:param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm’s</p>
+<blockquote>
+<div><p>update ops also require the inputs)</p>
+</div></blockquote>
 <dl class="field-list simple">
 <dt class="field-odd">Returns</dt>
 <dd class="field-odd"><p></p>
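Similarly, a hedged sketch of the gradients override documented above, for a single-worker setup where the gradients were computed outside the wrapper; avg_grads is a hypothetical list of per-weight gradient arrays, not something the library provides:

```python
# Illustrative only: `network` is an assumed NetworkWrapper, and `avg_grads`
# is a hypothetical list of per-weight gradient arrays computed externally
# (for example, averaged over several accumulation passes).

def apply_external_gradients(network, avg_grads, batch_inputs):
    network.apply_gradients_to_online_network(
        gradients=avg_grads,             # used instead of the accumulated gradients
        additional_inputs=batch_inputs,  # only needed when apply-time ops consume inputs
    )
```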
@@ -650,7 +663,7 @@ target_network or global_network) and the second element is the inputs</p>
 
 <dl class="method">
 <dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.train_and_sync_networks">
-<code class="sig-name descname">train_and_sync_networks</code><span class="sig-paren">(</span><em class="sig-param">inputs</em>, <em class="sig-param">targets</em>, <em class="sig-param">additional_fetches=[]</em>, <em class="sig-param">importance_weights=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.train_and_sync_networks"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.train_and_sync_networks" title="Permalink to this definition">¶</a></dt>
+<code class="sig-name descname">train_and_sync_networks</code><span class="sig-paren">(</span><em class="sig-param">inputs</em>, <em class="sig-param">targets</em>, <em class="sig-param">additional_fetches=[]</em>, <em class="sig-param">importance_weights=None</em>, <em class="sig-param">use_inputs_for_apply_gradients=False</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.train_and_sync_networks"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.train_and_sync_networks" title="Permalink to this definition">¶</a></dt>
 <dd><p>A generic training function that enables multi-threading training using a global network if necessary.</p>
 <dl class="field-list simple">
 <dt class="field-odd">Parameters</dt>
@@ -660,6 +673,8 @@ target_network or global_network) and the second element is the inputs</p>
 <li><p><strong>additional_fetches</strong> – Any additional tensor the user wants to fetch</p></li>
 <li><p><strong>importance_weights</strong> – A coefficient for each sample in the batch, which will be used to rescale the loss
 error of this sample. If it is not given, the samples losses won’t be scaled</p></li>
+<li><p><strong>use_inputs_for_apply_gradients</strong> – Add the inputs also when applying gradients
+(e.g. for incorporating batchnorm update ops)</p></li>
 </ul>
 </dd>
 <dt class="field-even">Returns</dt>
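Finally, a minimal usage sketch of the extended train_and_sync_networks signature; everything other than the keyword names shown above (the network object, batch data, and per-sample weights) is assumed:

```python
# Illustrative only: `network`, `batch_inputs`, `batch_targets` and
# `per_sample_weights` are assumed to exist; the keyword names come from
# the signature documented above.

def training_step(network, batch_inputs, batch_targets, per_sample_weights):
    return network.train_and_sync_networks(
        inputs=batch_inputs,
        targets=batch_targets,
        importance_weights=per_sample_weights,  # optional per-sample loss rescaling
        # New flag from this commit: also feed the inputs when applying the
        # gradients, so that e.g. batchnorm update ops can be incorporated.
        use_inputs_for_apply_gradients=True,
    )
```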