Mirror of https://github.com/gryf/coach.git
bug-fix for l2_regularization not in use (#230)

* bug-fix for l2_regularization not in use
* remove the unused TF REGULARIZATION_LOSSES collection
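The bug in one picture: the L2 term (and, in the policy heads, the entropy and KL terms) was pushed into the tf.GraphKeys.REGULARIZATION_LOSSES collection, but nothing ever read that collection back into the loss handed to the optimizer, so the regularization had no effect on training. A minimal TF 1.x sketch of the failure mode, with toy names and values rather than the repo's actual code:

    import tensorflow as tf

    w = tf.Variable(tf.ones([3, 3]), name='w')
    data_loss = tf.reduce_sum(tf.square(w))   # stand-in for the real objective
    l2_term = 1e-4 * tf.nn.l2_loss(w)

    # Old pattern: stash the term in a collection...
    tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, l2_term)
    # ...then minimize only data_loss, so l2_term never influences a single step.
    buggy_train_op = tf.train.GradientDescentOptimizer(0.1).minimize(data_loss)

    # Fixed pattern (what this commit does): skip the collection and fold the
    # term directly into the loss that is actually minimized.
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(data_loss + l2_term)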
@@ -102,10 +102,7 @@ class TensorFlowArchitecture(Architecture):
         self.global_step = tf.train.get_or_create_global_step()
 
         # build the network
-        self.get_model()
+        self.weights = self.get_model()
 
-        # model weights
-        self.weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name)
-
         # create the placeholder for the assigning gradients and some tensorboard summaries for the weights
         for idx, var in enumerate(self.weights):
@@ -125,12 +122,6 @@ class TensorFlowArchitecture(Architecture):
         # gradients ops
         self._create_gradient_ops()
 
-        # L2 regularization
-        if self.network_parameters.l2_regularization != 0:
-            self.l2_regularization = [tf.add_n([tf.nn.l2_loss(v) for v in self.weights])
-                                      * self.network_parameters.l2_regularization]
-            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, self.l2_regularization)
-
         self.inc_step = self.global_step.assign_add(1)
 
         # reset LSTM hidden cells
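The block deleted above was doubly inert. Beyond never being summed into the loss, it wrapped the L2 term in a one-element Python list before adding it to the collection; in TF 1.x a scoped collection lookup filters entries by their .name attribute, which a plain list lacks, so even tf.losses.get_regularization_losses(scope=...) could not have retrieved it. A toy illustration (coefficient value assumed); the working replacement now lives in GeneralTensorFlowNetwork.get_model() further down:

    import tensorflow as tf

    w = tf.Variable(tf.ones([2, 2]), name='w')
    l2 = [tf.nn.l2_loss(w) * 1e-4]   # note: a one-element Python list
    tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, l2)

    # The scoped lookup skips entries without a .name, so the term is invisible:
    print(tf.losses.get_regularization_losses(scope='w'))   # -> []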
@@ -150,11 +141,13 @@ class TensorFlowArchitecture(Architecture):
         # set the fetches for training
         self._set_initial_fetch_list()
 
-    def get_model(self) -> None:
+    def get_model(self) -> List:
         """
         Constructs the model using `network_parameters` and sets `input_embedders`, `middleware`,
         `output_heads`, `outputs`, `losses`, `total_loss`, `adaptive_learning_rate_scheme`,
-        `current_learning_rate`, and `optimizer`
+        `current_learning_rate`, and `optimizer`.
+
+        :return: A list of the model's weights
         """
         raise NotImplementedError
 
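With the new signature, a subclass builds its graph and hands the trainable variables back to the caller, which assigns self.weights from the return value. A hypothetical minimal override honoring that contract; MyNetwork, input_ph, the dense layer, and the import path are all invented or assumed for illustration:

    import tensorflow as tf
    from typing import List

    # assumed import path for this repo's layout
    from rl_coach.architectures.tensorflow_components.architecture import TensorFlowArchitecture

    class MyNetwork(TensorFlowArchitecture):
        def get_model(self) -> List:
            # build the network (input_ph is a hypothetical placeholder)
            self.output = tf.layers.dense(self.input_ph, units=10, name='fc')
            # hand this scope's trainable variables back to the base class
            return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name)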
@@ -222,7 +222,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
                                      'agent_parameters': self.ap, 'spaces': self.spaces, 'network_name': self.network_wrapper_name,
                                      'head_idx': head_idx, 'is_local': self.network_is_local})
 
-    def get_model(self):
+    def get_model(self) -> List:
         # validate the configuration
         if len(self.network_parameters.input_embedders_parameters) == 0:
             raise ValueError("At least one input type should be defined")
@@ -338,9 +338,18 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
 
                 head_count += 1
 
+        # model weights
+        self.weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name)
+
         # Losses
         self.losses = tf.losses.get_losses(self.full_name)
-        self.losses += tf.losses.get_regularization_losses(self.full_name)
+
+        # L2 regularization
+        if self.network_parameters.l2_regularization != 0:
+            self.l2_regularization = tf.add_n([tf.nn.l2_loss(v) for v in self.weights]) \
+                                     * self.network_parameters.l2_regularization
+            self.losses += self.l2_regularization
+
         self.total_loss = tf.reduce_sum(self.losses)
         # tf.summary.scalar('total_loss', self.total_loss)
 
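A standalone sketch of the fixed computation, with toy variables and an assumed coefficient. The point is that the L2 term is appended to the same list that reduce_sum collapses into total_loss, so its gradient finally reaches the optimizer:

    import tensorflow as tf

    weights = [tf.Variable(tf.ones([2, 2]), name='w0'), tf.Variable(tf.ones([2]), name='w1')]
    l2_coefficient = 1e-4

    losses = [tf.constant(1.0)]      # stand-in for tf.losses.get_losses(full_name)
    l2_regularization = tf.add_n([tf.nn.l2_loss(v) for v in weights]) * l2_coefficient
    losses += [l2_regularization]    # folded into the loss list, not a side collection
    total_loss = tf.reduce_sum(losses)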
@@ -386,6 +395,8 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
             else:
                 raise Exception("{} is not a valid optimizer type".format(self.network_parameters.optimizer_type))
 
+        return self.weights
+
     def __str__(self):
         result = []
 
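Together with the constructor change in the first hunk, the weight hand-off is now a single explicit round trip. A condensed pure-Python sketch with shortened class names and the TF calls stubbed out:

    class Architecture:
        def __init__(self):
            # the base class consumes get_model()'s return value directly
            self.weights = self.get_model()

        def get_model(self):
            raise NotImplementedError

    class Network(Architecture):
        def get_model(self):
            # ... build embedders, middleware and heads here ...
            weights = []   # stands in for tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, ...)
            return weights

    net = Network()
    assert net.weights == []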
@@ -56,7 +56,6 @@ class ACERPolicyHead(Head):
         if self.beta:
             self.entropy = tf.reduce_mean(self.policy_distribution.entropy())
             self.regularizations += [-tf.multiply(self.beta, self.entropy, name='entropy_regularization')]
-            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, self.regularizations)
 
         # Truncated importance sampling with bias corrections
         importance_sampling_weight = tf.placeholder(tf.float32, [None, self.num_actions],
@@ -78,8 +78,6 @@ class PolicyHead(Head):
             self.entropy = tf.add_n([tf.reduce_mean(dist.entropy()) for dist in self.policy_distributions])
             self.regularizations += [-tf.multiply(self.beta, self.entropy, name='entropy_regularization')]
 
-            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, self.regularizations)
-
         # calculate loss
         self.action_log_probs_wrt_policy = \
             tf.add_n([dist.log_prob(action) for dist, action in zip(self.policy_distributions, self.actions)])
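Both heads could drop the collection call because each regularizer already lands in the head's self.regularizations list, which the surrounding head code is responsible for summing into the loss; the global collection was a dead end. A toy version of that accumulate-then-sum pattern, with stand-in values and the head plumbing elided:

    import tensorflow as tf

    beta = 0.01
    entropy = tf.constant(0.7)          # stand-in for the policy entropy
    surrogate_loss = tf.constant(1.0)   # stand-in for the head's main loss

    regularizations = []
    regularizations += [-tf.multiply(beta, entropy, name='entropy_regularization')]

    # summed straight into the head loss; no collection involved
    loss = surrogate_loss + tf.add_n(regularizations)
    tf.losses.add_loss(loss)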
@@ -68,9 +68,8 @@ class PPOHead(Head):
         if self.use_kl_regularization:
             # no clipping => use kl regularization
             self.weighted_kl_divergence = tf.multiply(self.kl_coefficient, self.kl_divergence)
-            self.regularizations = self.weighted_kl_divergence + self.high_kl_penalty_coefficient * \
-                                   tf.square(tf.maximum(0.0, self.kl_divergence - self.kl_cutoff))
-            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, self.regularizations)
+            self.regularizations += [self.weighted_kl_divergence + self.high_kl_penalty_coefficient * \
+                                     tf.square(tf.maximum(0.0, self.kl_divergence - self.kl_cutoff))]
 
         # calculate surrogate loss
         self.advantages = tf.placeholder(tf.float32, [None], name="advantages")
@@ -93,8 +92,7 @@ class PPOHead(Head):
         # add entropy regularization
         if self.beta:
             self.entropy = tf.reduce_mean(self.policy_distribution.entropy())
-            self.regularizations = -tf.multiply(self.beta, self.entropy, name='entropy_regularization')
-            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, self.regularizations)
+            self.regularizations += [-tf.multiply(self.beta, self.entropy, name='entropy_regularization')]
 
         self.loss = self.surrogate_loss
         tf.losses.add_loss(self.loss)
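The two PPO hunks also fix a subtler bug visible in the old lines: both sites assigned to self.regularizations with a bare =, so enabling KL and entropy regularization together let the entropy term silently overwrite the KL term (and replaced the list with a raw tensor). Appending one-element lists preserves every term; a toy demonstration with stand-in values:

    import tensorflow as tf

    regularizations = []                 # invariant: one tensor per regularizer

    kl_term = tf.constant(0.3)           # stand-in for the weighted KL penalty
    entropy_term = tf.constant(-0.1)     # stand-in for -beta * entropy

    # with the old `=` the second line would have discarded kl_term entirely
    regularizations += [kl_term]
    regularizations += [entropy_term]

    total_regularization = tf.add_n(regularizations)   # both terms survive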