mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
parallel trace tests fix
This commit is contained in:
@@ -21,6 +21,7 @@ import shutil
|
||||
import subprocess
|
||||
import multiprocessing
|
||||
import sys
|
||||
import signal
|
||||
from importlib import import_module
|
||||
from os import path
|
||||
sys.path.append('.')
|
||||
@@ -30,6 +31,22 @@ import time
|
||||
# -*- coding: utf-8 -*-
|
||||
from rl_coach.logger import screen
|
||||
|
||||
processes = []
|
||||
|
||||
|
||||
def sigint_handler(signum, frame):
|
||||
for proc in processes:
|
||||
os.killpg(os.getpgid(proc[2].pid), signal.SIGTERM)
|
||||
for f in os.listdir('experiments/'):
|
||||
if '__test_trace' in f:
|
||||
shutil.rmtree(os.path.join('experiments', f))
|
||||
for f in os.listdir('.'):
|
||||
if 'trace_test_log' in f:
|
||||
os.remove(f)
|
||||
exit()
|
||||
|
||||
signal.signal(signal.SIGINT, sigint_handler)
|
||||
|
||||
|
||||
def read_csv_paths(test_path, filename_pattern, read_csv_tries=100):
|
||||
csv_paths = []
|
||||
@@ -57,7 +74,7 @@ def run_trace_based_test(preset_name, num_env_steps, level=None):
|
||||
|
||||
# run the experiment in a separate thread
|
||||
screen.log_title("Running test {}{}".format(preset_name, ' - ' + level if level else ''))
|
||||
log_file_name = 'test_log_{preset_name}.txt'.format(preset_name=test_name[13:])
|
||||
log_file_name = 'trace_test_log_{preset_name}.txt'.format(preset_name=test_name[13:])
|
||||
|
||||
cmd = (
|
||||
'python3 rl_coach/coach.py '
|
||||
@@ -93,7 +110,6 @@ def wait_and_check(args, processes, force=False):
|
||||
test_name = test_path.split('/')[-1]
|
||||
log_file_name = processes[0][1]
|
||||
p = processes[0][2]
|
||||
processes.pop(0)
|
||||
p.wait()
|
||||
|
||||
filename_pattern = '*.csv'
|
||||
@@ -112,7 +128,7 @@ def wait_and_check(args, processes, force=False):
|
||||
trace_path = os.path.join('./rl_coach', 'traces', test_name[13:])
|
||||
if not os.path.exists(trace_path):
|
||||
screen.log('No trace found, creating new trace in: {}'.format(trace_path))
|
||||
os.makedirs(os.path.dirname(trace_path))
|
||||
os.makedirs(trace_path)
|
||||
df = pd.read_csv(csv_paths[0])
|
||||
df = clean_df(df)
|
||||
df.to_csv(os.path.join(trace_path, 'trace.csv'), index=False)
|
||||
@@ -143,6 +159,7 @@ def wait_and_check(args, processes, force=False):
|
||||
|
||||
shutil.rmtree(test_path)
|
||||
os.remove(log_file_name)
|
||||
processes.pop(0)
|
||||
return test_passed
|
||||
|
||||
|
||||
@@ -185,7 +202,6 @@ def main():
|
||||
|
||||
fail_count = 0
|
||||
test_count = 0
|
||||
processes = []
|
||||
|
||||
if args.ignore_presets is not None:
|
||||
presets_to_ignore = args.ignore_presets.split(',')
|
||||
|
||||
Reference in New Issue
Block a user