mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Add Flatten layer to architectures + make flatten optional in embedders (#483)
Flatten layer required for embedders that mix conv and dense (Cherry picking from #478)
This commit is contained in:
@@ -23,7 +23,7 @@ MOD_NAMES = {'image': 'ImageEmbedder', 'vector': 'VectorEmbedder', 'tensor': 'Te
|
|||||||
class InputEmbedderParameters(NetworkComponentParameters):
|
class InputEmbedderParameters(NetworkComponentParameters):
|
||||||
def __init__(self, activation_function: str='relu', scheme: Union[List, EmbedderScheme]=EmbedderScheme.Medium,
|
def __init__(self, activation_function: str='relu', scheme: Union[List, EmbedderScheme]=EmbedderScheme.Medium,
|
||||||
batchnorm: bool=False, dropout_rate: float=0.0, name: str='embedder', input_rescaling=None,
|
batchnorm: bool=False, dropout_rate: float=0.0, name: str='embedder', input_rescaling=None,
|
||||||
input_offset=None, input_clipping=None, dense_layer=None, is_training=False):
|
input_offset=None, input_clipping=None, dense_layer=None, is_training=False, flatten=True):
|
||||||
super().__init__(dense_layer=dense_layer)
|
super().__init__(dense_layer=dense_layer)
|
||||||
self.activation_function = activation_function
|
self.activation_function = activation_function
|
||||||
self.scheme = scheme
|
self.scheme = scheme
|
||||||
@@ -40,6 +40,7 @@ class InputEmbedderParameters(NetworkComponentParameters):
|
|||||||
self.input_clipping = input_clipping
|
self.input_clipping = input_clipping
|
||||||
self.name = name
|
self.name = name
|
||||||
self.is_training = is_training
|
self.is_training = is_training
|
||||||
|
self.flatten = flatten
|
||||||
|
|
||||||
def path(self, emb_type):
|
def path(self, emb_type):
|
||||||
return 'rl_coach.architectures.tensorflow_components.embedders:' + MOD_NAMES[emb_type]
|
return 'rl_coach.architectures.tensorflow_components.embedders:' + MOD_NAMES[emb_type]
|
||||||
|
|||||||
@@ -76,3 +76,11 @@ class NoisyNetDense(object):
|
|||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "Noisy Dense (num outputs = {})".format(self.units)
|
return "Noisy Dense (num outputs = {})".format(self.units)
|
||||||
|
|
||||||
|
|
||||||
|
class Flatten(object):
|
||||||
|
"""
|
||||||
|
Base class for framework specific flatten layer (used to convert 3D convolution output to 1D dense input)
|
||||||
|
"""
|
||||||
|
def __str__(self):
|
||||||
|
return "Flatten"
|
||||||
|
|||||||
@@ -20,7 +20,8 @@ import copy
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from rl_coach.architectures.tensorflow_components.layers import BatchnormActivationDropout, convert_layer, Dense
|
from rl_coach.architectures.tensorflow_components.layers import BatchnormActivationDropout, convert_layer, Dense, \
|
||||||
|
Flatten
|
||||||
from rl_coach.base_parameters import EmbedderScheme, NetworkComponentParameters
|
from rl_coach.base_parameters import EmbedderScheme, NetworkComponentParameters
|
||||||
|
|
||||||
from rl_coach.core_types import InputEmbedding
|
from rl_coach.core_types import InputEmbedding
|
||||||
@@ -36,7 +37,7 @@ class InputEmbedder(object):
|
|||||||
def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
|
def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
|
||||||
scheme: EmbedderScheme=None, batchnorm: bool=False, dropout_rate: float=0.0,
|
scheme: EmbedderScheme=None, batchnorm: bool=False, dropout_rate: float=0.0,
|
||||||
name: str= "embedder", input_rescaling=1.0, input_offset=0.0, input_clipping=None, dense_layer=Dense,
|
name: str= "embedder", input_rescaling=1.0, input_offset=0.0, input_clipping=None, dense_layer=Dense,
|
||||||
is_training=False):
|
is_training=False, flatten=True):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.input_size = input_size
|
self.input_size = input_size
|
||||||
self.activation_function = activation_function
|
self.activation_function = activation_function
|
||||||
@@ -55,6 +56,7 @@ class InputEmbedder(object):
|
|||||||
if self.dense_layer is None:
|
if self.dense_layer is None:
|
||||||
self.dense_layer = Dense
|
self.dense_layer = Dense
|
||||||
self.is_training = is_training
|
self.is_training = is_training
|
||||||
|
self.flatten = flatten
|
||||||
|
|
||||||
# layers order is conv -> batchnorm -> activation -> dropout
|
# layers order is conv -> batchnorm -> activation -> dropout
|
||||||
if isinstance(self.scheme, EmbedderScheme):
|
if isinstance(self.scheme, EmbedderScheme):
|
||||||
@@ -116,7 +118,10 @@ class InputEmbedder(object):
|
|||||||
is_training=self.is_training)
|
is_training=self.is_training)
|
||||||
))
|
))
|
||||||
|
|
||||||
self.output = tf.contrib.layers.flatten(self.layers[-1])
|
if self.flatten:
|
||||||
|
self.output = Flatten()(self.layers[-1])
|
||||||
|
else:
|
||||||
|
self.output = self.layers[-1]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_size(self) -> List[int]:
|
def input_size(self) -> List[int]:
|
||||||
|
|||||||
@@ -34,9 +34,10 @@ class ImageEmbedder(InputEmbedder):
|
|||||||
def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
|
def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
|
||||||
scheme: EmbedderScheme=EmbedderScheme.Medium, batchnorm: bool=False, dropout_rate: float=0.0,
|
scheme: EmbedderScheme=EmbedderScheme.Medium, batchnorm: bool=False, dropout_rate: float=0.0,
|
||||||
name: str= "embedder", input_rescaling: float=255.0, input_offset: float=0.0, input_clipping=None,
|
name: str= "embedder", input_rescaling: float=255.0, input_offset: float=0.0, input_clipping=None,
|
||||||
dense_layer=Dense, is_training=False):
|
dense_layer=Dense, is_training=False, flatten=True):
|
||||||
super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name, input_rescaling,
|
super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name, input_rescaling,
|
||||||
input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training)
|
input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training,
|
||||||
|
flatten=flatten)
|
||||||
self.return_type = InputImageEmbedding
|
self.return_type = InputImageEmbedding
|
||||||
if len(input_size) != 3 and scheme != EmbedderScheme.Empty:
|
if len(input_size) != 3 and scheme != EmbedderScheme.Empty:
|
||||||
raise ValueError("Image embedders expect the input size to have 3 dimensions. The given size is: {}"
|
raise ValueError("Image embedders expect the input size to have 3 dimensions. The given size is: {}"
|
||||||
|
|||||||
@@ -41,9 +41,10 @@ class TensorEmbedder(InputEmbedder):
|
|||||||
def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
|
def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
|
||||||
scheme: EmbedderScheme=None, batchnorm: bool=False, dropout_rate: float=0.0,
|
scheme: EmbedderScheme=None, batchnorm: bool=False, dropout_rate: float=0.0,
|
||||||
name: str= "embedder", input_rescaling: float=1.0, input_offset: float=0.0, input_clipping=None,
|
name: str= "embedder", input_rescaling: float=1.0, input_offset: float=0.0, input_clipping=None,
|
||||||
dense_layer=Dense, is_training=False):
|
dense_layer=Dense, is_training=False, flatten=True):
|
||||||
super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name, input_rescaling,
|
super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name, input_rescaling,
|
||||||
input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training)
|
input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training,
|
||||||
|
flatten=flatten)
|
||||||
self.return_type = InputTensorEmbedding
|
self.return_type = InputTensorEmbedding
|
||||||
assert scheme is not None, "Custom scheme (a list of callables) must be specified for TensorEmbedder"
|
assert scheme is not None, "Custom scheme (a list of callables) must be specified for TensorEmbedder"
|
||||||
|
|
||||||
|
|||||||
@@ -33,10 +33,10 @@ class VectorEmbedder(InputEmbedder):
|
|||||||
def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
|
def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
|
||||||
scheme: EmbedderScheme=EmbedderScheme.Medium, batchnorm: bool=False, dropout_rate: float=0.0,
|
scheme: EmbedderScheme=EmbedderScheme.Medium, batchnorm: bool=False, dropout_rate: float=0.0,
|
||||||
name: str= "embedder", input_rescaling: float=1.0, input_offset: float=0.0, input_clipping=None,
|
name: str= "embedder", input_rescaling: float=1.0, input_offset: float=0.0, input_clipping=None,
|
||||||
dense_layer=Dense, is_training=False):
|
dense_layer=Dense, is_training=False, flatten=True):
|
||||||
super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name,
|
super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name,
|
||||||
input_rescaling, input_offset, input_clipping, dense_layer=dense_layer,
|
input_rescaling, input_offset, input_clipping, dense_layer=dense_layer,
|
||||||
is_training=is_training)
|
is_training=is_training, flatten=flatten)
|
||||||
|
|
||||||
self.return_type = InputVectorEmbedding
|
self.return_type = InputVectorEmbedding
|
||||||
if len(self.input_size) != 1 and scheme != EmbedderScheme.Empty:
|
if len(self.input_size) != 1 and scheme != EmbedderScheme.Empty:
|
||||||
|
|||||||
@@ -264,3 +264,26 @@ class NoisyNetDense(layers.NoisyNetDense):
|
|||||||
@reg_to_tf_class(layers.NoisyNetDense)
|
@reg_to_tf_class(layers.NoisyNetDense)
|
||||||
def to_tf_class():
|
def to_tf_class():
|
||||||
return NoisyNetDense
|
return NoisyNetDense
|
||||||
|
|
||||||
|
|
||||||
|
class Flatten(layers.Flatten):
|
||||||
|
def __init__(self):
|
||||||
|
super(Flatten, self).__init__()
|
||||||
|
|
||||||
|
def __call__(self, input_layer, **kwargs):
|
||||||
|
"""
|
||||||
|
returns a tensorflow flatten layer
|
||||||
|
:param input_layer: previous layer
|
||||||
|
:return: flatten layer
|
||||||
|
"""
|
||||||
|
return tf.contrib.layers.flatten(input_layer)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@reg_to_tf_instance(layers.Flatten)
|
||||||
|
def to_tf_instance(base: layers.Flatten):
|
||||||
|
return Flatten()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@reg_to_tf_class(layers.Flatten)
|
||||||
|
def to_tf_class():
|
||||||
|
return Flatten
|
||||||
|
|||||||
Reference in New Issue
Block a user