
bug-fix for clipped_ppo not logging several signals + small cleanup

Gal Leibovich
2018-10-02 14:22:37 +03:00
parent 73cc6e39d0
commit 72ea933384
2 changed files with 6 additions and 4 deletions


@@ -51,6 +51,9 @@ class ClippedPPONetworkParameters(NetworkParameters):
         self.use_separate_networks_per_head = True
         self.async_training = False
         self.l2_regularization = 0
+        # The target network is used in order to freeze the old policy, while making updates to the new one
+        # in train_network()
         self.create_target_network = True
         self.shared_optimizer = True
         self.scale_down_gradients_by_number_of_workers_for_sync_training = True
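The comment added in this hunk documents why create_target_network stays True for clipped PPO: the target network holds a frozen snapshot of the old policy while the online network is updated in train_network(). To illustrate why that frozen copy matters for the clipped objective (and for the 'Likelihood Ratio' / 'Clipped Likelihood Ratio' signals registered in the next hunk), here is a minimal NumPy sketch; the function name and arguments are illustrative, not Coach's API.

# Illustrative sketch (not Coach's actual code) of the clipped PPO surrogate.
# old_log_probs come from the frozen (target) copy of the policy that collected the data.
import numpy as np

def clipped_surrogate_loss(new_log_probs, old_log_probs, advantages, clip_ratio=0.2):
    ratio = np.exp(new_log_probs - old_log_probs)            # likelihood ratio
    clipped = np.clip(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)
    # Maximize the minimum of the clipped and unclipped objectives (written here as a loss).
    return -np.mean(np.minimum(ratio * advantages, clipped * advantages))

If the "old" network were updated together with the online one, old_log_probs would track new_log_probs, the ratio would collapse toward 1, and the clipping would never constrain the update; freezing the target network is what keeps the ratio meaningful.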
@@ -99,7 +102,6 @@ class ClippedPPOAgent(ActorCriticAgent):
         self.likelihood_ratio = self.register_signal('Likelihood Ratio')
         self.clipped_likelihood_ratio = self.register_signal('Clipped Likelihood Ratio')

     def set_session(self, sess):
         super().set_session(sess)
         if self.ap.algorithm.normalization_stats is not None:
@@ -219,6 +221,7 @@ class ClippedPPOAgent(ActorCriticAgent):
         self.value_loss.add_sample(batch_results['losses'][0])
         self.policy_loss.add_sample(batch_results['losses'][1])
+        self.loss.add_sample(batch_results['total_loss'])
         if self.ap.network_wrappers['main'].learning_rate_decay_rate != 0:
             curr_learning_rate = self.networks['main'].online_network.get_variable_value(
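This added line is the core of the bug-fix named in the commit message: the total-loss signal was registered but never sampled, so it stayed empty in the training logs. A minimal sketch of the register/sample pattern the fix relies on, using a hypothetical Signal class that is not Coach's implementation:

# Hypothetical sketch: a signal only shows up in the logs if add_sample() is
# actually called on it during training; registering it alone is not enough.
class Signal:
    def __init__(self, name):
        self.name = name
        self.samples = []

    def add_sample(self, value):
        self.samples.append(float(value))

    def mean(self):
        return sum(self.samples) / len(self.samples) if self.samples else 0.0

Without the added self.loss.add_sample(batch_results['total_loss']) call, the registered loss signal never receives samples and the corresponding log column stays blank.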
@@ -268,7 +271,8 @@ class ClippedPPOAgent(ActorCriticAgent):
         self.post_training_commands()
         self.training_iteration += 1
-        # self.update_log()  # should be done in order to update the data that has been accumulated * while not playing *
+        # should be done in order to update the data that has been accumulated * while not playing *
+        self.update_log()
         return None

     def run_pre_network_filter_for_inference(self, state: StateType):
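The last hunk in this file re-enables the previously commented-out update_log() call, so signals accumulated during training (rather than while acting) are flushed into the log at the end of each training iteration. A rough sketch of that ordering, with purely illustrative names rather than Coach's actual methods:

# Hypothetical sketch: samples produced by train() sit in memory until the log
# is refreshed, so update_log() must also run after training, not only after acting.
class AgentSketch:
    def __init__(self):
        self.training_iteration = 0
        self.pending_samples = []
        self.logged_rows = []

    def train(self):
        self.pending_samples.append({'total_loss': 0.42})  # accumulated while not playing
        self.training_iteration += 1
        self.update_log()                                   # flush what train() produced
        return None

    def update_log(self):
        self.logged_rows.extend(self.pending_samples)
        self.pending_samples.clear()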


@@ -61,8 +61,6 @@ class Head(object):
         self.regularizations = []
         self.loss_weight = tf.Variable([float(w) for w in force_list(loss_weight)],
                                        trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
-        self.loss_weight_placeholder = tf.placeholder("float")
-        self.set_loss_weight = tf.assign(self.loss_weight, self.loss_weight_placeholder)
         self.target = []
         self.importance_weight = []
         self.input = []
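The Head cleanup deletes a placeholder/assign pair that nothing feeds, leaving loss_weight as a plain non-trainable variable. A small sketch of the deleted pattern, assuming TensorFlow 1.x as in the surrounding file; the names mirror the diff, but the snippet is illustrative and not Coach's code:

# Sketch (TensorFlow 1.x): a placeholder plus an assign op is only useful if the
# loss weight must be changed at run time via session.run; if nothing ever feeds
# the placeholder, both ops are dead graph nodes and can be removed.
import tensorflow as tf

loss_weight = tf.Variable([1.0], trainable=False,
                          collections=[tf.GraphKeys.LOCAL_VARIABLES])

# The pattern removed by this commit:
loss_weight_placeholder = tf.placeholder("float")
set_loss_weight = tf.assign(loss_weight, loss_weight_placeholder)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    # Only if a call like this existed somewhere would the pair be needed:
    sess.run(set_loss_weight, feed_dict={loss_weight_placeholder: [0.5]})
    print(sess.run(loss_weight))  # [0.5]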