mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
parallel trace tests fix
This commit is contained in:
@@ -21,6 +21,7 @@ import shutil
|
|||||||
import subprocess
|
import subprocess
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import sys
|
import sys
|
||||||
|
import signal
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from os import path
|
from os import path
|
||||||
sys.path.append('.')
|
sys.path.append('.')
|
||||||
@@ -30,6 +31,22 @@ import time
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from rl_coach.logger import screen
|
from rl_coach.logger import screen
|
||||||
|
|
||||||
|
processes = []
|
||||||
|
|
||||||
|
|
||||||
|
def sigint_handler(signum, frame):
|
||||||
|
for proc in processes:
|
||||||
|
os.killpg(os.getpgid(proc[2].pid), signal.SIGTERM)
|
||||||
|
for f in os.listdir('experiments/'):
|
||||||
|
if '__test_trace' in f:
|
||||||
|
shutil.rmtree(os.path.join('experiments', f))
|
||||||
|
for f in os.listdir('.'):
|
||||||
|
if 'trace_test_log' in f:
|
||||||
|
os.remove(f)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
signal.signal(signal.SIGINT, sigint_handler)
|
||||||
|
|
||||||
|
|
||||||
def read_csv_paths(test_path, filename_pattern, read_csv_tries=100):
|
def read_csv_paths(test_path, filename_pattern, read_csv_tries=100):
|
||||||
csv_paths = []
|
csv_paths = []
|
||||||
@@ -57,7 +74,7 @@ def run_trace_based_test(preset_name, num_env_steps, level=None):
|
|||||||
|
|
||||||
# run the experiment in a separate thread
|
# run the experiment in a separate thread
|
||||||
screen.log_title("Running test {}{}".format(preset_name, ' - ' + level if level else ''))
|
screen.log_title("Running test {}{}".format(preset_name, ' - ' + level if level else ''))
|
||||||
log_file_name = 'test_log_{preset_name}.txt'.format(preset_name=test_name[13:])
|
log_file_name = 'trace_test_log_{preset_name}.txt'.format(preset_name=test_name[13:])
|
||||||
|
|
||||||
cmd = (
|
cmd = (
|
||||||
'python3 rl_coach/coach.py '
|
'python3 rl_coach/coach.py '
|
||||||
@@ -93,7 +110,6 @@ def wait_and_check(args, processes, force=False):
|
|||||||
test_name = test_path.split('/')[-1]
|
test_name = test_path.split('/')[-1]
|
||||||
log_file_name = processes[0][1]
|
log_file_name = processes[0][1]
|
||||||
p = processes[0][2]
|
p = processes[0][2]
|
||||||
processes.pop(0)
|
|
||||||
p.wait()
|
p.wait()
|
||||||
|
|
||||||
filename_pattern = '*.csv'
|
filename_pattern = '*.csv'
|
||||||
@@ -112,7 +128,7 @@ def wait_and_check(args, processes, force=False):
|
|||||||
trace_path = os.path.join('./rl_coach', 'traces', test_name[13:])
|
trace_path = os.path.join('./rl_coach', 'traces', test_name[13:])
|
||||||
if not os.path.exists(trace_path):
|
if not os.path.exists(trace_path):
|
||||||
screen.log('No trace found, creating new trace in: {}'.format(trace_path))
|
screen.log('No trace found, creating new trace in: {}'.format(trace_path))
|
||||||
os.makedirs(os.path.dirname(trace_path))
|
os.makedirs(trace_path)
|
||||||
df = pd.read_csv(csv_paths[0])
|
df = pd.read_csv(csv_paths[0])
|
||||||
df = clean_df(df)
|
df = clean_df(df)
|
||||||
df.to_csv(os.path.join(trace_path, 'trace.csv'), index=False)
|
df.to_csv(os.path.join(trace_path, 'trace.csv'), index=False)
|
||||||
@@ -143,6 +159,7 @@ def wait_and_check(args, processes, force=False):
|
|||||||
|
|
||||||
shutil.rmtree(test_path)
|
shutil.rmtree(test_path)
|
||||||
os.remove(log_file_name)
|
os.remove(log_file_name)
|
||||||
|
processes.pop(0)
|
||||||
return test_passed
|
return test_passed
|
||||||
|
|
||||||
|
|
||||||
@@ -185,7 +202,6 @@ def main():
|
|||||||
|
|
||||||
fail_count = 0
|
fail_count = 0
|
||||||
test_count = 0
|
test_count = 0
|
||||||
processes = []
|
|
||||||
|
|
||||||
if args.ignore_presets is not None:
|
if args.ignore_presets is not None:
|
||||||
presets_to_ignore = args.ignore_presets.split(',')
|
presets_to_ignore = args.ignore_presets.split(',')
|
||||||
|
|||||||
Reference in New Issue
Block a user