bug-fix for l2_regularization not in use (#230)
* bug-fix for l2_regularization not in use
* remove the unused TF REGULARIZATION_LOSSES collection
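For context (a reading of the bug, not text from the commit): in TensorFlow 1.x, tf.GraphKeys.REGULARIZATION_LOSSES is just a named collection. Adding a tensor to it does nothing on its own; the penalty only influences training if some code reads the collection back and folds it into the loss being minimized, and nothing in the training path did that here, which is what "not in use" refers to. A minimal standalone sketch of the inert pattern (the toy placeholders, variable, and loss are illustrative, not coach code):

    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 10])
    y = tf.placeholder(tf.float32, [None, 1])
    w = tf.get_variable('w', shape=[10, 1])
    data_loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))

    # The penalty is built and stored in the collection...
    l2_term = 1e-4 * tf.nn.l2_loss(w)
    tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, l2_term)

    # ...but the optimizer only ever sees data_loss, so l2_term never
    # contributes a gradient and the regularization silently does nothing.
    train_op = tf.train.AdamOptimizer(1e-3).minimize(data_loss)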
@@ -102,10 +102,7 @@ class TensorFlowArchitecture(Architecture):
         self.global_step = tf.train.get_or_create_global_step()
 
         # build the network
-        self.get_model()
-
-        # model weights
-        self.weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name)
+        self.weights = self.get_model()
 
         # create the placeholder for the assigning gradients and some tensorboard summaries for the weights
         for idx, var in enumerate(self.weights):
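The hunk above changes how the constructor obtains the weight list: instead of building the model and then asking the graph for trainable variables whose names fall under self.full_name, it takes whatever list get_model() returns. A small standalone sketch of the two styles (the scope name and variables are illustrative):

    import tensorflow as tf

    with tf.variable_scope('online/network_0'):
        w = tf.get_variable('w', shape=[4, 2])
        b = tf.get_variable('b', shape=[2])

    # Old style: query the graph collection by scope name.
    by_scope = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='online/network_0')

    # New style: the code that created the variables returns them explicitly.
    explicit = [w, b]

    assert [v.name for v in by_scope] == [v.name for v in explicit]

Returning the list makes ownership explicit, and presumably lets the concrete network hand the exact same variables to whatever applies the L2 penalty.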
@@ -125,12 +122,6 @@ class TensorFlowArchitecture(Architecture):
         # gradients ops
         self._create_gradient_ops()
 
-        # L2 regularization
-        if self.network_parameters.l2_regularization != 0:
-            self.l2_regularization = [tf.add_n([tf.nn.l2_loss(v) for v in self.weights])
-                                      * self.network_parameters.l2_regularization]
-            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, self.l2_regularization)
-
         self.inc_step = self.global_step.assign_add(1)
 
         # reset LSTM hidden cells
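The hunk above deletes the dead path: the penalty was computed over self.weights, wrapped in a one-element list, and pushed into the REGULARIZATION_LOSSES collection that nothing consumed. This diff only shows the removal; per the commit message the collection is dropped, and the penalty would instead have to be applied where the loss is actually assembled. A hedged sketch of that direct style (function and variable names are illustrative, not coach's):

    import tensorflow as tf

    def with_l2_penalty(task_loss, weights, l2_coefficient):
        # Fold an L2 penalty over an explicit weight list into the loss the optimizer sees.
        if l2_coefficient == 0:
            return task_loss
        return task_loss + l2_coefficient * tf.add_n([tf.nn.l2_loss(v) for v in weights])

    # Usage sketch: the optimizer minimizes the combined loss, so the penalty
    # actually produces gradients, unlike the removed collection-based code.
    x = tf.placeholder(tf.float32, [None, 10])
    y = tf.placeholder(tf.float32, [None, 1])
    w = tf.get_variable('w', shape=[10, 1])
    task_loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
    train_op = tf.train.AdamOptimizer(1e-3).minimize(with_l2_penalty(task_loss, [w], 1e-4))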
@@ -150,11 +141,13 @@ class TensorFlowArchitecture(Architecture):
         # set the fetches for training
         self._set_initial_fetch_list()
 
-    def get_model(self) -> None:
+    def get_model(self) -> List:
         """
         Constructs the model using `network_parameters` and sets `input_embedders`, `middleware`,
         `output_heads`, `outputs`, `losses`, `total_loss`, `adaptive_learning_rate_scheme`,
-        `current_learning_rate`, and `optimizer`
+        `current_learning_rate`, and `optimizer`.
+
+        :return: A list of the model's weights
         """
         raise NotImplementedError
 
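The last hunk updates the abstract get_model() signature and docstring to match the new contract: build the graph and return the list of weights. A self-contained toy class illustrating the contract (the layer sizes, observation placeholder, and class name are made up for the example; only get_model, self.weights, and self.full_name appear in the diff above):

    from typing import List

    import tensorflow as tf

    class ToyNetwork:
        # Stand-in for a TensorFlowArchitecture subclass, illustrative only.
        def __init__(self, full_name: str):
            self.full_name = full_name
            self.weights = self.get_model()  # mirrors the new constructor line in the first hunk

        def get_model(self) -> List:
            # Build the graph under this network's scope...
            with tf.variable_scope(self.full_name):
                self.observation = tf.placeholder(tf.float32, [None, 4], name='observation')
                hidden = tf.layers.dense(self.observation, 64, activation=tf.nn.relu)
                self.output = tf.layers.dense(hidden, 2, name='output')
            # ...and return the weights it created, per the new return contract.
            return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name)

    net = ToyNetwork('online/network_0')
    print([v.name for v in net.weights])  # kernels and biases of the two dense layers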