mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
tests: added new checkpoint and functional tests (#265)
* added new tests - test_preset_n_and_ew - test_preset_n_and_ew_and_onnx * code utils improvements (all utils) * improve checkpoint_test * new functionality for functional_test markers and presets lists * removed special environment container * add xfail to certain tests
This commit is contained in:
@@ -24,6 +24,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import rl_coach.tests.utils.args_utils as a_utils
|
||||
import rl_coach.tests.utils.test_utils as test_utils
|
||||
import rl_coach.tests.utils.presets_utils as p_utils
|
||||
from rl_coach import checkpoint
|
||||
from rl_coach.tests.utils.definitions import Definitions as Def
|
||||
|
||||
@@ -54,7 +55,8 @@ def test_get_checkpoint_state():
|
||||
|
||||
|
||||
@pytest.mark.functional_test
|
||||
def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
|
||||
def test_restore_checkpoint(preset_args, clres, start_time=time.time(),
|
||||
timeout=Def.TimeOuts.test_time_limit):
|
||||
""" Create checkpoint and restore them in second run."""
|
||||
|
||||
def _create_cmd_and_run(flag):
|
||||
@@ -71,6 +73,7 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
|
||||
|
||||
return p
|
||||
|
||||
p_valid_params = p_utils.validation_params(preset_args)
|
||||
create_cp_proc = _create_cmd_and_run(flag=['--checkpoint_save_secs', '5'])
|
||||
|
||||
# wait for checkpoint files
|
||||
@@ -84,12 +87,11 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
|
||||
if os.path.exists(checkpoint_test_dir):
|
||||
shutil.rmtree(checkpoint_test_dir)
|
||||
|
||||
entities = a_utils.get_files_from_dir(checkpoint_dir)
|
||||
assert a_utils.is_reward_reached(csv_path=csv_list[0],
|
||||
p_valid_params=p_valid_params,
|
||||
start_time=start_time, time_limit=timeout)
|
||||
|
||||
while not any("10_Step" in file for file in entities) and time.time() - \
|
||||
start_time < Def.TimeOuts.test_time_limit:
|
||||
entities = a_utils.get_files_from_dir(checkpoint_dir)
|
||||
time.sleep(1)
|
||||
entities = a_utils.get_files_from_dir(checkpoint_dir)
|
||||
|
||||
assert len(entities) > 0
|
||||
assert "checkpoint" in entities
|
||||
@@ -101,7 +103,7 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
|
||||
csv = pd.read_csv(csv_list[0])
|
||||
rewards = csv['Evaluation Reward'].values
|
||||
rewards = rewards[~np.isnan(rewards)]
|
||||
min_reward = np.amin(rewards)
|
||||
max_reward = np.amax(rewards)
|
||||
|
||||
if os.path.isdir(checkpoint_dir):
|
||||
shutil.copytree(exp_dir, checkpoint_test_dir)
|
||||
@@ -119,7 +121,10 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
|
||||
|
||||
csv = pd.read_csv(new_csv_list[0])
|
||||
res = csv['Episode Length'].values[-1]
|
||||
assert res >= min_reward, \
|
||||
Def.Consts.ASSERT_MSG.format(str(res) + ">=" + str(min_reward),
|
||||
str(res) + " < " + str(min_reward))
|
||||
assert res == max_reward, Def.Consts.ASSERT_MSG.format(str(max_reward),
|
||||
str(res))
|
||||
restore_cp_proc.kill()
|
||||
|
||||
test_folder = os.path.join(Def.Path.experiments, Def.Path.test_dir)
|
||||
if os.path.exists(test_folder):
|
||||
shutil.rmtree(test_folder)
|
||||
|
||||
Reference in New Issue
Block a user