
Add Flatten layer to architectures + make flatten optional in embedders (#483)

A Flatten layer is required for embedders that mix conv and dense layers.
(Cherry-picked from #478)
Author:  Guy Jacob (committed by GitHub)
Date:    2021-05-12 11:11:10 +03:00
Parent:  c369984c2e
Commit:  235a259223

7 changed files with 49 additions and 10 deletions
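For context, a minimal sketch of the new knob from the user's side (only the flatten keyword is introduced by this commit; the default preserves the old behavior):

    from rl_coach.base_parameters import InputEmbedderParameters

    # flatten=True (the default) keeps the old behavior: the embedder's
    # output is flattened to a 1D feature vector per sample.
    # flatten=False passes the last layer's output through unchanged,
    # e.g. a 3D conv feature map.
    embedder_params = InputEmbedderParameters(flatten=False)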

rl_coach/base_parameters.py

@@ -23,7 +23,7 @@ MOD_NAMES = {'image': 'ImageEmbedder', 'vector': 'VectorEmbedder', 'tensor': 'Te
 class InputEmbedderParameters(NetworkComponentParameters):
     def __init__(self, activation_function: str='relu', scheme: Union[List, EmbedderScheme]=EmbedderScheme.Medium,
                  batchnorm: bool=False, dropout_rate: float=0.0, name: str='embedder', input_rescaling=None,
-                 input_offset=None, input_clipping=None, dense_layer=None, is_training=False):
+                 input_offset=None, input_clipping=None, dense_layer=None, is_training=False, flatten=True):
         super().__init__(dense_layer=dense_layer)
         self.activation_function = activation_function
         self.scheme = scheme
@@ -40,6 +40,7 @@ class InputEmbedderParameters(NetworkComponentParameters):
         self.input_clipping = input_clipping
         self.name = name
         self.is_training = is_training
+        self.flatten = flatten
 
     def path(self, emb_type):
         return 'rl_coach.architectures.tensorflow_components.embedders:' + MOD_NAMES[emb_type]

rl_coach/architectures/layers.py

@@ -76,3 +76,11 @@ class NoisyNetDense(object):
 
     def __str__(self):
         return "Noisy Dense (num outputs = {})".format(self.units)
+
+
+class Flatten(object):
+    """
+    Base class for framework specific flatten layer (used to convert 3D convolution output to 1D dense input)
+    """
+    def __str__(self):
+        return "Flatten"

rl_coach/architectures/tensorflow_components/embedders/embedder.py

@@ -20,7 +20,8 @@
 import copy
 
 import numpy as np
 import tensorflow as tf
-from rl_coach.architectures.tensorflow_components.layers import BatchnormActivationDropout, convert_layer, Dense
+from rl_coach.architectures.tensorflow_components.layers import BatchnormActivationDropout, convert_layer, Dense, \
+    Flatten
 from rl_coach.base_parameters import EmbedderScheme, NetworkComponentParameters
 from rl_coach.core_types import InputEmbedding
@@ -36,7 +37,7 @@ class InputEmbedder(object):
     def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
                  scheme: EmbedderScheme=None, batchnorm: bool=False, dropout_rate: float=0.0,
                  name: str= "embedder", input_rescaling=1.0, input_offset=0.0, input_clipping=None, dense_layer=Dense,
-                 is_training=False):
+                 is_training=False, flatten=True):
         self.name = name
         self.input_size = input_size
         self.activation_function = activation_function
@@ -55,6 +56,7 @@ class InputEmbedder(object):
         if self.dense_layer is None:
             self.dense_layer = Dense
         self.is_training = is_training
+        self.flatten = flatten
 
         # layers order is conv -> batchnorm -> activation -> dropout
         if isinstance(self.scheme, EmbedderScheme):
@@ -116,7 +118,10 @@ class InputEmbedder(object):
                                                  is_training=self.is_training)
                 ))
 
-        self.output = tf.contrib.layers.flatten(self.layers[-1])
+        if self.flatten:
+            self.output = Flatten()(self.layers[-1])
+        else:
+            self.output = self.layers[-1]
 
     @property
     def input_size(self) -> List[int]:
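Concretely, with the tf.contrib API used here, flattening collapses every dimension except the batch dimension; a standalone TF 1.x sketch of what the two branches produce:

    import tensorflow as tf  # TF 1.x, matching the tf.contrib usage above

    x = tf.placeholder(tf.float32, shape=[None, 7, 7, 64])  # conv feature map
    flat = tf.contrib.layers.flatten(x)
    print(flat.get_shape().as_list())  # [None, 3136] -- the flatten=True output
    print(x.get_shape().as_list())     # [None, 7, 7, 64] -- passed through when flatten=False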

rl_coach/architectures/tensorflow_components/embedders/image_embedder.py

@@ -34,9 +34,10 @@ class ImageEmbedder(InputEmbedder):
     def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
                  scheme: EmbedderScheme=EmbedderScheme.Medium, batchnorm: bool=False, dropout_rate: float=0.0,
                  name: str= "embedder", input_rescaling: float=255.0, input_offset: float=0.0, input_clipping=None,
-                 dense_layer=Dense, is_training=False):
+                 dense_layer=Dense, is_training=False, flatten=True):
         super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name, input_rescaling,
-                         input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training)
+                         input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training,
+                         flatten=flatten)
         self.return_type = InputImageEmbedding
         if len(input_size) != 3 and scheme != EmbedderScheme.Empty:
             raise ValueError("Image embedders expect the input size to have 3 dimensions. The given size is: {}"

rl_coach/architectures/tensorflow_components/embedders/tensor_embedder.py

@@ -41,9 +41,10 @@ class TensorEmbedder(InputEmbedder):
     def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
                  scheme: EmbedderScheme=None, batchnorm: bool=False, dropout_rate: float=0.0,
                  name: str= "embedder", input_rescaling: float=1.0, input_offset: float=0.0, input_clipping=None,
-                 dense_layer=Dense, is_training=False):
+                 dense_layer=Dense, is_training=False, flatten=True):
         super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name, input_rescaling,
-                         input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training)
+                         input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training,
+                         flatten=flatten)
         self.return_type = InputTensorEmbedding
 
         assert scheme is not None, "Custom scheme (a list of callables) must be specified for TensorEmbedder"

rl_coach/architectures/tensorflow_components/embedders/vector_embedder.py

@@ -33,10 +33,10 @@ class VectorEmbedder(InputEmbedder):
     def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
                  scheme: EmbedderScheme=EmbedderScheme.Medium, batchnorm: bool=False, dropout_rate: float=0.0,
                  name: str= "embedder", input_rescaling: float=1.0, input_offset: float=0.0, input_clipping=None,
-                 dense_layer=Dense, is_training=False):
+                 dense_layer=Dense, is_training=False, flatten=True):
         super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name,
                          input_rescaling, input_offset, input_clipping, dense_layer=dense_layer,
-                         is_training=is_training)
+                         is_training=is_training, flatten=flatten)
         self.return_type = InputVectorEmbedding
 
         if len(self.input_size) != 1 and scheme != EmbedderScheme.Empty:

rl_coach/architectures/tensorflow_components/layers.py

@@ -264,3 +264,26 @@ class NoisyNetDense(layers.NoisyNetDense):
     @reg_to_tf_class(layers.NoisyNetDense)
     def to_tf_class():
         return NoisyNetDense
+
+
+class Flatten(layers.Flatten):
+    def __init__(self):
+        super(Flatten, self).__init__()
+
+    def __call__(self, input_layer, **kwargs):
+        """
+        returns a tensorflow flatten layer
+        :param input_layer: previous layer
+        :return: flatten layer
+        """
+        return tf.contrib.layers.flatten(input_layer)
+
+    @staticmethod
+    @reg_to_tf_instance(layers.Flatten)
+    def to_tf_instance(base: layers.Flatten):
+        return Flatten()
+
+    @staticmethod
+    @reg_to_tf_class(layers.Flatten)
+    def to_tf_class():
+        return Flatten
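The two registrations close the loop: when an embedder's scheme contains the framework-agnostic layers.Flatten, coach's convert_layer (imported in embedder.py above) can look up and instantiate this TensorFlow version. A sketch of that round-trip, assuming the registries dispatch as the decorators suggest:

    from rl_coach.architectures import layers
    from rl_coach.architectures.tensorflow_components.layers import convert_layer

    tf_flatten = convert_layer(layers.Flatten())  # framework-agnostic -> TF instance
    print(tf_flatten)  # "Flatten", via the shared __str__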