1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

bug-fix for clipped_ppo not logging several signals + small cleanup

This commit is contained in:
Gal Leibovich
2018-10-02 14:22:37 +03:00
parent 73cc6e39d0
commit 72ea933384
2 changed files with 6 additions and 4 deletions

View File

@@ -51,6 +51,9 @@ class ClippedPPONetworkParameters(NetworkParameters):
self.use_separate_networks_per_head = True
self.async_training = False
self.l2_regularization = 0
# The target network is used in order to freeze the old policy, while making updates to the new one
# in train_network()
self.create_target_network = True
self.shared_optimizer = True
self.scale_down_gradients_by_number_of_workers_for_sync_training = True
@@ -99,7 +102,6 @@ class ClippedPPOAgent(ActorCriticAgent):
self.likelihood_ratio = self.register_signal('Likelihood Ratio')
self.clipped_likelihood_ratio = self.register_signal('Clipped Likelihood Ratio')
def set_session(self, sess):
super().set_session(sess)
if self.ap.algorithm.normalization_stats is not None:
@@ -219,6 +221,7 @@ class ClippedPPOAgent(ActorCriticAgent):
self.value_loss.add_sample(batch_results['losses'][0])
self.policy_loss.add_sample(batch_results['losses'][1])
self.loss.add_sample(batch_results['total_loss'])
if self.ap.network_wrappers['main'].learning_rate_decay_rate != 0:
curr_learning_rate = self.networks['main'].online_network.get_variable_value(
@@ -268,7 +271,8 @@ class ClippedPPOAgent(ActorCriticAgent):
self.post_training_commands()
self.training_iteration += 1
# self.update_log() # should be done in order to update the data that has been accumulated * while not playing *
# should be done in order to update the data that has been accumulated * while not playing *
self.update_log()
return None
def run_pre_network_filter_for_inference(self, state: StateType):