mirror of https://github.com/gryf/coach.git
appending CSVs from the logger instead of rewriting them
@@ -28,9 +28,13 @@ class ActorCriticAgent(PolicyOptimizationAgent):
         self.action_advantages = Signal('Advantages')
         self.state_values = Signal('Values')
         self.unclipped_grads = Signal('Grads (unclipped)')
+        self.value_loss = Signal('Value Loss')
+        self.policy_loss = Signal('Policy Loss')
         self.signals.append(self.action_advantages)
         self.signals.append(self.state_values)
         self.signals.append(self.unclipped_grads)
+        self.signals.append(self.value_loss)
+        self.signals.append(self.policy_loss)

     # Discounting function used to calculate discounted returns.
     def discount(self, x, gamma):
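For context, the hunk above registers two new Signal objects alongside the existing ones, so that the value and policy losses flow through the same signals mechanism as the advantages and gradients. A minimal sketch of this buffering pattern, with hypothetical internals (the diff only guarantees the Signal(name) constructor, the add_sample method, and the self.signals list; the reduce/reset behavior here is an assumption about how the logger consumes the buffer):

    class Signal:
        def __init__(self, name):
            self.name = name      # column name used when the logger dumps signals
            self.samples = []     # per-step values buffered between dumps

        def add_sample(self, sample):
            # Buffer one value per training step; assumed to be reduced
            # (e.g. averaged) and written out as a CSV column later.
            self.samples.append(sample)

        def reset(self):
            # Assumed to be called after each dump so samples do not accumulate.
            self.samples = []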
@@ -104,8 +108,8 @@ class ActorCriticAgent(PolicyOptimizationAgent):
         total_loss, losses, unclipped_grads = result[:3]
         self.action_advantages.add_sample(action_advantages)
         self.unclipped_grads.add_sample(unclipped_grads)
-        logger.create_signal_value('Value Loss', losses[0])
-        logger.create_signal_value('Policy Loss', losses[1])
+        self.value_loss.add_sample(losses[0])
+        self.policy_loss.add_sample(losses[1])

         return total_loss

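The commit title refers to how these buffered signals end up on disk: the logger appends new rows to the experiment CSV on each dump instead of rewriting the whole file. A hedged sketch of that append-mode idiom using pandas (dump_signals, csv_path, and new_rows are illustrative names, not Coach's actual API):

    import os
    import pandas as pd

    def dump_signals(csv_path, new_rows):
        # Append the latest rows to the experiment CSV rather than
        # rewriting it from scratch on every dump.
        df = pd.DataFrame(new_rows)
        # Write the header row only when the file does not exist yet,
        # so repeated appends do not duplicate it.
        write_header = not os.path.isfile(csv_path)
        df.to_csv(csv_path, mode='a', header=write_header, index=False)

Appending keeps each dump O(new rows) instead of O(all rows written so far), which matters for long training runs that flush signals frequently.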