mirror of
https://github.com/gryf/coach.git
synced 2026-02-23 18:45:50 +01:00
batchnorm fixes + disabling batchnorm in DDPG (#353)
Co-authored-by: James Casbon <casbon+gh@gmail.com>
This commit is contained in:
@@ -26,9 +26,9 @@ from rl_coach.spaces import SpacesDefinition
|
||||
class DDPGActor(Head):
|
||||
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
|
||||
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh',
|
||||
batchnorm: bool=True, dense_layer=Dense):
|
||||
batchnorm: bool=True, dense_layer=Dense, is_training=False):
|
||||
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
|
||||
dense_layer=dense_layer)
|
||||
dense_layer=dense_layer, is_training=is_training)
|
||||
self.name = 'ddpg_actor_head'
|
||||
self.return_type = ActionProbabilities
|
||||
|
||||
@@ -50,7 +50,7 @@ class DDPGActor(Head):
|
||||
batchnorm=self.batchnorm,
|
||||
activation_function=self.activation_function,
|
||||
dropout_rate=0,
|
||||
is_training=False,
|
||||
is_training=self.is_training,
|
||||
name="BatchnormActivationDropout_0")[-1]
|
||||
self.policy_mean = tf.multiply(policy_values_mean, self.output_scale, name='output_mean')
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ class Head(object):
|
||||
"""
|
||||
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
|
||||
head_idx: int=0, loss_weight: float=1., is_local: bool=True, activation_function: str='relu',
|
||||
dense_layer=Dense):
|
||||
dense_layer=Dense, is_training=False):
|
||||
self.head_idx = head_idx
|
||||
self.network_name = network_name
|
||||
self.network_parameters = agent_parameters.network_wrappers[self.network_name]
|
||||
@@ -64,6 +64,7 @@ class Head(object):
|
||||
self.dense_layer = Dense
|
||||
else:
|
||||
self.dense_layer = convert_layer_class(self.dense_layer)
|
||||
self.is_training = is_training
|
||||
|
||||
def __call__(self, input_layer):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user