Mirror of https://github.com/gryf/coach.git
tests: added new checkpoint and functional tests (#265)
* added new tests: test_preset_n_and_ew and test_preset_n_and_ew_and_onnx
* code utils improvements (all utils)
* improved checkpoint_test
* new functionality for functional_test markers and preset lists
* removed the special environment container
* added xfail to certain tests
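The functional_test marker used throughout this commit is a custom pytest marker; tests carrying it are selected with `pytest -m functional_test`. How coach registers the marker is not part of this diff; one conventional way, shown purely as an illustrative sketch, is a conftest.py hook:

# conftest.py -- illustrative only; whether coach registers its markers
# this way is an assumption, not something this commit shows
def pytest_configure(config):
    config.addinivalue_line(
        "markers",
        "functional_test: long-running end-to-end run of coach.py")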
@@ -24,6 +24,7 @@ import numpy as np
 import pandas as pd
 import rl_coach.tests.utils.args_utils as a_utils
 import rl_coach.tests.utils.test_utils as test_utils
+import rl_coach.tests.utils.presets_utils as p_utils
 from rl_coach import checkpoint
 from rl_coach.tests.utils.definitions import Definitions as Def
 
@@ -54,7 +55,8 @@ def test_get_checkpoint_state():
 
 
 @pytest.mark.functional_test
-def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
+def test_restore_checkpoint(preset_args, clres, start_time=time.time(),
+                            timeout=Def.TimeOuts.test_time_limit):
     """ Create checkpoint and restore them in second run."""
 
     def _create_cmd_and_run(flag):
@@ -71,6 +73,7 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
 
         return p
 
+    p_valid_params = p_utils.validation_params(preset_args)
    create_cp_proc = _create_cmd_and_run(flag=['--checkpoint_save_secs', '5'])
 
    # wait for checkpoint files
@@ -84,12 +87,11 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
     if os.path.exists(checkpoint_test_dir):
         shutil.rmtree(checkpoint_test_dir)
 
-    entities = a_utils.get_files_from_dir(checkpoint_dir)
-
-    while not any("10_Step" in file for file in entities) and time.time() - \
-            start_time < Def.TimeOuts.test_time_limit:
-        entities = a_utils.get_files_from_dir(checkpoint_dir)
-        time.sleep(1)
+    assert a_utils.is_reward_reached(csv_path=csv_list[0],
+                                     p_valid_params=p_valid_params,
+                                     start_time=start_time, time_limit=timeout)
+
+    entities = a_utils.get_files_from_dir(checkpoint_dir)
 
     assert len(entities) > 0
     assert "checkpoint" in entities
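The `a_utils.is_reward_reached` helper that replaces the inline file-polling loop is not shown in this excerpt. A minimal sketch of the assumed behaviour, polling the experiment csv until the evaluation reward clears a preset threshold (the `min_reward_threshold` attribute name is hypothetical, not taken from the diff):

import os
import time

import numpy as np
import pandas as pd


def is_reward_reached(csv_path, p_valid_params, start_time, time_limit):
    """Poll the experiment csv until 'Evaluation Reward' crosses the
    preset threshold or the time limit expires. Hypothetical sketch."""
    while time.time() - start_time < time_limit:
        if os.path.exists(csv_path):
            rewards = pd.read_csv(csv_path)['Evaluation Reward'].values
            rewards = rewards[~np.isnan(rewards)]
            # the attribute name below is an assumption, not from the diff
            if rewards.size and \
                    np.amax(rewards) >= p_valid_params.min_reward_threshold:
                return True
        time.sleep(1)
    return False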
@@ -101,7 +103,7 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
     csv = pd.read_csv(csv_list[0])
     rewards = csv['Evaluation Reward'].values
     rewards = rewards[~np.isnan(rewards)]
-    min_reward = np.amin(rewards)
+    max_reward = np.amax(rewards)
 
     if os.path.isdir(checkpoint_dir):
         shutil.copytree(exp_dir, checkpoint_test_dir)
@@ -119,7 +121,10 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
 
     csv = pd.read_csv(new_csv_list[0])
     res = csv['Episode Length'].values[-1]
-    assert res >= min_reward, \
-        Def.Consts.ASSERT_MSG.format(str(res) + ">=" + str(min_reward),
-                                     str(res) + " < " + str(min_reward))
+    assert res == max_reward, Def.Consts.ASSERT_MSG.format(str(max_reward),
+                                                           str(res))
     restore_cp_proc.kill()
+
+    test_folder = os.path.join(Def.Path.experiments, Def.Path.test_dir)
+    if os.path.exists(test_folder):
+        shutil.rmtree(test_folder)
@@ -143,3 +143,113 @@ def test_preset_seed(preset_args_for_seed, clres, start_time=time.time(),
             assert False
 
     close_processes()
+
+
+@pytest.mark.functional_test
+def test_preset_n_and_ew(preset_args, clres, start_time=time.time(),
+                         time_limit=Def.TimeOuts.test_time_limit):
+    """
+    Test command arguments - check evaluation worker with number of workers
+    """
+
+    ew_flag = ['-ew']
+    n_flag = ['-n', Def.Flags.enw]
+    p_valid_params = p_utils.validation_params(preset_args)
+
+    run_cmd = [
+        'python3', 'rl_coach/coach.py',
+        '-p', '{}'.format(preset_args),
+        '-e', '{}'.format("ExpName_" + preset_args),
+    ]
+
+    # add flags to run command
+    test_ew_flag = a_utils.add_one_flag_value(flag=ew_flag)
+    test_n_flag = a_utils.add_one_flag_value(flag=n_flag)
+    run_cmd.extend(test_ew_flag)
+    run_cmd.extend(test_n_flag)
+
+    print(str(run_cmd))
+
+    proc = subprocess.Popen(run_cmd, stdout=clres.stdout, stderr=clres.stdout)
+
+    try:
+        a_utils.validate_arg_result(flag=test_ew_flag,
+                                    p_valid_params=p_valid_params, clres=clres,
+                                    process=proc, start_time=start_time,
+                                    timeout=time_limit)
+
+        a_utils.validate_arg_result(flag=test_n_flag,
+                                    p_valid_params=p_valid_params, clres=clres,
+                                    process=proc, start_time=start_time,
+                                    timeout=time_limit)
+    except AssertionError:
+        # close process once get assert false
+        proc.kill()
+        assert False
+
+    proc.kill()
+
+
+@pytest.mark.functional_test
+@pytest.mark.xfail(reason="https://github.com/NervanaSystems/coach/issues/257")
+def test_preset_n_and_ew_and_onnx(preset_args, clres, start_time=time.time(),
+                                  time_limit=Def.TimeOuts.test_time_limit):
+    """
+    Test command arguments - check evaluation worker, number of workers and
+    onnx.
+    """
+
+    ew_flag = ['-ew']
+    n_flag = ['-n', Def.Flags.enw]
+    onnx_flag = ['-onnx']
+    s_flag = ['-s', Def.Flags.css]
+    p_valid_params = p_utils.validation_params(preset_args)
+
+    run_cmd = [
+        'python3', 'rl_coach/coach.py',
+        '-p', '{}'.format(preset_args),
+        '-e', '{}'.format("ExpName_" + preset_args),
+    ]
+
+    # add flags to run command
+    test_ew_flag = a_utils.add_one_flag_value(flag=ew_flag)
+    test_n_flag = a_utils.add_one_flag_value(flag=n_flag)
+    test_onnx_flag = a_utils.add_one_flag_value(flag=onnx_flag)
+    test_s_flag = a_utils.add_one_flag_value(flag=s_flag)
+
+    run_cmd.extend(test_ew_flag)
+    run_cmd.extend(test_n_flag)
+    run_cmd.extend(test_onnx_flag)
+    run_cmd.extend(test_s_flag)
+
+    print(str(run_cmd))
+
+    proc = subprocess.Popen(run_cmd, stdout=clres.stdout, stderr=clres.stdout)
+
+    try:
+        # Check csv files has been created
+        a_utils.validate_arg_result(flag=test_ew_flag,
+                                    p_valid_params=p_valid_params, clres=clres,
+                                    process=proc, start_time=start_time,
+                                    timeout=time_limit)
+
+        # Check csv files created same as the number of the workers
+        a_utils.validate_arg_result(flag=test_n_flag,
+                                    p_valid_params=p_valid_params, clres=clres,
+                                    process=proc, start_time=start_time,
+                                    timeout=time_limit)
+
+        # Check checkpoint files
+        a_utils.validate_arg_result(flag=test_s_flag,
+                                    p_valid_params=p_valid_params, clres=clres,
+                                    process=proc, start_time=start_time,
+                                    timeout=time_limit)
+
+        # TODO: add onnx check; issue found #257
+
+    except AssertionError:
+        # close process once get assert false
+        proc.kill()
+        assert False
+
+    proc.kill()
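Both new tests pass each raw flag through `a_utils.add_one_flag_value` before appending it to the command line; that helper is outside this excerpt. A hedged sketch of what it plausibly does, resolving symbolic placeholders such as `Def.Flags.enw` into concrete argument values (the concrete values below are illustrative, and the `Def.Flags.css` mapping is an assumption):

def add_one_flag_value(flag):
    """Return the flag with any symbolic placeholder value resolved.
    Hypothetical sketch of the rl_coach.tests.utils.args_utils helper."""
    if len(flag) < 2:
        return list(flag)          # boolean flags like ['-ew'] pass through
    value = flag[1]
    if value == "num_workers":     # i.e. Def.Flags.enw
        value = "3"                # illustrative worker count
    elif value == "checkpoint_save_secs":  # i.e. Def.Flags.css, assumed
        value = "5"                # illustrative save interval
    return [flag[0], value]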
@@ -20,6 +20,7 @@ import signal
 import time
 import pandas as pd
 import numpy as np
+import pytest
 from rl_coach.tests.utils.test_utils import get_csv_path, get_files_from_dir, \
     find_string_in_logs
 from rl_coach.tests.utils.definitions import Definitions as Def
@@ -56,7 +57,7 @@ def collect_preset_for_seed():
     definitions (args_test under Presets).
     :return: preset(s) list
     """
-    for pn in Def.Presets.seed_args_test:
+    for pn in Def.Presets.args_for_seed_test:
         assert pn, Def.Consts.ASSERT_MSG.format("Preset name", pn)
         yield pn
 
@@ -251,6 +252,8 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
         -dg, --dump_gifs: Once selected, a new folder should be created in
                           experiment folder for gifs files.
         """
+        pytest.xfail(reason="GUI issue on CI")
+
         csv_path = get_csv_path(clres)
         assert len(csv_path) > 0, \
             Def.Consts.ASSERT_MSG.format("path not found", csv_path)
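Unlike the `@pytest.mark.xfail` decorator applied to `test_preset_n_and_ew_and_onnx` above, the imperative `pytest.xfail(...)` call added here raises immediately, so everything after it in the branch (the csv and gifs-folder checks) no longer executes. The difference in a nutshell:

import pytest


@pytest.mark.xfail(reason="marker style: the body still runs")
def test_marker_style():
    # executes to the end; a failure is recorded as xfail, a pass as xpass
    assert 1 + 1 == 3


def test_imperative_style():
    # stops right here and is recorded as xfail
    pytest.xfail(reason="GUI issue on CI")
    assert 1 + 1 == 3  # never reached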
@@ -267,13 +270,14 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
 
         # check if folder contain files
         get_files_from_dir(dir_path=gifs_path)
-        # TODO: check if play window is opened
 
     elif flag[0] == "-dm" or flag[0] == "--dump_mp4":
         """
         -dm, --dump_mp4: Once selected, a new folder should be created in
                          experiment folder for videos files.
         """
+        pytest.xfail(reason="GUI issue on CI")
+
         csv_path = get_csv_path(clres)
         assert len(csv_path) > 0, \
             Def.Consts.ASSERT_MSG.format("path not found", csv_path)
@@ -290,7 +294,6 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
 
         # check if folder contain files
         get_files_from_dir(dir_path=videos_path)
-        # TODO: check if play window is opened
 
     elif flag[0] == "--nocolor":
         """
@@ -363,7 +366,7 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
         csv file is created.
         """
         # wait until files created
-        csv_path = get_csv_path(clres=clres)
+        csv_path = get_csv_path(clres=clres, extra_tries=10)
         assert len(csv_path) > 0, \
             Def.Consts.ASSERT_MSG.format("path not found", csv_path)
 
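`get_csv_path` now takes an `extra_tries` argument (it gets `extra_tries=20` in the `--num_workers` branch added further below), presumably so the lookup keeps polling while slow worker processes create their csv files. The utility itself is outside this excerpt; a minimal retry-loop sketch under that assumption (the `clres.exp_path` attribute is hypothetical):

import glob
import os
import time


def get_csv_path(clres, extra_tries=0, delay=1):
    """Return the experiment csv paths, retrying while workers start up.
    Hypothetical sketch; only the signature is inferred from the callers."""
    csv_paths = []
    for _ in range(1 + extra_tries):
        # clres.exp_path is an assumed attribute of the test resources
        csv_paths = glob.glob(os.path.join(clres.exp_path, '**', '*.csv'),
                              recursive=True)
        if csv_paths:
            break
        time.sleep(delay)
    return csv_paths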
@@ -383,11 +386,14 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
 
         # check heat-up value
         results = []
-        while csv["In Heatup"].values[-1] == 1:
-            csv = pd.read_csv(csv_path[0])
-            last_step = csv["Total steps"].values
-            time.sleep(1)
-            results.append(last_step[-1])
+        if csv["In Heatup"].values[-1] == 0:
+            results.append(csv["Total steps"].values[-1])
+        else:
+            while csv["In Heatup"].values[-1] == 1:
+                csv = pd.read_csv(csv_path[0])
+                last_step = csv["Total steps"].values
+                results.append(last_step[-1])
+                time.sleep(1)
 
         assert results[-1] >= Def.Consts.num_hs, \
             Def.Consts.ASSERT_MSG.format("bigger than " + Def.Consts.num_hs,
@@ -475,3 +481,18 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
 
     elif flag[0] == "-c" or flag[0] == "--use_cpu":
         pass
+
+    elif flag[0] == "-n" or flag[0] == "--num_workers":
+
+        """
+        -n, --num_workers: Once selected alone, check that csv created for each
+                           worker, and check results.
+        """
+        # wait until files created
+        csv_path = get_csv_path(clres=clres, extra_tries=20)
+
+        expected_files = int(flag[1])
+        assert len(csv_path) >= expected_files, \
+            Def.Consts.ASSERT_MSG.format(str(expected_files),
+                                         str(len(csv_path)))
+
@@ -39,7 +39,7 @@ class Definitions:
         enw = "num_workers"
         fw_ten = "framework_tensorflow"
         fw_mx = "framework_mxnet"
-        et = "rl_coach.environments.gym_environment:Atari"
+        # et = "rl_coach.environments.gym_environment:Atari" TODO
 
         """
         Arguments that can be tested for python coach command
@@ -92,46 +92,16 @@ class Definitions:
         ]
 
         # Preset for testing seed argument
-        seed_args_test = [
-            "Atari_A3C",
-            "Atari_A3C_LSTM",
-            "Atari_Bootstrapped_DQN",
-            "Atari_C51",
-            "Atari_DDQN",
+        args_for_seed_test = [
             "Atari_DQN",
-            "Atari_DQN_with_PER",
-            "Atari_Dueling_DDQN",
-            "Atari_Dueling_DDQN_with_PER_OpenAI",
-            "Atari_NStepQ",
-            "Atari_QR_DQN",
-            "Atari_Rainbow",
-            "Atari_UCB_with_Q_Ensembles",
+            "Doom_Basic_DQN",
             "BitFlip_DQN",
-            "BitFlip_DQN_HER",
-            "CartPole_A3C",
-            "CartPole_ClippedPPO",
-            "CartPole_DFP",
             "CartPole_DQN",
-            "CartPole_Dueling_DDQN",
-            "CartPole_NStepQ",
-            "CartPole_PAL",
-            "CartPole_PG",
+            "CARLA_Dueling_DDQN",
             "ControlSuite_DDPG",
-            "ExplorationChain_Bootstrapped_DQN",
             "ExplorationChain_Dueling_DDQN",
-            "ExplorationChain_UCB_Q_ensembles",
             "Fetch_DDPG_HER_baselines",
-            "InvertedPendulum_PG",
-            "MontezumaRevenge_BC",
-            "Mujoco_A3C",
-            "Mujoco_A3C_LSTM",
             "Mujoco_ClippedPPO",
-            "Mujoco_DDPG",
-            "Mujoco_NAF",
-            "Mujoco_PPO",
-            "Pendulum_HAC",
-            "Starcraft_CollectMinerals_A3C",
             "Starcraft_CollectMinerals_Dueling_DDQN",
         ]
 
@@ -16,6 +16,7 @@
 """Manage all preset"""
 
 import os
+import pytest
 from importlib import import_module
 from rl_coach.tests.utils.definitions import Definitions as Def
 
@@ -26,7 +27,13 @@ def import_preset(preset_name):
     :param preset_name: preset name
     :return: imported module
     """
-    return import_module('{}.presets.{}'.format(Def.GROUP_NAME, preset_name))
+    try:
+        module = import_module('{}.presets.{}'
+                               .format(Def.GROUP_NAME, preset_name))
+    except:
+        pytest.skip("Can't import module: {}".format(preset_name))
+
+    return module
 
 
 def validation_params(preset_name):
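With this change a preset whose module fails to import (a missing optional environment dependency, for instance) skips the calling test instead of erroring during collection. A usage sketch; the `graph_manager` attribute check is illustrative, based on the convention that coach presets define one:

from rl_coach.tests.utils.presets_utils import import_preset


def test_preset_importable():
    # pytest.skip() inside import_preset marks this test as skipped
    # rather than failed if the preset module cannot be loaded
    preset = import_preset("CartPole_DQN")
    assert hasattr(preset, "graph_manager")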
@@ -84,7 +84,7 @@ def get_files_from_dir(dir_path):
     :return: |list| return files in folder
     """
     start_time = time.time()
-    entities = None
+    entities = []
     while time.time() - start_time < Def.TimeOuts.wait_for_files:
         # wait until logs created
         if os.path.exists(dir_path):
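Initialising `entities` to an empty list rather than `None` matters because callers iterate over the return value straight away; with the old default, a directory that never appeared within the timeout produced a `TypeError` instead of a clean assertion failure:

entities = None
# iterating None blows up before any assert can report the real problem
any("10_Step" in f for f in entities)  # TypeError: 'NoneType' object is not iterable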
@@ -118,17 +118,15 @@ def find_string_in_logs(log_path, str, timeout=Def.TimeOuts.wait_for_files,
     if not os.path.exists(log_path):
         return False
 
-    with open(log_path, 'r') as fr:
-        if str in fr.read():
-            return True
-    fr.close()
-
-    while time.time() - start_time < Def.TimeOuts.test_time_limit \
-            and wait_and_find:
+    while time.time() - start_time < Def.TimeOuts.test_time_limit:
         with open(log_path, 'r') as fr:
             if str in fr.read():
                 return True
         fr.close()
 
+        if not wait_and_find:
+            break
+
     return False
 
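After the rework there is a single polling loop: each pass reads the whole log, returns True on a match, and `wait_and_find=False` breaks out after the first read instead of waiting for the full time limit. A usage sketch (the path and search string are illustrative):

# wait up to the configured time limit for a message to appear in the log
found = find_string_in_logs(log_path="experiments/test/log.txt",
                            str="Checkpoint saved",
                            wait_and_find=True)
assert found, "expected log message never appeared"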