1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

tests: added new checkpoint and functional tests (#265)

* added new tests
- test_preset_n_and_ew
- test_preset_n_and_ew_and_onnx

* code utils improvements (all utils)
* improve checkpoint_test
* new functionality for functional_test markers and presets lists
* removed special environment container
* add xfail to certain tests
This commit is contained in:
anabwan
2019-03-28 22:57:31 +02:00
committed by Scott Leishman
parent 310d31c227
commit 869bd421a3
6 changed files with 173 additions and 62 deletions

View File

@@ -24,6 +24,7 @@ import numpy as np
import pandas as pd
import rl_coach.tests.utils.args_utils as a_utils
import rl_coach.tests.utils.test_utils as test_utils
import rl_coach.tests.utils.presets_utils as p_utils
from rl_coach import checkpoint
from rl_coach.tests.utils.definitions import Definitions as Def
@@ -54,7 +55,8 @@ def test_get_checkpoint_state():
@pytest.mark.functional_test
def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
def test_restore_checkpoint(preset_args, clres, start_time=time.time(),
timeout=Def.TimeOuts.test_time_limit):
""" Create checkpoint and restore them in second run."""
def _create_cmd_and_run(flag):
@@ -71,6 +73,7 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
return p
p_valid_params = p_utils.validation_params(preset_args)
create_cp_proc = _create_cmd_and_run(flag=['--checkpoint_save_secs', '5'])
# wait for checkpoint files
@@ -84,12 +87,11 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
if os.path.exists(checkpoint_test_dir):
shutil.rmtree(checkpoint_test_dir)
entities = a_utils.get_files_from_dir(checkpoint_dir)
assert a_utils.is_reward_reached(csv_path=csv_list[0],
p_valid_params=p_valid_params,
start_time=start_time, time_limit=timeout)
while not any("10_Step" in file for file in entities) and time.time() - \
start_time < Def.TimeOuts.test_time_limit:
entities = a_utils.get_files_from_dir(checkpoint_dir)
time.sleep(1)
entities = a_utils.get_files_from_dir(checkpoint_dir)
assert len(entities) > 0
assert "checkpoint" in entities
@@ -101,7 +103,7 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
csv = pd.read_csv(csv_list[0])
rewards = csv['Evaluation Reward'].values
rewards = rewards[~np.isnan(rewards)]
min_reward = np.amin(rewards)
max_reward = np.amax(rewards)
if os.path.isdir(checkpoint_dir):
shutil.copytree(exp_dir, checkpoint_test_dir)
@@ -119,7 +121,10 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
csv = pd.read_csv(new_csv_list[0])
res = csv['Episode Length'].values[-1]
assert res >= min_reward, \
Def.Consts.ASSERT_MSG.format(str(res) + ">=" + str(min_reward),
str(res) + " < " + str(min_reward))
assert res == max_reward, Def.Consts.ASSERT_MSG.format(str(max_reward),
str(res))
restore_cp_proc.kill()
test_folder = os.path.join(Def.Path.experiments, Def.Path.test_dir)
if os.path.exists(test_folder):
shutil.rmtree(test_folder)