mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Multiple improvements and bug fixes (#66)
* Multiple improvements and bug fixes:
* Using lazy stacking to save on memory when using a replay buffer
* Remove step counting for evaluation episodes
* Reset game between heatup and training
* Major bug fixes in NEC (is reproducing the paper results for pong now)
* Image input rescaling to 0-1 is now optional
* Change the terminal title to be the experiment name
* Observation cropping for atari is now optional
* Added random number of noop actions for gym to match the dqn paper
* Fixed a bug where the evaluation episodes won't start with the max possible ale lives
* Added a script for plotting the results of an experiment over all the atari games
This commit is contained in:
21
logger.py
21
logger.py
@@ -115,6 +115,14 @@ class ScreenLogger(object):
|
||||
if default is not None:
|
||||
return default
|
||||
|
||||
def change_terminal_title(self, title: str):
|
||||
"""
|
||||
Changes the title of the terminal window
|
||||
:param title: The new title
|
||||
:return: None
|
||||
"""
|
||||
print("\x1b]2;{}\x07".format(title))
|
||||
|
||||
|
||||
class BaseLogger(object):
|
||||
def __init__(self):
|
||||
@@ -157,6 +165,7 @@ class Logger(BaseLogger):
|
||||
self.time = None
|
||||
self.experiments_path = ""
|
||||
self.last_line_idx_written_to_csv = 0
|
||||
self.experiment_name = ""
|
||||
|
||||
def set_current_time(self, time):
|
||||
self.time = time
|
||||
@@ -205,7 +214,9 @@ class Logger(BaseLogger):
|
||||
|
||||
def signal_value_exists(self, time, signal_name):
|
||||
try:
|
||||
self.get_signal_value(time, signal_name)
|
||||
value = self.get_signal_value(time, signal_name)
|
||||
if value != value: # value is nan
|
||||
return False
|
||||
except:
|
||||
return False
|
||||
return True
|
||||
@@ -229,7 +240,8 @@ class Logger(BaseLogger):
|
||||
if self.start_time:
|
||||
self.create_signal_value('Wall-Clock Time', time.time() - self.start_time, time=episode)
|
||||
else:
|
||||
self.create_signal_value('Wall-Clock Time', time.time(), time=episode)
|
||||
self.create_signal_value('Wall-Clock Time', 0, time=episode)
|
||||
self.start_time = time.time()
|
||||
|
||||
def create_gif(self, images, fps=10, name="Gif"):
|
||||
output_file = '{}_{}.gif'.format(datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S'), name)
|
||||
@@ -243,7 +255,7 @@ class Logger(BaseLogger):
|
||||
def remove_experiment_dir(self):
|
||||
shutil.rmtree(self.experiments_path)
|
||||
|
||||
def print_summary(self):
|
||||
def summarize_experiment(self):
|
||||
screen.separator()
|
||||
screen.log_title("Results stored at: {}".format(self.experiments_path))
|
||||
screen.log_title("Total runtime: {}".format(datetime.datetime.now() - self.time_started))
|
||||
@@ -273,7 +285,8 @@ class Logger(BaseLogger):
|
||||
screen.error('Experiment name must be composed only of alphanumeric letters, '
|
||||
'underscores and dashes and should not be longer than 100 characters.')
|
||||
|
||||
return match.group(0)
|
||||
self.experiment_name = match.group(0)
|
||||
return self.experiment_name
|
||||
|
||||
def get_experiment_path(self, experiment_name, create_path=True):
|
||||
general_experiments_path = os.path.join('./experiments/', experiment_name)
|
||||
|
||||
Reference in New Issue
Block a user