1
0
mirror of https://github.com/gryf/coach.git synced 2026-02-23 18:45:50 +01:00

batchnorm fixes + disabling batchnorm in DDPG (#353)

Co-authored-by: James Casbon <casbon+gh@gmail.com>
This commit is contained in:
Gal Leibovich
2019-06-23 11:28:22 +03:00
committed by GitHub
parent 7b5d6a3f03
commit d6795bd524
22 changed files with 105 additions and 50 deletions

View File

@@ -26,9 +26,9 @@ from rl_coach.spaces import SpacesDefinition
class DDPGActor(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh',
- batchnorm: bool=True, dense_layer=Dense):
+ batchnorm: bool=True, dense_layer=Dense, is_training=False):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
- dense_layer=dense_layer)
+ dense_layer=dense_layer, is_training=is_training)
self.name = 'ddpg_actor_head'
self.return_type = ActionProbabilities
@@ -50,7 +50,7 @@ class DDPGActor(Head):
batchnorm=self.batchnorm,
activation_function=self.activation_function,
dropout_rate=0,
- is_training=False,
+ is_training=self.is_training,
name="BatchnormActivationDropout_0")[-1]
self.policy_mean = tf.multiply(policy_values_mean, self.output_scale, name='output_mean')

View File

@@ -40,7 +40,7 @@ class Head(object):
"""
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int=0, loss_weight: float=1., is_local: bool=True, activation_function: str='relu',
- dense_layer=Dense):
+ dense_layer=Dense, is_training=False):
self.head_idx = head_idx
self.network_name = network_name
self.network_parameters = agent_parameters.network_wrappers[self.network_name]
@@ -64,6 +64,7 @@ class Head(object):
self.dense_layer = Dense
else:
self.dense_layer = convert_layer_class(self.dense_layer)
+ self.is_training = is_training
def __call__(self, input_layer):
"""