mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 19:50:17 +01:00
enumerate each preset as its own test
This commit is contained in:
@@ -12,12 +12,23 @@ from subprocess import Popen, DEVNULL
|
|||||||
from rl_coach.logger import screen
|
from rl_coach.logger import screen
|
||||||
|
|
||||||
|
|
||||||
|
def all_presets():
|
||||||
|
result = []
|
||||||
|
for f in sorted(os.listdir('rl_coach/presets')):
|
||||||
|
if f.endswith('.py') and f != '__init__.py':
|
||||||
|
result.append(f.split('.')[0])
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=all_presets())
|
||||||
|
def preset(request):
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration_test
|
@pytest.mark.integration_test
|
||||||
def test_all_presets_are_running():
|
def test_preset_runs(preset):
|
||||||
# os.chdir("../../")
|
|
||||||
test_failed = False
|
test_failed = False
|
||||||
all_presets = sorted([f.split('.')[0] for f in os.listdir('rl_coach/presets') if f.endswith('.py') and f != '__init__.py'])
|
|
||||||
for preset in all_presets:
|
|
||||||
print("Testing preset {}".format(preset))
|
print("Testing preset {}".format(preset))
|
||||||
|
|
||||||
# TODO: this is a temporary workaround for presets which define more than a single available level.
|
# TODO: this is a temporary workaround for presets which define more than a single available level.
|
||||||
@@ -29,7 +40,10 @@ def test_all_presets_are_running():
|
|||||||
level = "inverted_pendulum"
|
level = "inverted_pendulum"
|
||||||
elif "ControlSuite" in preset:
|
elif "ControlSuite" in preset:
|
||||||
level = "pendulum:swingup"
|
level = "pendulum:swingup"
|
||||||
params = ["python3", "rl_coach/coach.py", "-p", preset, "-ns", "-e", ".test"]
|
|
||||||
|
experiment_name = ".test-" + preset
|
||||||
|
|
||||||
|
params = ["python3", "rl_coach/coach.py", "-p", preset, "-ns", "-e", experiment_name]
|
||||||
if level != "":
|
if level != "":
|
||||||
params += ["-lvl", level]
|
params += ["-lvl", level]
|
||||||
|
|
||||||
@@ -46,11 +60,7 @@ def test_all_presets_are_running():
|
|||||||
screen.error("{} failed".format(preset), crash=False)
|
screen.error("{} failed".format(preset), crash=False)
|
||||||
|
|
||||||
p.kill()
|
p.kill()
|
||||||
if os.path.exists("experiments/.test"):
|
if os.path.exists("experiments/" + experiment_name):
|
||||||
shutil.rmtree("experiments/.test")
|
shutil.rmtree("experiments/" + experiment_name)
|
||||||
|
|
||||||
assert not test_failed
|
assert not test_failed
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_all_presets_are_running()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user