Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 19:20:19 +01:00)
Appending CSVs from the logger instead of rewriting them
@@ -28,9 +28,13 @@ class ActorCriticAgent(PolicyOptimizationAgent):
         self.action_advantages = Signal('Advantages')
         self.state_values = Signal('Values')
         self.unclipped_grads = Signal('Grads (unclipped)')
+        self.value_loss = Signal('Value Loss')
+        self.policy_loss = Signal('Policy Loss')
         self.signals.append(self.action_advantages)
         self.signals.append(self.state_values)
         self.signals.append(self.unclipped_grads)
+        self.signals.append(self.value_loss)
+        self.signals.append(self.policy_loss)
 
     # Discounting function used to calculate discounted returns.
     def discount(self, x, gamma):
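
Note: the hunk ends at the `discount` helper mentioned in the context lines. For readers who have not seen it, a minimal sketch of such a discounted-returns function is below. This is the common lfilter-based implementation, given here only as an illustration; it is not necessarily the exact code in this repository.

import numpy as np
import scipy.signal

def discount(x, gamma):
    # Discounted cumulative sum: y[t] = x[t] + gamma * y[t + 1].
    # Running lfilter over the reversed sequence evaluates the recursion backwards.
    return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]
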
@@ -104,8 +108,8 @@ class ActorCriticAgent(PolicyOptimizationAgent):
         total_loss, losses, unclipped_grads = result[:3]
         self.action_advantages.add_sample(action_advantages)
         self.unclipped_grads.add_sample(unclipped_grads)
-        logger.create_signal_value('Value Loss', losses[0])
-        logger.create_signal_value('Policy Loss', losses[1])
+        self.value_loss.add_sample(losses[0])
+        self.policy_loss.add_sample(losses[1])
 
         return total_loss
 
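
Note: these hunks replace direct logger.create_signal_value() calls with per-episode Signal objects, which the base Agent later reduces to per-episode statistics (see the "{}/Min" dump further down). A minimal sketch of what such a Signal helper needs to provide, assuming it simply buffers samples; the actual class in coach may expose more methods.

import numpy as np

class Signal(object):
    """Collects per-step samples and exposes episode-level statistics."""
    def __init__(self, name):
        self.name = name
        self.samples = []

    def add_sample(self, sample):
        # Samples may be scalars or arrays (e.g. unclipped gradients).
        self.samples.append(sample)

    def get_mean(self):
        return np.mean(self.samples) if self.samples else np.nan

    def get_min(self):
        return np.min(self.samples) if self.samples else np.nan
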
@@ -164,10 +164,11 @@ class Agent(object):
         logger.create_signal_value('Episode Length', self.current_episode_steps_counter)
         logger.create_signal_value('Total steps', self.total_steps_counter)
         logger.create_signal_value("Epsilon", self.exploration_policy.get_control_param())
-        if phase == RunPhase.TRAIN:
-            logger.create_signal_value("Training Reward", self.total_reward_in_current_episode)
-        elif phase == RunPhase.TEST:
-            logger.create_signal_value('Evaluation Reward', self.total_reward_in_current_episode)
+        logger.create_signal_value("Training Reward", self.total_reward_in_current_episode
+                                   if phase == RunPhase.TRAIN else np.nan)
+        logger.create_signal_value('Evaluation Reward', self.total_reward_in_current_episode
+                                   if phase == RunPhase.TEST else np.nan)
+        logger.create_signal_value('Update Target Network', 0, overwrite=False)
         logger.update_wall_clock_time(self.current_episode)
 
         for signal in self.signals:
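
Note: logging both reward columns every episode, with np.nan for the inactive phase, keeps the DataFrame's column set stable across rows. That matters once rows are appended to the CSV without rewriting the header. A small illustration of why header-less appends need a fixed schema; the file name and values here are only for illustration.

import numpy as np
import pandas as pd

# First dump writes the header once.
df = pd.DataFrame({'Training Reward': [1.0], 'Evaluation Reward': [np.nan]}, index=[0])
df.index.name = 'Episode #'
df.to_csv('worker_0.csv')

# Later dumps append rows without a header, so the columns must line up.
new_rows = pd.DataFrame({'Training Reward': [np.nan], 'Evaluation Reward': [2.5]}, index=[1])
new_rows.to_csv('worker_0.csv', mode='a', header=False)
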
@@ -177,7 +178,8 @@ class Agent(object):
             logger.create_signal_value("{}/Min".format(signal.name), signal.get_min())
 
         # dump
-        if self.current_episode % self.tp.visualization.dump_signals_to_csv_every_x_episodes == 0:
+        if self.current_episode % self.tp.visualization.dump_signals_to_csv_every_x_episodes == 0 \
+                and self.current_episode > 0:
             logger.dump_output_csv()
 
     def reset_game(self, do_not_reset_env=False):
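
Note: the added `and self.current_episode > 0` guard is needed because `0 % N == 0` is true, so without it the very first episode would already trigger a dump before any useful data has accumulated. A quick check of the condition:

dump_every = 5
for episode in [0, 1, 5, 10]:
    should_dump = episode % dump_every == 0 and episode > 0
    print(episode, should_dump)  # 0 -> False, 1 -> False, 5 -> True, 10 -> True
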
@@ -28,8 +28,10 @@ class NStepQAgent(ValueOptimizationAgent, PolicyOptimizationAgent):
         self.last_gradient_update_step_idx = 0
         self.q_values = Signal('Q Values')
         self.unclipped_grads = Signal('Grads (unclipped)')
+        self.value_loss = Signal('Value Loss')
         self.signals.append(self.q_values)
         self.signals.append(self.unclipped_grads)
+        self.signals.append(self.value_loss)
 
     def learn_from_batch(self, batch):
         # batch contains a list of episodes to learn from
@@ -69,7 +71,7 @@ class NStepQAgent(ValueOptimizationAgent, PolicyOptimizationAgent):
         # logging
         total_loss, losses, unclipped_grads = result[:3]
         self.unclipped_grads.add_sample(unclipped_grads)
-        logger.create_signal_value('Value Loss', losses[0])
+        self.value_loss.add_sample(losses[0])
 
         return total_loss
 
@@ -30,7 +30,10 @@ from utils import *
 class PolicyGradientsAgent(PolicyOptimizationAgent):
     def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
         PolicyOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
+        self.returns_mean = Signal('Returns Mean')
+        self.returns_variance = Signal('Returns Variance')
+        self.signals.append(self.returns_mean)
+        self.signals.append(self.returns_variance)
         self.last_gradient_update_step_idx = 0
 
     def learn_from_batch(self, batch):
@@ -58,8 +61,8 @@ class PolicyGradientsAgent(PolicyOptimizationAgent):
         if not self.env.discrete_controls and len(actions.shape) < 2:
             actions = np.expand_dims(actions, -1)
 
-        logger.create_signal_value('Returns Variance', np.std(total_returns), self.task_id)
-        logger.create_signal_value('Returns Mean', np.mean(total_returns), self.task_id)
+        self.returns_mean.add_sample(np.mean(total_returns))
+        self.returns_variance.add_sample(np.std(total_returns))
 
         result = self.main_network.online_network.accumulate_gradients([current_states, actions], targets)
         total_loss = result[0]
@@ -250,7 +250,7 @@ class VisualizationParameters(Parameters):
     show_saliency_maps_every_num_episodes = 1000000000
     print_summary = False
     dump_csv = True
-    dump_signals_to_csv_every_x_episodes = 10
+    dump_signals_to_csv_every_x_episodes = 5
     render = False
     dump_gifs = True
     max_fps_for_human_control = 10
logger.py (12 changed lines)
@@ -127,6 +127,7 @@ class Logger(BaseLogger):
         self.start_time = None
         self.time = None
         self.experiments_path = ""
+        self.last_line_idx_written_to_csv = 0
 
     def set_current_time(self, time):
         self.time = time
@@ -184,16 +185,23 @@ class Logger(BaseLogger):
     def get_signal_value(self, time, signal_name):
         return self.data.loc[time, signal_name]
 
-    def dump_output_csv(self):
+    def dump_output_csv(self, append=True):
         self.data.index.name = "Episode #"
         if len(self.data.index) == 1:
             self.start_time = time.time()
 
-        self.data.to_csv(self.csv_path)
+        if os.path.exists(self.csv_path) and append:
+            self.data[self.last_line_idx_written_to_csv:].to_csv(self.csv_path, mode='a', header=False)
+        else:
+            self.data.to_csv(self.csv_path)
+
+        self.last_line_idx_written_to_csv = len(self.data.index)
 
     def update_wall_clock_time(self, episode):
         if self.start_time:
             self.create_signal_value('Wall-Clock Time', time.time() - self.start_time, time=episode)
+        else:
+            self.create_signal_value('Wall-Clock Time', time.time(), time=episode)
 
     def create_gif(self, images, fps=10, name="Gif"):
         output_file = '{}_{}.gif'.format(datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S'), name)
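
Note: this hunk is the core of the commit: dump_output_csv() now appends only the rows added since the previous dump instead of rewriting the whole CSV each time. A standalone sketch of the same pattern with pandas; the class name, path handling and DataFrame contents here are illustrative, not the repository's actual code.

import os
import pandas as pd

class IncrementalCsvWriter(object):
    """Append-only CSV dumping, mirroring the logger change above."""
    def __init__(self, csv_path):
        self.csv_path = csv_path
        self.data = pd.DataFrame()
        self.last_line_idx_written_to_csv = 0

    def dump(self, append=True):
        self.data.index.name = "Episode #"
        if os.path.exists(self.csv_path) and append:
            # Write only the rows accumulated since the last dump, without a header.
            new_rows = self.data.iloc[self.last_line_idx_written_to_csv:]
            new_rows.to_csv(self.csv_path, mode='a', header=False)
        else:
            # First dump (or append disabled): write the full file including the header.
            self.data.to_csv(self.csv_path)
        self.last_line_idx_written_to_csv = len(self.data.index)

One design consequence worth keeping in mind: rows that have already been written are never rewritten, so any signal value changed retroactively in the in-memory DataFrame would presumably only be reflected in rows dumped afterwards.
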