mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
tests: added new checkpoint and functional tests (#265)
* added new tests - test_preset_n_and_ew - test_preset_n_and_ew_and_onnx * code utils improvements (all utils) * improve checkpoint_test * new functionality for functional_test markers and presets lists * removed special environment container * add xfail to certain tests
This commit is contained in:
@@ -24,6 +24,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import rl_coach.tests.utils.args_utils as a_utils
|
||||
import rl_coach.tests.utils.test_utils as test_utils
|
||||
import rl_coach.tests.utils.presets_utils as p_utils
|
||||
from rl_coach import checkpoint
|
||||
from rl_coach.tests.utils.definitions import Definitions as Def
|
||||
|
||||
@@ -54,7 +55,8 @@ def test_get_checkpoint_state():
|
||||
|
||||
|
||||
@pytest.mark.functional_test
|
||||
def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
|
||||
def test_restore_checkpoint(preset_args, clres, start_time=time.time(),
|
||||
timeout=Def.TimeOuts.test_time_limit):
|
||||
""" Create checkpoint and restore them in second run."""
|
||||
|
||||
def _create_cmd_and_run(flag):
|
||||
@@ -71,6 +73,7 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
|
||||
|
||||
return p
|
||||
|
||||
p_valid_params = p_utils.validation_params(preset_args)
|
||||
create_cp_proc = _create_cmd_and_run(flag=['--checkpoint_save_secs', '5'])
|
||||
|
||||
# wait for checkpoint files
|
||||
@@ -84,12 +87,11 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
|
||||
if os.path.exists(checkpoint_test_dir):
|
||||
shutil.rmtree(checkpoint_test_dir)
|
||||
|
||||
entities = a_utils.get_files_from_dir(checkpoint_dir)
|
||||
assert a_utils.is_reward_reached(csv_path=csv_list[0],
|
||||
p_valid_params=p_valid_params,
|
||||
start_time=start_time, time_limit=timeout)
|
||||
|
||||
while not any("10_Step" in file for file in entities) and time.time() - \
|
||||
start_time < Def.TimeOuts.test_time_limit:
|
||||
entities = a_utils.get_files_from_dir(checkpoint_dir)
|
||||
time.sleep(1)
|
||||
entities = a_utils.get_files_from_dir(checkpoint_dir)
|
||||
|
||||
assert len(entities) > 0
|
||||
assert "checkpoint" in entities
|
||||
@@ -101,7 +103,7 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
|
||||
csv = pd.read_csv(csv_list[0])
|
||||
rewards = csv['Evaluation Reward'].values
|
||||
rewards = rewards[~np.isnan(rewards)]
|
||||
min_reward = np.amin(rewards)
|
||||
max_reward = np.amax(rewards)
|
||||
|
||||
if os.path.isdir(checkpoint_dir):
|
||||
shutil.copytree(exp_dir, checkpoint_test_dir)
|
||||
@@ -119,7 +121,10 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
|
||||
|
||||
csv = pd.read_csv(new_csv_list[0])
|
||||
res = csv['Episode Length'].values[-1]
|
||||
assert res >= min_reward, \
|
||||
Def.Consts.ASSERT_MSG.format(str(res) + ">=" + str(min_reward),
|
||||
str(res) + " < " + str(min_reward))
|
||||
assert res == max_reward, Def.Consts.ASSERT_MSG.format(str(max_reward),
|
||||
str(res))
|
||||
restore_cp_proc.kill()
|
||||
|
||||
test_folder = os.path.join(Def.Path.experiments, Def.Path.test_dir)
|
||||
if os.path.exists(test_folder):
|
||||
shutil.rmtree(test_folder)
|
||||
|
||||
@@ -143,3 +143,113 @@ def test_preset_seed(preset_args_for_seed, clres, start_time=time.time(),
|
||||
assert False
|
||||
|
||||
close_processes()
|
||||
|
||||
|
||||
@pytest.mark.functional_test
|
||||
def test_preset_n_and_ew(preset_args, clres, start_time=time.time(),
|
||||
time_limit=Def.TimeOuts.test_time_limit):
|
||||
"""
|
||||
Test command arguments - check evaluation worker with number of workers
|
||||
"""
|
||||
|
||||
ew_flag = ['-ew']
|
||||
n_flag = ['-n', Def.Flags.enw]
|
||||
p_valid_params = p_utils.validation_params(preset_args)
|
||||
|
||||
run_cmd = [
|
||||
'python3', 'rl_coach/coach.py',
|
||||
'-p', '{}'.format(preset_args),
|
||||
'-e', '{}'.format("ExpName_" + preset_args),
|
||||
]
|
||||
|
||||
# add flags to run command
|
||||
test_ew_flag = a_utils.add_one_flag_value(flag=ew_flag)
|
||||
test_n_flag = a_utils.add_one_flag_value(flag=n_flag)
|
||||
run_cmd.extend(test_ew_flag)
|
||||
run_cmd.extend(test_n_flag)
|
||||
|
||||
print(str(run_cmd))
|
||||
|
||||
proc = subprocess.Popen(run_cmd, stdout=clres.stdout, stderr=clres.stdout)
|
||||
|
||||
try:
|
||||
a_utils.validate_arg_result(flag=test_ew_flag,
|
||||
p_valid_params=p_valid_params, clres=clres,
|
||||
process=proc, start_time=start_time,
|
||||
timeout=time_limit)
|
||||
|
||||
a_utils.validate_arg_result(flag=test_n_flag,
|
||||
p_valid_params=p_valid_params, clres=clres,
|
||||
process=proc, start_time=start_time,
|
||||
timeout=time_limit)
|
||||
except AssertionError:
|
||||
# close process once get assert false
|
||||
proc.kill()
|
||||
assert False
|
||||
|
||||
proc.kill()
|
||||
|
||||
|
||||
@pytest.mark.functional_test
|
||||
@pytest.mark.xfail(reason="https://github.com/NervanaSystems/coach/issues/257")
|
||||
def test_preset_n_and_ew_and_onnx(preset_args, clres, start_time=time.time(),
|
||||
time_limit=Def.TimeOuts.test_time_limit):
|
||||
"""
|
||||
Test command arguments - check evaluation worker, number of workers and
|
||||
onnx.
|
||||
"""
|
||||
|
||||
ew_flag = ['-ew']
|
||||
n_flag = ['-n', Def.Flags.enw]
|
||||
onnx_flag = ['-onnx']
|
||||
s_flag = ['-s', Def.Flags.css]
|
||||
p_valid_params = p_utils.validation_params(preset_args)
|
||||
|
||||
run_cmd = [
|
||||
'python3', 'rl_coach/coach.py',
|
||||
'-p', '{}'.format(preset_args),
|
||||
'-e', '{}'.format("ExpName_" + preset_args),
|
||||
]
|
||||
|
||||
# add flags to run command
|
||||
test_ew_flag = a_utils.add_one_flag_value(flag=ew_flag)
|
||||
test_n_flag = a_utils.add_one_flag_value(flag=n_flag)
|
||||
test_onnx_flag = a_utils.add_one_flag_value(flag=onnx_flag)
|
||||
test_s_flag = a_utils.add_one_flag_value(flag=s_flag)
|
||||
|
||||
run_cmd.extend(test_ew_flag)
|
||||
run_cmd.extend(test_n_flag)
|
||||
run_cmd.extend(test_onnx_flag)
|
||||
run_cmd.extend(test_s_flag)
|
||||
|
||||
print(str(run_cmd))
|
||||
|
||||
proc = subprocess.Popen(run_cmd, stdout=clres.stdout, stderr=clres.stdout)
|
||||
|
||||
try:
|
||||
# Check csv files has been created
|
||||
a_utils.validate_arg_result(flag=test_ew_flag,
|
||||
p_valid_params=p_valid_params, clres=clres,
|
||||
process=proc, start_time=start_time,
|
||||
timeout=time_limit)
|
||||
|
||||
# Check csv files created same as the number of the workers
|
||||
a_utils.validate_arg_result(flag=test_n_flag,
|
||||
p_valid_params=p_valid_params, clres=clres,
|
||||
process=proc, start_time=start_time,
|
||||
timeout=time_limit)
|
||||
|
||||
# Check checkpoint files
|
||||
a_utils.validate_arg_result(flag=test_s_flag,
|
||||
p_valid_params=p_valid_params, clres=clres,
|
||||
process=proc, start_time=start_time,
|
||||
timeout=time_limit)
|
||||
|
||||
# TODO: add onnx check; issue found #257
|
||||
|
||||
except AssertionError:
|
||||
# close process once get assert false
|
||||
proc.kill()
|
||||
assert False
|
||||
|
||||
proc.kill()
|
||||
|
||||
@@ -20,6 +20,7 @@ import signal
|
||||
import time
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
from rl_coach.tests.utils.test_utils import get_csv_path, get_files_from_dir, \
|
||||
find_string_in_logs
|
||||
from rl_coach.tests.utils.definitions import Definitions as Def
|
||||
@@ -56,7 +57,7 @@ def collect_preset_for_seed():
|
||||
definitions (args_test under Presets).
|
||||
:return: preset(s) list
|
||||
"""
|
||||
for pn in Def.Presets.seed_args_test:
|
||||
for pn in Def.Presets.args_for_seed_test:
|
||||
assert pn, Def.Consts.ASSERT_MSG.format("Preset name", pn)
|
||||
yield pn
|
||||
|
||||
@@ -251,6 +252,8 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
|
||||
-dg, --dump_gifs: Once selected, a new folder should be created in
|
||||
experiment folder for gifs files.
|
||||
"""
|
||||
pytest.xfail(reason="GUI issue on CI")
|
||||
|
||||
csv_path = get_csv_path(clres)
|
||||
assert len(csv_path) > 0, \
|
||||
Def.Consts.ASSERT_MSG.format("path not found", csv_path)
|
||||
@@ -267,13 +270,14 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
|
||||
|
||||
# check if folder contain files
|
||||
get_files_from_dir(dir_path=gifs_path)
|
||||
# TODO: check if play window is opened
|
||||
|
||||
elif flag[0] == "-dm" or flag[0] == "--dump_mp4":
|
||||
"""
|
||||
-dm, --dump_mp4: Once selected, a new folder should be created in
|
||||
experiment folder for videos files.
|
||||
"""
|
||||
pytest.xfail(reason="GUI issue on CI")
|
||||
|
||||
csv_path = get_csv_path(clres)
|
||||
assert len(csv_path) > 0, \
|
||||
Def.Consts.ASSERT_MSG.format("path not found", csv_path)
|
||||
@@ -290,7 +294,6 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
|
||||
|
||||
# check if folder contain files
|
||||
get_files_from_dir(dir_path=videos_path)
|
||||
# TODO: check if play window is opened
|
||||
|
||||
elif flag[0] == "--nocolor":
|
||||
"""
|
||||
@@ -363,7 +366,7 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
|
||||
csv file is created.
|
||||
"""
|
||||
# wait until files created
|
||||
csv_path = get_csv_path(clres=clres)
|
||||
csv_path = get_csv_path(clres=clres, extra_tries=10)
|
||||
assert len(csv_path) > 0, \
|
||||
Def.Consts.ASSERT_MSG.format("path not found", csv_path)
|
||||
|
||||
@@ -383,11 +386,14 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
|
||||
|
||||
# check heat-up value
|
||||
results = []
|
||||
while csv["In Heatup"].values[-1] == 1:
|
||||
csv = pd.read_csv(csv_path[0])
|
||||
last_step = csv["Total steps"].values
|
||||
time.sleep(1)
|
||||
results.append(last_step[-1])
|
||||
if csv["In Heatup"].values[-1] == 0:
|
||||
results.append(csv["Total steps"].values[-1])
|
||||
else:
|
||||
while csv["In Heatup"].values[-1] == 1:
|
||||
csv = pd.read_csv(csv_path[0])
|
||||
last_step = csv["Total steps"].values
|
||||
results.append(last_step[-1])
|
||||
time.sleep(1)
|
||||
|
||||
assert results[-1] >= Def.Consts.num_hs, \
|
||||
Def.Consts.ASSERT_MSG.format("bigger than " + Def.Consts.num_hs,
|
||||
@@ -475,3 +481,18 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
|
||||
|
||||
elif flag[0] == "-c" or flag[0] == "--use_cpu":
|
||||
pass
|
||||
|
||||
elif flag[0] == "-n" or flag[0] == "--num_workers":
|
||||
|
||||
"""
|
||||
-n, --num_workers: Once selected alone, check that csv created for each
|
||||
worker, and check results.
|
||||
"""
|
||||
# wait until files created
|
||||
csv_path = get_csv_path(clres=clres, extra_tries=20)
|
||||
|
||||
expected_files = int(flag[1])
|
||||
assert len(csv_path) >= expected_files, \
|
||||
Def.Consts.ASSERT_MSG.format(str(expected_files),
|
||||
str(len(csv_path)))
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ class Definitions:
|
||||
enw = "num_workers"
|
||||
fw_ten = "framework_tensorflow"
|
||||
fw_mx = "framework_mxnet"
|
||||
et = "rl_coach.environments.gym_environment:Atari"
|
||||
# et = "rl_coach.environments.gym_environment:Atari" TODO
|
||||
|
||||
"""
|
||||
Arguments that can be tested for python coach command
|
||||
@@ -92,46 +92,16 @@ class Definitions:
|
||||
]
|
||||
|
||||
# Preset for testing seed argument
|
||||
seed_args_test = [
|
||||
"Atari_A3C",
|
||||
"Atari_A3C_LSTM",
|
||||
"Atari_Bootstrapped_DQN",
|
||||
"Atari_C51",
|
||||
"Atari_DDQN",
|
||||
"Atari_DQN_with_PER",
|
||||
args_for_seed_test = [
|
||||
"Atari_DQN",
|
||||
"Atari_DQN_with_PER",
|
||||
"Atari_Dueling_DDQN",
|
||||
"Atari_Dueling_DDQN_with_PER_OpenAI",
|
||||
"Atari_NStepQ",
|
||||
"Atari_QR_DQN",
|
||||
"Atari_Rainbow",
|
||||
"Atari_UCB_with_Q_Ensembles",
|
||||
"Doom_Basic_DQN",
|
||||
"BitFlip_DQN",
|
||||
"BitFlip_DQN_HER",
|
||||
"CartPole_A3C",
|
||||
"CartPole_ClippedPPO",
|
||||
"CartPole_DFP",
|
||||
"CartPole_DQN",
|
||||
"CartPole_Dueling_DDQN",
|
||||
"CartPole_NStepQ",
|
||||
"CartPole_PAL",
|
||||
"CartPole_PG",
|
||||
"CARLA_Dueling_DDQN",
|
||||
"ControlSuite_DDPG",
|
||||
"ExplorationChain_Bootstrapped_DQN",
|
||||
"ExplorationChain_Dueling_DDQN",
|
||||
"ExplorationChain_UCB_Q_ensembles",
|
||||
"Fetch_DDPG_HER_baselines",
|
||||
"InvertedPendulum_PG",
|
||||
"MontezumaRevenge_BC",
|
||||
"Mujoco_A3C",
|
||||
"Mujoco_A3C_LSTM",
|
||||
"Mujoco_ClippedPPO",
|
||||
"Mujoco_DDPG",
|
||||
"Mujoco_NAF",
|
||||
"Mujoco_PPO",
|
||||
"Pendulum_HAC",
|
||||
"Starcraft_CollectMinerals_A3C",
|
||||
"Starcraft_CollectMinerals_Dueling_DDQN",
|
||||
]
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
"""Manage all preset"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from importlib import import_module
|
||||
from rl_coach.tests.utils.definitions import Definitions as Def
|
||||
|
||||
@@ -26,7 +27,13 @@ def import_preset(preset_name):
|
||||
:param preset_name: preset name
|
||||
:return: imported module
|
||||
"""
|
||||
return import_module('{}.presets.{}'.format(Def.GROUP_NAME, preset_name))
|
||||
try:
|
||||
module = import_module('{}.presets.{}'
|
||||
.format(Def.GROUP_NAME, preset_name))
|
||||
except:
|
||||
pytest.skip("Can't import module: {}".format(preset_name))
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def validation_params(preset_name):
|
||||
|
||||
@@ -84,7 +84,7 @@ def get_files_from_dir(dir_path):
|
||||
:return: |list| return files in folder
|
||||
"""
|
||||
start_time = time.time()
|
||||
entities = None
|
||||
entities = []
|
||||
while time.time() - start_time < Def.TimeOuts.wait_for_files:
|
||||
# wait until logs created
|
||||
if os.path.exists(dir_path):
|
||||
@@ -118,17 +118,15 @@ def find_string_in_logs(log_path, str, timeout=Def.TimeOuts.wait_for_files,
|
||||
if not os.path.exists(log_path):
|
||||
return False
|
||||
|
||||
with open(log_path, 'r') as fr:
|
||||
if str in fr.read():
|
||||
return True
|
||||
fr.close()
|
||||
|
||||
while time.time() - start_time < Def.TimeOuts.test_time_limit \
|
||||
and wait_and_find:
|
||||
while time.time() - start_time < Def.TimeOuts.test_time_limit:
|
||||
with open(log_path, 'r') as fr:
|
||||
if str in fr.read():
|
||||
return True
|
||||
fr.close()
|
||||
|
||||
if not wait_and_find:
|
||||
break
|
||||
|
||||
return False
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user