mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
TD3 (#338)
This commit is contained in:
Binary file not shown.
|
Before Width: | Height: | Size: 59 KiB After Width: | Height: | Size: 60 KiB |
BIN
docs_raw/source/_static/img/design_imgs/td3.png
Normal file
BIN
docs_raw/source/_static/img/design_imgs/td3.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 48 KiB |
1
docs_raw/source/algorithms.xml
Normal file
1
docs_raw/source/algorithms.xml
Normal file
@@ -0,0 +1 @@
|
||||
<mxfile modified="2019-06-13T11:04:47.252Z" host="www.draw.io" agent="Mozilla/5.0 (Macintosh; Intel Mac OS X 10_14_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/72.0.3626.121 Safari/537.36" etag="OGV5teY4xcR0Xj7nvQBA" version="10.7.7" type="device"><diagram id="Fja6IZyvrddIfr-74nt7" name="Page-1">7V1bk5s2FP41O9M+JIMQ18e9NduZbCdp0jR5ysgg26SAXMC73v76CnMVEphdkLGzbB5sDhJgne/7dCQdkQt4HezeRWizvicu9i9Uxd1dwJsLVQVAt+hHannKLIZuZ4ZV5Ll5ocrwyfsP50Ylt249F8dMwYQQP/E2rNEhYYidhLGhKCKPbLEl8dm7btAKc4ZPDvJ569+em6wzq6Urlf0Oe6t1cWeg5GcCVBTODfEaueSxZoK3F/A6IiTJvgW7a+ynjVe0S1bvt5az5YNFOEz6VPh+fW+gL8GXH+p7HP5156+j3dc3mp4/XPJU/GLs0gbID0mUrMmKhMi/raxXEdmGLk4vq9Cjqsx7QjbUCKjxB06Sp9ybaJsQalongZ+fxTsv+ZpWf6vnR99qZ252+ZX3B0/FQZhET7VK6eG3+rmq2v6oqBcnEfmn9F16j+wXpz+ztSVzU0y2kYM7mi8ttYckilY46SgItdLjlCqYBJg+I60YYR8l3gP7JCjH7Kosl1e9jCL0VCuwIV6YxLUrf0gNtEBOP6jl2MvJpxsNiDTLm1ZXefole4LiqPZTKtMeds+AYP6jH5C/zZuBg2QFuNR/j2svwZ82aO+ZRyo7LLiWnu9fE59E+7pwqaf/SiDUzhj7v7QGCZOaPftLaxSMTXGEIifHs8qhSi1R9YCjBO+6ccWjoKigs+4ChXQ8Vsqj5i5R1jXVgYXoiJBT894LnKNwzpj14Rn60FceBqrDIB+rHAHJIsbRA701CYdxkWEQT0DHwoulkE0NErsIW0uHYSEYiXRU9N7qDO0MnnW2zpNOlcY5wDX6zLn+nIM9OTe0Rx7kYyjo9AyfPu7Vgn5ZJfsmyQxp78Sgwfh3S4oTb+K9Py9pAaBvdtXJ4iq/h5ttWv02WGDXxVFxVfrU2YXZm1Fz7QHkUd9SF3Df9x6mvo4tV2Oob4zEfMVqMB9AnvpAhzz3LVnc146Fi3vPdX38SBtrRkQNETaDBxoBC/Cg8njQZOFBHzk4PgMfADYILrnG+EAQBJfkHd0JhgRSQhEpb24+vKMnL52EtrOq3GHkDmEndULScD/j8ZCEuOHa3IR8bxXSQ4c6jfYZ8Cp1qecg/zI/EezlowVvbCzCQG4EgABDacg2FNBUEyBEXsSm/ewR24jxmdkzPrOmjM/MVsr3ZbgtInirbiyiUgNwGGMxvZtF/9gG9EZkm9AQL87uGm6D78hJB24xX/6AltQLyutjXNNeKPwgQNDHLJfYcCQN+lRWQVRL0NGropkWXZaEzBOxQ0TF6ikq9pSiYo0fRwhV5jMK754VN7w+/utqM4aA8C2vAYZo3gfKkgB7loABEmD3lAAwaWBhcxrwlfe673ubuC26rvkOxZtszXPp7VIMXLE8HZ2IYzDPZId30OSHd1C4xCGLdXBm3QDWaT1ZZ05JOmBO6WNQ83Dl70M+ZjxcOXwKHwPjHKIr0DlNcziUskSh1GXbWOpPFK56DtM47L2aeRlz4nkZYxjvwSvXdtB3VJXNSh87u6UctOdw06wD2S2ws7yk7BZ+0JfNz1BbgNFrXF/XtIkX2IsLzylHKrDYcFyzBSlHotUWaEjzjtraj/deXVFEfblDQgclrRMjLx2EncGYy26k9QHeyaZglVsaAeHYOX9nsKxpaKwPBLGRcGkZyPOCiGg/uRcaCbGmIMMSiDIsAZCW8WHNMeqQbK++6V5Z8DdZvld7wtcJLSjOC4Rt05S6LVgcMARhkaz1QW0eyg6Sib7zlNCYVCba8/9mmThBmYAaO7SdXCaKCHKWiZfJhN5TJrRJp7qhKC10lomTlYlmNGEIZsCOKxPDtpi8epnouyCmK2JcHEkm2lfEZpk4QZloRhOTy4Q2RxPH2R0+dP/ni9bPnr07XIFd5eWsn2nGDMH5BQWTvqBAH7a7YxoIqqeDQa1vZqY+iQyqjSVQ07QPyGBneUkYPJtVgiNAaShCDkLAMhurPRl081qNoGqMTo7PCp4kUBcG5WXlkN9uBOZwvmux0WSFQrC7/Lij/mF7jKbpyMzT6ciKPJgT7chgMatTdkwHginV7CovqSMT5V7JkLrxZiAGPVeXaM4SCUEzOXB6kTybASevdSewu1vvm49hqFJCOQ2wotY3lHu22AJ2MsTSDoitAbrKSxJbCa8ZEqrcvRceELXXvuHUMliZsy3BdtNipuEo20314YMOAERg+DjOZYqrxNsSOPWes2amnyhIMREu4vQj+0UH8diZ/vvTpeFr7MyZJdh3KdztLG1vjioJfoEX/jIBBGkrKB8FFVRxhV9nfLKRoAobkaAIoeZREdqeZVG68fMNlWfly0X6EM1OrfH6pkWrL/n9gCM0KLAbgXUxyVTfjCdI+i/W/MZvzh6v0RK+le7YDWexEaQhyBHXRB21tA0xZ7P+cJIjkhzPR9rEKNh12Nhh1Xx35EgDkuZuyOJF9G0DEq3xXM8tf+TdlvSwelV+Vrz6Dwfg7f8=</diagram></mxfile>
|
||||
@@ -21,6 +21,7 @@ A detailed description of those algorithms can be found by navigating to each of
|
||||
imitation/cil
|
||||
policy_optimization/cppo
|
||||
policy_optimization/ddpg
|
||||
policy_optimization/td3
|
||||
policy_optimization/sac
|
||||
other/dfp
|
||||
value_optimization/double_dqn
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
Twin Delayed Deep Deterministic Policy Gradient
|
||||
==================================
|
||||
|
||||
**Actions space:** Continuous
|
||||
|
||||
**References:** `Addressing Function Approximation Error in Actor-Critic Methods <https://arxiv.org/pdf/1802.09477>`_
|
||||
|
||||
Network Structure
|
||||
-----------------
|
||||
|
||||
.. image:: /_static/img/design_imgs/td3.png
|
||||
:align: center
|
||||
|
||||
Algorithm Description
|
||||
---------------------
|
||||
Choosing an action
|
||||
++++++++++++++++++
|
||||
|
||||
Pass the current states through the actor network, and get an action mean vector :math:`\mu`.
|
||||
While in training phase, use a continuous exploration policy, such as a small zero-meaned gaussian noise,
|
||||
to add exploration noise to the action. When testing, use the mean vector :math:`\mu` as-is.
|
||||
|
||||
Training the network
|
||||
++++++++++++++++++++
|
||||
|
||||
Start by sampling a batch of transitions from the experience replay.
|
||||
|
||||
* To train the two **critic networks**, use the following targets:
|
||||
|
||||
:math:`y_t=r(s_t,a_t )+\gamma \cdot \min_{i=1,2} Q_{i}(s_{t+1},\mu(s_{t+1} )+[\mathcal{N}(0,\,\sigma^{2})]^{MAX\_NOISE}_{MIN\_NOISE})`
|
||||
|
||||
First run the actor target network, using the next states as the inputs, and get :math:`\mu (s_{t+1} )`. Then, add a
|
||||
clipped gaussian noise to these actions, and clip the resulting actions to the actions space.
|
||||
Next, run the critic target networks using the next states and :math:`\mu (s_{t+1} )+[\mathcal{N}(0,\,\sigma^{2})]^{MAX\_NOISE}_{MIN\_NOISE}`,
|
||||
and use the minimum between the two critic networks predictions in order to calculate :math:`y_t` according to the
|
||||
equation above. To train the networks, use the current states and actions as the inputs, and :math:`y_t`
|
||||
as the targets.
|
||||
|
||||
* To train the **actor network**, use the following equation:
|
||||
|
||||
:math:`\nabla_{\theta^\mu } J \approx E_{s_t \tilde{} \rho^\beta } [\nabla_a Q_{1}(s,a)|_{s=s_t,a=\mu (s_t ) } \cdot \nabla_{\theta^\mu} \mu(s)|_{s=s_t} ]`
|
||||
|
||||
Use the actor's online network to get the action mean values using the current states as the inputs.
|
||||
Then, use the first critic's online network in order to get the gradients of the critic output with respect to the
|
||||
action mean values :math:`\nabla _a Q_{1}(s,a)|_{s=s_t,a=\mu(s_t ) }`.
|
||||
Using the chain rule, calculate the gradients of the actor's output, with respect to the actor weights,
|
||||
given :math:`\nabla_a Q(s,a)`. Finally, apply those gradients to the actor network.
|
||||
|
||||
The actor's training is done at a slower frequency than the critic's training, in order to allow the critic to better fit the
|
||||
current policy, before exercising the critic in order to train the actor.
|
||||
Following the same, delayed, actor's training cadence, do a soft update of the critic and actor target networks' weights
|
||||
from the online networks.
|
||||
|
||||
|
||||
.. autoclass:: rl_coach.agents.td3_agent.TD3AlgorithmParameters
|
||||
File diff suppressed because one or more lines are too long
@@ -214,6 +214,16 @@ The algorithms are ordered by their release date in descending order.
|
||||
and therefore it is able to use a replay buffer in order to improve sample efficiency.
|
||||
</span>
|
||||
</div>
|
||||
<div class="algorithm continuous off-policy" data-year="201509">
|
||||
<span class="badge">
|
||||
<a href="components/agents/policy_optimization/td3.html">TD3</a>
|
||||
<br>
|
||||
Very similar to DDPG, i.e. an actor-critic for continuous action spaces, that uses a replay buffer in
|
||||
order to improve sample efficiency. TD3 uses two critic networks in order to mitigate the overestimation
|
||||
in the Q state-action value prediction, slows down the actor updates in order to increase stability and
|
||||
adds noise to actions while training the critic in order to smooth out the critic's predictions.
|
||||
</span>
|
||||
</div>
|
||||
<div class="algorithm continuous discrete on-policy" data-year="201706">
|
||||
<span class="badge">
|
||||
<a href="components/agents/policy_optimization/ppo.html">PPO</a>
|
||||
|
||||
Reference in New Issue
Block a user