mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Add Flatten layer to architectures + make flatten optional in embedders (#483)
Flatten layer required for embedders that mix conv and dense (Cherry picking from #478)
This commit is contained in:
@@ -264,3 +264,26 @@ class NoisyNetDense(layers.NoisyNetDense):
|
||||
@reg_to_tf_class(layers.NoisyNetDense)
|
||||
def to_tf_class():
|
||||
return NoisyNetDense
|
||||
|
||||
|
||||
class Flatten(layers.Flatten):
|
||||
def __init__(self):
|
||||
super(Flatten, self).__init__()
|
||||
|
||||
def __call__(self, input_layer, **kwargs):
|
||||
"""
|
||||
returns a tensorflow flatten layer
|
||||
:param input_layer: previous layer
|
||||
:return: flatten layer
|
||||
"""
|
||||
return tf.contrib.layers.flatten(input_layer)
|
||||
|
||||
@staticmethod
|
||||
@reg_to_tf_instance(layers.Flatten)
|
||||
def to_tf_instance(base: layers.Flatten):
|
||||
return Flatten()
|
||||
|
||||
@staticmethod
|
||||
@reg_to_tf_class(layers.Flatten)
|
||||
def to_tf_class():
|
||||
return Flatten
|
||||
|
||||
Reference in New Issue
Block a user