1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Add Flatten layer to architectures + make flatten optional in embedders (#483)

Flatten layer required for embedders that mix conv and dense
(Cherry picking from #478)
This commit is contained in:
Guy Jacob
2021-05-12 11:11:10 +03:00
committed by GitHub
parent c369984c2e
commit 235a259223
7 changed files with 49 additions and 10 deletions

View File

@@ -264,3 +264,26 @@ class NoisyNetDense(layers.NoisyNetDense):
@reg_to_tf_class(layers.NoisyNetDense)
def to_tf_class():
return NoisyNetDense
class Flatten(layers.Flatten):
def __init__(self):
super(Flatten, self).__init__()
def __call__(self, input_layer, **kwargs):
"""
returns a tensorflow flatten layer
:param input_layer: previous layer
:return: flatten layer
"""
return tf.contrib.layers.flatten(input_layer)
@staticmethod
@reg_to_tf_instance(layers.Flatten)
def to_tf_instance(base: layers.Flatten):
return Flatten()
@staticmethod
@reg_to_tf_class(layers.Flatten)
def to_tf_class():
return Flatten