1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

parallel trace tests fix

This commit is contained in:
Shadi Endrawis
2018-09-03 20:47:10 +03:00
parent 2c62a40466
commit 7086492127

View File

@@ -21,6 +21,7 @@ import shutil
import subprocess import subprocess
import multiprocessing import multiprocessing
import sys import sys
import signal
from importlib import import_module from importlib import import_module
from os import path from os import path
sys.path.append('.') sys.path.append('.')
@@ -30,6 +31,22 @@ import time
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from rl_coach.logger import screen from rl_coach.logger import screen
processes = []
def sigint_handler(signum, frame):
for proc in processes:
os.killpg(os.getpgid(proc[2].pid), signal.SIGTERM)
for f in os.listdir('experiments/'):
if '__test_trace' in f:
shutil.rmtree(os.path.join('experiments', f))
for f in os.listdir('.'):
if 'trace_test_log' in f:
os.remove(f)
exit()
signal.signal(signal.SIGINT, sigint_handler)
def read_csv_paths(test_path, filename_pattern, read_csv_tries=100): def read_csv_paths(test_path, filename_pattern, read_csv_tries=100):
csv_paths = [] csv_paths = []
@@ -57,7 +74,7 @@ def run_trace_based_test(preset_name, num_env_steps, level=None):
# run the experiment in a separate thread # run the experiment in a separate thread
screen.log_title("Running test {}{}".format(preset_name, ' - ' + level if level else '')) screen.log_title("Running test {}{}".format(preset_name, ' - ' + level if level else ''))
log_file_name = 'test_log_{preset_name}.txt'.format(preset_name=test_name[13:]) log_file_name = 'trace_test_log_{preset_name}.txt'.format(preset_name=test_name[13:])
cmd = ( cmd = (
'python3 rl_coach/coach.py ' 'python3 rl_coach/coach.py '
@@ -93,7 +110,6 @@ def wait_and_check(args, processes, force=False):
test_name = test_path.split('/')[-1] test_name = test_path.split('/')[-1]
log_file_name = processes[0][1] log_file_name = processes[0][1]
p = processes[0][2] p = processes[0][2]
processes.pop(0)
p.wait() p.wait()
filename_pattern = '*.csv' filename_pattern = '*.csv'
@@ -112,7 +128,7 @@ def wait_and_check(args, processes, force=False):
trace_path = os.path.join('./rl_coach', 'traces', test_name[13:]) trace_path = os.path.join('./rl_coach', 'traces', test_name[13:])
if not os.path.exists(trace_path): if not os.path.exists(trace_path):
screen.log('No trace found, creating new trace in: {}'.format(trace_path)) screen.log('No trace found, creating new trace in: {}'.format(trace_path))
os.makedirs(os.path.dirname(trace_path)) os.makedirs(trace_path)
df = pd.read_csv(csv_paths[0]) df = pd.read_csv(csv_paths[0])
df = clean_df(df) df = clean_df(df)
df.to_csv(os.path.join(trace_path, 'trace.csv'), index=False) df.to_csv(os.path.join(trace_path, 'trace.csv'), index=False)
@@ -143,6 +159,7 @@ def wait_and_check(args, processes, force=False):
shutil.rmtree(test_path) shutil.rmtree(test_path)
os.remove(log_file_name) os.remove(log_file_name)
processes.pop(0)
return test_passed return test_passed
@@ -185,7 +202,6 @@ def main():
fail_count = 0 fail_count = 0
test_count = 0 test_count = 0
processes = []
if args.ignore_presets is not None: if args.ignore_presets is not None:
presets_to_ignore = args.ignore_presets.split(',') presets_to_ignore = args.ignore_presets.split(',')