diff --git a/coach.py b/coach.py index e02e702..b8b16be 100644 --- a/coach.py +++ b/coach.py @@ -29,54 +29,11 @@ import argparse from subprocess import Popen import datetime import presets +import atexit if len(set(failed_imports)) > 0: screen.warning("Warning: failed to import the following packages - {}".format(', '.join(set(failed_imports)))) -time_started = datetime.datetime.now() -cur_time = time_started.time() -cur_date = time_started.date() - - -def get_experiment_name(initial_experiment_name=''): - match = None - while match is None: - if initial_experiment_name == '': - experiment_name = screen.ask_input("Please enter an experiment name: ") - else: - experiment_name = initial_experiment_name - - experiment_name = experiment_name.replace(" ", "_") - match = re.match("^$|^[\w -/]{1,100}$", experiment_name) - - if match is None: - screen.error('Experiment name must be composed only of alphanumeric letters, ' - 'underscores and dashes and should not be longer than 100 characters.') - - return match.group(0) - - -def get_experiment_path(experiment_name, create_path=True): - general_experiments_path = os.path.join('./experiments/', experiment_name) - - if not os.path.exists(general_experiments_path) and create_path: - os.makedirs(general_experiments_path) - experiment_path = os.path.join(general_experiments_path, '{}_{}_{}-{}_{}' - .format(logger.two_digits(cur_date.day), logger.two_digits(cur_date.month), - cur_date.year, logger.two_digits(cur_time.hour), - logger.two_digits(cur_time.minute))) - i = 0 - while True: - if os.path.exists(experiment_path): - experiment_path = os.path.join(general_experiments_path, '{}_{}_{}-{}_{}_{}' - .format(cur_date.day, cur_date.month, cur_date.year, cur_time.hour, - cur_time.minute, i)) - i += 1 - else: - if create_path: - os.makedirs(experiment_path) - return experiment_path - def set_framework(framework_type): # choosing neural network framework @@ -146,8 +103,8 @@ def check_input_and_fill_run_dict(parser): args.exploration_policy_type = 'ExplorationParameters' # get experiment name and path - experiment_name = get_experiment_name(args.experiment_name) - experiment_path = get_experiment_path(experiment_name) + experiment_name = logger.get_experiment_name(args.experiment_name) + experiment_path = logger.get_experiment_path(experiment_name) if args.play and num_workers > 1: screen.warning("Playing the game as a human is only available with a single worker. " @@ -289,6 +246,7 @@ if __name__ == "__main__": # dump documentation logger.set_dump_dir(run_dict['experiment_path'], add_timestamp=True) + atexit.register(logger.print_summary) # Single-threaded runs if run_dict['num_threads'] == 1: diff --git a/logger.py b/logger.py index ce07bee..b5a1302 100644 --- a/logger.py +++ b/logger.py @@ -16,6 +16,7 @@ from pandas import * import os +import re from pprint import pprint import threading from subprocess import Popen, PIPE @@ -23,6 +24,8 @@ import time import datetime from six.moves import input from PIL import Image +from typing import Union +import shutil global failed_imports failed_imports = [] @@ -87,6 +90,31 @@ class ScreenLogger(object): def ask_input(self, title): return input("{}{}{}".format(Colors.BG_CYAN, title, Colors.END)) + def ask_yes_no(self, title: str, default: Union[None, bool]=None): + """ + Ask the user for a yes / no question and return True if the answer is yes and False otherwise. + The function will keep asking the user for an answer until he answers one of the possible responses. + A default answer can be passed and will be selected if the user presses enter + :param title: The question to ask the user + :param default: the default answer + :return: True / False according to the users answer + """ + default_answer = 'y/n' + if default == True: + default_answer = 'Y/n' + elif default == False: + default_answer = 'y/N' + + while True: + answer = input("{}{}{} ({})".format(Colors.BG_CYAN, title, Colors.END, default_answer)) + if answer == "yes" or answer == "YES" or answer == "y" or answer == "Y": + return True + elif answer == "no" or answer == "NO" or answer == "n" or answer == "N": + return False + elif answer == "": + if default is not None: + return default + class BaseLogger(object): def __init__(self): @@ -124,6 +152,7 @@ class Logger(BaseLogger): self.csv_path = '' self.doc_path = '' self.aggregated_data_across_threads = None + self.time_started = datetime.datetime.now() self.start_time = None self.time = None self.experiments_path = "" @@ -143,7 +172,6 @@ class Logger(BaseLogger): filename += "_{}".format(task_id) # add timestamp - self.time_started = datetime.datetime.now() if add_timestamp: t = self.time_started.time() d = self.time_started.date() @@ -212,6 +240,65 @@ class Logger(BaseLogger): pil_images = [Image.fromarray(image) for image in images] pil_images[0].save(output_path, save_all=True, append_images=pil_images[1:], duration=1.0 / fps, loop=0) + def remove_experiment_dir(self): + os.removedirs(self.experiments_path) + + def print_summary(self): + screen.separator() + screen.log_title("Results stored at: {}".format(self.experiments_path)) + screen.log_title("Total runtime: {}".format(datetime.datetime.now() - self.time_started)) + if 'Training Reward' in self.data.keys() and 'Evaluation Reward' in self.data.keys(): + screen.log_title("Max training reward: {}, max evaluation reward: {}".format(self.data['Training Reward'].max(), self.data['Evaluation Reward'].max())) + screen.separator() + if screen.ask_yes_no("Do you want to discard the experiment results (Warning: this cannot be undone)?", False): + self.remove_experiment_dir() + if screen.ask_yes_no("Do you want to specify a different experiment name to save to?", False): + new_name = self.get_experiment_name() + new_path = self.get_experiment_path(new_name, create_path=False) + shutil.move(self.experiments_path, new_path) + screen.log_title("Results moved to: {}".format(new_path)) + + def get_experiment_name(self, initial_experiment_name=''): + match = None + while match is None: + if initial_experiment_name == '': + experiment_name = screen.ask_input("Please enter an experiment name: ") + else: + experiment_name = initial_experiment_name + + experiment_name = experiment_name.replace(" ", "_") + match = re.match("^$|^[\w -/]{1,100}$", experiment_name) + + if match is None: + screen.error('Experiment name must be composed only of alphanumeric letters, ' + 'underscores and dashes and should not be longer than 100 characters.') + + return match.group(0) + + def get_experiment_path(self, experiment_name, create_path=True): + general_experiments_path = os.path.join('./experiments/', experiment_name) + + cur_date = self.time_started.date() + cur_time = self.time_started.time() + + if not os.path.exists(general_experiments_path) and create_path: + os.makedirs(general_experiments_path) + experiment_path = os.path.join(general_experiments_path, '{}_{}_{}-{}_{}' + .format(logger.two_digits(cur_date.day), logger.two_digits(cur_date.month), + cur_date.year, logger.two_digits(cur_time.hour), + logger.two_digits(cur_time.minute))) + i = 0 + while True: + if os.path.exists(experiment_path): + experiment_path = os.path.join(general_experiments_path, '{}_{}_{}-{}_{}_{}' + .format(cur_date.day, cur_date.month, cur_date.year, cur_time.hour, + cur_time.minute, i)) + i += 1 + else: + if create_path: + os.makedirs(experiment_path) + return experiment_path + global logger logger = Logger()