
Appending CSVs from the logger instead of rewriting them

Itai Caspi
2018-02-12 01:33:43 +02:00
committed by Itai Caspi
parent 569ca39ce6
commit ba96e585d2
6 changed files with 33 additions and 14 deletions
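Note: the logger-side change that actually switches dump_output_csv to append mode is in the changed files not shown below; only the call sites appear in these hunks. As a rough illustration of the idea (not the repository's actual implementation), an append-mode dump built on pandas could look like the following sketch, where the function name and arguments are hypothetical:

import os

import pandas as pd


def dump_output_csv(signal_values, csv_path):
    """Hypothetical sketch: append this episode's signal values as one CSV row.

    signal_values: dict mapping column names to scalars, e.g.
    {'Training Reward': 12.5, 'Evaluation Reward': np.nan}.
    """
    row = pd.DataFrame([signal_values])
    write_header = not os.path.exists(csv_path)
    # mode='a' appends the new row instead of rewriting the whole file;
    # the header is only written when the file is first created.
    row.to_csv(csv_path, mode='a', header=write_header, index=False)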

View File

@@ -28,9 +28,13 @@ class ActorCriticAgent(PolicyOptimizationAgent):
         self.action_advantages = Signal('Advantages')
         self.state_values = Signal('Values')
         self.unclipped_grads = Signal('Grads (unclipped)')
+        self.value_loss = Signal('Value Loss')
+        self.policy_loss = Signal('Policy Loss')
         self.signals.append(self.action_advantages)
         self.signals.append(self.state_values)
         self.signals.append(self.unclipped_grads)
+        self.signals.append(self.value_loss)
+        self.signals.append(self.policy_loss)

     # Discounting function used to calculate discounted returns.
     def discount(self, x, gamma):
@@ -104,8 +108,8 @@ class ActorCriticAgent(PolicyOptimizationAgent):
         total_loss, losses, unclipped_grads = result[:3]
         self.action_advantages.add_sample(action_advantages)
         self.unclipped_grads.add_sample(unclipped_grads)
-        logger.create_signal_value('Value Loss', losses[0])
-        logger.create_signal_value('Policy Loss', losses[1])
+        self.value_loss.add_sample(losses[0])
+        self.policy_loss.add_sample(losses[1])

         return total_loss
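The pattern in this file (and in the agents below) replaces direct logger.create_signal_value(...) calls inside the training loop with per-agent Signal objects that accumulate samples and are flushed once per episode by Agent (see the "{}/Min" aggregation in the next file). The Signal class itself is not part of this diff; a minimal sketch consistent with the calls shown here is given below — only name, add_sample and get_min appear in the hunks, so the reset and the other aggregate methods are assumptions:

import numpy as np


class Signal(object):
    """Hypothetical sketch of the Signal pattern used in these hunks:
    accumulate scalar samples during an episode and expose aggregate
    statistics for the end-of-episode CSV dump."""

    def __init__(self, name):
        self.name = name
        self.samples = []

    def add_sample(self, sample):
        # store the per-step value; aggregation happens at episode end
        self.samples.append(sample)

    def reset(self):
        self.samples = []

    def get_mean(self):
        return np.mean(self.samples) if self.samples else np.nan

    def get_stdev(self):
        return np.std(self.samples) if self.samples else np.nan

    def get_min(self):
        return np.min(self.samples) if self.samples else np.nan

    def get_max(self):
        return np.max(self.samples) if self.samples else np.nan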

View File

@@ -164,10 +164,11 @@ class Agent(object):
         logger.create_signal_value('Episode Length', self.current_episode_steps_counter)
         logger.create_signal_value('Total steps', self.total_steps_counter)
         logger.create_signal_value("Epsilon", self.exploration_policy.get_control_param())
-        if phase == RunPhase.TRAIN:
-            logger.create_signal_value("Training Reward", self.total_reward_in_current_episode)
-        elif phase == RunPhase.TEST:
-            logger.create_signal_value('Evaluation Reward', self.total_reward_in_current_episode)
+        logger.create_signal_value("Training Reward", self.total_reward_in_current_episode
+                                   if phase == RunPhase.TRAIN else np.nan)
+        logger.create_signal_value('Evaluation Reward', self.total_reward_in_current_episode
+                                   if phase == RunPhase.TEST else np.nan)
         logger.create_signal_value('Update Target Network', 0, overwrite=False)
         logger.update_wall_clock_time(self.current_episode)
         for signal in self.signals:
@@ -177,7 +178,8 @@ class Agent(object):
logger.create_signal_value("{}/Min".format(signal.name), signal.get_min())
# dump
if self.current_episode % self.tp.visualization.dump_signals_to_csv_every_x_episodes == 0:
if self.current_episode % self.tp.visualization.dump_signals_to_csv_every_x_episodes == 0 \
and self.current_episode > 0:
logger.dump_output_csv()
def reset_game(self, do_not_reset_env=False):
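A side effect of appending rows instead of rewriting the whole file is that every episode must produce the same set of columns, which is presumably why 'Training Reward' and 'Evaluation Reward' are now both written every episode (with np.nan for whichever phase is inactive) and why the dump is skipped for episode 0. A small, hypothetical illustration of that column-stability idea:

import numpy as np


def reward_columns(phase_is_train, episode_reward):
    # Both columns are always present; the inactive phase gets NaN so that
    # appended rows keep lining up with the header written on the first dump.
    return {
        'Training Reward': episode_reward if phase_is_train else np.nan,
        'Evaluation Reward': np.nan if phase_is_train else episode_reward,
    }


print(reward_columns(True, 12.5))    # training episode
print(reward_columns(False, 200.0))  # evaluation episode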

View File

@@ -28,8 +28,10 @@ class NStepQAgent(ValueOptimizationAgent, PolicyOptimizationAgent):
         self.last_gradient_update_step_idx = 0
         self.q_values = Signal('Q Values')
         self.unclipped_grads = Signal('Grads (unclipped)')
+        self.value_loss = Signal('Value Loss')
         self.signals.append(self.q_values)
         self.signals.append(self.unclipped_grads)
+        self.signals.append(self.value_loss)

     def learn_from_batch(self, batch):
         # batch contains a list of episodes to learn from
@@ -69,7 +71,7 @@ class NStepQAgent(ValueOptimizationAgent, PolicyOptimizationAgent):
         # logging
         total_loss, losses, unclipped_grads = result[:3]
         self.unclipped_grads.add_sample(unclipped_grads)
-        logger.create_signal_value('Value Loss', losses[0])
+        self.value_loss.add_sample(losses[0])

         return total_loss

View File

@@ -30,7 +30,10 @@ from utils import *
 class PolicyGradientsAgent(PolicyOptimizationAgent):
     def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
         PolicyOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
+        self.returns_mean = Signal('Returns Mean')
+        self.returns_variance = Signal('Returns Variance')
+        self.signals.append(self.returns_mean)
+        self.signals.append(self.returns_variance)
         self.last_gradient_update_step_idx = 0

     def learn_from_batch(self, batch):
@@ -58,8 +61,8 @@ class PolicyGradientsAgent(PolicyOptimizationAgent):
         if not self.env.discrete_controls and len(actions.shape) < 2:
             actions = np.expand_dims(actions, -1)

-        logger.create_signal_value('Returns Variance', np.std(total_returns), self.task_id)
-        logger.create_signal_value('Returns Mean', np.mean(total_returns), self.task_id)
+        self.returns_mean.add_sample(np.mean(total_returns))
+        self.returns_variance.add_sample(np.std(total_returns))
         result = self.main_network.online_network.accumulate_gradients([current_states, actions], targets)

         total_loss = result[0]