Mirror of https://github.com/gryf/coach.git, synced 2025-12-17 19:20:19 +01:00
bug-fix for clipped_ppo not logging several signals + small cleanup

@@ -51,6 +51,9 @@ class ClippedPPONetworkParameters(NetworkParameters):
         self.use_separate_networks_per_head = True
         self.async_training = False
         self.l2_regularization = 0
+
+        # The target network is used in order to freeze the old policy, while making updates to the new one
+        # in train_network()
         self.create_target_network = True
         self.shared_optimizer = True
         self.scale_down_gradients_by_number_of_workers_for_sync_training = True
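
The comment added above documents why ClippedPPO keeps a target network: the target holds the old policy fixed while the online network is optimized, and the surrogate objective is built from the likelihood ratio between the two (the same quantities fed into the 'Likelihood Ratio' and 'Clipped Likelihood Ratio' signals below). A minimal NumPy sketch of that ratio-and-clip step, with illustrative names only (this is not Coach's implementation):

import numpy as np

def clipped_surrogate(new_log_probs, old_log_probs, advantages, clip_eps=0.2):
    # old_log_probs come from the frozen (target) policy and are constants
    # with respect to the update of the new policy
    ratio = np.exp(new_log_probs - old_log_probs)              # likelihood ratio
    clipped_ratio = np.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
    # the objective takes the element-wise minimum of the two surrogate terms
    return np.mean(np.minimum(ratio * advantages, clipped_ratio * advantages))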

@@ -99,7 +102,6 @@ class ClippedPPOAgent(ActorCriticAgent):
         self.likelihood_ratio = self.register_signal('Likelihood Ratio')
         self.clipped_likelihood_ratio = self.register_signal('Clipped Likelihood Ratio')
-
 
     def set_session(self, sess):
         super().set_session(sess)
         if self.ap.algorithm.normalization_stats is not None:

@@ -219,6 +221,7 @@ class ClippedPPOAgent(ActorCriticAgent):
 
         self.value_loss.add_sample(batch_results['losses'][0])
         self.policy_loss.add_sample(batch_results['losses'][1])
+        self.loss.add_sample(batch_results['total_loss'])
 
         if self.ap.network_wrappers['main'].learning_rate_decay_rate != 0:
             curr_learning_rate = self.networks['main'].online_network.get_variable_value(
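
This hunk is part of the fix for the missing signals: the value and policy losses were already being fed into their signals, but the total loss never was, so the signal behind self.loss stayed empty in the experiment logs. The added line samples batch_results['total_loss'] into it after every training batch. A minimal stand-in for the register/add_sample pattern used here (illustrative only, not rl_coach's actual Signal class):

class Signal:
    # collects scalar samples under a name; a signal that never receives
    # samples simply does not show up in the logged statistics
    def __init__(self, name):
        self.name = name
        self.samples = []

    def add_sample(self, sample):
        self.samples.append(sample)

loss = Signal('Loss')
loss.add_sample(0.42)  # e.g. the batch's total_loss, as in the added line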

@@ -268,7 +271,8 @@ class ClippedPPOAgent(ActorCriticAgent):
 
         self.post_training_commands()
         self.training_iteration += 1
-        # self.update_log() # should be done in order to update the data that has been accumulated * while not playing *
+        # should be done in order to update the data that has been accumulated * while not playing *
+        self.update_log()
         return None
 
     def run_pre_network_filter_for_inference(self, state: StateType):

@@ -61,8 +61,6 @@ class Head(object):
         self.regularizations = []
         self.loss_weight = tf.Variable([float(w) for w in force_list(loss_weight)],
                                        trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
-        self.loss_weight_placeholder = tf.placeholder("float")
-        self.set_loss_weight = tf.assign(self.loss_weight, self.loss_weight_placeholder)
         self.target = []
         self.importance_weight = []
         self.input = []
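
The last hunk is the "small cleanup": the placeholder/assign pair that allowed a head's loss weight to be overwritten at run time is dropped, while the weight itself stays a non-trainable variable in the local-variables collection. For readers unfamiliar with this TensorFlow 1.x graph-mode pattern, here is a self-contained sketch of the mechanism being removed (variable names are illustrative, not Coach's):

import tensorflow as tf  # TensorFlow 1.x graph-mode API, as in the diff

# non-trainable weight, kept out of the trainable collection (as in the kept lines)
loss_weight = tf.Variable([1.0], trainable=False,
                          collections=[tf.GraphKeys.LOCAL_VARIABLES])

# the removed pattern: a placeholder plus an assign op for run-time updates
loss_weight_placeholder = tf.placeholder("float")
set_loss_weight = tf.assign(loss_weight, loss_weight_placeholder)

with tf.Session() as sess:
    # LOCAL_VARIABLES are initialized separately from the global variables
    sess.run(tf.local_variables_initializer())
    sess.run(set_loss_weight, feed_dict={loss_weight_placeholder: [0.5]})
    print(sess.run(loss_weight))  # -> [0.5]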