1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Append to CSVs from the logger instead of rewriting them

This commit is contained in:
Itai Caspi
2018-02-12 01:33:43 +02:00
committed by Itai Caspi
parent 569ca39ce6
commit ba96e585d2
6 changed files with 33 additions and 14 deletions

View File

@@ -28,9 +28,13 @@ class ActorCriticAgent(PolicyOptimizationAgent):
self.action_advantages = Signal('Advantages')
self.state_values = Signal('Values')
self.unclipped_grads = Signal('Grads (unclipped)')
+self.value_loss = Signal('Value Loss')
+self.policy_loss = Signal('Policy Loss')
self.signals.append(self.action_advantages)
self.signals.append(self.state_values)
self.signals.append(self.unclipped_grads)
+self.signals.append(self.value_loss)
+self.signals.append(self.policy_loss)
# Discounting function used to calculate discounted returns.
def discount(self, x, gamma):
@@ -104,8 +108,8 @@ class ActorCriticAgent(PolicyOptimizationAgent):
total_loss, losses, unclipped_grads = result[:3]
self.action_advantages.add_sample(action_advantages)
self.unclipped_grads.add_sample(unclipped_grads)
-logger.create_signal_value('Value Loss', losses[0])
-logger.create_signal_value('Policy Loss', losses[1])
+self.value_loss.add_sample(losses[0])
+self.policy_loss.add_sample(losses[1])
return total_loss