diff --git a/rl_coach/architectures/tensorflow_components/architecture.py b/rl_coach/architectures/tensorflow_components/architecture.py
index 4af1622..648381f 100644
--- a/rl_coach/architectures/tensorflow_components/architecture.py
+++ b/rl_coach/architectures/tensorflow_components/architecture.py
@@ -102,10 +102,7 @@ class TensorFlowArchitecture(Architecture):
             self.global_step = tf.train.get_or_create_global_step()
 
             # build the network
-            self.get_model()
-
-            # model weights
-            self.weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name)
+            self.weights = self.get_model()
 
             # create the placeholder for the assigning gradients and some tensorboard summaries for the weights
             for idx, var in enumerate(self.weights):
@@ -125,12 +122,6 @@ class TensorFlowArchitecture(Architecture):
             # gradients ops
             self._create_gradient_ops()
 
-            # L2 regularization
-            if self.network_parameters.l2_regularization != 0:
-                self.l2_regularization = [tf.add_n([tf.nn.l2_loss(v) for v in self.weights])
-                                          * self.network_parameters.l2_regularization]
-                tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, self.l2_regularization)
-
             self.inc_step = self.global_step.assign_add(1)
 
             # reset LSTM hidden cells
@@ -150,11 +141,13 @@ class TensorFlowArchitecture(Architecture):
         # set the fetches for training
         self._set_initial_fetch_list()
 
-    def get_model(self) -> None:
+    def get_model(self) -> List:
         """
         Constructs the model using `network_parameters` and sets `input_embedders`, `middleware`,
         `output_heads`, `outputs`, `losses`, `total_loss`, `adaptive_learning_rate_scheme`,
-        `current_learning_rate`, and `optimizer`
+        `current_learning_rate`, and `optimizer`.
+
+        :return: A list of the model's weights
         """
         raise NotImplementedError
 
diff --git a/rl_coach/architectures/tensorflow_components/general_network.py b/rl_coach/architectures/tensorflow_components/general_network.py
index 1a657b3..0103659 100644
--- a/rl_coach/architectures/tensorflow_components/general_network.py
+++ b/rl_coach/architectures/tensorflow_components/general_network.py
@@ -222,7 +222,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
             'agent_parameters': self.ap, 'spaces': self.spaces, 'network_name': self.network_wrapper_name,
             'head_idx': head_idx, 'is_local': self.network_is_local})
 
-    def get_model(self):
+    def get_model(self) -> List:
         # validate the configuration
         if len(self.network_parameters.input_embedders_parameters) == 0:
             raise ValueError("At least one input type should be defined")
@@ -338,9 +338,18 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
 
                 head_count += 1
 
+        # model weights
+        self.weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name)
+
         # Losses
         self.losses = tf.losses.get_losses(self.full_name)
-        self.losses += tf.losses.get_regularization_losses(self.full_name)
+
+        # L2 regularization
+        if self.network_parameters.l2_regularization != 0:
+            self.l2_regularization = tf.add_n([tf.nn.l2_loss(v) for v in self.weights]) \
+                                     * self.network_parameters.l2_regularization
+            self.losses += [self.l2_regularization]
+
         self.total_loss = tf.reduce_sum(self.losses)
         # tf.summary.scalar('total_loss', self.total_loss)
 
@@ -386,6 +395,8 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
             else:
                 raise Exception("{} is not a valid optimizer type".format(self.network_parameters.optimizer_type))
 
+        return self.weights
+
     def __str__(self):
         result = []
 
diff --git a/rl_coach/architectures/tensorflow_components/heads/acer_policy_head.py b/rl_coach/architectures/tensorflow_components/heads/acer_policy_head.py
index 567cfb5..d31fa3d 100644
--- a/rl_coach/architectures/tensorflow_components/heads/acer_policy_head.py
+++ b/rl_coach/architectures/tensorflow_components/heads/acer_policy_head.py
@@ -56,7 +56,6 @@ class ACERPolicyHead(Head):
         if self.beta:
             self.entropy = tf.reduce_mean(self.policy_distribution.entropy())
             self.regularizations += [-tf.multiply(self.beta, self.entropy, name='entropy_regularization')]
-            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, self.regularizations)
 
         # Truncated importance sampling with bias corrections
         importance_sampling_weight = tf.placeholder(tf.float32, [None, self.num_actions],
diff --git a/rl_coach/architectures/tensorflow_components/heads/policy_head.py b/rl_coach/architectures/tensorflow_components/heads/policy_head.py
index 99c9958..540bd1a 100644
--- a/rl_coach/architectures/tensorflow_components/heads/policy_head.py
+++ b/rl_coach/architectures/tensorflow_components/heads/policy_head.py
@@ -78,8 +78,6 @@ class PolicyHead(Head):
             self.entropy = tf.add_n([tf.reduce_mean(dist.entropy()) for dist in self.policy_distributions])
             self.regularizations += [-tf.multiply(self.beta, self.entropy, name='entropy_regularization')]
 
-            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, self.regularizations)
-
         # calculate loss
         self.action_log_probs_wrt_policy = \
             tf.add_n([dist.log_prob(action) for dist, action in zip(self.policy_distributions, self.actions)])
diff --git a/rl_coach/architectures/tensorflow_components/heads/ppo_head.py b/rl_coach/architectures/tensorflow_components/heads/ppo_head.py
index 2dacaea..63f95a3 100644
--- a/rl_coach/architectures/tensorflow_components/heads/ppo_head.py
+++ b/rl_coach/architectures/tensorflow_components/heads/ppo_head.py
@@ -68,9 +68,8 @@ class PPOHead(Head):
         if self.use_kl_regularization:
             # no clipping => use kl regularization
             self.weighted_kl_divergence = tf.multiply(self.kl_coefficient, self.kl_divergence)
-            self.regularizations = self.weighted_kl_divergence + self.high_kl_penalty_coefficient * \
-                                   tf.square(tf.maximum(0.0, self.kl_divergence - self.kl_cutoff))
-            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, self.regularizations)
+            self.regularizations += [self.weighted_kl_divergence + self.high_kl_penalty_coefficient * \
+                                     tf.square(tf.maximum(0.0, self.kl_divergence - self.kl_cutoff))]
 
         # calculate surrogate loss
         self.advantages = tf.placeholder(tf.float32, [None], name="advantages")
@@ -93,8 +92,7 @@ class PPOHead(Head):
         # add entropy regularization
         if self.beta:
             self.entropy = tf.reduce_mean(self.policy_distribution.entropy())
-            self.regularizations = -tf.multiply(self.beta, self.entropy, name='entropy_regularization')
-            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, self.regularizations)
+            self.regularizations += [-tf.multiply(self.beta, self.entropy, name='entropy_regularization')]
 
         self.loss = self.surrogate_loss
         tf.losses.add_loss(self.loss)
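
A minimal standalone sketch of the pattern this patch consolidates into GeneralTensorFlowNetwork.get_model(): collect the scope's trainable variables, append the L2 term to the scope's loss list instead of the GraphKeys.REGULARIZATION_LOSSES collection, and hand the weights back to the caller. It assumes TensorFlow 1.x graph mode (tf.compat.v1 under TF2); build_model, scope_name and the toy two-layer network are illustrative stand-ins, not rl_coach code.

import tensorflow as tf


def build_model(scope_name='toy_network', l2_regularization=1e-4):
    # Hypothetical stand-in for get_model(); mirrors only the patched flow.
    with tf.variable_scope(scope_name):
        observation = tf.placeholder(tf.float32, [None, 4], name='observation')
        target = tf.placeholder(tf.float32, [None, 2], name='target')
        hidden = tf.layers.dense(observation, 32, activation=tf.nn.relu)
        prediction = tf.layers.dense(hidden, 2)
        # registers the loss in the GraphKeys.LOSSES collection under scope_name
        tf.losses.mean_squared_error(labels=target, predictions=prediction)

    # model weights, scoped to this network (as in the patched get_model)
    weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope_name)

    # losses registered under this scope
    losses = tf.losses.get_losses(scope_name)

    # L2 regularization folded into the loss list rather than a separate collection
    if l2_regularization != 0:
        l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in weights]) * l2_regularization
        losses += [l2_loss]

    total_loss = tf.reduce_sum(losses)
    return weights, total_loss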