Mirror of https://github.com/gryf/coach.git, synced 2025-12-17 11:10:20 +01:00
bug-fix for clipped_ppo not logging several signals + small cleanup
@@ -51,6 +51,9 @@ class ClippedPPONetworkParameters(NetworkParameters):
        self.use_separate_networks_per_head = True
        self.async_training = False
        self.l2_regularization = 0

        # The target network is used in order to freeze the old policy, while making updates to the new one
        # in train_network()
        self.create_target_network = True
        self.shared_optimizer = True
        self.scale_down_gradients_by_number_of_workers_for_sync_training = True
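The comment introduced here explains the role of the target network in clipped PPO: it freezes the old policy so that the likelihood ratio between the updated (online) policy and the pre-update policy can be computed while train_network() modifies the online weights. A minimal NumPy sketch of that clipped surrogate objective (clip_ratio, advantages and the log-prob arrays are illustrative names, not Coach's API):

    import numpy as np

    def clipped_surrogate_loss(new_log_probs, old_log_probs, advantages, clip_ratio=0.2):
        # pi_new(a|s) / pi_old(a|s); old_log_probs come from the frozen target network
        ratio = np.exp(new_log_probs - old_log_probs)
        clipped = np.clip(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)
        # maximize the pessimistic surrogate -> minimize its negation
        return -np.mean(np.minimum(ratio * advantages, clipped * advantages))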
@@ -99,7 +102,6 @@ class ClippedPPOAgent(ActorCriticAgent):
        self.likelihood_ratio = self.register_signal('Likelihood Ratio')
        self.clipped_likelihood_ratio = self.register_signal('Clipped Likelihood Ratio')

    def set_session(self, sess):
        super().set_session(sess)
        if self.ap.algorithm.normalization_stats is not None:
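register_signal() returns an accumulator that the agent samples into during training and later dumps to the experiment log; the 'Likelihood Ratio' and 'Clipped Likelihood Ratio' signals registered above are among the ones this commit makes sure actually reach the log. A rough sketch of how such a signal accumulator behaves (an illustration only, not Coach's Signal implementation):

    class Signal:
        """Accumulates scalar samples until the agent's log is next updated."""
        def __init__(self, name):
            self.name = name
            self.samples = []

        def add_sample(self, value):
            self.samples.append(float(value))

        def flush(self):
            # mean of everything accumulated since the last dump, then reset
            mean = sum(self.samples) / len(self.samples) if self.samples else 0.0
            self.samples = []
            return mean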
@@ -219,6 +221,7 @@ class ClippedPPOAgent(ActorCriticAgent):

            self.value_loss.add_sample(batch_results['losses'][0])
            self.policy_loss.add_sample(batch_results['losses'][1])
            self.loss.add_sample(batch_results['total_loss'])

            if self.ap.network_wrappers['main'].learning_rate_decay_rate != 0:
                curr_learning_rate = self.networks['main'].online_network.get_variable_value(
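Besides logging the per-batch losses, this block reads the current learning rate back from the online network when learning-rate decay is enabled, so the decayed value can be logged as well. In a TF1-style graph such a decayed rate typically comes from tf.train.exponential_decay; a small sketch of that mechanism under assumed hyperparameters:

    import tensorflow as tf  # TF1.x-style API, as used by Coach at the time

    global_step = tf.Variable(0, trainable=False)
    # decayed_lr = 0.0003 * 0.96 ** (global_step / 10000)
    decayed_lr = tf.train.exponential_decay(0.0003, global_step,
                                            decay_steps=10000, decay_rate=0.96)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        current_lr = sess.run(decayed_lr)  # the value one would add_sample() each batch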
@@ -268,7 +271,8 @@ class ClippedPPOAgent(ActorCriticAgent):

        self.post_training_commands()
        self.training_iteration += 1
        # self.update_log() # should be done in order to update the data that has been accumulated * while not playing *
        # should be done in order to update the data that has been accumulated * while not playing *
        self.update_log()
        return None

    def run_pre_network_filter_for_inference(self, state: StateType):
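This is the core of the bug-fix: the previously commented-out update_log() call (the first of the three middle lines above) is replaced by an unconditional call after every training iteration, so the signals sampled while training, i.e. *while not playing*, are finally flushed to the log. Continuing the illustrative Signal sketch from earlier, the flush step might look roughly like this (hypothetical helper names, not Coach's implementation):

    def update_log(agent):
        # write the mean of every accumulated signal as one log row per training iteration
        row = {signal.name: signal.flush() for signal in agent.registered_signals}
        agent.logger.write(row)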
@@ -61,8 +61,6 @@ class Head(object):
        self.regularizations = []
        self.loss_weight = tf.Variable([float(w) for w in force_list(loss_weight)],
                                       trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
        self.loss_weight_placeholder = tf.placeholder("float")
        self.set_loss_weight = tf.assign(self.loss_weight, self.loss_weight_placeholder)
        self.target = []
        self.importance_weight = []
        self.input = []
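In the Head base class, loss_weight stays a non-trainable variable, while the hunk header (-61,8 +61,6) suggests the placeholder/assign pair used to overwrite it at runtime is the "small cleanup" being removed, presumably because nothing was feeding it anymore. For reference, a self-contained sketch of that TF1 pattern outside of Coach:

    import tensorflow as tf  # TF1.x-style API

    # non-trainable weight kept in the LOCAL_VARIABLES collection
    loss_weight = tf.Variable([1.0], trainable=False,
                              collections=[tf.GraphKeys.LOCAL_VARIABLES])

    # placeholder + assign op allow changing the weight at runtime
    loss_weight_placeholder = tf.placeholder("float")
    set_loss_weight = tf.assign(loss_weight, loss_weight_placeholder)

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())  # variable lives in LOCAL_VARIABLES
        sess.run(set_loss_weight, feed_dict={loss_weight_placeholder: [0.5]})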