diff --git a/rl_coach/tests/test_checkpoint.py b/rl_coach/tests/test_checkpoint.py index b1ecc5b..b72e999 100644 --- a/rl_coach/tests/test_checkpoint.py +++ b/rl_coach/tests/test_checkpoint.py @@ -24,6 +24,7 @@ import numpy as np import pandas as pd import rl_coach.tests.utils.args_utils as a_utils import rl_coach.tests.utils.test_utils as test_utils +import rl_coach.tests.utils.presets_utils as p_utils from rl_coach import checkpoint from rl_coach.tests.utils.definitions import Definitions as Def @@ -54,7 +55,8 @@ def test_get_checkpoint_state(): @pytest.mark.functional_test -def test_restore_checkpoint(preset_args, clres, start_time=time.time()): +def test_restore_checkpoint(preset_args, clres, start_time=time.time(), + timeout=Def.TimeOuts.test_time_limit): """ Create checkpoint and restore them in second run.""" def _create_cmd_and_run(flag): @@ -71,6 +73,7 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()): return p + p_valid_params = p_utils.validation_params(preset_args) create_cp_proc = _create_cmd_and_run(flag=['--checkpoint_save_secs', '5']) # wait for checkpoint files @@ -84,12 +87,11 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()): if os.path.exists(checkpoint_test_dir): shutil.rmtree(checkpoint_test_dir) - entities = a_utils.get_files_from_dir(checkpoint_dir) + assert a_utils.is_reward_reached(csv_path=csv_list[0], + p_valid_params=p_valid_params, + start_time=start_time, time_limit=timeout) - while not any("10_Step" in file for file in entities) and time.time() - \ - start_time < Def.TimeOuts.test_time_limit: - entities = a_utils.get_files_from_dir(checkpoint_dir) - time.sleep(1) + entities = a_utils.get_files_from_dir(checkpoint_dir) assert len(entities) > 0 assert "checkpoint" in entities @@ -101,7 +103,7 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()): csv = pd.read_csv(csv_list[0]) rewards = csv['Evaluation Reward'].values rewards = rewards[~np.isnan(rewards)] - min_reward = np.amin(rewards) + max_reward = np.amax(rewards) if os.path.isdir(checkpoint_dir): shutil.copytree(exp_dir, checkpoint_test_dir) @@ -119,7 +121,10 @@ def test_restore_checkpoint(preset_args, clres, start_time=time.time()): csv = pd.read_csv(new_csv_list[0]) res = csv['Episode Length'].values[-1] - assert res >= min_reward, \ - Def.Consts.ASSERT_MSG.format(str(res) + ">=" + str(min_reward), - str(res) + " < " + str(min_reward)) + assert res == max_reward, Def.Consts.ASSERT_MSG.format(str(max_reward), + str(res)) restore_cp_proc.kill() + + test_folder = os.path.join(Def.Path.experiments, Def.Path.test_dir) + if os.path.exists(test_folder): + shutil.rmtree(test_folder) diff --git a/rl_coach/tests/test_coach_args.py b/rl_coach/tests/test_coach_args.py index 23423fc..0f46940 100644 --- a/rl_coach/tests/test_coach_args.py +++ b/rl_coach/tests/test_coach_args.py @@ -143,3 +143,113 @@ def test_preset_seed(preset_args_for_seed, clres, start_time=time.time(), assert False close_processes() + + +@pytest.mark.functional_test +def test_preset_n_and_ew(preset_args, clres, start_time=time.time(), + time_limit=Def.TimeOuts.test_time_limit): + """ + Test command arguments - check evaluation worker with number of workers + """ + + ew_flag = ['-ew'] + n_flag = ['-n', Def.Flags.enw] + p_valid_params = p_utils.validation_params(preset_args) + + run_cmd = [ + 'python3', 'rl_coach/coach.py', + '-p', '{}'.format(preset_args), + '-e', '{}'.format("ExpName_" + preset_args), + ] + + # add flags to run command + test_ew_flag = a_utils.add_one_flag_value(flag=ew_flag) + test_n_flag = a_utils.add_one_flag_value(flag=n_flag) + run_cmd.extend(test_ew_flag) + run_cmd.extend(test_n_flag) + + print(str(run_cmd)) + + proc = subprocess.Popen(run_cmd, stdout=clres.stdout, stderr=clres.stdout) + + try: + a_utils.validate_arg_result(flag=test_ew_flag, + p_valid_params=p_valid_params, clres=clres, + process=proc, start_time=start_time, + timeout=time_limit) + + a_utils.validate_arg_result(flag=test_n_flag, + p_valid_params=p_valid_params, clres=clres, + process=proc, start_time=start_time, + timeout=time_limit) + except AssertionError: + # close process once get assert false + proc.kill() + assert False + + proc.kill() + + +@pytest.mark.functional_test +@pytest.mark.xfail(reason="https://github.com/NervanaSystems/coach/issues/257") +def test_preset_n_and_ew_and_onnx(preset_args, clres, start_time=time.time(), + time_limit=Def.TimeOuts.test_time_limit): + """ + Test command arguments - check evaluation worker, number of workers and + onnx. + """ + + ew_flag = ['-ew'] + n_flag = ['-n', Def.Flags.enw] + onnx_flag = ['-onnx'] + s_flag = ['-s', Def.Flags.css] + p_valid_params = p_utils.validation_params(preset_args) + + run_cmd = [ + 'python3', 'rl_coach/coach.py', + '-p', '{}'.format(preset_args), + '-e', '{}'.format("ExpName_" + preset_args), + ] + + # add flags to run command + test_ew_flag = a_utils.add_one_flag_value(flag=ew_flag) + test_n_flag = a_utils.add_one_flag_value(flag=n_flag) + test_onnx_flag = a_utils.add_one_flag_value(flag=onnx_flag) + test_s_flag = a_utils.add_one_flag_value(flag=s_flag) + + run_cmd.extend(test_ew_flag) + run_cmd.extend(test_n_flag) + run_cmd.extend(test_onnx_flag) + run_cmd.extend(test_s_flag) + + print(str(run_cmd)) + + proc = subprocess.Popen(run_cmd, stdout=clres.stdout, stderr=clres.stdout) + + try: + # Check csv files has been created + a_utils.validate_arg_result(flag=test_ew_flag, + p_valid_params=p_valid_params, clres=clres, + process=proc, start_time=start_time, + timeout=time_limit) + + # Check csv files created same as the number of the workers + a_utils.validate_arg_result(flag=test_n_flag, + p_valid_params=p_valid_params, clres=clres, + process=proc, start_time=start_time, + timeout=time_limit) + + # Check checkpoint files + a_utils.validate_arg_result(flag=test_s_flag, + p_valid_params=p_valid_params, clres=clres, + process=proc, start_time=start_time, + timeout=time_limit) + + # TODO: add onnx check; issue found #257 + + except AssertionError: + # close process once get assert false + proc.kill() + assert False + + proc.kill() diff --git a/rl_coach/tests/utils/args_utils.py b/rl_coach/tests/utils/args_utils.py index cf26012..324a633 100644 --- a/rl_coach/tests/utils/args_utils.py +++ b/rl_coach/tests/utils/args_utils.py @@ -20,6 +20,7 @@ import signal import time import pandas as pd import numpy as np +import pytest from rl_coach.tests.utils.test_utils import get_csv_path, get_files_from_dir, \ find_string_in_logs from rl_coach.tests.utils.definitions import Definitions as Def @@ -56,7 +57,7 @@ def collect_preset_for_seed(): definitions (args_test under Presets). :return: preset(s) list """ - for pn in Def.Presets.seed_args_test: + for pn in Def.Presets.args_for_seed_test: assert pn, Def.Consts.ASSERT_MSG.format("Preset name", pn) yield pn @@ -251,6 +252,8 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None, -dg, --dump_gifs: Once selected, a new folder should be created in experiment folder for gifs files. """ + pytest.xfail(reason="GUI issue on CI") + csv_path = get_csv_path(clres) assert len(csv_path) > 0, \ Def.Consts.ASSERT_MSG.format("path not found", csv_path) @@ -267,13 +270,14 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None, # check if folder contain files get_files_from_dir(dir_path=gifs_path) - # TODO: check if play window is opened elif flag[0] == "-dm" or flag[0] == "--dump_mp4": """ -dm, --dump_mp4: Once selected, a new folder should be created in experiment folder for videos files. """ + pytest.xfail(reason="GUI issue on CI") + csv_path = get_csv_path(clres) assert len(csv_path) > 0, \ Def.Consts.ASSERT_MSG.format("path not found", csv_path) @@ -290,7 +294,6 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None, # check if folder contain files get_files_from_dir(dir_path=videos_path) - # TODO: check if play window is opened elif flag[0] == "--nocolor": """ @@ -363,7 +366,7 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None, csv file is created. """ # wait until files created - csv_path = get_csv_path(clres=clres) + csv_path = get_csv_path(clres=clres, extra_tries=10) assert len(csv_path) > 0, \ Def.Consts.ASSERT_MSG.format("path not found", csv_path) @@ -383,11 +386,14 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None, # check heat-up value results = [] - while csv["In Heatup"].values[-1] == 1: - csv = pd.read_csv(csv_path[0]) - last_step = csv["Total steps"].values - time.sleep(1) - results.append(last_step[-1]) + if csv["In Heatup"].values[-1] == 0: + results.append(csv["Total steps"].values[-1]) + else: + while csv["In Heatup"].values[-1] == 1: + csv = pd.read_csv(csv_path[0]) + last_step = csv["Total steps"].values + results.append(last_step[-1]) + time.sleep(1) assert results[-1] >= Def.Consts.num_hs, \ Def.Consts.ASSERT_MSG.format("bigger than " + Def.Consts.num_hs, @@ -475,3 +481,18 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None, elif flag[0] == "-c" or flag[0] == "--use_cpu": pass + + elif flag[0] == "-n" or flag[0] == "--num_workers": + + """ + -n, --num_workers: Once selected alone, check that csv created for each + worker, and check results. + """ + # wait until files created + csv_path = get_csv_path(clres=clres, extra_tries=20) + + expected_files = int(flag[1]) + assert len(csv_path) >= expected_files, \ + Def.Consts.ASSERT_MSG.format(str(expected_files), + str(len(csv_path))) + diff --git a/rl_coach/tests/utils/definitions.py b/rl_coach/tests/utils/definitions.py index 56740e0..9f853fb 100644 --- a/rl_coach/tests/utils/definitions.py +++ b/rl_coach/tests/utils/definitions.py @@ -39,7 +39,7 @@ class Definitions: enw = "num_workers" fw_ten = "framework_tensorflow" fw_mx = "framework_mxnet" - et = "rl_coach.environments.gym_environment:Atari" + # et = "rl_coach.environments.gym_environment:Atari" TODO """ Arguments that can be tested for python coach command @@ -92,46 +92,16 @@ class Definitions: ] # Preset for testing seed argument - seed_args_test = [ - "Atari_A3C", - "Atari_A3C_LSTM", - "Atari_Bootstrapped_DQN", - "Atari_C51", - "Atari_DDQN", - "Atari_DQN_with_PER", + args_for_seed_test = [ "Atari_DQN", - "Atari_DQN_with_PER", - "Atari_Dueling_DDQN", - "Atari_Dueling_DDQN_with_PER_OpenAI", - "Atari_NStepQ", - "Atari_QR_DQN", - "Atari_Rainbow", - "Atari_UCB_with_Q_Ensembles", + "Doom_Basic_DQN", "BitFlip_DQN", - "BitFlip_DQN_HER", - "CartPole_A3C", - "CartPole_ClippedPPO", - "CartPole_DFP", "CartPole_DQN", - "CartPole_Dueling_DDQN", - "CartPole_NStepQ", - "CartPole_PAL", - "CartPole_PG", + "CARLA_Dueling_DDQN", "ControlSuite_DDPG", - "ExplorationChain_Bootstrapped_DQN", "ExplorationChain_Dueling_DDQN", - "ExplorationChain_UCB_Q_ensembles", "Fetch_DDPG_HER_baselines", - "InvertedPendulum_PG", - "MontezumaRevenge_BC", - "Mujoco_A3C", - "Mujoco_A3C_LSTM", "Mujoco_ClippedPPO", - "Mujoco_DDPG", - "Mujoco_NAF", - "Mujoco_PPO", - "Pendulum_HAC", - "Starcraft_CollectMinerals_A3C", "Starcraft_CollectMinerals_Dueling_DDQN", ] diff --git a/rl_coach/tests/utils/presets_utils.py b/rl_coach/tests/utils/presets_utils.py index 0dc8a06..3ac8e9a 100644 --- a/rl_coach/tests/utils/presets_utils.py +++ b/rl_coach/tests/utils/presets_utils.py @@ -16,6 +16,7 @@ """Manage all preset""" import os +import pytest from importlib import import_module from rl_coach.tests.utils.definitions import Definitions as Def @@ -26,7 +27,13 @@ def import_preset(preset_name): :param preset_name: preset name :return: imported module """ - return import_module('{}.presets.{}'.format(Def.GROUP_NAME, preset_name)) + try: + module = import_module('{}.presets.{}' + .format(Def.GROUP_NAME, preset_name)) + except: + pytest.skip("Can't import module: {}".format(preset_name)) + + return module def validation_params(preset_name): diff --git a/rl_coach/tests/utils/test_utils.py b/rl_coach/tests/utils/test_utils.py index 85b72bb..03d54dc 100644 --- a/rl_coach/tests/utils/test_utils.py +++ b/rl_coach/tests/utils/test_utils.py @@ -84,7 +84,7 @@ def get_files_from_dir(dir_path): :return: |list| return files in folder """ start_time = time.time() - entities = None + entities = [] while time.time() - start_time < Def.TimeOuts.wait_for_files: # wait until logs created if os.path.exists(dir_path): @@ -118,17 +118,15 @@ def find_string_in_logs(log_path, str, timeout=Def.TimeOuts.wait_for_files, if not os.path.exists(log_path): return False - with open(log_path, 'r') as fr: - if str in fr.read(): - return True - fr.close() - - while time.time() - start_time < Def.TimeOuts.test_time_limit \ - and wait_and_find: + while time.time() - start_time < Def.TimeOuts.test_time_limit: with open(log_path, 'r') as fr: if str in fr.read(): return True fr.close() + + if not wait_and_find: + break + return False