mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
@@ -8,7 +8,7 @@ from rl_coach.architectures import layers
|
||||
from rl_coach.architectures.tensorflow_components import utils
|
||||
|
||||
|
||||
def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dropout, dropout_rate, is_training, name):
|
||||
def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dropout_rate, is_training, name):
|
||||
layers = [input_layer]
|
||||
|
||||
# batchnorm
|
||||
@@ -26,7 +26,7 @@ def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dr
|
||||
)
|
||||
|
||||
# dropout
|
||||
if dropout:
|
||||
if dropout_rate > 0:
|
||||
layers.append(
|
||||
tf.layers.dropout(layers[-1], dropout_rate, name="{}_dropout".format(name), training=is_training)
|
||||
)
|
||||
@@ -100,7 +100,7 @@ class BatchnormActivationDropout(layers.BatchnormActivationDropout):
|
||||
"""
|
||||
return batchnorm_activation_dropout(input_layer, batchnorm=self.batchnorm,
|
||||
activation_function=self.activation_function,
|
||||
dropout=self.dropout_rate > 0, dropout_rate=self.dropout_rate,
|
||||
dropout_rate=self.dropout_rate,
|
||||
is_training=is_training, name=name)
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user