
SAC algorithm (#282)

* SAC algorithm

* SAC - updated the agent (learn_from_batch), sac_head and sac_q_head to fix a problem in the gradient calculation; the SAC agent is now able to train.
gym_environment - fixed an error in accessing gym.spaces

* Soft Actor Critic - code cleanup

* code cleanup

* V-head initialization fix

* SAC benchmarks

* SAC Documentation

* typo fix

* documentation fixes

* documentation and version update

* README typo
Author: guyk1971
Date: 2019-05-01 18:37:49 +03:00
Committed by: shadiendrawis
Parent: 33dc29ee99
Commit: 74db141d5e
92 changed files with 2812 additions and 402 deletions
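For readers unfamiliar with the algorithm being added: the sac_head, sac_q_head and V-head mentioned in the commit message correspond to the policy, Q-function and state-value networks of standard Soft Actor-Critic (Haarnoja et al., 2018). The snippet below is a rough, framework-agnostic sketch of the soft value target those heads are trained against; it is not code from this commit, and the names and the temperature value are illustrative only.

```python
import numpy as np

def soft_v_target(q1, q2, log_pi, alpha=0.2):
    """Standard SAC state-value target: V(s) ~ E_a[min(Q1, Q2) - alpha * log pi(a|s)].

    q1, q2  -- Q-estimates for actions sampled from the current policy
    log_pi  -- log-probabilities of those sampled actions
    alpha   -- entropy temperature (0.2 is a common default, not necessarily Coach's)
    """
    return np.minimum(q1, q2) - alpha * log_pi

# Tiny worked example with made-up numbers:
q1 = np.array([1.0, 2.0])
q2 = np.array([1.5, 1.8])
log_pi = np.array([-0.7, -1.2])
print(soft_v_target(q1, q2, log_pi))  # [1.14 2.04]
```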


@@ -190,13 +190,15 @@
<em class="property">class </em><code class="descclassname">rl_coach.agents.dqn_agent.</code><code class="descname">DQNAgent</code><span class="sig-paren">(</span><em>agent_parameters</em>, <em>parent: Union[LevelManager</em>, <em>CompositeAgent] = None</em><span class="sig-paren">)</span><a class="reference internal" href="_modules/rl_coach/agents/dqn_agent.html#DQNAgent"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent" title="Permalink to this definition"></a></dt>
<dd><dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.act">
<code class="descname">act</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.ActionInfo<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.act" title="Permalink to this definition"></a></dt>
<code class="descname">act</code><span class="sig-paren">(</span><em>action: Union[None</em>, <em>int</em>, <em>float</em>, <em>numpy.ndarray</em>, <em>List] = None</em><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.ActionInfo<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.act" title="Permalink to this definition"></a></dt>
<dd><p>Given the agent's current knowledge, decide on the next action to apply to the environment</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
<col class="field-body" />
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Returns:</th><td class="field-body">An ActionInfo object, which contains the action and any additional info from the action decision process</td>
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>action</strong> An action to take, overriding whatever the current policy is</td>
</tr>
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body">An ActionInfo object, which contains the action and any additional info from the action decision process</td>
</tr>
</tbody>
</table>
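The change above documents the new optional <code>action</code> argument on <code>act()</code>, which lets a caller override the policy's own decision. A minimal usage sketch, assuming an already-constructed DQNAgent instance named <code>agent</code> (agent construction and environment wiring are omitted here):

```python
# Let the agent pick its own action from its current policy
action_info = agent.act()
print(action_info.action)  # the chosen action, plus any extra decision info on action_info

# Force a specific action (e.g. when replaying a logged decision), overriding the policy
forced_info = agent.act(action=2)
```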
@@ -267,26 +269,6 @@ for creating the network.</p>
</table>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.emulate_act_on_trainer">
<code class="descname">emulate_act_on_trainer</code><span class="sig-paren">(</span><em>transition: rl_coach.core_types.Transition</em><span class="sig-paren">)</span> &#x2192; rl_coach.core_types.ActionInfo<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.emulate_act_on_trainer" title="Permalink to this definition"></a></dt>
<dd><p>This emulates the act step on the training worker, using the transition obtained from the rollout worker,
in the case of distributed training.
Given the agent's current knowledge, decide on the next action to apply to the environment.
:return: an action and a dictionary containing any additional info from the action decision process</p>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.emulate_observe_on_trainer">
<code class="descname">emulate_observe_on_trainer</code><span class="sig-paren">(</span><em>transition: rl_coach.core_types.Transition</em><span class="sig-paren">)</span> &#x2192; bool<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.emulate_observe_on_trainer" title="Permalink to this definition"></a></dt>
<dd><p>This emulates the observe step on the training worker, using the transition obtained from the rollout worker,
in the case of distributed training.
Given a response from the environment, distill the observation from it and store it for later use.
The response should be a dictionary containing the performed action, the new observation and measurements,
the reward, a game-over flag and any additional information necessary.
:return:</p>
</dd></dl>
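Taken together, the two emulate_* methods above let a training worker replay decisions that were actually made on a rollout worker. Below is a rough sketch of such a replay loop; the transition source, the call order, and the end-of-episode hook are assumptions for illustration rather than details taken from Coach's graph manager.

```python
# Hypothetical replay loop on the training worker. transitions_from_rollout is assumed
# to be an iterable of rl_coach.core_types.Transition objects fetched from rollout workers.
for transition in transitions_from_rollout:
    # Re-derive the action decision for this transition's state on the trainer side
    action_info = agent.emulate_act_on_trainer(transition)
    # Store the observation/reward/game-over info so the trainer's internal state
    # (memory, counters) mirrors what the rollout worker saw
    episode_ended = agent.emulate_observe_on_trainer(transition)
    if episode_ended:
        agent.handle_episode_ended()  # assumed end-of-episode hook
```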
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.get_predictions">
<code class="descname">get_predictions</code><span class="sig-paren">(</span><em>states: List[Dict[str, numpy.ndarray]], prediction_type: rl_coach.core_types.PredictionType</em><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.get_predictions" title="Permalink to this definition"></a></dt>
@@ -342,6 +324,22 @@ This function is called right after each episode is ended.</p>
</table>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.improve_reward_model">
<code class="descname">improve_reward_model</code><span class="sig-paren">(</span><em>epochs: int</em><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.improve_reward_model" title="Permalink to this definition"></a></dt>
<dd><p>Train a reward model to be used by the doubly-robust estimator</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
<col class="field-body" />
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>epochs</strong> The total number of epochs to use for training a reward model</td>
</tr>
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body">None</td>
</tr>
</tbody>
</table>
</dd></dl>
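The improve_reward_model entry documented above belongs to Coach's off-policy evaluation workflow: a reward model is fit on the stored dataset so that the doubly-robust estimator can combine model-based reward predictions with importance-weighted returns. A minimal sketch of how it might be invoked, assuming an agent instance <code>agent</code> whose memory has already been filled with logged transitions (the epoch count is illustrative):

```python
# Fit the reward model on the logged dataset; 30 epochs is an arbitrary illustrative choice.
agent.improve_reward_model(epochs=30)

# The fitted reward model is then consumed internally by the off-policy evaluators,
# e.g. when run_off_policy_evaluation() is called (see further below).
```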
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.init_environment_dependent_modules">
<code class="descname">init_environment_dependent_modules</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.init_environment_dependent_modules" title="Permalink to this definition"></a></dt>
@@ -450,7 +448,7 @@ given observation</td>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.prepare_batch_for_inference">
<code class="descname">prepare_batch_for_inference</code><span class="sig-paren">(</span><em>states: Union[Dict[str, numpy.ndarray], List[Dict[str, numpy.ndarray]]], network_name: str</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.core.multiarray.array]<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.prepare_batch_for_inference" title="Permalink to this definition"></a></dt>
<code class="descname">prepare_batch_for_inference</code><span class="sig-paren">(</span><em>states: Union[Dict[str, numpy.ndarray], List[Dict[str, numpy.ndarray]]], network_name: str</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.array]<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.prepare_batch_for_inference" title="Permalink to this definition"></a></dt>
<dd><p>Convert curr_state into the input tensors TensorFlow is expecting, i.e. if we have several input states, stack all
observations together, measurements together, etc.</p>
<table class="docutils field-list" frame="void" rules="none">
@@ -542,6 +540,14 @@ by val, and by the current phase set in self.phase.</p>
</table>
</dd></dl>
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.run_off_policy_evaluation">
<code class="descname">run_off_policy_evaluation</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.run_off_policy_evaluation" title="Permalink to this definition"></a></dt>
<dd><p>Run the off-policy evaluation estimators to get a prediction for the performance of the current policy, based on
an evaluation dataset that was collected by one or more other policies.
:return: None</p>
</dd></dl>
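run_off_policy_evaluation is where the estimators mentioned above, including the doubly-robust one that consumes the reward model, are applied to the logged evaluation dataset. As a hedged illustration of the idea behind a doubly-robust estimate in its simplest bandit-style form (not Coach's exact implementation; all inputs below are made up):

```python
import numpy as np

def doubly_robust_estimate(rewards, behavior_probs, target_probs,
                           predicted_rewards_taken, predicted_rewards_target):
    """Bandit-style doubly-robust value estimate of a target policy.

    rewards                  -- observed rewards for the logged actions
    behavior_probs           -- probability the logging policy gave each logged action
    target_probs             -- probability the evaluated policy gives each logged action
    predicted_rewards_taken  -- reward-model prediction for the logged (state, action) pairs
    predicted_rewards_target -- reward-model prediction of the target policy's expected reward per state
    """
    rho = target_probs / behavior_probs  # importance weights
    return np.mean(predicted_rewards_target + rho * (rewards - predicted_rewards_taken))

# Tiny made-up example
print(doubly_robust_estimate(
    rewards=np.array([1.0, 0.0]),
    behavior_probs=np.array([0.5, 0.5]),
    target_probs=np.array([0.8, 0.2]),
    predicted_rewards_taken=np.array([0.6, 0.3]),
    predicted_rewards_target=np.array([0.7, 0.5]),
))  # -> 0.86
```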
<dl class="method">
<dt id="rl_coach.agents.dqn_agent.DQNAgent.run_pre_network_filter_for_inference">
<code class="descname">run_pre_network_filter_for_inference</code><span class="sig-paren">(</span><em>state: Dict[str, numpy.ndarray], update_filter_internal_state: bool = True</em><span class="sig-paren">)</span> &#x2192; Dict[str, numpy.ndarray]<a class="headerlink" href="#rl_coach.agents.dqn_agent.DQNAgent.run_pre_network_filter_for_inference" title="Permalink to this definition"></a></dt>