mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
bug-fix for l2_regularization not in use (#230)
* bug-fix for l2_regularization not in use * removing not in use TF REGULARIZATION_LOSSES collection
This commit is contained in:
@@ -68,9 +68,8 @@ class PPOHead(Head):
|
||||
if self.use_kl_regularization:
|
||||
# no clipping => use kl regularization
|
||||
self.weighted_kl_divergence = tf.multiply(self.kl_coefficient, self.kl_divergence)
|
||||
self.regularizations = self.weighted_kl_divergence + self.high_kl_penalty_coefficient * \
|
||||
tf.square(tf.maximum(0.0, self.kl_divergence - self.kl_cutoff))
|
||||
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, self.regularizations)
|
||||
self.regularizations += [self.weighted_kl_divergence + self.high_kl_penalty_coefficient * \
|
||||
tf.square(tf.maximum(0.0, self.kl_divergence - self.kl_cutoff))]
|
||||
|
||||
# calculate surrogate loss
|
||||
self.advantages = tf.placeholder(tf.float32, [None], name="advantages")
|
||||
@@ -93,8 +92,7 @@ class PPOHead(Head):
|
||||
# add entropy regularization
|
||||
if self.beta:
|
||||
self.entropy = tf.reduce_mean(self.policy_distribution.entropy())
|
||||
self.regularizations = -tf.multiply(self.beta, self.entropy, name='entropy_regularization')
|
||||
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, self.regularizations)
|
||||
self.regularizations += [-tf.multiply(self.beta, self.entropy, name='entropy_regularization')]
|
||||
|
||||
self.loss = self.surrogate_loss
|
||||
tf.losses.add_loss(self.loss)
|
||||
|
||||
Reference in New Issue
Block a user