mirror of https://github.com/gryf/coach.git, synced 2025-12-17 19:20:19 +01:00
Appending CSVs from the logger instead of rewriting them
@@ -30,7 +30,10 @@ from utils import *
 class PolicyGradientsAgent(PolicyOptimizationAgent):
     def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
         PolicyOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
 
+        self.returns_mean = Signal('Returns Mean')
+        self.returns_variance = Signal('Returns Variance')
+        self.signals.append(self.returns_mean)
+        self.signals.append(self.returns_variance)
         self.last_gradient_update_step_idx = 0
 
     def learn_from_batch(self, batch):
@@ -58,8 +61,8 @@ class PolicyGradientsAgent(PolicyOptimizationAgent):
         if not self.env.discrete_controls and len(actions.shape) < 2:
             actions = np.expand_dims(actions, -1)
 
-        logger.create_signal_value('Returns Variance', np.std(total_returns), self.task_id)
-        logger.create_signal_value('Returns Mean', np.mean(total_returns), self.task_id)
+        self.returns_mean.add_sample(np.mean(total_returns))
+        self.returns_variance.add_sample(np.std(total_returns))
 
         result = self.main_network.online_network.accumulate_gradients([current_states, actions], targets)
         total_loss = result[0]
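For context, below is a minimal sketch of the pattern this change leans on: per-statistic Signal objects collect samples during training, and a logger flushes them by appending rows to a CSV file rather than rewriting it. The Signal name, add_sample, and the 'Returns Mean'/'Returns Variance' signals come from the diff above; the CsvLogger class, its dump and consume_mean helpers, and the experiment.csv path are illustrative assumptions, not Coach's actual API.

# Minimal sketch (assumption): Signal buffers samples between dumps;
# CsvLogger appends one row per dump instead of rewriting the whole file.
import csv
import os


class Signal(object):
    """Accumulates samples for a single tracked statistic between dumps."""
    def __init__(self, name):
        self.name = name
        self.samples = []

    def add_sample(self, sample):
        self.samples.append(sample)

    def consume_mean(self):
        # Return the mean of the collected samples and reset the buffer.
        if not self.samples:
            return None
        mean = sum(self.samples) / float(len(self.samples))
        self.samples = []
        return mean


class CsvLogger(object):
    """Appends one row per dump to a CSV file instead of rewriting it."""
    def __init__(self, path, fieldnames):
        self.path = path
        self.fieldnames = fieldnames

    def dump(self, signals, step):
        write_header = not os.path.exists(self.path)
        # Open in append mode so rows from previous dumps are preserved.
        with open(self.path, 'a', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=['step'] + self.fieldnames)
            if write_header:
                writer.writeheader()
            row = {'step': step}
            for signal in signals:
                row[signal.name] = signal.consume_mean()
            writer.writerow(row)


if __name__ == '__main__':
    returns_mean = Signal('Returns Mean')
    returns_variance = Signal('Returns Variance')
    logger = CsvLogger('experiment.csv', ['Returns Mean', 'Returns Variance'])

    returns_mean.add_sample(1.5)
    returns_variance.add_sample(0.2)
    logger.dump([returns_mean, returns_variance], step=1)

In this sketch, repeated calls to dump keep extending experiment.csv, which mirrors the intent stated in the commit message: append to the logger's CSVs rather than rewrite them on every update.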