1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00
This commit is contained in:
Gal Leibovich
2019-06-16 11:11:21 +03:00
committed by GitHub
parent 8df3c46756
commit 7eb884c5b2
107 changed files with 2200 additions and 495 deletions

View File

@@ -28,23 +28,28 @@ class FCMiddleware(Middleware):
def __init__(self, activation_function=tf.nn.relu,
scheme: MiddlewareScheme = MiddlewareScheme.Medium,
batchnorm: bool = False, dropout_rate: float = 0.0,
name="middleware_fc_embedder", dense_layer=Dense, is_training=False):
name="middleware_fc_embedder", dense_layer=Dense, is_training=False, num_streams: int = 1):
super().__init__(activation_function=activation_function, batchnorm=batchnorm,
dropout_rate=dropout_rate, scheme=scheme, name=name, dense_layer=dense_layer,
is_training=is_training)
self.return_type = Middleware_FC_Embedding
self.layers = []
assert(isinstance(num_streams, int) and num_streams >= 1)
self.num_streams = num_streams
def _build_module(self):
self.layers.append(self.input)
self.output = []
for idx, layer_params in enumerate(self.layers_params):
self.layers.extend(force_list(
layer_params(self.layers[-1], name='{}_{}'.format(layer_params.__class__.__name__, idx),
is_training=self.is_training)
))
for stream_idx in range(self.num_streams):
layers = [self.input]
self.output = self.layers[-1]
for idx, layer_params in enumerate(self.layers_params):
layers.extend(force_list(
layer_params(layers[-1], name='{}_{}'.format(layer_params.__class__.__name__,
idx + stream_idx * len(self.layers_params)),
is_training=self.is_training)
))
self.output.append((layers[-1]))
@property
def schemes(self):
@@ -72,3 +77,15 @@ class FCMiddleware(Middleware):
]
}
def __str__(self):
stream = [str(l) for l in self.layers_params]
if self.layers_params:
if self.num_streams > 1:
stream = [''] + ['\t' + l for l in stream]
result = stream * self.num_streams
result[0::len(stream)] = ['Stream {}'.format(i) for i in range(self.num_streams)]
else:
result = stream
return '\n'.join(result)
else:
return 'No layers'