1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Adding a summary when exiting coach

This commit is contained in:
Itai Caspi
2018-02-12 16:47:47 +02:00
committed by Itai Caspi
parent ba96e585d2
commit 5d1a2bc392
2 changed files with 92 additions and 47 deletions

View File

@@ -29,54 +29,11 @@ import argparse
from subprocess import Popen
import datetime
import presets
import atexit
if len(set(failed_imports)) > 0:
screen.warning("Warning: failed to import the following packages - {}".format(', '.join(set(failed_imports))))
time_started = datetime.datetime.now()
cur_time = time_started.time()
cur_date = time_started.date()
def get_experiment_name(initial_experiment_name=''):
match = None
while match is None:
if initial_experiment_name == '':
experiment_name = screen.ask_input("Please enter an experiment name: ")
else:
experiment_name = initial_experiment_name
experiment_name = experiment_name.replace(" ", "_")
match = re.match("^$|^[\w -/]{1,100}$", experiment_name)
if match is None:
screen.error('Experiment name must be composed only of alphanumeric letters, '
'underscores and dashes and should not be longer than 100 characters.')
return match.group(0)
def get_experiment_path(experiment_name, create_path=True):
general_experiments_path = os.path.join('./experiments/', experiment_name)
if not os.path.exists(general_experiments_path) and create_path:
os.makedirs(general_experiments_path)
experiment_path = os.path.join(general_experiments_path, '{}_{}_{}-{}_{}'
.format(logger.two_digits(cur_date.day), logger.two_digits(cur_date.month),
cur_date.year, logger.two_digits(cur_time.hour),
logger.two_digits(cur_time.minute)))
i = 0
while True:
if os.path.exists(experiment_path):
experiment_path = os.path.join(general_experiments_path, '{}_{}_{}-{}_{}_{}'
.format(cur_date.day, cur_date.month, cur_date.year, cur_time.hour,
cur_time.minute, i))
i += 1
else:
if create_path:
os.makedirs(experiment_path)
return experiment_path
def set_framework(framework_type):
# choosing neural network framework
@@ -146,8 +103,8 @@ def check_input_and_fill_run_dict(parser):
args.exploration_policy_type = 'ExplorationParameters'
# get experiment name and path
experiment_name = get_experiment_name(args.experiment_name)
experiment_path = get_experiment_path(experiment_name)
experiment_name = logger.get_experiment_name(args.experiment_name)
experiment_path = logger.get_experiment_path(experiment_name)
if args.play and num_workers > 1:
screen.warning("Playing the game as a human is only available with a single worker. "
@@ -289,6 +246,7 @@ if __name__ == "__main__":
# dump documentation
logger.set_dump_dir(run_dict['experiment_path'], add_timestamp=True)
atexit.register(logger.print_summary)
# Single-threaded runs
if run_dict['num_threads'] == 1: