From b3db9ce77d05161091b792fc50623bc258054b19 Mon Sep 17 00:00:00 2001 From: anabwan <46447582+anabwan@users.noreply.github.com> Date: Tue, 23 Apr 2019 15:12:11 +0300 Subject: [PATCH] tests: fixed failed tests - stabling CI (#298) * tests: stabling CI * tests: fix failed tests - stabling CI * fix get csv files. - fixed seed test * fix clres on conftest - now can modify paths during test run. - this fixed the mxnet checkpoint test * tests: fix comments --- rl_coach/tests/conftest.py | 21 ++++++++++++++++--- rl_coach/tests/test_checkpoint.py | 6 ++++++ rl_coach/tests/utils/args_utils.py | 4 ++-- rl_coach/tests/utils/test_utils.py | 33 +++++++++++++++++------------- 4 files changed, 45 insertions(+), 19 deletions(-) diff --git a/rl_coach/tests/conftest.py b/rl_coach/tests/conftest.py index 2886483..c3266d3 100644 --- a/rl_coach/tests/conftest.py +++ b/rl_coach/tests/conftest.py @@ -82,9 +82,25 @@ def clres(request): """ def __init__(self, csv, log, pattern): self.exp_path = csv - self.stdout = log + self.stdout = open(log, 'w') self.fn_pattern = pattern + @property + def experiment_path(self): + return self.exp_path + + @property + def stdout_path(self): + return self.stdout + + @experiment_path.setter + def experiment_path(self, val): + self.exp_path = val + + @stdout_path.setter + def stdout_path(self, val): + self.stdout = open(val, 'w') + # get preset name from test request params idx = 0 if 'preset' in list(request.node.funcargs.items())[0][0] else 1 p_name = list(request.node.funcargs.items())[idx][1] @@ -99,10 +115,9 @@ def clres(request): # get the stdout for logs results log_file_name = 'test_log_{}.txt'.format(p_name) - stdout = open(log_file_name, 'w') fn_pattern = '*.csv' if p_valid_params.num_workers > 1 else 'worker_0*.csv' - res = CreateCsvLog(test_path, stdout, fn_pattern) + res = CreateCsvLog(test_path, log_file_name, fn_pattern) yield res diff --git a/rl_coach/tests/test_checkpoint.py b/rl_coach/tests/test_checkpoint.py index 1884bc3..1b0eed2 100644 --- a/rl_coach/tests/test_checkpoint.py +++ b/rl_coach/tests/test_checkpoint.py @@ -91,7 +91,13 @@ def test_restore_checkpoint(preset_args, clres, framework, return p if framework == "mxnet": + # update preset name - for mxnet framework we are using *_DQN preset_args = Def.Presets.mxnet_args_test[0] + # update logs paths + test_name = 'ExpName_{}'.format(preset_args) + test_path = os.path.join(Def.Path.experiments, test_name) + clres.experiment_path = test_path + clres.stdout_path = 'test_log_{}.txt'.format(preset_args) p_valid_params = p_utils.validation_params(preset_args) create_cp_proc = _create_cmd_and_run(flag=['--checkpoint_save_secs', '5']) diff --git a/rl_coach/tests/utils/args_utils.py b/rl_coach/tests/utils/args_utils.py index 3384faf..568d398 100644 --- a/rl_coach/tests/utils/args_utils.py +++ b/rl_coach/tests/utils/args_utils.py @@ -453,10 +453,10 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None, num_expected_files=int(flag[1])) assert len(csv_path) > 0, \ - Def.Consts.ASSERT_MSG.format("paths are not found", csv_path) + Def.Consts.ASSERT_MSG.format("paths are not found", str(csv_path)) assert int(flag[1]) == len(csv_path), Def.Consts.ASSERT_MSG. \ - format(len(csv_path), int(flag[1])) + format(int(flag[1]), len(csv_path)) # wait for getting results in csv's for i in range(len(csv_path)): diff --git a/rl_coach/tests/utils/test_utils.py b/rl_coach/tests/utils/test_utils.py index d58e5b0..c2e39ee 100644 --- a/rl_coach/tests/utils/test_utils.py +++ b/rl_coach/tests/utils/test_utils.py @@ -32,7 +32,6 @@ def print_progress(averaged_rewards, last_num_episodes, start_time, time_limit, :param start_time: start time of test :param time_limit: time out of test :param p_valid_params: preset validation parameters - :return: """ max_episodes_to_archive = p_valid_params.max_episodes_to_achieve_reward min_reward = p_valid_params.min_reward_threshold @@ -55,30 +54,36 @@ def read_csv_paths(test_path, filename_pattern, read_csv_tries=120, extra_tries=0, num_expected_files=None): """ Return file path once it found - :param test_path: test folder path - :param filename_pattern: csv file pattern - :param read_csv_tries: number of iterations until file found - :param extra_tries: add number of extra tries to check after getting all - the paths. + :param test_path: |string| test folder path + :param filename_pattern: |string| csv file pattern + :param read_csv_tries: |int| number of iterations until file found + :param extra_tries: |int| add number of extra tries to check after getting + all the paths. :param num_expected_files: find all expected file in experiment folder. :return: |string| return csv file path """ csv_paths = [] tries_counter = 0 - while not csv_paths or extra_tries > 0: - csv_paths = glob.glob(path.join(test_path, '*', filename_pattern)) - if tries_counter > read_csv_tries: - break - if num_expected_files and num_expected_files == len(csv_paths): + if isinstance(extra_tries, int) and extra_tries >= 0: + read_csv_tries += extra_tries + + while tries_counter < read_csv_tries: + csv_paths = glob.glob(path.join(test_path, '*', filename_pattern)) + + if num_expected_files: + if num_expected_files == len(csv_paths): + break + else: + time.sleep(1) + tries_counter += 1 + continue + elif csv_paths: break time.sleep(1) tries_counter += 1 - if csv_paths and extra_tries > 0: - extra_tries -= 1 - return csv_paths