mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Move embedder, middleware, and head parameters to framework agnostic modules. (#45)
Part of #28
This commit is contained in:
committed by
Scott Leishman
parent
16b3e99f37
commit
a888226641
@@ -33,19 +33,6 @@ def normalized_columns_initializer(std=1.0):
|
||||
return _initializer
|
||||
|
||||
|
||||
class HeadParameters(NetworkComponentParameters):
|
||||
def __init__(self, parameterized_class: Type['Head'], activation_function: str = 'relu', name: str= 'head',
|
||||
num_output_head_copies: int=1, rescale_gradient_from_head_by_factor: float=1.0,
|
||||
loss_weight: float=1.0, dense_layer=Dense):
|
||||
super().__init__(dense_layer=dense_layer)
|
||||
self.activation_function = activation_function
|
||||
self.name = name
|
||||
self.num_output_head_copies = num_output_head_copies
|
||||
self.rescale_gradient_from_head_by_factor = rescale_gradient_from_head_by_factor
|
||||
self.loss_weight = loss_weight
|
||||
self.parameterized_class_name = parameterized_class.__name__
|
||||
|
||||
|
||||
class Head(object):
|
||||
"""
|
||||
A head is the final part of the network. It takes the embedding from the middleware embedder and passes it through
|
||||
@@ -74,6 +61,8 @@ class Head(object):
|
||||
self.return_type = None
|
||||
self.activation_function = activation_function
|
||||
self.dense_layer = dense_layer
|
||||
if self.dense_layer is None:
|
||||
self.dense_layer = Dense
|
||||
|
||||
def __call__(self, input_layer):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user