From 881f78f45a1ca319bf1daf78ae8fe58dcddcf9c5 Mon Sep 17 00:00:00 2001
From: anabwan <46447582+anabwan@users.noreply.github.com>
Date: Sun, 7 Apr 2019 07:36:44 +0300
Subject: [PATCH] tests: new checkpoint mxnet test + fix utils (#273)

* tests: new mxnet test + fix utils

new test added:
- test_restore_checkpoint[tensorflow, mxnet]

fix failing tests in CI

improve utils

* tests: fix comments for mxnet checkpoint test and utils
---
 rl_coach/tests/test_checkpoint.py  | 28 ++++++++++++++++++++++++----
 rl_coach/tests/utils/args_utils.py | 15 ++++++++-------
 rl_coach/tests/utils/test_utils.py | 13 ++++++++++---
 3 files changed, 42 insertions(+), 14 deletions(-)

diff --git a/rl_coach/tests/test_checkpoint.py b/rl_coach/tests/test_checkpoint.py
index b72e999..1da9f4a 100644
--- a/rl_coach/tests/test_checkpoint.py
+++ b/rl_coach/tests/test_checkpoint.py
@@ -55,24 +55,44 @@ def test_get_checkpoint_state():
 
 
 @pytest.mark.functional_test
-def test_restore_checkpoint(preset_args, clres, start_time=time.time(),
+@pytest.mark.parametrize("framework", ["mxnet", "tensorflow"])
+def test_restore_checkpoint(preset_args, clres, framework,
+                            start_time=time.time(),
                             timeout=Def.TimeOuts.test_time_limit):
-    """ Create checkpoint and restore them in second run."""
+    """
+    Create checkpoints and restore them in a second run.
+    :param preset_args: all presets that can be tested for argument tests
+    :param clres: logs and csv files
+    :param framework: name of the framework under test
+    :param start_time: time when the test started
+    :param timeout: maximum run time for the test
+    """
 
     def _create_cmd_and_run(flag):
-
+        """
+        Create the default command with the given flag and run it.
+        :param flag: name of the tested flag; it is appended to the running
+                     command line
+        :return: active process
+        """
        run_cmd = [
             'python3', 'rl_coach/coach.py',
             '-p', '{}'.format(preset_args),
             '-e', '{}'.format("ExpName_" + preset_args),
+            '--seed', '{}'.format(42),
+            '-f', '{}'.format(framework),
         ]
+
         test_flag = a_utils.add_one_flag_value(flag=flag)
         run_cmd.extend(test_flag)
-
+        print(str(run_cmd))
         p = subprocess.Popen(run_cmd, stdout=clres.stdout, stderr=clres.stdout)
 
         return p
 
+    if framework == "mxnet":
+        preset_args = Def.Presets.mxnet_args_test
+
     p_valid_params = p_utils.validation_params(preset_args)
 
     create_cp_proc = _create_cmd_and_run(flag=['--checkpoint_save_secs', '5'])
diff --git a/rl_coach/tests/utils/args_utils.py b/rl_coach/tests/utils/args_utils.py
index 324a633..e0b0c48 100644
--- a/rl_coach/tests/utils/args_utils.py
+++ b/rl_coach/tests/utils/args_utils.py
@@ -395,9 +395,9 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
             results.append(last_step[-1])
             time.sleep(1)
 
-        assert results[-1] >= Def.Consts.num_hs, \
-            Def.Consts.ASSERT_MSG.format("bigger than " + Def.Consts.num_hs,
-                                         results[-1])
+        assert int(results[-1]) >= Def.Consts.num_hs, \
+            Def.Consts.ASSERT_MSG.format("bigger than " +
+                                         str(Def.Consts.num_hs), results[-1])
 
     elif flag[0] == "-f" or flag[0] == "--framework":
         """
@@ -445,7 +445,8 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
         """
         lst_csv = []
         # wait until files created
-        csv_path = get_csv_path(clres=clres, extra_tries=10)
+        csv_path = get_csv_path(clres=clres, extra_tries=20,
+                                num_expected_files=int(flag[1]))
 
         assert len(csv_path) > 0, \
             Def.Consts.ASSERT_MSG.format("paths are not found", csv_path)
@@ -491,8 +492,8 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
         # wait until files created
         csv_path = get_csv_path(clres=clres, extra_tries=20)
 
-        expected_files = int(flag[1])
-        assert len(csv_path) >= expected_files, \
-            Def.Consts.ASSERT_MSG.format(str(expected_files),
+        num_expected_files = int(flag[1])
+        assert len(csv_path) >= num_expected_files, \
+            Def.Consts.ASSERT_MSG.format(str(num_expected_files),
                                          str(len(csv_path)))
 
diff --git a/rl_coach/tests/utils/test_utils.py b/rl_coach/tests/utils/test_utils.py
index 03d54dc..d58e5b0 100644
--- a/rl_coach/tests/utils/test_utils.py
+++ b/rl_coach/tests/utils/test_utils.py
@@ -52,7 +52,7 @@ def print_progress(averaged_rewards, last_num_episodes, start_time, time_limit,
 
 
 def read_csv_paths(test_path, filename_pattern, read_csv_tries=120,
-                   extra_tries=0):
+                   extra_tries=0, num_expected_files=None):
     """
     Return file path once it found
     :param test_path: test folder path
@@ -60,6 +60,7 @@ def read_csv_paths(test_path, filename_pattern, read_csv_tries=120,
     :param read_csv_tries: number of iterations until file found
     :param extra_tries: add number of extra tries to check after getting all
            the paths.
+    :param num_expected_files: expected number of csv files to wait for.
     :return: |string| return csv file path
     """
     csv_paths = []
@@ -68,6 +69,10 @@ def read_csv_paths(test_path, filename_pattern, read_csv_tries=120,
         csv_paths = glob.glob(path.join(test_path, '*', filename_pattern))
         if tries_counter > read_csv_tries:
             break
+
+        if num_expected_files and num_expected_files == len(csv_paths):
+            break
+
         time.sleep(1)
         tries_counter += 1
 
@@ -131,17 +136,19 @@ def find_string_in_logs(log_path, str, timeout=Def.TimeOuts.wait_for_files,
 
 
 def get_csv_path(clres, tries_for_csv=Def.TimeOuts.wait_for_csv,
-                 extra_tries=0):
+                 extra_tries=0, num_expected_files=None):
     """
     Get the csv path with the results - reading csv paths will take some
     time
     :param clres: object of files that test is creating
     :param tries_for_csv: timeout of tires until getting all csv files
     :param extra_tries: add number of extra tries to check after getting all
            the paths.
+    :param num_expected_files: expected number of csv files to wait for.
     :return: |list| csv path
     """
     return read_csv_paths(test_path=clres.exp_path,
                           filename_pattern=clres.fn_pattern,
                           read_csv_tries=tries_for_csv,
-                          extra_tries=extra_tries)
+                          extra_tries=extra_tries,
+                          num_expected_files=num_expected_files)
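
Note on the utils change above: read_csv_paths()/get_csv_path() can now stop polling as soon as the expected number of csv result files is on disk, instead of always waiting out the full retry budget. The snippet below is a minimal, self-contained sketch of that polling pattern only; the function name wait_for_csv_files and the experiment_dir/pattern/max_tries parameters are illustrative placeholders, not code from the repository.

import glob
import time
from os import path


def wait_for_csv_files(experiment_dir, pattern='*.csv', max_tries=120,
                       num_expected_files=None):
    """Poll experiment_dir until csv files appear or max_tries runs out."""
    csv_paths = []
    for _ in range(max_tries):
        # look one directory level below the experiment root, mirroring the
        # glob used in read_csv_paths()
        csv_paths = glob.glob(path.join(experiment_dir, '*', pattern))

        # stop early once the expected number of files is on disk, or as
        # soon as anything is found when no expectation was given
        if num_expected_files and len(csv_paths) == num_expected_files:
            break
        if not num_expected_files and csv_paths:
            break

        time.sleep(1)

    return csv_paths

Breaking out early keeps the multi-file checks in args_utils.py from sleeping through the extra tries once every file they asked for already exists.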