mirror of
https://github.com/gryf/coach.git
synced 2026-04-02 01:53:33 +02:00
Running trace tests in parallel + other small fixes
This commit is contained in:
@@ -44,12 +44,6 @@ def read_csv_paths(test_path, filename_pattern, read_csv_tries=100):
|
||||
return csv_paths
|
||||
|
||||
|
||||
def clean_df(df):
|
||||
if 'Wall-Clock Time' in df.keys():
|
||||
df.drop(['Wall-Clock Time'], 1, inplace=True)
|
||||
return df
|
||||
|
||||
|
||||
def print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, args):
|
||||
percentage = int((100 * last_num_episodes) / preset_validation_params.max_episodes_to_achieve_reward)
|
||||
sys.stdout.write("\rReward: ({}/{})".format(round(averaged_rewards[-1], 1),
|
||||
@@ -186,94 +180,8 @@ def perform_reward_based_tests(args, preset_validation_params, preset_name):
|
||||
return test_passed
|
||||
|
||||
|
||||
def perform_trace_based_tests(args, preset_name, num_env_steps, level=None):
|
||||
test_name = '__test_trace'
|
||||
test_path = os.path.join('./experiments', test_name)
|
||||
if path.exists(test_path):
|
||||
shutil.rmtree(test_path)
|
||||
|
||||
# run the experiment in a separate thread
|
||||
screen.log_title("Running test {}{}".format(preset_name, ' - ' + level if level else ''))
|
||||
log_file_name = 'test_log_{preset_name}.txt'.format(preset_name=preset_name)
|
||||
|
||||
cmd = (
|
||||
'python3 rl_coach/coach.py '
|
||||
'-p {preset_name} '
|
||||
'-e {test_name} '
|
||||
'--seed 42 '
|
||||
'-c '
|
||||
'--no_summary '
|
||||
'-cp {custom_param} '
|
||||
'{level} '
|
||||
'&> {log_file_name} '
|
||||
).format(
|
||||
preset_name=preset_name,
|
||||
test_name=test_name,
|
||||
log_file_name=log_file_name,
|
||||
level='-lvl ' + level if level else '',
|
||||
custom_param='\"improve_steps=EnvironmentSteps({n});'
|
||||
'steps_between_evaluation_periods=EnvironmentSteps({n});'
|
||||
'evaluation_steps=EnvironmentSteps(1);'
|
||||
'heatup_steps=EnvironmentSteps(1024)\"'.format(n=num_env_steps)
|
||||
)
|
||||
|
||||
p = subprocess.Popen(cmd, shell=True, executable="/bin/bash", preexec_fn=os.setsid)
|
||||
p.wait()
|
||||
|
||||
filename_pattern = '*.csv'
|
||||
|
||||
# get the csv with the results
|
||||
csv_paths = read_csv_paths(test_path, filename_pattern)
|
||||
|
||||
test_passed = False
|
||||
if not csv_paths:
|
||||
screen.error("csv file never found", crash=False)
|
||||
if args.verbose:
|
||||
screen.error("command exitcode: {}".format(p.returncode), crash=False)
|
||||
screen.error(open(log_file_name).read(), crash=False)
|
||||
else:
|
||||
trace_path = os.path.join('./rl_coach', 'traces', preset_name + '_' + level.replace(':', '_') if level else preset_name, '')
|
||||
if not os.path.exists(trace_path):
|
||||
screen.log('No trace found, creating new trace in: {}'.format(trace_path))
|
||||
os.makedirs(os.path.dirname(trace_path))
|
||||
df = pd.read_csv(csv_paths[0])
|
||||
df = clean_df(df)
|
||||
df.to_csv(os.path.join(trace_path, 'trace.csv'), index=False)
|
||||
screen.success("Successfully created new trace.")
|
||||
test_passed = True
|
||||
else:
|
||||
test_df = pd.read_csv(csv_paths[0])
|
||||
test_df = clean_df(test_df)
|
||||
new_trace_csv_path = os.path.join(trace_path, 'trace_new.csv')
|
||||
test_df.to_csv(new_trace_csv_path, index=False)
|
||||
test_df = pd.read_csv(new_trace_csv_path)
|
||||
trace_csv_path = glob.glob(path.join(trace_path, 'trace.csv'))
|
||||
trace_csv_path = trace_csv_path[0]
|
||||
trace_df = pd.read_csv(trace_csv_path)
|
||||
test_passed = test_df.equals(trace_df)
|
||||
if test_passed:
|
||||
screen.success("Passed successfully.")
|
||||
os.remove(new_trace_csv_path)
|
||||
test_passed = True
|
||||
else:
|
||||
screen.error("Trace test failed.", crash=False)
|
||||
if args.overwrite:
|
||||
os.remove(trace_csv_path)
|
||||
os.rename(new_trace_csv_path, trace_csv_path)
|
||||
screen.error("Overwriting old trace.", crash=False)
|
||||
else:
|
||||
screen.error("bcompare {} {}".format(trace_csv_path, new_trace_csv_path), crash=False)
|
||||
|
||||
shutil.rmtree(test_path)
|
||||
os.remove(log_file_name)
|
||||
return test_passed
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-t', '--trace',
|
||||
help="(flag) perform trace based testing",
|
||||
action='store_true')
|
||||
parser.add_argument('-p', '--preset',
|
||||
help="(string) Name of a preset to run (as configured in presets.py)",
|
||||
default=None,
|
||||
@@ -295,15 +203,11 @@ def main():
|
||||
parser.add_argument('-np', '--no_progress_bar',
|
||||
help="(flag) Don't print the progress bar (makes jenkins logs more readable)",
|
||||
action='store_true')
|
||||
parser.add_argument('-ow', '--overwrite',
|
||||
help="(flag) overwrite old trace with new ones in trace testing mode",
|
||||
action='store_true')
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.preset is not None:
|
||||
presets_lists = [args.preset]
|
||||
else:
|
||||
# presets_lists = list_all_classes_in_module(presets)
|
||||
presets_lists = [f[:-3] for f in os.listdir(os.path.join('rl_coach', 'presets')) if
|
||||
f[-3:] == '.py' and not f == '__init__.py']
|
||||
|
||||
@@ -329,26 +233,13 @@ def main():
|
||||
continue
|
||||
|
||||
preset_validation_params = preset.graph_manager.preset_validation_params
|
||||
if not args.trace and not preset_validation_params.test:
|
||||
if not preset_validation_params.test:
|
||||
continue
|
||||
|
||||
if args.trace:
|
||||
num_env_steps = preset_validation_params.trace_max_env_steps
|
||||
if preset_validation_params.trace_test_levels:
|
||||
for level in preset_validation_params.trace_test_levels:
|
||||
test_count += 1
|
||||
test_passed = perform_trace_based_tests(args, preset_name, num_env_steps, level)
|
||||
if not test_passed:
|
||||
fail_count += 1
|
||||
else:
|
||||
test_count += 1
|
||||
test_passed = perform_trace_based_tests(args, preset_name, num_env_steps)
|
||||
if not test_passed:
|
||||
fail_count += 1
|
||||
else:
|
||||
test_passed = perform_reward_based_tests(args, preset_validation_params, preset_name)
|
||||
if not test_passed:
|
||||
fail_count += 1
|
||||
test_count += 1
|
||||
test_passed = perform_reward_based_tests(args, preset_validation_params, preset_name)
|
||||
if not test_passed:
|
||||
fail_count += 1
|
||||
|
||||
screen.separator()
|
||||
if fail_count == 0:
|
||||
@@ -358,6 +249,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
os.environ['DISABLE_MUJOCO_RENDERING'] = '1'
|
||||
main()
|
||||
del os.environ['DISABLE_MUJOCO_RENDERING']
|
||||
|
||||
Reference in New Issue
Block a user