mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
bug-fix for l2_regularization not in use (#230)
* bug-fix for l2_regularization not in use * removing not in use TF REGULARIZATION_LOSSES collection
This commit is contained in:
@@ -222,7 +222,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
'agent_parameters': self.ap, 'spaces': self.spaces, 'network_name': self.network_wrapper_name,
|
||||
'head_idx': head_idx, 'is_local': self.network_is_local})
|
||||
|
||||
def get_model(self):
|
||||
def get_model(self) -> List:
|
||||
# validate the configuration
|
||||
if len(self.network_parameters.input_embedders_parameters) == 0:
|
||||
raise ValueError("At least one input type should be defined")
|
||||
@@ -338,9 +338,18 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
|
||||
head_count += 1
|
||||
|
||||
# model weights
|
||||
self.weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name)
|
||||
|
||||
# Losses
|
||||
self.losses = tf.losses.get_losses(self.full_name)
|
||||
self.losses += tf.losses.get_regularization_losses(self.full_name)
|
||||
|
||||
# L2 regularization
|
||||
if self.network_parameters.l2_regularization != 0:
|
||||
self.l2_regularization = tf.add_n([tf.nn.l2_loss(v) for v in self.weights]) \
|
||||
* self.network_parameters.l2_regularization
|
||||
self.losses += self.l2_regularization
|
||||
|
||||
self.total_loss = tf.reduce_sum(self.losses)
|
||||
# tf.summary.scalar('total_loss', self.total_loss)
|
||||
|
||||
@@ -386,6 +395,8 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
else:
|
||||
raise Exception("{} is not a valid optimizer type".format(self.network_parameters.optimizer_type))
|
||||
|
||||
return self.weights
|
||||
|
||||
def __str__(self):
|
||||
result = []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user