From 7086492127087bca426afd85d1797bf71ab22b8b Mon Sep 17 00:00:00 2001 From: Shadi Endrawis Date: Mon, 3 Sep 2018 20:47:10 +0300 Subject: [PATCH] parallel trace tests fix --- rl_coach/tests/trace_tests.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/rl_coach/tests/trace_tests.py b/rl_coach/tests/trace_tests.py index b9f5647..ca90ea9 100644 --- a/rl_coach/tests/trace_tests.py +++ b/rl_coach/tests/trace_tests.py @@ -21,6 +21,7 @@ import shutil import subprocess import multiprocessing import sys +import signal from importlib import import_module from os import path sys.path.append('.') @@ -30,6 +31,22 @@ import time # -*- coding: utf-8 -*- from rl_coach.logger import screen +processes = [] + + +def sigint_handler(signum, frame): + for proc in processes: + os.killpg(os.getpgid(proc[2].pid), signal.SIGTERM) + for f in os.listdir('experiments/'): + if '__test_trace' in f: + shutil.rmtree(os.path.join('experiments', f)) + for f in os.listdir('.'): + if 'trace_test_log' in f: + os.remove(f) + exit() + +signal.signal(signal.SIGINT, sigint_handler) + def read_csv_paths(test_path, filename_pattern, read_csv_tries=100): csv_paths = [] @@ -57,7 +74,7 @@ def run_trace_based_test(preset_name, num_env_steps, level=None): # run the experiment in a separate thread screen.log_title("Running test {}{}".format(preset_name, ' - ' + level if level else '')) - log_file_name = 'test_log_{preset_name}.txt'.format(preset_name=test_name[13:]) + log_file_name = 'trace_test_log_{preset_name}.txt'.format(preset_name=test_name[13:]) cmd = ( 'python3 rl_coach/coach.py ' @@ -93,7 +110,6 @@ def wait_and_check(args, processes, force=False): test_name = test_path.split('/')[-1] log_file_name = processes[0][1] p = processes[0][2] - processes.pop(0) p.wait() filename_pattern = '*.csv' @@ -112,7 +128,7 @@ def wait_and_check(args, processes, force=False): trace_path = os.path.join('./rl_coach', 'traces', test_name[13:]) if not os.path.exists(trace_path): screen.log('No trace found, creating new trace in: {}'.format(trace_path)) - os.makedirs(os.path.dirname(trace_path)) + os.makedirs(trace_path) df = pd.read_csv(csv_paths[0]) df = clean_df(df) df.to_csv(os.path.join(trace_path, 'trace.csv'), index=False) @@ -143,6 +159,7 @@ def wait_and_check(args, processes, force=False): shutil.rmtree(test_path) os.remove(log_file_name) + processes.pop(0) return test_passed @@ -185,7 +202,6 @@ def main(): fail_count = 0 test_count = 0 - processes = [] if args.ignore_presets is not None: presets_to_ignore = args.ignore_presets.split(',')