mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
TD3 (#338)
This commit is contained in:
@@ -196,7 +196,7 @@ own components under a dedicated directory. For example, tensorflow components w
|
||||
parts that are implemented using TensorFlow.</p>
|
||||
<dl class="class">
|
||||
<dt id="rl_coach.base_parameters.NetworkParameters">
|
||||
<em class="property">class </em><code class="descclassname">rl_coach.base_parameters.</code><code class="descname">NetworkParameters</code><span class="sig-paren">(</span><em>force_cpu=False</em>, <em>async_training=False</em>, <em>shared_optimizer=True</em>, <em>scale_down_gradients_by_number_of_workers_for_sync_training=True</em>, <em>clip_gradients=None</em>, <em>gradients_clipping_method=<GradientClippingMethod.ClipByGlobalNorm: 0></em>, <em>l2_regularization=0</em>, <em>learning_rate=0.00025</em>, <em>learning_rate_decay_rate=0</em>, <em>learning_rate_decay_steps=0</em>, <em>input_embedders_parameters={}</em>, <em>embedding_merger_type=<EmbeddingMergerType.Concat: 0></em>, <em>middleware_parameters=None</em>, <em>heads_parameters=[]</em>, <em>use_separate_networks_per_head=False</em>, <em>optimizer_type='Adam'</em>, <em>optimizer_epsilon=0.0001</em>, <em>adam_optimizer_beta1=0.9</em>, <em>adam_optimizer_beta2=0.99</em>, <em>rms_prop_optimizer_decay=0.9</em>, <em>batch_size=32</em>, <em>replace_mse_with_huber_loss=False</em>, <em>create_target_network=False</em>, <em>tensorflow_support=True</em>, <em>softmax_temperature=1</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/base_parameters.html#NetworkParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.NetworkParameters" title="Permalink to this definition">¶</a></dt>
|
||||
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.base_parameters.</code><code class="sig-name descname">NetworkParameters</code><span class="sig-paren">(</span><em class="sig-param">force_cpu=False</em>, <em class="sig-param">async_training=False</em>, <em class="sig-param">shared_optimizer=True</em>, <em class="sig-param">scale_down_gradients_by_number_of_workers_for_sync_training=True</em>, <em class="sig-param">clip_gradients=None</em>, <em class="sig-param">gradients_clipping_method=<GradientClippingMethod.ClipByGlobalNorm: 0></em>, <em class="sig-param">l2_regularization=0</em>, <em class="sig-param">learning_rate=0.00025</em>, <em class="sig-param">learning_rate_decay_rate=0</em>, <em class="sig-param">learning_rate_decay_steps=0</em>, <em class="sig-param">input_embedders_parameters={}</em>, <em class="sig-param">embedding_merger_type=<EmbeddingMergerType.Concat: 0></em>, <em class="sig-param">middleware_parameters=None</em>, <em class="sig-param">heads_parameters=[]</em>, <em class="sig-param">use_separate_networks_per_head=False</em>, <em class="sig-param">optimizer_type='Adam'</em>, <em class="sig-param">optimizer_epsilon=0.0001</em>, <em class="sig-param">adam_optimizer_beta1=0.9</em>, <em class="sig-param">adam_optimizer_beta2=0.99</em>, <em class="sig-param">rms_prop_optimizer_decay=0.9</em>, <em class="sig-param">batch_size=32</em>, <em class="sig-param">replace_mse_with_huber_loss=False</em>, <em class="sig-param">create_target_network=False</em>, <em class="sig-param">tensorflow_support=True</em>, <em class="sig-param">softmax_temperature=1</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/base_parameters.html#NetworkParameters"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.base_parameters.NetworkParameters" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
<dd class="field-odd"><ul class="simple">
|
||||
@@ -268,7 +268,7 @@ online network at will.</p></li>
|
||||
<h2>Architecture<a class="headerlink" href="#architecture" title="Permalink to this headline">¶</a></h2>
|
||||
<dl class="class">
|
||||
<dt id="rl_coach.architectures.architecture.Architecture">
|
||||
<em class="property">class </em><code class="descclassname">rl_coach.architectures.architecture.</code><code class="descname">Architecture</code><span class="sig-paren">(</span><em>agent_parameters: rl_coach.base_parameters.AgentParameters</em>, <em>spaces: rl_coach.spaces.SpacesDefinition</em>, <em>name: str = ''</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture" title="Permalink to this definition">¶</a></dt>
|
||||
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.architectures.architecture.</code><code class="sig-name descname">Architecture</code><span class="sig-paren">(</span><em class="sig-param">agent_parameters: rl_coach.base_parameters.AgentParameters</em>, <em class="sig-param">spaces: rl_coach.spaces.SpacesDefinition</em>, <em class="sig-param">name: str = ''</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Creates a neural network ‘architecture’, that can be trained and used for inference.</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
@@ -281,7 +281,7 @@ online network at will.</p></li>
|
||||
</dl>
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.architecture.Architecture.accumulate_gradients">
|
||||
<code class="descname">accumulate_gradients</code><span class="sig-paren">(</span><em>inputs: Dict[str, numpy.ndarray], targets: List[numpy.ndarray], additional_fetches: list = None, importance_weights: numpy.ndarray = None, no_accumulation: bool = False</em><span class="sig-paren">)</span> → Tuple[float, List[float], float, list]<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.accumulate_gradients"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.accumulate_gradients" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="sig-name descname">accumulate_gradients</code><span class="sig-paren">(</span><em class="sig-param">inputs: Dict[str, numpy.ndarray], targets: List[numpy.ndarray], additional_fetches: list = None, importance_weights: numpy.ndarray = None, no_accumulation: bool = False</em><span class="sig-paren">)</span> → Tuple[float, List[float], float, list]<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.accumulate_gradients"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.accumulate_gradients" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Given a batch of inputs (i.e. states) and targets (e.g. discounted rewards), computes and accumulates the
|
||||
gradients for model parameters. Will run forward and backward pass to compute gradients, clip the gradient
|
||||
values if required and then accumulate gradients from all learners. It does not update the model weights,
|
||||
@@ -324,7 +324,7 @@ fetched_tensors: all values for additional_fetches</p>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.architecture.Architecture.apply_and_reset_gradients">
|
||||
<code class="descname">apply_and_reset_gradients</code><span class="sig-paren">(</span><em>gradients: List[numpy.ndarray], scaler: float = 1.0</em><span class="sig-paren">)</span> → None<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.apply_and_reset_gradients"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.apply_and_reset_gradients" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="sig-name descname">apply_and_reset_gradients</code><span class="sig-paren">(</span><em class="sig-param">gradients: List[numpy.ndarray], scaler: float = 1.0</em><span class="sig-paren">)</span> → None<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.apply_and_reset_gradients"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.apply_and_reset_gradients" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Applies the given gradients to the network weights and resets the gradient accumulations.
|
||||
Has the same impact as calling <cite>apply_gradients</cite>, then <cite>reset_accumulated_gradients</cite>.</p>
|
||||
<dl class="field-list simple">
|
||||
@@ -340,7 +340,7 @@ of an identical network (either self or another identical network)</p></li>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.architecture.Architecture.apply_gradients">
|
||||
<code class="descname">apply_gradients</code><span class="sig-paren">(</span><em>gradients: List[numpy.ndarray], scaler: float = 1.0</em><span class="sig-paren">)</span> → None<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.apply_gradients"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.apply_gradients" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="sig-name descname">apply_gradients</code><span class="sig-paren">(</span><em class="sig-param">gradients: List[numpy.ndarray], scaler: float = 1.0</em><span class="sig-paren">)</span> → None<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.apply_gradients"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.apply_gradients" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Applies the given gradients to the network weights.
|
||||
Will be performed sync or async depending on <cite>network_parameters.async_training</cite></p>
|
||||
<dl class="field-list simple">
|
||||
@@ -356,7 +356,7 @@ of an identical network (either self or another identical network)</p></li>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.architecture.Architecture.collect_savers">
|
||||
<code class="descname">collect_savers</code><span class="sig-paren">(</span><em>parent_path_suffix: str</em><span class="sig-paren">)</span> → rl_coach.saver.SaverCollection<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.collect_savers"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.collect_savers" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="sig-name descname">collect_savers</code><span class="sig-paren">(</span><em class="sig-param">parent_path_suffix: str</em><span class="sig-paren">)</span> → rl_coach.saver.SaverCollection<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.collect_savers"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.collect_savers" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Collection of all savers for the network (typically only one saver for network and one for ONNX export)
|
||||
:param parent_path_suffix: path suffix of the parent of the network</p>
|
||||
<blockquote>
|
||||
@@ -369,9 +369,9 @@ of an identical network (either self or another identical network)</p></li>
|
||||
</dl>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="staticmethod">
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.architecture.Architecture.construct">
|
||||
<em class="property">static </em><code class="descname">construct</code><span class="sig-paren">(</span><em>variable_scope: str, devices: List[str], *args, **kwargs</em><span class="sig-paren">)</span> → rl_coach.architectures.architecture.Architecture<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.construct"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.construct" title="Permalink to this definition">¶</a></dt>
|
||||
<em class="property">static </em><code class="sig-name descname">construct</code><span class="sig-paren">(</span><em class="sig-param">variable_scope: str, devices: List[str], *args, **kwargs</em><span class="sig-paren">)</span> → rl_coach.architectures.architecture.Architecture<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.construct"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.construct" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Construct a network class using the provided variable scope and on requested devices
|
||||
:param variable_scope: string specifying variable scope under which to create network variables
|
||||
:param devices: list of devices (can be list of Device objects, or string for TF distributed)
|
||||
@@ -382,7 +382,7 @@ of an identical network (either self or another identical network)</p></li>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.architecture.Architecture.get_variable_value">
|
||||
<code class="descname">get_variable_value</code><span class="sig-paren">(</span><em>variable: Any</em><span class="sig-paren">)</span> → numpy.ndarray<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.get_variable_value"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.get_variable_value" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="sig-name descname">get_variable_value</code><span class="sig-paren">(</span><em class="sig-param">variable: Any</em><span class="sig-paren">)</span> → numpy.ndarray<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.get_variable_value"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.get_variable_value" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Gets value of a specified variable. Type of variable is dependant on the framework.
|
||||
Example of a variable is head.kl_coefficient, which could be a symbol for evaluation
|
||||
or could be a string representing the value.</p>
|
||||
@@ -398,7 +398,7 @@ or could be a string representing the value.</p>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.architecture.Architecture.get_weights">
|
||||
<code class="descname">get_weights</code><span class="sig-paren">(</span><span class="sig-paren">)</span> → List[numpy.ndarray]<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.get_weights"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.get_weights" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="sig-name descname">get_weights</code><span class="sig-paren">(</span><span class="sig-paren">)</span> → List[numpy.ndarray]<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.get_weights"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.get_weights" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Gets model weights as a list of ndarrays. It is used for synchronizing weight between two identical networks.</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Returns</dt>
|
||||
@@ -407,9 +407,9 @@ or could be a string representing the value.</p>
|
||||
</dl>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="staticmethod">
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.architecture.Architecture.parallel_predict">
|
||||
<em class="property">static </em><code class="descname">parallel_predict</code><span class="sig-paren">(</span><em>sess: Any, network_input_tuples: List[Tuple[Architecture, Dict[str, numpy.ndarray]]]</em><span class="sig-paren">)</span> → Tuple[numpy.ndarray, ...]<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.parallel_predict"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.parallel_predict" title="Permalink to this definition">¶</a></dt>
|
||||
<em class="property">static </em><code class="sig-name descname">parallel_predict</code><span class="sig-paren">(</span><em class="sig-param">sess: Any, network_input_tuples: List[Tuple[Architecture, Dict[str, numpy.ndarray]]]</em><span class="sig-paren">)</span> → Tuple[numpy.ndarray, ...]<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.parallel_predict"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.parallel_predict" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
<dd class="field-odd"><ul class="simple">
|
||||
@@ -425,7 +425,7 @@ or could be a string representing the value.</p>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.architecture.Architecture.predict">
|
||||
<code class="descname">predict</code><span class="sig-paren">(</span><em>inputs: Dict[str, numpy.ndarray], outputs: List[Any] = None, squeeze_output: bool = True, initial_feed_dict: Dict[Any, numpy.ndarray] = None</em><span class="sig-paren">)</span> → Tuple[numpy.ndarray, ...]<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.predict"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.predict" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="sig-name descname">predict</code><span class="sig-paren">(</span><em class="sig-param">inputs: Dict[str, numpy.ndarray], outputs: List[Any] = None, squeeze_output: bool = True, initial_feed_dict: Dict[Any, numpy.ndarray] = None</em><span class="sig-paren">)</span> → Tuple[numpy.ndarray, ...]<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.predict"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.predict" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Given input observations, use the model to make predictions (e.g. action or value).</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
@@ -446,7 +446,7 @@ depends on the framework backend.</p></li>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.architecture.Architecture.reset_accumulated_gradients">
|
||||
<code class="descname">reset_accumulated_gradients</code><span class="sig-paren">(</span><span class="sig-paren">)</span> → None<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.reset_accumulated_gradients"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.reset_accumulated_gradients" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="sig-name descname">reset_accumulated_gradients</code><span class="sig-paren">(</span><span class="sig-paren">)</span> → None<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.reset_accumulated_gradients"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.reset_accumulated_gradients" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Sets gradient of all parameters to 0.</p>
|
||||
<p>Once gradients are reset, they must be accessible by <cite>accumulated_gradients</cite> property of this class,
|
||||
which must return a list of numpy ndarrays. Child class must ensure that <cite>accumulated_gradients</cite> is set.</p>
|
||||
@@ -454,7 +454,7 @@ which must return a list of numpy ndarrays. Child class must ensure that <cite>a
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.architecture.Architecture.set_variable_value">
|
||||
<code class="descname">set_variable_value</code><span class="sig-paren">(</span><em>assign_op: Any</em>, <em>value: numpy.ndarray</em>, <em>placeholder: Any</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.set_variable_value"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.set_variable_value" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="sig-name descname">set_variable_value</code><span class="sig-paren">(</span><em class="sig-param">assign_op: Any</em>, <em class="sig-param">value: numpy.ndarray</em>, <em class="sig-param">placeholder: Any</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.set_variable_value"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.set_variable_value" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Updates the value of a specified variable. Type of assign_op is dependant on the framework
|
||||
and is a unique identifier for assigning value to a variable. For example an agent may use
|
||||
head.assign_kl_coefficient. There is a one to one mapping between assign_op and placeholder
|
||||
@@ -472,7 +472,7 @@ head.assign_kl_coefficient. There is a one to one mapping between assign_op and
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.architecture.Architecture.set_weights">
|
||||
<code class="descname">set_weights</code><span class="sig-paren">(</span><em>weights: List[numpy.ndarray], rate: float = 1.0</em><span class="sig-paren">)</span> → None<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.set_weights"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.set_weights" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="sig-name descname">set_weights</code><span class="sig-paren">(</span><em class="sig-param">weights: List[numpy.ndarray], rate: float = 1.0</em><span class="sig-paren">)</span> → None<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.set_weights"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.set_weights" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Sets model weights for provided layer parameters.</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
@@ -490,7 +490,7 @@ i.e. new_weight = rate * given_weight + (1 - rate) * old_weight</p></li>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.architecture.Architecture.train_on_batch">
|
||||
<code class="descname">train_on_batch</code><span class="sig-paren">(</span><em>inputs: Dict[str, numpy.ndarray], targets: List[numpy.ndarray], scaler: float = 1.0, additional_fetches: list = None, importance_weights: numpy.ndarray = None</em><span class="sig-paren">)</span> → Tuple[float, List[float], float, list]<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.train_on_batch"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.train_on_batch" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="sig-name descname">train_on_batch</code><span class="sig-paren">(</span><em class="sig-param">inputs: Dict[str, numpy.ndarray], targets: List[numpy.ndarray], scaler: float = 1.0, additional_fetches: list = None, importance_weights: numpy.ndarray = None</em><span class="sig-paren">)</span> → Tuple[float, List[float], float, list]<a class="reference internal" href="../../_modules/rl_coach/architectures/architecture.html#Architecture.train_on_batch"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.architecture.Architecture.train_on_batch" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Given a batch of inputs (e.g. states) and targets (e.g. discounted rewards), takes a training step: i.e. runs a
|
||||
forward pass and backward pass of the network, accumulates the gradients and applies an optimization step to
|
||||
update the weights.
|
||||
@@ -535,7 +535,7 @@ fetched_tensors: all values for additional_fetches</p>
|
||||
<a class="reference internal image-reference" href="../../_images/distributed.png"><img alt="../../_images/distributed.png" class="align-center" src="../../_images/distributed.png" style="width: 600px;" /></a>
|
||||
<dl class="class">
|
||||
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper">
|
||||
<em class="property">class </em><code class="descclassname">rl_coach.architectures.network_wrapper.</code><code class="descname">NetworkWrapper</code><span class="sig-paren">(</span><em>agent_parameters: rl_coach.base_parameters.AgentParameters</em>, <em>has_target: bool</em>, <em>has_global: bool</em>, <em>name: str</em>, <em>spaces: rl_coach.spaces.SpacesDefinition</em>, <em>replicated_device=None</em>, <em>worker_device=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper" title="Permalink to this definition">¶</a></dt>
|
||||
<em class="property">class </em><code class="sig-prename descclassname">rl_coach.architectures.network_wrapper.</code><code class="sig-name descname">NetworkWrapper</code><span class="sig-paren">(</span><em class="sig-param">agent_parameters: rl_coach.base_parameters.AgentParameters</em>, <em class="sig-param">has_target: bool</em>, <em class="sig-param">has_global: bool</em>, <em class="sig-param">name: str</em>, <em class="sig-param">spaces: rl_coach.spaces.SpacesDefinition</em>, <em class="sig-param">replicated_device=None</em>, <em class="sig-param">worker_device=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>The network wrapper contains multiple copies of the same network, each one with a different set of weights which is
|
||||
updating in a different time scale. The network wrapper will always contain an online network.
|
||||
It will contain an additional slow updating target network if it was requested by the user,
|
||||
@@ -544,7 +544,7 @@ multi-process distributed mode. The network wrapper contains functionality for m
|
||||
between them.</p>
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_and_sync_networks">
|
||||
<code class="descname">apply_gradients_and_sync_networks</code><span class="sig-paren">(</span><em>reset_gradients=True</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_and_sync_networks"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_and_sync_networks" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="sig-name descname">apply_gradients_and_sync_networks</code><span class="sig-paren">(</span><em class="sig-param">reset_gradients=True</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_and_sync_networks"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_and_sync_networks" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Applies the gradients accumulated in the online network to the global network or to itself and syncs the
|
||||
networks if necessary</p>
|
||||
<dl class="field-list simple">
|
||||
@@ -559,7 +559,7 @@ complexity for this function by around 10%</p>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_global_network">
|
||||
<code class="descname">apply_gradients_to_global_network</code><span class="sig-paren">(</span><em>gradients=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_to_global_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_global_network" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="sig-name descname">apply_gradients_to_global_network</code><span class="sig-paren">(</span><em class="sig-param">gradients=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_to_global_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_global_network" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Apply gradients from the online network on the global network</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
@@ -573,7 +573,7 @@ complexity for this function by around 10%</p>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_online_network">
|
||||
<code class="descname">apply_gradients_to_online_network</code><span class="sig-paren">(</span><em>gradients=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_to_online_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_online_network" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="sig-name descname">apply_gradients_to_online_network</code><span class="sig-paren">(</span><em class="sig-param">gradients=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.apply_gradients_to_online_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.apply_gradients_to_online_network" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Apply gradients from the online network on itself</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Returns</dt>
|
||||
@@ -584,7 +584,7 @@ complexity for this function by around 10%</p>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.collect_savers">
|
||||
<code class="descname">collect_savers</code><span class="sig-paren">(</span><em>parent_path_suffix: str</em><span class="sig-paren">)</span> → rl_coach.saver.SaverCollection<a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.collect_savers"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.collect_savers" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="sig-name descname">collect_savers</code><span class="sig-paren">(</span><em class="sig-param">parent_path_suffix: str</em><span class="sig-paren">)</span> → rl_coach.saver.SaverCollection<a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.collect_savers"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.collect_savers" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Collect all of network’s savers for global or online network
|
||||
Note: global, online, and target network are all copies fo the same network which parameters that are</p>
|
||||
<blockquote>
|
||||
@@ -610,7 +610,7 @@ for saving.</p>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.parallel_prediction">
|
||||
<code class="descname">parallel_prediction</code><span class="sig-paren">(</span><em>network_input_tuples: List[Tuple]</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.parallel_prediction"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.parallel_prediction" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="sig-name descname">parallel_prediction</code><span class="sig-paren">(</span><em class="sig-param">network_input_tuples: List[Tuple]</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.parallel_prediction"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.parallel_prediction" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Run several network prediction in parallel. Currently this only supports running each of the network once.</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
@@ -625,7 +625,7 @@ target_network or global_network) and the second element is the inputs</p>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.set_is_training">
|
||||
<code class="descname">set_is_training</code><span class="sig-paren">(</span><em>state: bool</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.set_is_training"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.set_is_training" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="sig-name descname">set_is_training</code><span class="sig-paren">(</span><em class="sig-param">state: bool</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.set_is_training"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.set_is_training" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Set the phase of the network between training and testing</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
@@ -639,7 +639,7 @@ target_network or global_network) and the second element is the inputs</p>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.sync">
|
||||
<code class="descname">sync</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.sync"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.sync" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="sig-name descname">sync</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.sync"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.sync" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Initializes the weights of the networks to match each other</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Returns</dt>
|
||||
@@ -650,7 +650,7 @@ target_network or global_network) and the second element is the inputs</p>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.train_and_sync_networks">
|
||||
<code class="descname">train_and_sync_networks</code><span class="sig-paren">(</span><em>inputs</em>, <em>targets</em>, <em>additional_fetches=[]</em>, <em>importance_weights=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.train_and_sync_networks"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.train_and_sync_networks" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="sig-name descname">train_and_sync_networks</code><span class="sig-paren">(</span><em class="sig-param">inputs</em>, <em class="sig-param">targets</em>, <em class="sig-param">additional_fetches=[]</em>, <em class="sig-param">importance_weights=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.train_and_sync_networks"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.train_and_sync_networks" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>A generic training function that enables multi-threading training using a global network if necessary.</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
@@ -670,7 +670,7 @@ error of this sample. If it is not given, the samples losses won’t be scaled</
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.update_online_network">
|
||||
<code class="descname">update_online_network</code><span class="sig-paren">(</span><em>rate=1.0</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.update_online_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.update_online_network" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="sig-name descname">update_online_network</code><span class="sig-paren">(</span><em class="sig-param">rate=1.0</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.update_online_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.update_online_network" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Copy weights: global network >>> online network</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
@@ -681,7 +681,7 @@ error of this sample. If it is not given, the samples losses won’t be scaled</
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.architectures.network_wrapper.NetworkWrapper.update_target_network">
|
||||
<code class="descname">update_target_network</code><span class="sig-paren">(</span><em>rate=1.0</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.update_target_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.update_target_network" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="sig-name descname">update_target_network</code><span class="sig-paren">(</span><em class="sig-param">rate=1.0</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/rl_coach/architectures/network_wrapper.html#NetworkWrapper.update_target_network"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.architectures.network_wrapper.NetworkWrapper.update_target_network" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Copy weights: online network >>> target network</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
|
||||
Reference in New Issue
Block a user