mirror of
https://github.com/gryf/coach.git
synced 2026-02-14 21:15:53 +01:00
Trace tests update
This commit is contained in:
@@ -232,7 +232,7 @@ def perform_trace_based_tests(args, preset_name, num_env_steps, level=None):
|
||||
screen.error("command exitcode: {}".format(p.returncode), crash=False)
|
||||
screen.error(open(log_file_name).read(), crash=False)
|
||||
else:
|
||||
trace_path = os.path.join('./rl_coach', 'traces', preset_name + '_' + level if level else preset_name, '')
|
||||
trace_path = os.path.join('./rl_coach', 'traces', preset_name + '_' + level.replace(':', '_') if level else preset_name, '')
|
||||
if not os.path.exists(trace_path):
|
||||
screen.log('No trace found, creating new trace in: {}'.format(trace_path))
|
||||
os.makedirs(os.path.dirname(trace_path))
|
||||
@@ -254,6 +254,7 @@ def perform_trace_based_tests(args, preset_name, num_env_steps, level=None):
|
||||
if test_passed:
|
||||
screen.success("Passed successfully.")
|
||||
os.remove(new_trace_csv_path)
|
||||
test_passed = True
|
||||
else:
|
||||
screen.error("Trace test failed.", crash=False)
|
||||
if args.overwrite:
|
||||
@@ -322,8 +323,9 @@ def main():
|
||||
try:
|
||||
preset = import_module('rl_coach.presets.{}'.format(preset_name))
|
||||
except:
|
||||
if args.verbose:
|
||||
screen.error("Failed to load preset <{}>".format(preset_name), crash=False)
|
||||
screen.error("Failed to load preset <{}>".format(preset_name), crash=False)
|
||||
fail_count += 1
|
||||
test_count += 1
|
||||
continue
|
||||
|
||||
preset_validation_params = preset.graph_manager.preset_validation_params
|
||||
@@ -336,13 +338,17 @@ def main():
|
||||
for level in preset_validation_params.trace_test_levels:
|
||||
test_count += 1
|
||||
test_passed = perform_trace_based_tests(args, preset_name, num_env_steps, level)
|
||||
if not test_passed:
|
||||
fail_count += 1
|
||||
else:
|
||||
test_count += 1
|
||||
test_passed = perform_trace_based_tests(args, preset_name, num_env_steps)
|
||||
if not test_passed:
|
||||
fail_count += 1
|
||||
else:
|
||||
test_passed = perform_reward_based_tests(args, preset_validation_params, preset_name)
|
||||
if not test_passed:
|
||||
fail_count += 1
|
||||
if not test_passed:
|
||||
fail_count += 1
|
||||
|
||||
screen.separator()
|
||||
if fail_count == 0:
|
||||
@@ -352,4 +358,6 @@ def main():
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
os.environ['DISABLE_MUJOCO_RENDERING'] = '1'
|
||||
main()
|
||||
del os.environ['DISABLE_MUJOCO_RENDERING']
|
||||
|
||||
Reference in New Issue
Block a user