1
0
mirror of https://github.com/gryf/coach.git synced 2026-02-14 21:15:53 +01:00

Trace tests update

This commit is contained in:
Shadi Endrawis
2018-08-20 13:01:17 +03:00
parent c1f428666e
commit 3abb6cd415
99 changed files with 12876 additions and 39 deletions

View File

@@ -232,7 +232,7 @@ def perform_trace_based_tests(args, preset_name, num_env_steps, level=None):
screen.error("command exitcode: {}".format(p.returncode), crash=False)
screen.error(open(log_file_name).read(), crash=False)
else:
trace_path = os.path.join('./rl_coach', 'traces', preset_name + '_' + level if level else preset_name, '')
trace_path = os.path.join('./rl_coach', 'traces', preset_name + '_' + level.replace(':', '_') if level else preset_name, '')
if not os.path.exists(trace_path):
screen.log('No trace found, creating new trace in: {}'.format(trace_path))
os.makedirs(os.path.dirname(trace_path))
@@ -254,6 +254,7 @@ def perform_trace_based_tests(args, preset_name, num_env_steps, level=None):
if test_passed:
screen.success("Passed successfully.")
os.remove(new_trace_csv_path)
test_passed = True
else:
screen.error("Trace test failed.", crash=False)
if args.overwrite:
@@ -322,8 +323,9 @@ def main():
try:
preset = import_module('rl_coach.presets.{}'.format(preset_name))
except:
if args.verbose:
screen.error("Failed to load preset <{}>".format(preset_name), crash=False)
screen.error("Failed to load preset <{}>".format(preset_name), crash=False)
fail_count += 1
test_count += 1
continue
preset_validation_params = preset.graph_manager.preset_validation_params
@@ -336,13 +338,17 @@ def main():
for level in preset_validation_params.trace_test_levels:
test_count += 1
test_passed = perform_trace_based_tests(args, preset_name, num_env_steps, level)
if not test_passed:
fail_count += 1
else:
test_count += 1
test_passed = perform_trace_based_tests(args, preset_name, num_env_steps)
if not test_passed:
fail_count += 1
else:
test_passed = perform_reward_based_tests(args, preset_validation_params, preset_name)
if not test_passed:
fail_count += 1
if not test_passed:
fail_count += 1
screen.separator()
if fail_count == 0:
@@ -352,4 +358,6 @@ def main():
if __name__ == '__main__':
os.environ['DISABLE_MUJOCO_RENDERING'] = '1'
main()
del os.environ['DISABLE_MUJOCO_RENDERING']