mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
convert golden tests into pytest format
@@ -41,7 +41,8 @@ integration_tests: build
 	${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 -m pytest rl_coach/tests -m integration_test -n auto --tb=short
 
 golden_tests: build
-	${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 rl_coach/tests/golden_tests.py
+	# ${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 rl_coach/tests/golden_tests.py
+	time ${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 -m pytest rl_coach/tests -m golden_test -n auto
 
 trace_tests: build
 	${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 rl_coach/tests/trace_tests.py -prl
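The rewritten golden_tests target selects tests by marker (-m golden_test) and fans out across CPUs with -n auto, which comes from the pytest-xdist plugin. Marker filtering works cleanly only when the custom markers are registered; that registration is not part of this diff, so the conftest.py hook below is a minimal sketch of one way it could be done, not the repo's actual setup:

    # conftest.py -- hypothetical sketch, not shown in this commit: register the
    # custom markers so `pytest -m golden_test` filters without unknown-marker warnings.
    def pytest_configure(config):
        config.addinivalue_line("markers", "golden_test: long-running golden tests")
        config.addinivalue_line("markers", "integration_test: dockerized integration tests")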
@@ -27,6 +27,7 @@ sys.path.append('.')
 import numpy as np
 import pandas as pd
 import time
+import pytest
 
 # -*- coding: utf-8 -*-
 from rl_coach.logger import screen
@@ -44,11 +45,11 @@ def read_csv_paths(test_path, filename_pattern, read_csv_tries=100):
     return csv_paths
 
 
-def print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, args):
+def print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, time_limit):
     percentage = int((100 * last_num_episodes) / preset_validation_params.max_episodes_to_achieve_reward)
     sys.stdout.write("\rReward: ({}/{})".format(round(averaged_rewards[-1], 1),
                                                 preset_validation_params.min_reward_threshold))
-    sys.stdout.write(' Time (sec): ({}/{})'.format(round(time.time() - start_time, 2), args.time_limit))
+    sys.stdout.write(' Time (sec): ({}/{})'.format(round(time.time() - start_time, 2), time_limit))
     sys.stdout.write(' Episode: ({}/{})'.format(last_num_episodes,
                                                 preset_validation_params.max_episodes_to_achieve_reward))
     sys.stdout.write(
@@ -56,10 +57,56 @@ def print_progress(averaged_rewards, last_num_episodes, preset_validation_params
     sys.stdout.flush()
 
 
-def perform_reward_based_tests(args, preset_validation_params, preset_name):
+def import_preset(preset_name):
+    return import_module('rl_coach.presets.{}'.format(preset_name))
+
+
+def validation_params(preset_name):
+    return import_preset(preset_name).graph_manager.preset_validation_params
+
+
+def all_presets():
+    return [
+        f[:-3] for f in os.listdir(os.path.join('rl_coach', 'presets'))
+        if f[-3:] == '.py' and not f == '__init__.py'
+    ]
+
+
+def importable(preset_name):
+    try:
+        import_preset(preset_name)
+        return True
+    except BaseException:
+        return False
+
+
+def has_test_parameters(preset_name):
+    return bool(validation_params(preset_name).test)
+
+
+def collect_presets():
+    for preset_name in all_presets():
+        # if it isn't importable, still include it so we can fail the test
+        if not importable(preset_name):
+            yield preset_name
+        # otherwise, make sure it has test parameters before including it
+        elif has_test_parameters(preset_name):
+            yield preset_name
+
+
+print(list(collect_presets()))
+@pytest.fixture(params=list(collect_presets()))
+def preset_name(request):
+    return request.param
+
+
+@pytest.mark.golden_test
+def test_preset_reward(preset_name, no_progress_bar=False, time_limit=60 * 60):
+    preset_validation_params = validation_params(preset_name)
+
     win_size = 10
 
-    test_name = '__test_reward'
+    test_name = '__test_reward_{}'.format(preset_name)
     test_path = os.path.join('./experiments', test_name)
     if path.exists(test_path):
         shutil.rmtree(test_path)
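This hunk carries the heart of the conversion: pytest's parametrized-fixture idiom. @pytest.fixture(params=list(collect_presets())) makes pytest instantiate test_preset_reward once per preset, with request.param supplying the name, and presets that fail to import are still yielded so they surface as failing tests instead of silently vanishing at collection. A self-contained sketch of the same pattern, with a made-up item list for illustration:

    import pytest

    # Hypothetical stand-in for collect_presets(); any iterable of ids works.
    ITEMS = ['CartPole_DQN', 'Pendulum_DDPG', 'SomeBrokenPreset']

    @pytest.fixture(params=ITEMS)
    def item(request):
        # pytest calls this once per element of `params`, so each item
        # becomes an independently collected and reported test case.
        return request.param

    def test_item_is_nonempty(item):
        assert item  # runs len(ITEMS) times, one result per item

Note also the module-level print(list(collect_presets())) left in above: it executes on every import of the file, i.e. during pytest collection.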
@@ -106,11 +153,11 @@ def perform_reward_based_tests(args, preset_validation_params, preset_name):
 
     last_num_episodes = 0
 
-    if not args.no_progress_bar:
-        print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, args)
+    if not no_progress_bar:
+        print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, time_limit)
 
     while csv is None or (csv['Episode #'].values[
-        -1] < preset_validation_params.max_episodes_to_achieve_reward and time.time() - start_time < args.time_limit):
+        -1] < preset_validation_params.max_episodes_to_achieve_reward and time.time() - start_time < time_limit):
         try:
             csv = pd.read_csv(csv_path)
         except:
@@ -130,8 +177,8 @@ def perform_reward_based_tests(args, preset_validation_params, preset_name):
             time.sleep(1)
             continue
 
-        if not args.no_progress_bar:
-            print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, args)
+        if not no_progress_bar:
+            print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, time_limit)
 
         if csv['Episode #'].shape[0] - last_num_episodes <= 0:
             continue
@@ -151,7 +198,7 @@ def perform_reward_based_tests(args, preset_validation_params, preset_name):
     if test_passed:
         screen.success("Passed successfully")
     else:
-        if time.time() - start_time > args.time_limit:
+        if time.time() - start_time > time_limit:
             screen.error("Failed due to exceeding time limit", crash=False)
         if args.verbose:
             screen.error("command exitcode: {}".format(p.returncode), crash=False)
@@ -178,13 +225,6 @@ def perform_reward_based_tests(args, preset_validation_params, preset_name):
     return test_passed
 
 
-def all_presets():
-    return [
-        f[:-3] for f in os.listdir(os.path.join('rl_coach', 'presets'))
-        if f[-3:] == '.py' and not f == '__init__.py'
-    ]
-
-
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('-p', '--preset',
@@ -228,20 +268,17 @@ def main():
         if args.stop_after_first_failure and fail_count > 0:
             break
         if preset_name not in presets_to_ignore:
-            try:
-                preset = import_module('rl_coach.presets.{}'.format(preset_name))
-            except:
+            if not importable(preset_name):
                 screen.error("Failed to load preset <{}>".format(preset_name), crash=False)
                 fail_count += 1
                 test_count += 1
                 continue
 
-            preset_validation_params = preset.graph_manager.preset_validation_params
-            if not preset_validation_params.test:
+            if not has_test_parameters(preset_name):
                 continue
 
             test_count += 1
-            test_passed = perform_reward_based_tests(args, preset_validation_params, preset_name)
+            test_passed = test_preset_reward(preset_name, args.no_progress_bar, args.time_limit)
             if not test_passed:
                 fail_count += 1
 
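With the helpers shared, main() keeps the old command-line behavior while delegating to the same test_preset_reward body that pytest runs. One caveat this diff leaves visible: the converted function still reads args.verbose and p.returncode in its failure path (the -151,7 hunk above) even though args is no longer a parameter, so that branch would appear to raise a NameError if reached. Either entry point exercises the same tests:

    python3 rl_coach/tests/golden_tests.py
    python3 -m pytest rl_coach/tests -m golden_test -n auto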