mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
SAC algorithm (#282)
* SAC algorithm * SAC - updates to agent (learn_from_batch), sac_head and sac_q_head to fix problem in gradient calculation. Now SAC agents is able to train. gym_environment - fixing an error in access to gym.spaces * Soft Actor Critic - code cleanup * code cleanup * V-head initialization fix * SAC benchmarks * SAC Documentation * typo fix * documentation fixes * documentation and version update * README typo
This commit is contained in:
@@ -190,13 +190,15 @@
|
||||
<em class="property">class </em><code class="descclassname">rl_coach.agents.dqn_agent.</code><code class="descname">DQNAgent</code><span class="sig-paren">(</span><em>agent_parameters</em>, <em>parent: Union[LevelManager</em>, <em>CompositeAgent] = None</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/rl_coach/agents/dqn_agent.html#DQNAgent"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><dl class="method">
|
||||
<dt id="rl_coach.agents.dqn_agent.DQNAgent.act">
|
||||
<code class="descname">act</code><span class="sig-paren">(</span><span class="sig-paren">)</span> → rl_coach.core_types.ActionInfo<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.act" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="descname">act</code><span class="sig-paren">(</span><em>action: Union[None</em>, <em>int</em>, <em>float</em>, <em>numpy.ndarray</em>, <em>List] = None</em><span class="sig-paren">)</span> → rl_coach.core_types.ActionInfo<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.act" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Given the agents current knowledge, decide on the next action to apply to the environment</p>
|
||||
<table class="docutils field-list" frame="void" rules="none">
|
||||
<col class="field-name" />
|
||||
<col class="field-body" />
|
||||
<tbody valign="top">
|
||||
<tr class="field-odd field"><th class="field-name">Returns:</th><td class="field-body">An ActionInfo object, which contains the action and any additional info from the action decision process</td>
|
||||
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>action</strong> – An action to take, overriding whatever the current policy is</td>
|
||||
</tr>
|
||||
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body">An ActionInfo object, which contains the action and any additional info from the action decision process</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
@@ -267,26 +269,6 @@ for creating the network.</p>
|
||||
</table>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.dqn_agent.DQNAgent.emulate_act_on_trainer">
|
||||
<code class="descname">emulate_act_on_trainer</code><span class="sig-paren">(</span><em>transition: rl_coach.core_types.Transition</em><span class="sig-paren">)</span> → rl_coach.core_types.ActionInfo<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.emulate_act_on_trainer" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>This emulates the act using the transition obtained from the rollout worker on the training worker
|
||||
in case of distributed training.
|
||||
Given the agents current knowledge, decide on the next action to apply to the environment
|
||||
:return: an action and a dictionary containing any additional info from the action decision process</p>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.dqn_agent.DQNAgent.emulate_observe_on_trainer">
|
||||
<code class="descname">emulate_observe_on_trainer</code><span class="sig-paren">(</span><em>transition: rl_coach.core_types.Transition</em><span class="sig-paren">)</span> → bool<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.emulate_observe_on_trainer" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>This emulates the observe using the transition obtained from the rollout worker on the training worker
|
||||
in case of distributed training.
|
||||
Given a response from the environment, distill the observation from it and store it for later use.
|
||||
The response should be a dictionary containing the performed action, the new observation and measurements,
|
||||
the reward, a game over flag and any additional information necessary.
|
||||
:return:</p>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.dqn_agent.DQNAgent.get_predictions">
|
||||
<code class="descname">get_predictions</code><span class="sig-paren">(</span><em>states: List[Dict[str, numpy.ndarray]], prediction_type: rl_coach.core_types.PredictionType</em><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.get_predictions" title="Permalink to this definition">¶</a></dt>
|
||||
@@ -342,6 +324,22 @@ This function is called right after each episode is ended.</p>
|
||||
</table>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.dqn_agent.DQNAgent.improve_reward_model">
|
||||
<code class="descname">improve_reward_model</code><span class="sig-paren">(</span><em>epochs: int</em><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.improve_reward_model" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Train a reward model to be used by the doubly-robust estimator</p>
|
||||
<table class="docutils field-list" frame="void" rules="none">
|
||||
<col class="field-name" />
|
||||
<col class="field-body" />
|
||||
<tbody valign="top">
|
||||
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>epochs</strong> – The total number of epochs to use for training a reward model</td>
|
||||
</tr>
|
||||
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body">None</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.dqn_agent.DQNAgent.init_environment_dependent_modules">
|
||||
<code class="descname">init_environment_dependent_modules</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.init_environment_dependent_modules" title="Permalink to this definition">¶</a></dt>
|
||||
@@ -450,7 +448,7 @@ given observation</td>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.dqn_agent.DQNAgent.prepare_batch_for_inference">
|
||||
<code class="descname">prepare_batch_for_inference</code><span class="sig-paren">(</span><em>states: Union[Dict[str, numpy.ndarray], List[Dict[str, numpy.ndarray]]], network_name: str</em><span class="sig-paren">)</span> → Dict[str, numpy.core.multiarray.array]<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.prepare_batch_for_inference" title="Permalink to this definition">¶</a></dt>
|
||||
<code class="descname">prepare_batch_for_inference</code><span class="sig-paren">(</span><em>states: Union[Dict[str, numpy.ndarray], List[Dict[str, numpy.ndarray]]], network_name: str</em><span class="sig-paren">)</span> → Dict[str, numpy.array]<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.prepare_batch_for_inference" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Convert curr_state into input tensors tensorflow is expecting. i.e. if we have several inputs states, stack all
|
||||
observations together, measurements together, etc.</p>
|
||||
<table class="docutils field-list" frame="void" rules="none">
|
||||
@@ -542,6 +540,14 @@ by val, and by the current phase set in self.phase.</p>
|
||||
</table>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.dqn_agent.DQNAgent.run_off_policy_evaluation">
|
||||
<code class="descname">run_off_policy_evaluation</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.run_off_policy_evaluation" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Run the off-policy evaluation estimators to get a prediction for the performance of the current policy based on
|
||||
an evaluation dataset, which was collected by another policy(ies).
|
||||
:return: None</p>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="method">
|
||||
<dt id="rl_coach.agents.dqn_agent.DQNAgent.run_pre_network_filter_for_inference">
|
||||
<code class="descname">run_pre_network_filter_for_inference</code><span class="sig-paren">(</span><em>state: Dict[str, numpy.ndarray], update_filter_internal_state: bool = True</em><span class="sig-paren">)</span> → Dict[str, numpy.ndarray]<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.run_pre_network_filter_for_inference" title="Permalink to this definition">¶</a></dt>
|
||||
|
||||
Reference in New Issue
Block a user