mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
appending csv's from logger instead of rewriting them
This commit is contained in:
@@ -164,10 +164,11 @@ class Agent(object):
|
||||
logger.create_signal_value('Episode Length', self.current_episode_steps_counter)
|
||||
logger.create_signal_value('Total steps', self.total_steps_counter)
|
||||
logger.create_signal_value("Epsilon", self.exploration_policy.get_control_param())
|
||||
if phase == RunPhase.TRAIN:
|
||||
logger.create_signal_value("Training Reward", self.total_reward_in_current_episode)
|
||||
elif phase == RunPhase.TEST:
|
||||
logger.create_signal_value('Evaluation Reward', self.total_reward_in_current_episode)
|
||||
logger.create_signal_value("Training Reward", self.total_reward_in_current_episode
|
||||
if phase == RunPhase.TRAIN else np.nan)
|
||||
logger.create_signal_value('Evaluation Reward', self.total_reward_in_current_episode
|
||||
if phase == RunPhase.TEST else np.nan)
|
||||
logger.create_signal_value('Update Target Network', 0, overwrite=False)
|
||||
logger.update_wall_clock_time(self.current_episode)
|
||||
|
||||
for signal in self.signals:
|
||||
@@ -177,7 +178,8 @@ class Agent(object):
|
||||
logger.create_signal_value("{}/Min".format(signal.name), signal.get_min())
|
||||
|
||||
# dump
|
||||
if self.current_episode % self.tp.visualization.dump_signals_to_csv_every_x_episodes == 0:
|
||||
if self.current_episode % self.tp.visualization.dump_signals_to_csv_every_x_episodes == 0 \
|
||||
and self.current_episode > 0:
|
||||
logger.dump_output_csv()
|
||||
|
||||
def reset_game(self, do_not_reset_env=False):
|
||||
|
||||
Reference in New Issue
Block a user