1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Adding a summary when exiting coach

This commit is contained in:
Itai Caspi
2018-02-12 16:47:47 +02:00
committed by Itai Caspi
parent ba96e585d2
commit 5d1a2bc392
2 changed files with 92 additions and 47 deletions

View File

@@ -16,6 +16,7 @@
from pandas import *
import os
import re
from pprint import pprint
import threading
from subprocess import Popen, PIPE
@@ -23,6 +24,8 @@ import time
import datetime
from six.moves import input
from PIL import Image
from typing import Union
import shutil
global failed_imports
failed_imports = []
@@ -87,6 +90,31 @@ class ScreenLogger(object):
def ask_input(self, title):
return input("{}{}{}".format(Colors.BG_CYAN, title, Colors.END))
def ask_yes_no(self, title: str, default: Union[None, bool]=None):
"""
Ask the user for a yes / no question and return True if the answer is yes and False otherwise.
The function will keep asking the user for an answer until he answers one of the possible responses.
A default answer can be passed and will be selected if the user presses enter
:param title: The question to ask the user
:param default: the default answer
:return: True / False according to the users answer
"""
default_answer = 'y/n'
if default == True:
default_answer = 'Y/n'
elif default == False:
default_answer = 'y/N'
while True:
answer = input("{}{}{} ({})".format(Colors.BG_CYAN, title, Colors.END, default_answer))
if answer == "yes" or answer == "YES" or answer == "y" or answer == "Y":
return True
elif answer == "no" or answer == "NO" or answer == "n" or answer == "N":
return False
elif answer == "":
if default is not None:
return default
class BaseLogger(object):
def __init__(self):
@@ -124,6 +152,7 @@ class Logger(BaseLogger):
self.csv_path = ''
self.doc_path = ''
self.aggregated_data_across_threads = None
self.time_started = datetime.datetime.now()
self.start_time = None
self.time = None
self.experiments_path = ""
@@ -143,7 +172,6 @@ class Logger(BaseLogger):
filename += "_{}".format(task_id)
# add timestamp
self.time_started = datetime.datetime.now()
if add_timestamp:
t = self.time_started.time()
d = self.time_started.date()
@@ -212,6 +240,65 @@ class Logger(BaseLogger):
pil_images = [Image.fromarray(image) for image in images]
pil_images[0].save(output_path, save_all=True, append_images=pil_images[1:], duration=1.0 / fps, loop=0)
def remove_experiment_dir(self):
os.removedirs(self.experiments_path)
def print_summary(self):
screen.separator()
screen.log_title("Results stored at: {}".format(self.experiments_path))
screen.log_title("Total runtime: {}".format(datetime.datetime.now() - self.time_started))
if 'Training Reward' in self.data.keys() and 'Evaluation Reward' in self.data.keys():
screen.log_title("Max training reward: {}, max evaluation reward: {}".format(self.data['Training Reward'].max(), self.data['Evaluation Reward'].max()))
screen.separator()
if screen.ask_yes_no("Do you want to discard the experiment results (Warning: this cannot be undone)?", False):
self.remove_experiment_dir()
if screen.ask_yes_no("Do you want to specify a different experiment name to save to?", False):
new_name = self.get_experiment_name()
new_path = self.get_experiment_path(new_name, create_path=False)
shutil.move(self.experiments_path, new_path)
screen.log_title("Results moved to: {}".format(new_path))
def get_experiment_name(self, initial_experiment_name=''):
match = None
while match is None:
if initial_experiment_name == '':
experiment_name = screen.ask_input("Please enter an experiment name: ")
else:
experiment_name = initial_experiment_name
experiment_name = experiment_name.replace(" ", "_")
match = re.match("^$|^[\w -/]{1,100}$", experiment_name)
if match is None:
screen.error('Experiment name must be composed only of alphanumeric letters, '
'underscores and dashes and should not be longer than 100 characters.')
return match.group(0)
def get_experiment_path(self, experiment_name, create_path=True):
general_experiments_path = os.path.join('./experiments/', experiment_name)
cur_date = self.time_started.date()
cur_time = self.time_started.time()
if not os.path.exists(general_experiments_path) and create_path:
os.makedirs(general_experiments_path)
experiment_path = os.path.join(general_experiments_path, '{}_{}_{}-{}_{}'
.format(logger.two_digits(cur_date.day), logger.two_digits(cur_date.month),
cur_date.year, logger.two_digits(cur_time.hour),
logger.two_digits(cur_time.minute)))
i = 0
while True:
if os.path.exists(experiment_path):
experiment_path = os.path.join(general_experiments_path, '{}_{}_{}-{}_{}_{}'
.format(cur_date.day, cur_date.month, cur_date.year, cur_time.hour,
cur_time.minute, i))
i += 1
else:
if create_path:
os.makedirs(experiment_path)
return experiment_path
global logger
logger = Logger()