Mirror of https://github.com/gryf/coach.git (synced 2025-12-18 11:40:18 +01:00)
convert golden tests into pytest format
Makefile
@@ -41,7 +41,8 @@ integration_tests: build
 	${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 -m pytest rl_coach/tests -m integration_test -n auto --tb=short
 
 golden_tests: build
-	${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 rl_coach/tests/golden_tests.py
+	# ${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 rl_coach/tests/golden_tests.py
+	time ${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 -m pytest rl_coach/tests -m golden_test -n auto
 
 trace_tests: build
 	${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 rl_coach/tests/trace_tests.py -prl
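Note: pytest warns about marks it does not recognize, so the golden_test and integration_test markers selected with -m above are presumably registered in the project's pytest configuration. If they are not, one way to register them is a hook in a top-level conftest.py; the following is only a sketch of that option, not something this commit adds:

# conftest.py (hypothetical) -- register the custom markers so pytest
# does not emit "unknown mark" warnings when collecting these tests.
def pytest_configure(config):
    config.addinivalue_line(
        "markers", "golden_test: long-running reward-based preset tests")
    config.addinivalue_line(
        "markers", "integration_test: end-to-end integration tests")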
rl_coach/tests/golden_tests.py
@@ -27,6 +27,7 @@ sys.path.append('.')
 import numpy as np
 import pandas as pd
 import time
+import pytest
 
 # -*- coding: utf-8 -*-
 from rl_coach.logger import screen
@@ -44,11 +45,11 @@ def read_csv_paths(test_path, filename_pattern, read_csv_tries=100):
     return csv_paths
 
 
-def print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, args):
+def print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, time_limit):
     percentage = int((100 * last_num_episodes) / preset_validation_params.max_episodes_to_achieve_reward)
     sys.stdout.write("\rReward: ({}/{})".format(round(averaged_rewards[-1], 1),
                                                 preset_validation_params.min_reward_threshold))
-    sys.stdout.write(' Time (sec): ({}/{})'.format(round(time.time() - start_time, 2), args.time_limit))
+    sys.stdout.write(' Time (sec): ({}/{})'.format(round(time.time() - start_time, 2), time_limit))
     sys.stdout.write(' Episode: ({}/{})'.format(last_num_episodes,
                                                 preset_validation_params.max_episodes_to_achieve_reward))
     sys.stdout.write(
@@ -56,10 +57,56 @@ def print_progress(averaged_rewards, last_num_episodes, preset_validation_params
     sys.stdout.flush()
 
 
-def perform_reward_based_tests(args, preset_validation_params, preset_name):
+def import_preset(preset_name):
+    return import_module('rl_coach.presets.{}'.format(preset_name))
+
+
+def validation_params(preset_name):
+    return import_preset(preset_name).graph_manager.preset_validation_params
+
+
+def all_presets():
+    return [
+        f[:-3] for f in os.listdir(os.path.join('rl_coach', 'presets'))
+        if f[-3:] == '.py' and not f == '__init__.py'
+    ]
+
+
+def importable(preset_name):
+    try:
+        import_preset(preset_name)
+        return True
+    except BaseException:
+        return False
+
+
+def has_test_parameters(preset_name):
+    return bool(validation_params(preset_name).test)
+
+
+def collect_presets():
+    for preset_name in all_presets():
+        # if it isn't importable, still include it so we can fail the test
+        if not importable(preset_name):
+            yield preset_name
+        # otherwise, make sure it has test parameters before including it
+        elif has_test_parameters(preset_name):
+            yield preset_name
+
+
+print(list(collect_presets()))
+@pytest.fixture(params=list(collect_presets()))
+def preset_name(request):
+    return request.param
+
+
+@pytest.mark.golden_test
+def test_preset_reward(preset_name, no_progress_bar=False, time_limit=60 * 60):
+    preset_validation_params = validation_params(preset_name)
+
     win_size = 10
 
-    test_name = '__test_reward'
+    test_name = '__test_reward_{}'.format(preset_name)
     test_path = os.path.join('./experiments', test_name)
     if path.exists(test_path):
         shutil.rmtree(test_path)
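For readers less familiar with this pytest idiom: a fixture declared with params makes pytest collect the test once per parameter, and request.param hands the current value to the test body. A minimal, self-contained sketch of the same mechanism (the preset names here are illustrative only):

import pytest

# Each entry in params becomes its own collected test case.
@pytest.fixture(params=['CartPole_DQN', 'Pendulum_DDPG'])
def preset_name(request):
    return request.param

@pytest.mark.golden_test
def test_preset_reward(preset_name):
    # runs once with 'CartPole_DQN' and once with 'Pendulum_DDPG'
    assert isinstance(preset_name, str)

In the real test above, collect_presets() supplies the parameter list, and presets that fail to import are still yielded so the corresponding test case fails at import time rather than being silently skipped.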
@@ -106,11 +153,11 @@ def perform_reward_based_tests(args, preset_validation_params, preset_name):
 
     last_num_episodes = 0
 
-    if not args.no_progress_bar:
-        print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, args)
+    if not no_progress_bar:
+        print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, time_limit)
 
     while csv is None or (csv['Episode #'].values[
-        -1] < preset_validation_params.max_episodes_to_achieve_reward and time.time() - start_time < args.time_limit):
+        -1] < preset_validation_params.max_episodes_to_achieve_reward and time.time() - start_time < time_limit):
         try:
             csv = pd.read_csv(csv_path)
         except:
@@ -130,8 +177,8 @@ def perform_reward_based_tests(args, preset_validation_params, preset_name):
             time.sleep(1)
             continue
 
-        if not args.no_progress_bar:
-            print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, args)
+        if not no_progress_bar:
+            print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, time_limit)
 
         if csv['Episode #'].shape[0] - last_num_episodes <= 0:
             continue
@@ -151,7 +198,7 @@ def perform_reward_based_tests(args, preset_validation_params, preset_name):
     if test_passed:
         screen.success("Passed successfully")
     else:
-        if time.time() - start_time > args.time_limit:
+        if time.time() - start_time > time_limit:
             screen.error("Failed due to exceeding time limit", crash=False)
         if args.verbose:
             screen.error("command exitcode: {}".format(p.returncode), crash=False)
@@ -178,13 +225,6 @@ def perform_reward_based_tests(args, preset_validation_params, preset_name):
     return test_passed
 
 
-def all_presets():
-    return [
-        f[:-3] for f in os.listdir(os.path.join('rl_coach', 'presets'))
-        if f[-3:] == '.py' and not f == '__init__.py'
-    ]
-
-
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('-p', '--preset',
@@ -228,20 +268,17 @@ def main():
         if args.stop_after_first_failure and fail_count > 0:
             break
         if preset_name not in presets_to_ignore:
-            try:
-                preset = import_module('rl_coach.presets.{}'.format(preset_name))
-            except:
+            if not importable(preset_name):
                 screen.error("Failed to load preset <{}>".format(preset_name), crash=False)
                 fail_count += 1
                 test_count += 1
                 continue
 
-            preset_validation_params = preset.graph_manager.preset_validation_params
-            if not preset_validation_params.test:
+            if not has_test_parameters(preset_name):
                 continue
 
             test_count += 1
-            test_passed = perform_reward_based_tests(args, preset_validation_params, preset_name)
+            test_passed = test_preset_reward(preset_name, args.no_progress_bar, args.time_limit)
             if not test_passed:
                 fail_count += 1
 
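Because test_preset_reward keeps default values for no_progress_bar and time_limit, it doubles as a plain function: pytest drives it through the parametrized fixture, while main() (last hunk above) still calls it directly with the parsed command-line arguments. Roughly, with an illustrative preset name:

# under pytest: preset_name comes from the fixture and the defaults apply
# from main(): everything is passed explicitly, e.g.
test_passed = test_preset_reward('CartPole_DQN',
                                 no_progress_bar=True,
                                 time_limit=60 * 60)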