
convert golden tests into pytest format

zach dwiel
2018-10-10 16:26:15 -04:00
parent 787ab42578
commit 430ca198e5
2 changed files with 62 additions and 24 deletions

View File

@@ -41,7 +41,8 @@ integration_tests: build
${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 -m pytest rl_coach/tests -m integration_test -n auto --tb=short
golden_tests: build
-${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 rl_coach/tests/golden_tests.py
+# ${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 rl_coach/tests/golden_tests.py
+time ${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 -m pytest rl_coach/tests -m golden_test -n auto
trace_tests: build
${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 rl_coach/tests/trace_tests.py -prl

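The new golden_tests target selects tests by the golden_test marker and parallelizes them with -n auto, an option provided by the pytest-xdist plugin. For -m golden_test to select tests cleanly (and without unknown-marker warnings on newer pytest versions), the marker has to be registered somewhere. Below is a minimal sketch of one way to do that in a conftest.py; the file and its wording are assumptions, and the repository may already register its markers elsewhere (for example in setup.cfg):

# conftest.py (hypothetical sketch, not part of this commit)
def pytest_configure(config):
    # Register the custom markers used by the Makefile targets so that
    # "-m golden_test" / "-m integration_test" select them without warnings.
    config.addinivalue_line("markers", "golden_test: long-running reward-based preset test")
    config.addinivalue_line("markers", "integration_test: end-to-end integration test")
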
View File

@@ -27,6 +27,7 @@ sys.path.append('.')
import numpy as np
import pandas as pd
import time
+import pytest
# -*- coding: utf-8 -*-
from rl_coach.logger import screen
@@ -44,11 +45,11 @@ def read_csv_paths(test_path, filename_pattern, read_csv_tries=100):
return csv_paths
-def print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, args):
+def print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, time_limit):
percentage = int((100 * last_num_episodes) / preset_validation_params.max_episodes_to_achieve_reward)
sys.stdout.write("\rReward: ({}/{})".format(round(averaged_rewards[-1], 1),
preset_validation_params.min_reward_threshold))
-sys.stdout.write(' Time (sec): ({}/{})'.format(round(time.time() - start_time, 2), args.time_limit))
+sys.stdout.write(' Time (sec): ({}/{})'.format(round(time.time() - start_time, 2), time_limit))
sys.stdout.write(' Episode: ({}/{})'.format(last_num_episodes,
preset_validation_params.max_episodes_to_achieve_reward))
sys.stdout.write(
@@ -56,10 +57,56 @@ def print_progress(averaged_rewards, last_num_episodes, preset_validation_params
sys.stdout.flush()
-def perform_reward_based_tests(args, preset_validation_params, preset_name):
+def import_preset(preset_name):
+return import_module('rl_coach.presets.{}'.format(preset_name))
+def validation_params(preset_name):
+return import_preset(preset_name).graph_manager.preset_validation_params
+def all_presets():
+return [
+f[:-3] for f in os.listdir(os.path.join('rl_coach', 'presets'))
+if f[-3:] == '.py' and not f == '__init__.py'
+]
+def importable(preset_name):
+try:
+import_preset(preset_name)
+return True
+except BaseException:
+return False
+def has_test_parameters(preset_name):
+return bool(validation_params(preset_name).test)
+def collect_presets():
+for preset_name in all_presets():
+# if it isn't importable, still include it so we can fail the test
+if not importable(preset_name):
+yield preset_name
+# otherwise, make sure it has test parameters before including it
+elif has_test_parameters(preset_name):
+yield preset_name
+print(list(collect_presets()))
+@pytest.fixture(params=list(collect_presets()))
+def preset_name(request):
+return request.param
+@pytest.mark.golden_test
+def test_preset_reward(preset_name, no_progress_bar=False, time_limit=60 * 60):
+preset_validation_params = validation_params(preset_name)
win_size = 10
-test_name = '__test_reward'
+test_name = '__test_reward_{}'.format(preset_name)
test_path = os.path.join('./experiments', test_name)
if path.exists(test_path):
shutil.rmtree(test_path)
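
The hunk above replaces the old perform_reward_based_tests(args, ...) entry point with a parametrized fixture: every name yielded by collect_presets() becomes its own pytest test item, so an unimportable preset fails as a single test instead of aborting a hand-rolled loop. A self-contained sketch of the same fixture-parametrization pattern, using made-up preset names rather than anything taken from this repository:

import pytest

# Stand-in for collect_presets(); any iterable of names works here.
PRESETS = ["CartPole_DQN", "CartPole_ClippedPPO"]

@pytest.fixture(params=PRESETS)
def preset_name(request):
    # request.param holds the value chosen for the current test item,
    # so pytest generates one test per entry in PRESETS.
    return request.param

@pytest.mark.golden_test
def test_preset_has_a_name(preset_name):
    assert preset_name  # collected as test_preset_has_a_name[CartPole_DQN], etc.

Note that pytest only treats arguments without default values as fixture requests, which is why the extra no_progress_bar and time_limit parameters of test_preset_reward keep their defaults under pytest while main() can still pass them explicitly.
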
@@ -106,11 +153,11 @@ def perform_reward_based_tests(args, preset_validation_params, preset_name):
last_num_episodes = 0
-if not args.no_progress_bar:
-print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, args)
+if not no_progress_bar:
+print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, time_limit)
while csv is None or (csv['Episode #'].values[
--1] < preset_validation_params.max_episodes_to_achieve_reward and time.time() - start_time < args.time_limit):
+-1] < preset_validation_params.max_episodes_to_achieve_reward and time.time() - start_time < time_limit):
try:
csv = pd.read_csv(csv_path)
except:
@@ -130,8 +177,8 @@ def perform_reward_based_tests(args, preset_validation_params, preset_name):
time.sleep(1)
continue
-if not args.no_progress_bar:
-print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, args)
+if not no_progress_bar:
+print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, time_limit)
if csv['Episode #'].shape[0] - last_num_episodes <= 0:
continue
@@ -151,7 +198,7 @@ def perform_reward_based_tests(args, preset_validation_params, preset_name):
if test_passed:
screen.success("Passed successfully")
else:
-if time.time() - start_time > args.time_limit:
+if time.time() - start_time > time_limit:
screen.error("Failed due to exceeding time limit", crash=False)
if args.verbose:
screen.error("command exitcode: {}".format(p.returncode), crash=False)
@@ -178,13 +225,6 @@ def perform_reward_based_tests(args, preset_validation_params, preset_name):
return test_passed
-def all_presets():
-return [
-f[:-3] for f in os.listdir(os.path.join('rl_coach', 'presets'))
-if f[-3:] == '.py' and not f == '__init__.py'
-]
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-p', '--preset',
@@ -228,20 +268,17 @@ def main():
if args.stop_after_first_failure and fail_count > 0:
break
if preset_name not in presets_to_ignore:
-try:
-preset = import_module('rl_coach.presets.{}'.format(preset_name))
-except:
+if not importable(preset_name):
screen.error("Failed to load preset <{}>".format(preset_name), crash=False)
fail_count += 1
test_count += 1
continue
-preset_validation_params = preset.graph_manager.preset_validation_params
-if not preset_validation_params.test:
+if not has_test_parameters(preset_name):
continue
test_count += 1
-test_passed = perform_reward_based_tests(args, preset_validation_params, preset_name)
+test_passed = test_preset_reward(preset_name, args.no_progress_bar, args.time_limit)
if not test_passed:
fail_count += 1
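
main() keeps the original argparse interface, so the script can still be run directly with -p/--preset (the Makefile retains the old invocation as a comment), while the pytest path runs every collected preset with the hard-coded time_limit default of one hour. If that limit should also be tunable under pytest, one hypothetical approach, not part of this commit, is a command-line option registered in a conftest.py and exposed through a fixture:

# conftest.py (hypothetical sketch, not part of this commit)
import pytest

def pytest_addoption(parser):
    # Expose a knob such as: pytest -m golden_test --time-limit 1800
    parser.addoption("--time-limit", action="store", type=int, default=60 * 60,
                     help="per-preset time limit in seconds (hypothetical option)")

@pytest.fixture
def time_limit(request):
    # Hand the option value to any test that requests time_limit as a fixture.
    return request.config.getoption("--time-limit")

test_preset_reward would then drop its time_limit default and receive the value from the fixture instead.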