1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

bug-fix for l2_regularization not in use (#230)

* bug-fix for l2_regularization not in use
* removing not in use TF REGULARIZATION_LOSSES collection
This commit is contained in:
Gal Leibovich
2019-03-03 15:11:06 +02:00
committed by Gal Novik
parent 10220be9be
commit 9a895a1ac7
5 changed files with 21 additions and 22 deletions

View File

@@ -102,10 +102,7 @@ class TensorFlowArchitecture(Architecture):
self.global_step = tf.train.get_or_create_global_step()
# build the network
self.get_model()
# model weights
self.weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name)
self.weights = self.get_model()
# create the placeholder for the assigning gradients and some tensorboard summaries for the weights
for idx, var in enumerate(self.weights):
@@ -125,12 +122,6 @@ class TensorFlowArchitecture(Architecture):
# gradients ops
self._create_gradient_ops()
# L2 regularization
if self.network_parameters.l2_regularization != 0:
self.l2_regularization = [tf.add_n([tf.nn.l2_loss(v) for v in self.weights])
* self.network_parameters.l2_regularization]
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, self.l2_regularization)
self.inc_step = self.global_step.assign_add(1)
# reset LSTM hidden cells
@@ -150,11 +141,13 @@ class TensorFlowArchitecture(Architecture):
# set the fetches for training
self._set_initial_fetch_list()
def get_model(self) -> None:
def get_model(self) -> List:
"""
Constructs the model using `network_parameters` and sets `input_embedders`, `middleware`,
`output_heads`, `outputs`, `losses`, `total_loss`, `adaptive_learning_rate_scheme`,
`current_learning_rate`, and `optimizer`
`current_learning_rate`, and `optimizer`.
:return: A list of the model's weights
"""
raise NotImplementedError