mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Adding a summary when exiting coach
This commit is contained in:
50
coach.py
50
coach.py
@@ -29,54 +29,11 @@ import argparse
|
|||||||
from subprocess import Popen
|
from subprocess import Popen
|
||||||
import datetime
|
import datetime
|
||||||
import presets
|
import presets
|
||||||
|
import atexit
|
||||||
|
|
||||||
if len(set(failed_imports)) > 0:
|
if len(set(failed_imports)) > 0:
|
||||||
screen.warning("Warning: failed to import the following packages - {}".format(', '.join(set(failed_imports))))
|
screen.warning("Warning: failed to import the following packages - {}".format(', '.join(set(failed_imports))))
|
||||||
|
|
||||||
time_started = datetime.datetime.now()
|
|
||||||
cur_time = time_started.time()
|
|
||||||
cur_date = time_started.date()
|
|
||||||
|
|
||||||
|
|
||||||
def get_experiment_name(initial_experiment_name=''):
|
|
||||||
match = None
|
|
||||||
while match is None:
|
|
||||||
if initial_experiment_name == '':
|
|
||||||
experiment_name = screen.ask_input("Please enter an experiment name: ")
|
|
||||||
else:
|
|
||||||
experiment_name = initial_experiment_name
|
|
||||||
|
|
||||||
experiment_name = experiment_name.replace(" ", "_")
|
|
||||||
match = re.match("^$|^[\w -/]{1,100}$", experiment_name)
|
|
||||||
|
|
||||||
if match is None:
|
|
||||||
screen.error('Experiment name must be composed only of alphanumeric letters, '
|
|
||||||
'underscores and dashes and should not be longer than 100 characters.')
|
|
||||||
|
|
||||||
return match.group(0)
|
|
||||||
|
|
||||||
|
|
||||||
def get_experiment_path(experiment_name, create_path=True):
|
|
||||||
general_experiments_path = os.path.join('./experiments/', experiment_name)
|
|
||||||
|
|
||||||
if not os.path.exists(general_experiments_path) and create_path:
|
|
||||||
os.makedirs(general_experiments_path)
|
|
||||||
experiment_path = os.path.join(general_experiments_path, '{}_{}_{}-{}_{}'
|
|
||||||
.format(logger.two_digits(cur_date.day), logger.two_digits(cur_date.month),
|
|
||||||
cur_date.year, logger.two_digits(cur_time.hour),
|
|
||||||
logger.two_digits(cur_time.minute)))
|
|
||||||
i = 0
|
|
||||||
while True:
|
|
||||||
if os.path.exists(experiment_path):
|
|
||||||
experiment_path = os.path.join(general_experiments_path, '{}_{}_{}-{}_{}_{}'
|
|
||||||
.format(cur_date.day, cur_date.month, cur_date.year, cur_time.hour,
|
|
||||||
cur_time.minute, i))
|
|
||||||
i += 1
|
|
||||||
else:
|
|
||||||
if create_path:
|
|
||||||
os.makedirs(experiment_path)
|
|
||||||
return experiment_path
|
|
||||||
|
|
||||||
|
|
||||||
def set_framework(framework_type):
|
def set_framework(framework_type):
|
||||||
# choosing neural network framework
|
# choosing neural network framework
|
||||||
@@ -146,8 +103,8 @@ def check_input_and_fill_run_dict(parser):
|
|||||||
args.exploration_policy_type = 'ExplorationParameters'
|
args.exploration_policy_type = 'ExplorationParameters'
|
||||||
|
|
||||||
# get experiment name and path
|
# get experiment name and path
|
||||||
experiment_name = get_experiment_name(args.experiment_name)
|
experiment_name = logger.get_experiment_name(args.experiment_name)
|
||||||
experiment_path = get_experiment_path(experiment_name)
|
experiment_path = logger.get_experiment_path(experiment_name)
|
||||||
|
|
||||||
if args.play and num_workers > 1:
|
if args.play and num_workers > 1:
|
||||||
screen.warning("Playing the game as a human is only available with a single worker. "
|
screen.warning("Playing the game as a human is only available with a single worker. "
|
||||||
@@ -289,6 +246,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# dump documentation
|
# dump documentation
|
||||||
logger.set_dump_dir(run_dict['experiment_path'], add_timestamp=True)
|
logger.set_dump_dir(run_dict['experiment_path'], add_timestamp=True)
|
||||||
|
atexit.register(logger.print_summary)
|
||||||
|
|
||||||
# Single-threaded runs
|
# Single-threaded runs
|
||||||
if run_dict['num_threads'] == 1:
|
if run_dict['num_threads'] == 1:
|
||||||
|
|||||||
89
logger.py
89
logger.py
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
from pandas import *
|
from pandas import *
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
import threading
|
import threading
|
||||||
from subprocess import Popen, PIPE
|
from subprocess import Popen, PIPE
|
||||||
@@ -23,6 +24,8 @@ import time
|
|||||||
import datetime
|
import datetime
|
||||||
from six.moves import input
|
from six.moves import input
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from typing import Union
|
||||||
|
import shutil
|
||||||
|
|
||||||
global failed_imports
|
global failed_imports
|
||||||
failed_imports = []
|
failed_imports = []
|
||||||
@@ -87,6 +90,31 @@ class ScreenLogger(object):
|
|||||||
def ask_input(self, title):
|
def ask_input(self, title):
|
||||||
return input("{}{}{}".format(Colors.BG_CYAN, title, Colors.END))
|
return input("{}{}{}".format(Colors.BG_CYAN, title, Colors.END))
|
||||||
|
|
||||||
|
def ask_yes_no(self, title: str, default: Union[None, bool]=None):
|
||||||
|
"""
|
||||||
|
Ask the user for a yes / no question and return True if the answer is yes and False otherwise.
|
||||||
|
The function will keep asking the user for an answer until he answers one of the possible responses.
|
||||||
|
A default answer can be passed and will be selected if the user presses enter
|
||||||
|
:param title: The question to ask the user
|
||||||
|
:param default: the default answer
|
||||||
|
:return: True / False according to the users answer
|
||||||
|
"""
|
||||||
|
default_answer = 'y/n'
|
||||||
|
if default == True:
|
||||||
|
default_answer = 'Y/n'
|
||||||
|
elif default == False:
|
||||||
|
default_answer = 'y/N'
|
||||||
|
|
||||||
|
while True:
|
||||||
|
answer = input("{}{}{} ({})".format(Colors.BG_CYAN, title, Colors.END, default_answer))
|
||||||
|
if answer == "yes" or answer == "YES" or answer == "y" or answer == "Y":
|
||||||
|
return True
|
||||||
|
elif answer == "no" or answer == "NO" or answer == "n" or answer == "N":
|
||||||
|
return False
|
||||||
|
elif answer == "":
|
||||||
|
if default is not None:
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
class BaseLogger(object):
|
class BaseLogger(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -124,6 +152,7 @@ class Logger(BaseLogger):
|
|||||||
self.csv_path = ''
|
self.csv_path = ''
|
||||||
self.doc_path = ''
|
self.doc_path = ''
|
||||||
self.aggregated_data_across_threads = None
|
self.aggregated_data_across_threads = None
|
||||||
|
self.time_started = datetime.datetime.now()
|
||||||
self.start_time = None
|
self.start_time = None
|
||||||
self.time = None
|
self.time = None
|
||||||
self.experiments_path = ""
|
self.experiments_path = ""
|
||||||
@@ -143,7 +172,6 @@ class Logger(BaseLogger):
|
|||||||
filename += "_{}".format(task_id)
|
filename += "_{}".format(task_id)
|
||||||
|
|
||||||
# add timestamp
|
# add timestamp
|
||||||
self.time_started = datetime.datetime.now()
|
|
||||||
if add_timestamp:
|
if add_timestamp:
|
||||||
t = self.time_started.time()
|
t = self.time_started.time()
|
||||||
d = self.time_started.date()
|
d = self.time_started.date()
|
||||||
@@ -212,6 +240,65 @@ class Logger(BaseLogger):
|
|||||||
pil_images = [Image.fromarray(image) for image in images]
|
pil_images = [Image.fromarray(image) for image in images]
|
||||||
pil_images[0].save(output_path, save_all=True, append_images=pil_images[1:], duration=1.0 / fps, loop=0)
|
pil_images[0].save(output_path, save_all=True, append_images=pil_images[1:], duration=1.0 / fps, loop=0)
|
||||||
|
|
||||||
|
def remove_experiment_dir(self):
|
||||||
|
os.removedirs(self.experiments_path)
|
||||||
|
|
||||||
|
def print_summary(self):
|
||||||
|
screen.separator()
|
||||||
|
screen.log_title("Results stored at: {}".format(self.experiments_path))
|
||||||
|
screen.log_title("Total runtime: {}".format(datetime.datetime.now() - self.time_started))
|
||||||
|
if 'Training Reward' in self.data.keys() and 'Evaluation Reward' in self.data.keys():
|
||||||
|
screen.log_title("Max training reward: {}, max evaluation reward: {}".format(self.data['Training Reward'].max(), self.data['Evaluation Reward'].max()))
|
||||||
|
screen.separator()
|
||||||
|
if screen.ask_yes_no("Do you want to discard the experiment results (Warning: this cannot be undone)?", False):
|
||||||
|
self.remove_experiment_dir()
|
||||||
|
if screen.ask_yes_no("Do you want to specify a different experiment name to save to?", False):
|
||||||
|
new_name = self.get_experiment_name()
|
||||||
|
new_path = self.get_experiment_path(new_name, create_path=False)
|
||||||
|
shutil.move(self.experiments_path, new_path)
|
||||||
|
screen.log_title("Results moved to: {}".format(new_path))
|
||||||
|
|
||||||
|
def get_experiment_name(self, initial_experiment_name=''):
|
||||||
|
match = None
|
||||||
|
while match is None:
|
||||||
|
if initial_experiment_name == '':
|
||||||
|
experiment_name = screen.ask_input("Please enter an experiment name: ")
|
||||||
|
else:
|
||||||
|
experiment_name = initial_experiment_name
|
||||||
|
|
||||||
|
experiment_name = experiment_name.replace(" ", "_")
|
||||||
|
match = re.match("^$|^[\w -/]{1,100}$", experiment_name)
|
||||||
|
|
||||||
|
if match is None:
|
||||||
|
screen.error('Experiment name must be composed only of alphanumeric letters, '
|
||||||
|
'underscores and dashes and should not be longer than 100 characters.')
|
||||||
|
|
||||||
|
return match.group(0)
|
||||||
|
|
||||||
|
def get_experiment_path(self, experiment_name, create_path=True):
|
||||||
|
general_experiments_path = os.path.join('./experiments/', experiment_name)
|
||||||
|
|
||||||
|
cur_date = self.time_started.date()
|
||||||
|
cur_time = self.time_started.time()
|
||||||
|
|
||||||
|
if not os.path.exists(general_experiments_path) and create_path:
|
||||||
|
os.makedirs(general_experiments_path)
|
||||||
|
experiment_path = os.path.join(general_experiments_path, '{}_{}_{}-{}_{}'
|
||||||
|
.format(logger.two_digits(cur_date.day), logger.two_digits(cur_date.month),
|
||||||
|
cur_date.year, logger.two_digits(cur_time.hour),
|
||||||
|
logger.two_digits(cur_time.minute)))
|
||||||
|
i = 0
|
||||||
|
while True:
|
||||||
|
if os.path.exists(experiment_path):
|
||||||
|
experiment_path = os.path.join(general_experiments_path, '{}_{}_{}-{}_{}_{}'
|
||||||
|
.format(cur_date.day, cur_date.month, cur_date.year, cur_time.hour,
|
||||||
|
cur_time.minute, i))
|
||||||
|
i += 1
|
||||||
|
else:
|
||||||
|
if create_path:
|
||||||
|
os.makedirs(experiment_path)
|
||||||
|
return experiment_path
|
||||||
|
|
||||||
|
|
||||||
global logger
|
global logger
|
||||||
logger = Logger()
|
logger = Logger()
|
||||||
|
|||||||
Reference in New Issue
Block a user