mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Batch RL (#238)
This commit is contained in:
@@ -185,6 +185,9 @@ class BaseLogger(object):
|
||||
self.time = time
|
||||
|
||||
def create_signal_value(self, signal_name, value, overwrite=True, time=None):
|
||||
if self.index_name == signal_name:
|
||||
return False # make sure that we don't create duplicate signals
|
||||
|
||||
if self.last_line_idx_written_to_csv != 0:
|
||||
assert signal_name in self.data.columns
|
||||
|
||||
@@ -227,12 +230,15 @@ class BaseLogger(object):
|
||||
|
||||
self.last_line_idx_written_to_csv = len(self.data.index)
|
||||
|
||||
def update_wall_clock_time(self, index):
|
||||
def get_current_wall_clock_time(self):
|
||||
if self.start_time:
|
||||
self.create_signal_value('Wall-Clock Time', time.time() - self.start_time, time=index)
|
||||
return time.time() - self.start_time
|
||||
else:
|
||||
self.create_signal_value('Wall-Clock Time', 0, time=index)
|
||||
self.start_time = time.time()
|
||||
return 0
|
||||
|
||||
def update_wall_clock_time(self, index):
|
||||
self.create_signal_value('Wall-Clock Time', self.get_current_wall_clock_time(), time=index)
|
||||
|
||||
|
||||
class EpisodeLogger(BaseLogger):
|
||||
@@ -263,10 +269,13 @@ class EpisodeLogger(BaseLogger):
|
||||
|
||||
|
||||
class Logger(BaseLogger):
|
||||
def __init__(self):
|
||||
def __init__(self, index_name='Episode #'):
|
||||
super().__init__()
|
||||
self.doc_path = ''
|
||||
self.index_name = 'Episode #'
|
||||
self.index_name = index_name
|
||||
|
||||
def set_index_name(self, index_name):
|
||||
self.index_name = index_name
|
||||
|
||||
def set_logger_filenames(self, _experiments_path, logger_prefix='', task_id=None, add_timestamp=False, filename=''):
|
||||
self.experiments_path = _experiments_path
|
||||
|
||||
Reference in New Issue
Block a user