1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Move embedder, middleware, and head parameters to framework agnostic modules. (#45)

Part of #28
This commit is contained in:
Sina Afrooze
2018-10-29 14:46:40 -07:00
committed by Scott Leishman
parent 16b3e99f37
commit a888226641
60 changed files with 410 additions and 330 deletions

View File

@@ -33,19 +33,6 @@ def normalized_columns_initializer(std=1.0):
return _initializer
class HeadParameters(NetworkComponentParameters):
def __init__(self, parameterized_class: Type['Head'], activation_function: str = 'relu', name: str= 'head',
num_output_head_copies: int=1, rescale_gradient_from_head_by_factor: float=1.0,
loss_weight: float=1.0, dense_layer=Dense):
super().__init__(dense_layer=dense_layer)
self.activation_function = activation_function
self.name = name
self.num_output_head_copies = num_output_head_copies
self.rescale_gradient_from_head_by_factor = rescale_gradient_from_head_by_factor
self.loss_weight = loss_weight
self.parameterized_class_name = parameterized_class.__name__
class Head(object):
"""
A head is the final part of the network. It takes the embedding from the middleware embedder and passes it through
@@ -74,6 +61,8 @@ class Head(object):
self.return_type = None
self.activation_function = activation_function
self.dense_layer = dense_layer
if self.dense_layer is None:
self.dense_layer = Dense
def __call__(self, input_layer):
"""