1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

Add Flatten layer to architectures + make flatten optional in embedders (#483)

Flatten layer required for embedders that mix conv and dense
(Cherry picking from #478)
This commit is contained in:
Guy Jacob
2021-05-12 11:11:10 +03:00
committed by GitHub
parent c369984c2e
commit 235a259223
7 changed files with 49 additions and 10 deletions

View File

@@ -20,7 +20,8 @@ import copy
import numpy as np
import tensorflow as tf
- from rl_coach.architectures.tensorflow_components.layers import BatchnormActivationDropout, convert_layer, Dense
+ from rl_coach.architectures.tensorflow_components.layers import BatchnormActivationDropout, convert_layer, Dense, \
+     Flatten
from rl_coach.base_parameters import EmbedderScheme, NetworkComponentParameters
from rl_coach.core_types import InputEmbedding
@@ -36,7 +37,7 @@ class InputEmbedder(object):
def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
scheme: EmbedderScheme=None, batchnorm: bool=False, dropout_rate: float=0.0,
name: str= "embedder", input_rescaling=1.0, input_offset=0.0, input_clipping=None, dense_layer=Dense,
- is_training=False):
+ is_training=False, flatten=True):
self.name = name
self.input_size = input_size
self.activation_function = activation_function
@@ -55,6 +56,7 @@ class InputEmbedder(object):
if self.dense_layer is None:
self.dense_layer = Dense
self.is_training = is_training
+ self.flatten = flatten
# layers order is conv -> batchnorm -> activation -> dropout
if isinstance(self.scheme, EmbedderScheme):
@@ -116,7 +118,10 @@ class InputEmbedder(object):
is_training=self.is_training)
))
- self.output = tf.contrib.layers.flatten(self.layers[-1])
+ if self.flatten:
+     self.output = Flatten()(self.layers[-1])
+ else:
+     self.output = self.layers[-1]
@property
def input_size(self) -> List[int]: