diff --git a/docker/Makefile b/docker/Makefile
index bd5cbe8..3fd4b71 100644
--- a/docker/Makefile
+++ b/docker/Makefile
@@ -41,7 +41,8 @@ integration_tests: build
 	${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 -m pytest rl_coach/tests -m integration_test -n auto --tb=short
 
 golden_tests: build
-	${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 rl_coach/tests/golden_tests.py
+	# ${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 rl_coach/tests/golden_tests.py
+	time ${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 -m pytest rl_coach/tests -m golden_test -n auto
 
 trace_tests: build
 	${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 rl_coach/tests/trace_tests.py -prl
diff --git a/rl_coach/tests/test_golden.py b/rl_coach/tests/test_golden.py
index 654da6f..d2ce972 100644
--- a/rl_coach/tests/test_golden.py
+++ b/rl_coach/tests/test_golden.py
@@ -27,6 +27,7 @@ sys.path.append('.')
 import numpy as np
 import pandas as pd
 import time
+import pytest
 
 # -*- coding: utf-8 -*-
 from rl_coach.logger import screen
@@ -44,11 +45,11 @@ def read_csv_paths(test_path, filename_pattern, read_csv_tries=100):
     return csv_paths
 
 
-def print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, args):
+def print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, time_limit):
     percentage = int((100 * last_num_episodes) / preset_validation_params.max_episodes_to_achieve_reward)
     sys.stdout.write("\rReward: ({}/{})".format(round(averaged_rewards[-1], 1),
                                                 preset_validation_params.min_reward_threshold))
-    sys.stdout.write(' Time (sec): ({}/{})'.format(round(time.time() - start_time, 2), args.time_limit))
+    sys.stdout.write(' Time (sec): ({}/{})'.format(round(time.time() - start_time, 2), time_limit))
     sys.stdout.write(' Episode: ({}/{})'.format(last_num_episodes,
                                                 preset_validation_params.max_episodes_to_achieve_reward))
     sys.stdout.write(
@@ -56,10 +57,55 @@ def print_progress(averaged_rewards, last_num_episodes, preset_validation_params
     sys.stdout.flush()
 
 
-def perform_reward_based_tests(args, preset_validation_params, preset_name):
+def import_preset(preset_name):
+    return import_module('rl_coach.presets.{}'.format(preset_name))
+
+
+def validation_params(preset_name):
+    return import_preset(preset_name).graph_manager.preset_validation_params
+
+
+def all_presets():
+    return [
+        f[:-3] for f in os.listdir(os.path.join('rl_coach', 'presets'))
+        if f[-3:] == '.py' and not f == '__init__.py'
+    ]
+
+
+def importable(preset_name):
+    try:
+        import_preset(preset_name)
+        return True
+    except BaseException:
+        return False
+
+
+def has_test_parameters(preset_name):
+    return bool(validation_params(preset_name).test)
+
+
+def collect_presets():
+    for preset_name in all_presets():
+        # if it isn't importable, still include it so we can fail the test
+        if not importable(preset_name):
+            yield preset_name
+        # otherwise, make sure it has test parameters before including it
+        elif has_test_parameters(preset_name):
+            yield preset_name
+
+
+@pytest.fixture(params=list(collect_presets()))
+def preset_name(request):
+    return request.param
+
+
+@pytest.mark.golden_test
+def test_preset_reward(preset_name, no_progress_bar=False, time_limit=60 * 60):
+    preset_validation_params = validation_params(preset_name)
+
     win_size = 10
 
-    test_name = '__test_reward'
+    test_name = '__test_reward_{}'.format(preset_name)
     test_path = os.path.join('./experiments', test_name)
     if path.exists(test_path):
         shutil.rmtree(test_path)
@@ -106,11 +152,11 @@ def perform_reward_based_tests(args, preset_validation_params, preset_name):
 
         last_num_episodes = 0
 
-        if not args.no_progress_bar:
-            print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, args)
+        if not no_progress_bar:
+            print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, time_limit)
 
         while csv is None or (csv['Episode #'].values[
-            -1] < preset_validation_params.max_episodes_to_achieve_reward and time.time() - start_time < args.time_limit):
+            -1] < preset_validation_params.max_episodes_to_achieve_reward and time.time() - start_time < time_limit):
             try:
                 csv = pd.read_csv(csv_path)
             except:
@@ -130,8 +176,8 @@ def perform_reward_based_tests(args, preset_validation_params, preset_name):
                 time.sleep(1)
                 continue
 
-            if not args.no_progress_bar:
-                print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, args)
+            if not no_progress_bar:
+                print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, time_limit)
 
             if csv['Episode #'].shape[0] - last_num_episodes <= 0:
                 continue
@@ -151,7 +197,7 @@ def perform_reward_based_tests(args, preset_validation_params, preset_name):
     if test_passed:
         screen.success("Passed successfully")
     else:
-        if time.time() - start_time > args.time_limit:
+        if time.time() - start_time > time_limit:
             screen.error("Failed due to exceeding time limit", crash=False)
         if args.verbose:
             screen.error("command exitcode: {}".format(p.returncode), crash=False)
@@ -178,13 +224,6 @@ def perform_reward_based_tests(args, preset_validation_params, preset_name):
     return test_passed
 
 
-def all_presets():
-    return [
-        f[:-3] for f in os.listdir(os.path.join('rl_coach', 'presets'))
-        if f[-3:] == '.py' and not f == '__init__.py'
-    ]
-
-
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('-p', '--preset',
@@ -228,20 +267,17 @@ def main():
 
         if args.stop_after_first_failure and fail_count > 0:
             break
         if preset_name not in presets_to_ignore:
-            try:
-                preset = import_module('rl_coach.presets.{}'.format(preset_name))
-            except:
+            if not importable(preset_name):
                 screen.error("Failed to load preset <{}>".format(preset_name), crash=False)
                 fail_count += 1
                 test_count += 1
                 continue
 
-            preset_validation_params = preset.graph_manager.preset_validation_params
-            if not preset_validation_params.test:
+            if not has_test_parameters(preset_name):
                 continue
 
             test_count += 1
-            test_passed = perform_reward_based_tests(args, preset_validation_params, preset_name)
+            test_passed = test_preset_reward(preset_name, args.no_progress_bar, args.time_limit)
             if not test_passed:
                 fail_count += 1
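
With the golden tests collected by pytest, the Makefile target above maps to an
ordinary pytest invocation. Outside Docker, assuming pytest and pytest-xdist are
installed, the equivalent command is:

    python3 -m pytest rl_coach/tests -m golden_test -n auto

Because presets are exposed through the parametrized preset_name fixture, a
single preset can be selected with pytest's -k filter on the generated test id;
the preset name here is only illustrative:

    python3 -m pytest rl_coach/tests/test_golden.py -m golden_test -k CartPole_DQN

Newer pytest releases warn about unregistered markers, so if golden_test is not
already declared in the project's pytest configuration, a minimal registration
(hypothetical, not part of this patch) would be:

    [pytest]
    markers =
        golden_test: long-running reward-based preset tests
        integration_test: tests run by the integration_tests target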