
tests: added new checkpoint and functional tests (#265)

* Added new tests:
  - test_preset_n_and_ew
  - test_preset_n_and_ew_and_onnx
* Improved the code utils (all utils)
* Improved checkpoint_test
* Added new functionality for the functional_test markers and preset lists
  (see the marker-selection sketch after the changed-files summary)
* Removed the special environment container
* Added xfail to certain tests
Authored by anabwan on 2019-03-28 22:57:31 +02:00; committed by Scott Leishman
parent 310d31c227
commit 869bd421a3
6 changed files with 173 additions and 62 deletions
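For orientation: the tests added here are tagged with the functional_test marker, so they can be selected with pytest's -m option. A minimal sketch, assuming the marker is registered in the project's pytest configuration; the test path 'rl_coach/tests' is illustrative, not taken from this commit, and the subprocess style mirrors the run_cmd pattern used in the tests below:

    # Select only tests tagged @pytest.mark.functional_test.
    # The test path is illustrative, not taken from this commit.
    import subprocess

    subprocess.run(['python3', '-m', 'pytest', '-m', 'functional_test',
                    'rl_coach/tests'])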

View File

@@ -24,6 +24,7 @@ import numpy as np
 import pandas as pd
 import rl_coach.tests.utils.args_utils as a_utils
 import rl_coach.tests.utils.test_utils as test_utils
+import rl_coach.tests.utils.presets_utils as p_utils
 from rl_coach import checkpoint
 from rl_coach.tests.utils.definitions import Definitions as Def
@@ -54,7 +55,8 @@ def test_get_checkpoint_state():
 @pytest.mark.functional_test
-def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
+def test_restore_checkpoint(preset_args, clres, start_time=time.time(),
+                            timeout=Def.TimeOuts.test_time_limit):
     """ Create checkpoint and restore them in second run."""
     def _create_cmd_and_run(flag):
@@ -71,6 +73,7 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
         return p

+    p_valid_params = p_utils.validation_params(preset_args)
     create_cp_proc = _create_cmd_and_run(flag=['--checkpoint_save_secs', '5'])

     # wait for checkpoint files
@@ -84,12 +87,11 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
     if os.path.exists(checkpoint_test_dir):
         shutil.rmtree(checkpoint_test_dir)

-    entities = a_utils.get_files_from_dir(checkpoint_dir)
+    assert a_utils.is_reward_reached(csv_path=csv_list[0],
+                                     p_valid_params=p_valid_params,
+                                     start_time=start_time, time_limit=timeout)
-    while not any("10_Step" in file for file in entities) and time.time() - \
-            start_time < Def.TimeOuts.test_time_limit:
-        entities = a_utils.get_files_from_dir(checkpoint_dir)
-        time.sleep(1)
+    entities = a_utils.get_files_from_dir(checkpoint_dir)
     assert len(entities) > 0
     assert "checkpoint" in entities
@@ -101,7 +103,7 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
     csv = pd.read_csv(csv_list[0])
     rewards = csv['Evaluation Reward'].values
     rewards = rewards[~np.isnan(rewards)]
-    min_reward = np.amin(rewards)
+    max_reward = np.amax(rewards)

     if os.path.isdir(checkpoint_dir):
         shutil.copytree(exp_dir, checkpoint_test_dir)
@@ -119,7 +121,10 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
     csv = pd.read_csv(new_csv_list[0])
     res = csv['Episode Length'].values[-1]
-    assert res >= min_reward, \
-        Def.Consts.ASSERT_MSG.format(str(res) + ">=" + str(min_reward),
-                                     str(res) + " < " + str(min_reward))
+    assert res == max_reward, Def.Consts.ASSERT_MSG.format(str(max_reward),
+                                                           str(res))
     restore_cp_proc.kill()
+
+    test_folder = os.path.join(Def.Path.experiments, Def.Path.test_dir)
+    if os.path.exists(test_folder):
+        shutil.rmtree(test_folder)
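The is_reward_reached helper called above is not part of this diff. A rough sketch of what such a poll-until-reward check could look like, under the assumption that p_valid_params carries a min_reward_threshold-style field (that field name is a guess, not the repository's API); the csv column matches the 'Evaluation Reward' column read elsewhere in this test:

    # Hypothetical sketch only; min_reward_threshold is an assumed field name.
    import time

    import numpy as np
    import pandas as pd

    def is_reward_reached(csv_path, p_valid_params, start_time, time_limit):
        # Poll the experiment csv until the evaluation reward crosses the
        # preset's threshold, or give up once time_limit expires.
        while time.time() - start_time < time_limit:
            csv = pd.read_csv(csv_path)
            rewards = csv['Evaluation Reward'].values
            rewards = rewards[~np.isnan(rewards)]
            if len(rewards) > 0 and \
                    np.amax(rewards) >= p_valid_params.min_reward_threshold:
                return True
            time.sleep(1)
        return False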

View File

@@ -143,3 +143,113 @@ def test_preset_seed(preset_args_for_seed, clres, start_time=time.time(),
         assert False

     close_processes()
+
+
+@pytest.mark.functional_test
+def test_preset_n_and_ew(preset_args, clres, start_time=time.time(),
+                         time_limit=Def.TimeOuts.test_time_limit):
+    """
+    Test command arguments - check evaluation worker with number of workers
+    """
+    ew_flag = ['-ew']
+    n_flag = ['-n', Def.Flags.enw]
+
+    p_valid_params = p_utils.validation_params(preset_args)
+
+    run_cmd = [
+        'python3', 'rl_coach/coach.py',
+        '-p', '{}'.format(preset_args),
+        '-e', '{}'.format("ExpName_" + preset_args),
+    ]
+
+    # add flags to run command
+    test_ew_flag = a_utils.add_one_flag_value(flag=ew_flag)
+    test_n_flag = a_utils.add_one_flag_value(flag=n_flag)
+    run_cmd.extend(test_ew_flag)
+    run_cmd.extend(test_n_flag)
+    print(str(run_cmd))
+
+    proc = subprocess.Popen(run_cmd, stdout=clres.stdout, stderr=clres.stdout)
+
+    try:
+        a_utils.validate_arg_result(flag=test_ew_flag,
+                                    p_valid_params=p_valid_params, clres=clres,
+                                    process=proc, start_time=start_time,
+                                    timeout=time_limit)
+
+        a_utils.validate_arg_result(flag=test_n_flag,
+                                    p_valid_params=p_valid_params, clres=clres,
+                                    process=proc, start_time=start_time,
+                                    timeout=time_limit)
+    except AssertionError:
+        # close process once get assert false
+        proc.kill()
+        assert False
+
+    proc.kill()
+
+
+@pytest.mark.functional_test
+@pytest.mark.xfail(reason="https://github.com/NervanaSystems/coach/issues/257")
+def test_preset_n_and_ew_and_onnx(preset_args, clres, start_time=time.time(),
+                                  time_limit=Def.TimeOuts.test_time_limit):
+    """
+    Test command arguments - check evaluation worker, number of workers and
+    onnx.
+    """
+    ew_flag = ['-ew']
+    n_flag = ['-n', Def.Flags.enw]
+    onnx_flag = ['-onnx']
+    s_flag = ['-s', Def.Flags.css]
+
+    p_valid_params = p_utils.validation_params(preset_args)
+
+    run_cmd = [
+        'python3', 'rl_coach/coach.py',
+        '-p', '{}'.format(preset_args),
+        '-e', '{}'.format("ExpName_" + preset_args),
+    ]
+
+    # add flags to run command
+    test_ew_flag = a_utils.add_one_flag_value(flag=ew_flag)
+    test_n_flag = a_utils.add_one_flag_value(flag=n_flag)
+    test_onnx_flag = a_utils.add_one_flag_value(flag=onnx_flag)
+    test_s_flag = a_utils.add_one_flag_value(flag=s_flag)
+    run_cmd.extend(test_ew_flag)
+    run_cmd.extend(test_n_flag)
+    run_cmd.extend(test_onnx_flag)
+    run_cmd.extend(test_s_flag)
+    print(str(run_cmd))
+
+    proc = subprocess.Popen(run_cmd, stdout=clres.stdout, stderr=clres.stdout)
+
+    try:
+        # Check csv files has been created
+        a_utils.validate_arg_result(flag=test_ew_flag,
+                                    p_valid_params=p_valid_params, clres=clres,
+                                    process=proc, start_time=start_time,
+                                    timeout=time_limit)
+
+        # Check csv files created same as the number of the workers
+        a_utils.validate_arg_result(flag=test_n_flag,
+                                    p_valid_params=p_valid_params, clres=clres,
+                                    process=proc, start_time=start_time,
+                                    timeout=time_limit)
+
+        # Check checkpoint files
+        a_utils.validate_arg_result(flag=test_s_flag,
+                                    p_valid_params=p_valid_params, clres=clres,
+                                    process=proc, start_time=start_time,
+                                    timeout=time_limit)
+
+        # TODO: add onnx check; issue found #257
+    except AssertionError:
+        # close process once get assert false
+        proc.kill()
+        assert False
+
+    proc.kill()
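The xfail marker above ties the test to the tracked bug. By standard pytest semantics an xfail-marked test is reported as XFAIL while it keeps failing and as XPASS once it starts passing, so the suite stays green without losing sight of the issue. A minimal illustration on a toy test:

    # Standard pytest behavior: reported as XFAIL while the assertion fails,
    # XPASS once the underlying bug (here, the onnx check) is fixed.
    import pytest

    @pytest.mark.xfail(reason='tracked upstream bug')
    def test_known_broken():
        assert False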

View File

@@ -20,6 +20,7 @@ import signal
 import time
 import pandas as pd
 import numpy as np
+import pytest
 from rl_coach.tests.utils.test_utils import get_csv_path, get_files_from_dir, \
     find_string_in_logs
 from rl_coach.tests.utils.definitions import Definitions as Def
@@ -56,7 +57,7 @@ def collect_preset_for_seed():
     definitions (args_test under Presets).

     :return: preset(s) list
     """
-    for pn in Def.Presets.seed_args_test:
+    for pn in Def.Presets.args_for_seed_test:
         assert pn, Def.Consts.ASSERT_MSG.format("Preset name", pn)
         yield pn
@@ -251,6 +252,8 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
         -dg, --dump_gifs: Once selected, a new folder should be created in
                           experiment folder for gifs files.
         """
+        pytest.xfail(reason="GUI issue on CI")
+
         csv_path = get_csv_path(clres)
         assert len(csv_path) > 0, \
             Def.Consts.ASSERT_MSG.format("path not found", csv_path)
@@ -267,13 +270,14 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
         # check if folder contain files
         get_files_from_dir(dir_path=gifs_path)

-        # TODO: check if play window is opened
     elif flag[0] == "-dm" or flag[0] == "--dump_mp4":
         """
         -dm, --dump_mp4: Once selected, a new folder should be created in
                          experiment folder for videos files.
         """
+        pytest.xfail(reason="GUI issue on CI")
+
         csv_path = get_csv_path(clres)
         assert len(csv_path) > 0, \
             Def.Consts.ASSERT_MSG.format("path not found", csv_path)
@@ -290,7 +294,6 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
         # check if folder contain files
         get_files_from_dir(dir_path=videos_path)

-        # TODO: check if play window is opened
     elif flag[0] == "--nocolor":
         """
@@ -363,7 +366,7 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
         csv file is created.
         """
         # wait until files created
-        csv_path = get_csv_path(clres=clres)
+        csv_path = get_csv_path(clres=clres, extra_tries=10)
         assert len(csv_path) > 0, \
             Def.Consts.ASSERT_MSG.format("path not found", csv_path)
@@ -383,11 +386,14 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
         # check heat-up value
         results = []
-        while csv["In Heatup"].values[-1] == 1:
-            csv = pd.read_csv(csv_path[0])
-            last_step = csv["Total steps"].values
-            time.sleep(1)
-            results.append(last_step[-1])
+        if csv["In Heatup"].values[-1] == 0:
+            results.append(csv["Total steps"].values[-1])
+        else:
+            while csv["In Heatup"].values[-1] == 1:
+                csv = pd.read_csv(csv_path[0])
+                last_step = csv["Total steps"].values
+                results.append(last_step[-1])
+                time.sleep(1)

         assert results[-1] >= Def.Consts.num_hs, \
             Def.Consts.ASSERT_MSG.format("bigger than " + Def.Consts.num_hs,
@@ -475,3 +481,18 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
     elif flag[0] == "-c" or flag[0] == "--use_cpu":
         pass
+
+    elif flag[0] == "-n" or flag[0] == "--num_workers":
+        """
+        -n, --num_workers: Once selected alone, check that csv created for
+                           each worker, and check results.
+        """
+        # wait until files created
+        csv_path = get_csv_path(clres=clres, extra_tries=20)
+
+        expected_files = int(flag[1])
+        assert len(csv_path) >= expected_files, \
+            Def.Consts.ASSERT_MSG.format(str(expected_files),
+                                         str(len(csv_path)))
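The reworked heat-up branch above guards against a race: when heat-up has already finished by the time of the first read, the old while loop body never ran, results stayed empty, and results[-1] raised IndexError. A stripped-down illustration with toy values (not the repository's code):

    # Toy reproduction of the race the heat-up fix addresses.
    results = []
    in_heatup = 0  # heat-up finished before the first poll

    # old logic: the loop body never runs, results stays empty,
    # and results[-1] would raise IndexError
    while in_heatup == 1:
        results.append(1)

    # new logic: record the step count even when heat-up is already over
    if in_heatup == 0:
        results.append(100)  # stand-in for csv['Total steps'].values[-1]

    assert results[-1] >= 1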

View File

@@ -39,7 +39,7 @@ class Definitions:
         enw = "num_workers"
         fw_ten = "framework_tensorflow"
         fw_mx = "framework_mxnet"
-        et = "rl_coach.environments.gym_environment:Atari"
+        # et = "rl_coach.environments.gym_environment:Atari" TODO
         """
         Arguments that can be tested for python coach command
@@ -92,46 +92,16 @@ class Definitions:
         ]

         # Preset for testing seed argument
-        seed_args_test = [
-            "Atari_A3C",
-            "Atari_A3C_LSTM",
-            "Atari_Bootstrapped_DQN",
-            "Atari_C51",
-            "Atari_DDQN",
-            "Atari_DQN_with_PER",
+        args_for_seed_test = [
+            "Atari_DQN",
+            "Atari_DQN_with_PER",
+            "Atari_Dueling_DDQN",
+            "Atari_Dueling_DDQN_with_PER_OpenAI",
-            "Atari_NStepQ",
-            "Atari_QR_DQN",
-            "Atari_Rainbow",
-            "Atari_UCB_with_Q_Ensembles",
-            "Doom_Basic_DQN",
-            "BitFlip_DQN",
-            "BitFlip_DQN_HER",
-            "CartPole_A3C",
-            "CartPole_ClippedPPO",
-            "CartPole_DFP",
-            "CartPole_DQN",
-            "CartPole_Dueling_DDQN",
-            "CartPole_NStepQ",
-            "CartPole_PAL",
-            "CartPole_PG",
-            "CARLA_Dueling_DDQN",
-            "ControlSuite_DDPG",
-            "ExplorationChain_Bootstrapped_DQN",
-            "ExplorationChain_Dueling_DDQN",
-            "ExplorationChain_UCB_Q_ensembles",
-            "Fetch_DDPG_HER_baselines",
-            "InvertedPendulum_PG",
-            "MontezumaRevenge_BC",
-            "Mujoco_A3C",
-            "Mujoco_A3C_LSTM",
-            "Mujoco_ClippedPPO",
-            "Mujoco_DDPG",
-            "Mujoco_NAF",
-            "Mujoco_PPO",
-            "Pendulum_HAC",
-            "Starcraft_CollectMinerals_A3C",
-            "Starcraft_CollectMinerals_Dueling_DDQN",
         ]

View File

@@ -16,6 +16,7 @@
"""Manage all preset"""
import os
import pytest
from importlib import import_module
from rl_coach.tests.utils.definitions import Definitions as Def
@@ -26,7 +27,13 @@ def import_preset(preset_name):
     :param preset_name: preset name
     :return: imported module
     """
-    return import_module('{}.presets.{}'.format(Def.GROUP_NAME, preset_name))
+    try:
+        module = import_module('{}.presets.{}'
+                               .format(Def.GROUP_NAME, preset_name))
+    except:
+        pytest.skip("Can't import module: {}".format(preset_name))
+
+    return module


 def validation_params(preset_name):
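pytest.skip raises a Skipped exception, so when import_preset is called from inside a test, an unimportable preset now marks that test as skipped instead of erroring out the whole run. A minimal sketch of the same pattern with a narrower except clause; 'rl_coach' is assumed here in place of Def.GROUP_NAME:

    # Minimal sketch; 'rl_coach' stands in for Def.GROUP_NAME.
    import pytest
    from importlib import import_module

    def import_preset_sketch(preset_name):
        try:
            return import_module('rl_coach.presets.{}'.format(preset_name))
        except ImportError:
            pytest.skip("Can't import module: {}".format(preset_name))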

View File

@@ -84,7 +84,7 @@ def get_files_from_dir(dir_path):
     :return: |list| return files in folder
     """
     start_time = time.time()
-    entities = None
+    entities = []

     while time.time() - start_time < Def.TimeOuts.wait_for_files:
         # wait until logs created
         if os.path.exists(dir_path):
@@ -118,17 +118,15 @@ def find_string_in_logs(log_path, str, timeout=Def.TimeOuts.wait_for_files,
     if not os.path.exists(log_path):
         return False

-    with open(log_path, 'r') as fr:
-        if str in fr.read():
-            return True
-    fr.close()
-    while time.time() - start_time < Def.TimeOuts.test_time_limit \
-            and wait_and_find:
+    while time.time() - start_time < Def.TimeOuts.test_time_limit:
         with open(log_path, 'r') as fr:
             if str in fr.read():
                 return True
         fr.close()
+
+        if not wait_and_find:
+            break

     return False
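The entities default change above matters because callers iterate over the return value (for example the '10_Step' check in the checkpoint test): with a None default, a directory that never appeared produced a TypeError instead of a clean assertion failure. A toy illustration:

    # With the old None default, iterating after a timeout raised TypeError.
    entities = None
    # any('10_Step' in f for f in entities)  # TypeError: not iterable

    # With an empty-list default the same check is simply False.
    entities = []
    print(any('10_Step' in f for f in entities))  # -> False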