Add Flatten layer to architectures + make flatten optional in embedders (#483)
A Flatten layer is required for embedders that mix convolutional and dense layers (cherry-picked from #478).
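The change in a nutshell: InputEmbedder previously always flattened its final output; the new flatten argument (default True, preserving the old behavior) makes that optional, and a Flatten layer descriptor is added so schemes can flatten explicitly mid-stack. A minimal usage sketch, assuming the import path below (inferred from the hunk context, not stated in this diff):

    # Sketch: disable the automatic output flatten, e.g. when a downstream
    # component wants the raw 3D feature map instead of a flat vector.
    from rl_coach.architectures.embedder_parameters import InputEmbedderParameters

    embedder_params = InputEmbedderParameters(
        activation_function='relu',
        flatten=False,  # new in this commit; defaults to True (old behavior)
    )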
@@ -23,7 +23,7 @@ MOD_NAMES = {'image': 'ImageEmbedder', 'vector': 'VectorEmbedder', 'tensor': 'Te
 class InputEmbedderParameters(NetworkComponentParameters):
     def __init__(self, activation_function: str='relu', scheme: Union[List, EmbedderScheme]=EmbedderScheme.Medium,
                  batchnorm: bool=False, dropout_rate: float=0.0, name: str='embedder', input_rescaling=None,
-                 input_offset=None, input_clipping=None, dense_layer=None, is_training=False):
+                 input_offset=None, input_clipping=None, dense_layer=None, is_training=False, flatten=True):
         super().__init__(dense_layer=dense_layer)
         self.activation_function = activation_function
         self.scheme = scheme
@@ -40,6 +40,7 @@ class InputEmbedderParameters(NetworkComponentParameters):
         self.input_clipping = input_clipping
         self.name = name
         self.is_training = is_training
+        self.flatten = flatten

     def path(self, emb_type):
         return 'rl_coach.architectures.tensorflow_components.embedders:' + MOD_NAMES[emb_type]
@@ -76,3 +76,11 @@ class NoisyNetDense(object):

     def __str__(self):
         return "Noisy Dense (num outputs = {})".format(self.units)
+
+
+class Flatten(object):
+    """
+    Base class for framework specific flatten layer (used to convert 3D convolution output to 1D dense input)
+    """
+    def __str__(self):
+        return "Flatten"
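The base Flatten above is deliberately framework-agnostic: schemes are plain lists of such layer descriptors, and mixing conv and dense layers in one scheme is exactly the case that needs an explicit Flatten between them. A hypothetical scheme along those lines (the import path and the Conv2d/Dense argument orders are assumptions based on the surrounding code; the numbers are placeholders):

    # Sketch: a custom scheme interleaving conv and dense layers. Without the
    # Flatten, the Dense descriptor would receive a 4D (batch, H, W, C) tensor.
    from rl_coach.architectures.layers import Conv2d, Dense, Flatten

    mixed_scheme = [
        Conv2d(32, 8, 4),   # 32 filters, 8x8 kernel, stride 4
        Conv2d(64, 4, 2),
        Flatten(),          # collapse (H, W, C) into a single feature axis
        Dense(256),         # now sees a 2D (batch, features) tensor
    ]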
@@ -20,7 +20,8 @@ import copy
 import numpy as np
 import tensorflow as tf

-from rl_coach.architectures.tensorflow_components.layers import BatchnormActivationDropout, convert_layer, Dense
+from rl_coach.architectures.tensorflow_components.layers import BatchnormActivationDropout, convert_layer, Dense, \
+    Flatten
 from rl_coach.base_parameters import EmbedderScheme, NetworkComponentParameters
 from rl_coach.core_types import InputEmbedding
@@ -36,7 +37,7 @@ class InputEmbedder(object):
     def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
                  scheme: EmbedderScheme=None, batchnorm: bool=False, dropout_rate: float=0.0,
                  name: str= "embedder", input_rescaling=1.0, input_offset=0.0, input_clipping=None, dense_layer=Dense,
-                 is_training=False):
+                 is_training=False, flatten=True):
         self.name = name
         self.input_size = input_size
         self.activation_function = activation_function
@@ -55,6 +56,7 @@ class InputEmbedder(object):
         if self.dense_layer is None:
             self.dense_layer = Dense
         self.is_training = is_training
+        self.flatten = flatten

         # layers order is conv -> batchnorm -> activation -> dropout
         if isinstance(self.scheme, EmbedderScheme):
@@ -116,7 +118,10 @@ class InputEmbedder(object):
                                              is_training=self.is_training)
             ))

-        self.output = tf.contrib.layers.flatten(self.layers[-1])
+        if self.flatten:
+            self.output = Flatten()(self.layers[-1])
+        else:
+            self.output = self.layers[-1]

     @property
     def input_size(self) -> List[int]:
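Shape-wise, the two branches of the new conditional differ only in rank: Flatten()/tf.contrib.layers.flatten keeps the batch dimension and collapses everything else, while flatten=False passes the last layer through untouched. A standalone TF 1.x sketch of that behavior (shapes are arbitrary examples):

    import tensorflow as tf  # TF 1.x, where tf.contrib is available

    # A batch of 7x7 feature maps with 64 channels, unknown batch size.
    conv_output = tf.placeholder(tf.float32, shape=[None, 7, 7, 64])

    flat = tf.contrib.layers.flatten(conv_output)
    print(flat.shape)         # (?, 3136) since 7 * 7 * 64 = 3136
    print(conv_output.shape)  # (?, 7, 7, 64) -- what flatten=False would emit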
@@ -34,9 +34,10 @@ class ImageEmbedder(InputEmbedder):
     def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
                  scheme: EmbedderScheme=EmbedderScheme.Medium, batchnorm: bool=False, dropout_rate: float=0.0,
                  name: str= "embedder", input_rescaling: float=255.0, input_offset: float=0.0, input_clipping=None,
-                 dense_layer=Dense, is_training=False):
+                 dense_layer=Dense, is_training=False, flatten=True):
         super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name, input_rescaling,
-                         input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training)
+                         input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training,
+                         flatten=flatten)
         self.return_type = InputImageEmbedding
         if len(input_size) != 3 and scheme != EmbedderScheme.Empty:
             raise ValueError("Image embedders expect the input size to have 3 dimensions. The given size is: {}"
@@ -41,9 +41,10 @@ class TensorEmbedder(InputEmbedder):
     def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
                  scheme: EmbedderScheme=None, batchnorm: bool=False, dropout_rate: float=0.0,
                  name: str= "embedder", input_rescaling: float=1.0, input_offset: float=0.0, input_clipping=None,
-                 dense_layer=Dense, is_training=False):
+                 dense_layer=Dense, is_training=False, flatten=True):
         super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name, input_rescaling,
-                         input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training)
+                         input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training,
+                         flatten=flatten)
         self.return_type = InputTensorEmbedding
         assert scheme is not None, "Custom scheme (a list of callables) must be specified for TensorEmbedder"

@@ -33,10 +33,10 @@ class VectorEmbedder(InputEmbedder):
     def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
                  scheme: EmbedderScheme=EmbedderScheme.Medium, batchnorm: bool=False, dropout_rate: float=0.0,
                  name: str= "embedder", input_rescaling: float=1.0, input_offset: float=0.0, input_clipping=None,
-                 dense_layer=Dense, is_training=False):
+                 dense_layer=Dense, is_training=False, flatten=True):
         super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name,
                          input_rescaling, input_offset, input_clipping, dense_layer=dense_layer,
-                         is_training=is_training)
+                         is_training=is_training, flatten=flatten)

         self.return_type = InputVectorEmbedding
         if len(self.input_size) != 1 and scheme != EmbedderScheme.Empty:
@@ -264,3 +264,26 @@ class NoisyNetDense(layers.NoisyNetDense):
     @reg_to_tf_class(layers.NoisyNetDense)
     def to_tf_class():
         return NoisyNetDense
+
+
+class Flatten(layers.Flatten):
+    def __init__(self):
+        super(Flatten, self).__init__()
+
+    def __call__(self, input_layer, **kwargs):
+        """
+        returns a tensorflow flatten layer
+        :param input_layer: previous layer
+        :return: flatten layer
+        """
+        return tf.contrib.layers.flatten(input_layer)
+
+    @staticmethod
+    @reg_to_tf_instance(layers.Flatten)
+    def to_tf_instance(base: layers.Flatten):
+        return Flatten()
+
+    @staticmethod
+    @reg_to_tf_class(layers.Flatten)
+    def to_tf_class():
+        return Flatten
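The two decorated static methods register the TF-specific Flatten with coach's layer-conversion machinery, mirroring the NoisyNetDense pattern above. A sketch of how a conversion would then flow, assuming convert_layer (imported in the embedder hunk) dispatches through that registry and that the base descriptors live in rl_coach.architectures.layers -- both are my reading of the code, not spelled out in this diff:

    import tensorflow as tf  # TF 1.x
    from rl_coach.architectures import layers
    from rl_coach.architectures.tensorflow_components.layers import convert_layer

    conv_output = tf.placeholder(tf.float32, shape=[None, 7, 7, 64])

    # Schemes store framework-agnostic descriptors such as layers.Flatten();
    # convert_layer resolves the TF class registered via the decorators above.
    tf_flatten = convert_layer(layers.Flatten())
    output = tf_flatten(conv_output)  # shape: (?, 3136)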