
Add Flatten layer to architectures + make flatten optional in embedders (#483)

A Flatten layer is required for embedders that mix convolutional and dense layers.
(Cherry-picked from #478)
Author: Guy Jacob
Date: 2021-05-12 11:11:10 +03:00
Committer: GitHub
Parent: c369984c2e
Commit: 235a259223
7 changed files with 49 additions and 10 deletions
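For context, a minimal usage sketch of what this change enables. It is not part of the diff, and it assumes the Conv2d and Dense layer descriptors that rl_coach.architectures.layers already provides, with Conv2d(num_filters, kernel_size, strides) and Dense(units) arguments:

    # Hypothetical illustration only: a custom embedder scheme that mixes conv and
    # dense layers, with the new Flatten descriptor marking the 3D -> 1D transition.
    from rl_coach.architectures.layers import Conv2d, Dense, Flatten
    from rl_coach.base_parameters import InputEmbedderParameters

    params = InputEmbedderParameters(
        scheme=[Conv2d(32, 8, 4),    # 3D conv output
                Flatten(),           # explicit conv -> dense transition (added here)
                Dense(256)],         # 1D dense input
        flatten=False)               # new flag: skip the formerly unconditional flatten

Since the scheme already ends in a dense layer, its output is 1D and the embedder's trailing flatten would be redundant; flatten=False turns it off.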


@@ -23,7 +23,7 @@ MOD_NAMES = {'image': 'ImageEmbedder', 'vector': 'VectorEmbedder', 'tensor': 'Te
 class InputEmbedderParameters(NetworkComponentParameters):
     def __init__(self, activation_function: str='relu', scheme: Union[List, EmbedderScheme]=EmbedderScheme.Medium,
                  batchnorm: bool=False, dropout_rate: float=0.0, name: str='embedder', input_rescaling=None,
-                 input_offset=None, input_clipping=None, dense_layer=None, is_training=False):
+                 input_offset=None, input_clipping=None, dense_layer=None, is_training=False, flatten=True):
         super().__init__(dense_layer=dense_layer)
         self.activation_function = activation_function
         self.scheme = scheme
@@ -40,6 +40,7 @@ class InputEmbedderParameters(NetworkComponentParameters):
         self.input_clipping = input_clipping
         self.name = name
         self.is_training = is_training
+        self.flatten = flatten

     def path(self, emb_type):
         return 'rl_coach.architectures.tensorflow_components.embedders:' + MOD_NAMES[emb_type]


@@ -76,3 +76,11 @@ class NoisyNetDense(object):

     def __str__(self):
         return "Noisy Dense (num outputs = {})".format(self.units)
+
+
+class Flatten(object):
+    """
+    Base class for framework specific flatten layer (used to convert 3D convolution output to 1D dense input)
+    """
+    def __str__(self):
+        return "Flatten"


@@ -20,7 +20,8 @@ import copy

 import numpy as np
 import tensorflow as tf

-from rl_coach.architectures.tensorflow_components.layers import BatchnormActivationDropout, convert_layer, Dense
+from rl_coach.architectures.tensorflow_components.layers import BatchnormActivationDropout, convert_layer, Dense, \
+    Flatten
 from rl_coach.base_parameters import EmbedderScheme, NetworkComponentParameters
 from rl_coach.core_types import InputEmbedding
@@ -36,7 +37,7 @@ class InputEmbedder(object):
     def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
                  scheme: EmbedderScheme=None, batchnorm: bool=False, dropout_rate: float=0.0,
                  name: str= "embedder", input_rescaling=1.0, input_offset=0.0, input_clipping=None, dense_layer=Dense,
-                 is_training=False):
+                 is_training=False, flatten=True):
         self.name = name
         self.input_size = input_size
         self.activation_function = activation_function
@@ -55,6 +56,7 @@ class InputEmbedder(object):
         if self.dense_layer is None:
             self.dense_layer = Dense
         self.is_training = is_training
+        self.flatten = flatten

         # layers order is conv -> batchnorm -> activation -> dropout
         if isinstance(self.scheme, EmbedderScheme):
@@ -116,7 +118,10 @@ class InputEmbedder(object):
                     is_training=self.is_training)
                 ))

-        self.output = tf.contrib.layers.flatten(self.layers[-1])
+        if self.flatten:
+            self.output = Flatten()(self.layers[-1])
+        else:
+            self.output = self.layers[-1]

     @property
     def input_size(self) -> List[int]:
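The behavioral change in isolation, as a sketch (TF 1.x API, which this code targets): with flatten left at its default of True, the output is the conv activation collapsed to 2D, exactly as before; with flatten=False the raw tensor passes through unchanged.

    # Shape effect (TF 1.x; tf.contrib.layers.flatten keeps the batch dimension):
    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 7, 7, 64])   # e.g. a final conv activation
    y = tf.contrib.layers.flatten(x)                   # what Flatten()(x) performs
    print(y.shape)                                     # (?, 3136), i.e. 7*7*64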


@@ -34,9 +34,10 @@ class ImageEmbedder(InputEmbedder):
     def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
                  scheme: EmbedderScheme=EmbedderScheme.Medium, batchnorm: bool=False, dropout_rate: float=0.0,
                  name: str= "embedder", input_rescaling: float=255.0, input_offset: float=0.0, input_clipping=None,
-                 dense_layer=Dense, is_training=False):
+                 dense_layer=Dense, is_training=False, flatten=True):
         super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name, input_rescaling,
-                         input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training)
+                         input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training,
+                         flatten=flatten)
         self.return_type = InputImageEmbedding
         if len(input_size) != 3 and scheme != EmbedderScheme.Empty:
             raise ValueError("Image embedders expect the input size to have 3 dimensions. The given size is: {}"


@@ -41,9 +41,10 @@ class TensorEmbedder(InputEmbedder):
     def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
                  scheme: EmbedderScheme=None, batchnorm: bool=False, dropout_rate: float=0.0,
                  name: str= "embedder", input_rescaling: float=1.0, input_offset: float=0.0, input_clipping=None,
-                 dense_layer=Dense, is_training=False):
+                 dense_layer=Dense, is_training=False, flatten=True):
         super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name, input_rescaling,
-                         input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training)
+                         input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training,
+                         flatten=flatten)
         self.return_type = InputTensorEmbedding

         assert scheme is not None, "Custom scheme (a list of callables) must be specified for TensorEmbedder"


@@ -33,10 +33,10 @@ class VectorEmbedder(InputEmbedder):
     def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
                  scheme: EmbedderScheme=EmbedderScheme.Medium, batchnorm: bool=False, dropout_rate: float=0.0,
                  name: str= "embedder", input_rescaling: float=1.0, input_offset: float=0.0, input_clipping=None,
-                 dense_layer=Dense, is_training=False):
+                 dense_layer=Dense, is_training=False, flatten=True):
         super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name,
                          input_rescaling, input_offset, input_clipping, dense_layer=dense_layer,
-                         is_training=is_training)
+                         is_training=is_training, flatten=flatten)
         self.return_type = InputVectorEmbedding

         if len(self.input_size) != 1 and scheme != EmbedderScheme.Empty:


@@ -264,3 +264,26 @@ class NoisyNetDense(layers.NoisyNetDense):
     @reg_to_tf_class(layers.NoisyNetDense)
     def to_tf_class():
         return NoisyNetDense
+
+
+class Flatten(layers.Flatten):
+    def __init__(self):
+        super(Flatten, self).__init__()
+
+    def __call__(self, input_layer, **kwargs):
+        """
+        returns a tensorflow flatten layer
+        :param input_layer: previous layer
+        :return: flatten layer
+        """
+        return tf.contrib.layers.flatten(input_layer)
+
+    @staticmethod
+    @reg_to_tf_instance(layers.Flatten)
+    def to_tf_instance(base: layers.Flatten):
+        return Flatten()
+
+    @staticmethod
+    @reg_to_tf_class(layers.Flatten)
+    def to_tf_class():
+        return Flatten
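With these two registrations in place, the framework-agnostic descriptor converts to its TensorFlow counterpart the same way the other layers do. A sketch, assuming convert_layer (imported in embedder.py above) dispatches instances through the reg_to_tf_instance registry:

    import tensorflow as tf
    from rl_coach.architectures import layers
    from rl_coach.architectures.tensorflow_components.layers import convert_layer

    conv_output = tf.placeholder(tf.float32, [None, 7, 7, 64])  # stand-in conv activation
    tf_flatten = convert_layer(layers.Flatten())  # base descriptor -> TF Flatten instance
    flat = tf_flatten(conv_output)                # [None, 3136]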