1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00
This commit is contained in:
Gal Leibovich
2019-03-19 18:07:09 +02:00
committed by GitHub
parent 4a8451ff02
commit e3c7e526c7
38 changed files with 1003 additions and 87 deletions

View File

@@ -185,6 +185,9 @@ class BaseLogger(object):
self.time = time
def create_signal_value(self, signal_name, value, overwrite=True, time=None):
if self.index_name == signal_name:
return False # make sure that we don't create duplicate signals
if self.last_line_idx_written_to_csv != 0:
assert signal_name in self.data.columns
@@ -227,12 +230,15 @@ class BaseLogger(object):
self.last_line_idx_written_to_csv = len(self.data.index)
def update_wall_clock_time(self, index):
def get_current_wall_clock_time(self):
if self.start_time:
self.create_signal_value('Wall-Clock Time', time.time() - self.start_time, time=index)
return time.time() - self.start_time
else:
self.create_signal_value('Wall-Clock Time', 0, time=index)
self.start_time = time.time()
return 0
def update_wall_clock_time(self, index):
self.create_signal_value('Wall-Clock Time', self.get_current_wall_clock_time(), time=index)
class EpisodeLogger(BaseLogger):
@@ -263,10 +269,13 @@ class EpisodeLogger(BaseLogger):
class Logger(BaseLogger):
def __init__(self):
def __init__(self, index_name='Episode #'):
super().__init__()
self.doc_path = ''
self.index_name = 'Episode #'
self.index_name = index_name
def set_index_name(self, index_name):
self.index_name = index_name
def set_logger_filenames(self, _experiments_path, logger_prefix='', task_id=None, add_timestamp=False, filename=''):
self.experiments_path = _experiments_path