mirror of
https://github.com/gryf/coach.git
synced 2026-04-03 10:43:33 +02:00
network_imporvements branch merge
This commit is contained in:
@@ -16,7 +16,7 @@
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.core_types import QActionStateValue
|
||||
@@ -25,9 +25,13 @@ from rl_coach.spaces import SpacesDefinition
|
||||
|
||||
|
||||
class NAFHeadParameters(HeadParameters):
|
||||
def __init__(self, activation_function: str ='tanh', name: str='naf_head_params', dense_layer=Dense):
|
||||
def __init__(self, activation_function: str ='tanh', name: str='naf_head_params',
|
||||
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
|
||||
loss_weight: float = 1.0, dense_layer=Dense):
|
||||
super().__init__(parameterized_class=NAFHead, activation_function=activation_function, name=name,
|
||||
dense_layer=dense_layer)
|
||||
dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
|
||||
rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
|
||||
loss_weight=loss_weight)
|
||||
|
||||
|
||||
class NAFHead(Head):
|
||||
@@ -90,3 +94,21 @@ class NAFHead(Head):
|
||||
self.Q = tf.add(self.V, self.A, name='Q')
|
||||
|
||||
self.output = self.Q
|
||||
|
||||
def __str__(self):
|
||||
result = [
|
||||
"State Value Stream - V",
|
||||
"\tDense (num outputs = 1)",
|
||||
"Action Advantage Stream - A",
|
||||
"\tDense (num outputs = {})".format((self.num_actions * (self.num_actions + 1)) / 2),
|
||||
"\tReshape to lower triangular matrix L (new size = {} x {})".format(self.num_actions, self.num_actions),
|
||||
"\tP = L*L^T",
|
||||
"\tA = -1/2 * (u - mu)^T * P * (u - mu)",
|
||||
"Action Stream - mu",
|
||||
"\tDense (num outputs = {})".format(self.num_actions),
|
||||
"\tActivation (type = {})".format(self.activation_function.__name__),
|
||||
"\tMultiply (factor = {})".format(self.output_scale),
|
||||
"State-Action Value Stream - Q",
|
||||
"\tAdd (V, A)"
|
||||
]
|
||||
return '\n'.join(result)
|
||||
|
||||
Reference in New Issue
Block a user