diff --git a/requirements.txt b/requirements.txt index f64ffe2..db1d27b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ kubernetes>=8.0.0b1 redis>=2.10.6 minio>=4.0.5 pytest>=3.8.2 +psutil>=5.5.0 diff --git a/rl_coach/coach.py b/rl_coach/coach.py index a3ded7e..d820d6d 100644 --- a/rl_coach/coach.py +++ b/rl_coach/coach.py @@ -454,7 +454,7 @@ class CoachLauncher(object): "effect and the CPU will be used either way.", action='store_true') parser.add_argument('-ew', '--evaluation_worker', - help="(int) If multiple workers are used, add an evaluation worker as well which will " + help="(flag) If multiple workers are used, add an evaluation worker as well which will " "evaluate asynchronously and independently during the training. NOTE: this worker will " "ignore the evaluation settings in the preset's ScheduleParams.", action='store_true') diff --git a/rl_coach/tests/README.md b/rl_coach/tests/README.md index d07f35a..e051986 100644 --- a/rl_coach/tests/README.md +++ b/rl_coach/tests/README.md @@ -46,7 +46,7 @@ several parts, each testing the framework in different areas and strictness. The golden tests can be run using the following command: ``` - python3 rl_coach/tests/test_golden.py + python3 -m pytest rl_coach/tests -m golden_test ``` * **Trace tests** - @@ -59,3 +59,19 @@ several parts, each testing the framework in different areas and strictness. ``` python3 rl_coach/tests/trace_tests.py -prl ``` + +* **Optional PyTest Flags** - + + Using -k expr to select tests based on their name; + The -k command line option to specify an expression which implements a substring match on the test names + instead of the exact match on markers that -m provides. This makes it easy to select tests based on their names: + ``` + python3 -m pytest rl_coach/tests -k Doom + ``` + Using -v (--verbose) expr to show tests progress during running the tests, -v can be added with -m or with -k, to use -v see + the following commands: + ``` + python3 -m pytest rl_coach/tests -v -m golden_test + OR + python3 -m pytest rl_coach/tests -v -k Doom + ``` \ No newline at end of file diff --git a/rl_coach/tests/conftest.py b/rl_coach/tests/conftest.py new file mode 100644 index 0000000..4cfdd1e --- /dev/null +++ b/rl_coach/tests/conftest.py @@ -0,0 +1,127 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""PyTest configuration.""" + +import configparser as cfgparser +import os +import platform +import shutil +import pytest +import rl_coach.tests.utils.args_utils as a_utils +import rl_coach.tests.utils.presets_utils as p_utils +from rl_coach.tests.utils.definitions import Definitions as Def +from os import path + + +def pytest_collection_modifyitems(config, items): + """pytest built in method to pre-process cli options""" + global test_config + test_config = cfgparser.ConfigParser() + str_rootdir = str(config.rootdir) + str_inifile = str(config.inifile) + # Get the relative path of the inifile + # By default is an absolute path but relative path when -c option used + config_path = os.path.relpath(str_inifile, str_rootdir) + config_path = os.path.join(str_rootdir, config_path) + assert (os.path.exists(config_path)) + test_config.read(config_path) + + +def pytest_runtest_setup(item): + """Called before test is run.""" + if len(item.own_markers) < 1: + return + if (item.own_markers[0].name == "unstable" and + "unstable" not in item.config.getoption("-m")): + pytest.skip("skipping unstable test") + + if item.own_markers[0].name == "linux_only": + if platform.system() != 'Linux': + pytest.skip("Skipping test that isn't Linux OS.") + + if item.own_markers[0].name == "golden_test": + """ do some custom configuration for golden tests. """ + # TODO: add custom functionality + pass + + +@pytest.fixture(scope="module", params=list(p_utils.collect_presets())) +def preset_name(request): + """ + Return all preset names + """ + return request.param + + +@pytest.fixture(scope="function", params=list(a_utils.collect_args())) +def flag(request): + """ + Return flags names in function scope + """ + return request.param + + +@pytest.fixture(scope="module", params=list(a_utils.collect_preset_for_args())) +def preset_args(request): + """ + Return preset names that can be used for args testing only; working in + module scope + """ + return request.param + + +@pytest.fixture(scope="function") +def clres(request): + """ + Create both file csv and log for testing + :yield: class of both files paths + """ + + class CreateCsvLog: + """ + Create a test and log paths + """ + def __init__(self, csv, log, pattern): + self.exp_path = csv + self.stdout = log + self.fn_pattern = pattern + + # get preset name from test request params + idx = 0 if 'preset' in list(request.node.funcargs.items())[0][0] else 1 + p_name = list(request.node.funcargs.items())[idx][1] + + p_valid_params = p_utils.validation_params(p_name) + + test_name = 'ExpName_{}'.format(p_name) + test_path = os.path.join(Def.Path.experiments, test_name) + if path.exists(test_path): + shutil.rmtree(test_path) + + # get the stdout for logs results + log_file_name = 'test_log_{}.txt'.format(p_name) + stdout = open(log_file_name, 'w') + fn_pattern = 'worker_0*.csv' if p_valid_params.num_workers > 1 else '*.csv' + + res = CreateCsvLog(test_path, stdout, fn_pattern) + + yield res + + # clean files + if path.exists(res.exp_path): + shutil.rmtree(res.exp_path) + + if os.path.exists(res.exp_path): + os.remove(res.stdout) diff --git a/rl_coach/tests/test_args.py b/rl_coach/tests/test_args.py new file mode 100644 index 0000000..deb5779 --- /dev/null +++ b/rl_coach/tests/test_args.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import subprocess +import time +import rl_coach.tests.utils.args_utils as a_utils +import rl_coach.tests.utils.presets_utils as p_utils +from rl_coach.tests.utils.definitions import Definitions as Def + + +def test_preset_args(preset_args, flag, clres, start_time=time.time(), + time_limit=Def.TimeOuts.test_time_limit): + """ Test command arguments - the test will check all flags one-by-one.""" + + p_valid_params = p_utils.validation_params(preset_args) + + run_cmd = [ + 'python3', 'rl_coach/coach.py', + '-p', '{}'.format(preset_args), + '-e', '{}'.format("ExpName_" + preset_args), + ] + + if p_valid_params.reward_test_level: + lvl = ['-lvl', '{}'.format(p_valid_params.reward_test_level)] + run_cmd.extend(lvl) + + # add flags to run command + test_flag = a_utils.add_one_flag_value(flag=flag) + run_cmd.extend(test_flag) + print(str(run_cmd)) + + # run command + p = subprocess.Popen(run_cmd, stdout=clres.stdout, stderr=clres.stdout) + + # validate results + a_utils.validate_args_results(test_flag, clres, p, start_time, time_limit) + + # Close process + p.kill() diff --git a/rl_coach/tests/utils/__init__.py b/rl_coach/tests/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rl_coach/tests/utils/args_utils.py b/rl_coach/tests/utils/args_utils.py new file mode 100644 index 0000000..e990cfd --- /dev/null +++ b/rl_coach/tests/utils/args_utils.py @@ -0,0 +1,354 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Manage all command arguments.""" + +import os +import re +import signal +import time + +import psutil as psutil + +from rl_coach.logger import screen +from rl_coach.tests.utils import test_utils +from rl_coach.tests.utils.definitions import Definitions as Def + + +def collect_preset_for_args(): + """ + Collect presets that relevant for args testing only. + This used for testing arguments for specific presets that defined in the + definitions (args_test under Presets). + :return: preset(s) list + """ + for pn in Def.Presets.args_test: + assert pn, Def.Consts.ASSERT_MSG.format("Preset name", pn) + yield pn + + +def collect_args(): + """ + Collect args from the cmd args list - on each test iteration, it will + yield one line (one arg). + :yield: one arg foe each test iteration + """ + for k, v in Def.Flags.cmd_args.items(): + cmd = [] + cmd.append(k) + if v is not None: + cmd.append(v) + assert cmd, Def.Consts.ASSERT_MSG.format("cmd array", str(cmd)) + yield cmd + + +def add_one_flag_value(flag): + """ + Add value to flag format in order to run the python command with arguments. + :param flag: dict flag + :return: flag with format + """ + if not flag or len(flag) > 2 or len(flag) == 0: + return [] + + if len(flag) == 1: + return flag + + if Def.Flags.css in flag[1]: + flag[1] = 30 + + elif Def.Flags.crd in flag[1]: + # TODO: check dir of checkpoint + flag[1] = os.path.join(Def.Path.experiments) + + elif Def.Flags.et in flag[1]: + # TODO: add valid value + flag[1] = "" + + elif Def.Flags.ept in flag[1]: + # TODO: add valid value + flag[1] = "" + + elif Def.Flags.cp in flag[1]: + # TODO: add valid value + flag[1] = "" + + elif Def.Flags.seed in flag[1]: + flag[1] = 0 + + elif Def.Flags.dccp in flag[1]: + # TODO: add valid value + flag[1] = "" + + return flag + + +def check_files_in_dir(dir_path): + """ + Check if folder has files + :param dir_path: |string| folder path + :return: |Array| return files in folder + """ + start_time = time.time() + entities = None + while time.time() - start_time < Def.TimeOuts.wait_for_files: + # wait until logs created + if os.path.exists(dir_path): + entities = os.listdir(dir_path) + if len(entities) > 0: + break + time.sleep(1) + + assert len(entities) > 0, \ + Def.Consts.ASSERT_MSG.format("num files > 0", len(entities)) + return entities + + +def find_string_in_logs(log_path, str): + """ + Find string into the log file + :param log_path: |string| log path + :param str: |string| search text + :return: |bool| true if string found in the log file + """ + start_time = time.time() + while time.time() - start_time < Def.TimeOuts.wait_for_files: + # wait until logs created + if os.path.exists(log_path): + break + time.sleep(1) + + if not os.path.exists(log_path): + return False + + if str in open(log_path, 'r').read(): + return True + return False + + +def get_csv_path(clres): + """ + Get the csv path with the results - reading csv paths will take some time + :param clres: object of files that test is creating + :return: |Array| csv path + """ + return test_utils.read_csv_paths(test_path=clres.exp_path, + filename_pattern=clres.fn_pattern) + + +def validate_args_results(flag, clres=None, process=None, start_time=None, + timeout=None): + """ + Validate results of one argument. + :param flag: flag to check + :param clres: object of files paths (results of test experiment) + :param process: process object + :param start_time: start time of the test + :param timeout: timeout of the test- fail test once over the timeout + """ + + if flag[0] == "-ns" or flag[0] == "--no-summary": + """ + --no-summary: Once selected, summary lines shouldn't appear in logs + """ + # send CTRL+C to close experiment + process.send_signal(signal.SIGINT) + + assert not find_string_in_logs(log_path=clres.stdout.name, + str=Def.Consts.RESULTS_SORTED), \ + Def.Consts.ASSERT_MSG.format("No Result summary", + Def.Consts.RESULTS_SORTED) + + assert not find_string_in_logs(log_path=clres.stdout.name, + str=Def.Consts.TOTAL_RUNTIME), \ + Def.Consts.ASSERT_MSG.format("No Total runtime summary", + Def.Consts.TOTAL_RUNTIME) + + assert not find_string_in_logs(log_path=clres.stdout.name, + str=Def.Consts.DISCARD_EXP), \ + Def.Consts.ASSERT_MSG.format("No discard message", + Def.Consts.DISCARD_EXP) + + elif flag[0] == "-asc" or flag[0] == "--apply_stop_condition": + """ + -asc, --apply_stop_condition: Once selected, coach stopped when + required success rate reached + """ + while time.time() - start_time < timeout: + + if find_string_in_logs(log_path=clres.stdout.name, + str=Def.Consts.REACHED_REQ_ASC): + assert True, Def.Consts.ASSERT_MSG. \ + format(Def.Consts.REACHED_REQ_ASC, "Message Not Found") + break + + elif flag[0] == "-d" or flag[0] == "--open_dashboard": + """ + -d, --open_dashboard: Once selected, firefox browser will open to show + coach's Dashboard. + """ + proc_id = None + start_time = time.time() + while time.time() - start_time < Def.TimeOuts.wait_for_files: + for proc in psutil.process_iter(): + if proc.name() == Def.DASHBOARD_PROC: + assert proc.name() == Def.DASHBOARD_PROC, \ + Def.Consts.ASSERT_MSG. format(Def.DASHBOARD_PROC, + proc.name()) + proc_id = proc.pid + break + if proc_id: + break + + if proc_id: + # kill firefox process + os.kill(proc_id, signal.SIGKILL) + else: + assert False, Def.Consts.ASSERT_MSG.format("Found Firefox process", + proc_id) + + elif flag[0] == "--print_networks_summary": + """ + --print_networks_summary: Once selected, agent summary should appear in + stdout. + """ + if find_string_in_logs(log_path=clres.stdout.name, + str=Def.Consts.INPUT_EMBEDDER): + assert True, Def.Consts.ASSERT_MSG.format( + Def.Consts.INPUT_EMBEDDER, "Not found") + + if find_string_in_logs(log_path=clres.stdout.name, + str=Def.Consts.MIDDLEWARE): + assert True, Def.Consts.ASSERT_MSG.format( + Def.Consts.MIDDLEWARE, "Not found") + + if find_string_in_logs(log_path=clres.stdout.name, + str=Def.Consts.OUTPUT_HEAD): + assert True, Def.Consts.ASSERT_MSG.format( + Def.Consts.OUTPUT_HEAD, "Not found") + + elif flag[0] == "-tb" or flag[0] == "--tensorboard": + """ + -tb, --tensorboard: Once selected, a new folder should be created in + experiment folder. + """ + csv_path = get_csv_path(clres) + assert len(csv_path) > 0, \ + Def.Consts.ASSERT_MSG.format("path not found", csv_path) + + exp_path = os.path.dirname(csv_path[0]) + tensorboard_path = os.path.join(exp_path, Def.Path.tensorboard) + + assert os.path.isdir(tensorboard_path), \ + Def.Consts.ASSERT_MSG.format("tensorboard path", tensorboard_path) + + # check if folder contain files + check_files_in_dir(dir_path=tensorboard_path) + + elif flag[0] == "-onnx" or flag[0] == "--export_onnx_graph": + """ + -onnx, --export_onnx_graph: Once selected, warning message should + appear, it should be with another flag. + """ + if find_string_in_logs(log_path=clres.stdout.name, + str=Def.Consts.ONNX_WARNING): + assert True, Def.Consts.ASSERT_MSG.format( + Def.Consts.ONNX_WARNING, "Not found") + + elif flag[0] == "-dg" or flag[0] == "--dump_gifs": + """ + -dg, --dump_gifs: Once selected, a new folder should be created in + experiment folder for gifs files. + """ + csv_path = get_csv_path(clres) + assert len(csv_path) > 0, \ + Def.Consts.ASSERT_MSG.format("path not found", csv_path) + + exp_path = os.path.dirname(csv_path[0]) + gifs_path = os.path.join(exp_path, Def.Path.gifs) + + # wait until gif folder were created + while time.time() - start_time < timeout: + if os.path.isdir(gifs_path): + assert os.path.isdir(gifs_path), \ + Def.Consts.ASSERT_MSG.format("gifs path", gifs_path) + break + + # check if folder contain files + check_files_in_dir(dir_path=gifs_path) + # TODO: check if play window is opened + + elif flag[0] == "-dm" or flag[0] == "--dump_mp4": + """ + -dm, --dump_mp4: Once selected, a new folder should be created in + experiment folder for videos files. + """ + csv_path = get_csv_path(clres) + assert len(csv_path) > 0, \ + Def.Consts.ASSERT_MSG.format("path not found", csv_path) + + exp_path = os.path.dirname(csv_path[0]) + videos_path = os.path.join(exp_path, Def.Path.videos) + + # wait until video folder were created + while time.time() - start_time < timeout: + if os.path.isdir(videos_path): + assert os.path.isdir(videos_path), \ + Def.Consts.ASSERT_MSG.format("videos path", videos_path) + break + + # check if folder contain files + check_files_in_dir(dir_path=videos_path) + # TODO: check if play window is opened + + elif flag[0] == "--nocolor": + """ + --nocolor: Once selected, check if color prefix is replacing the actual + color; example: '## agent: ...' + """ + while time.time() - start_time < timeout: + + if find_string_in_logs(log_path=clres.stdout.name, + str=Def.Consts.COLOR_PREFIX): + assert True, Def.Consts.ASSERT_MSG. \ + format(Def.Consts.COLOR_PREFIX, "Color Prefix Not Found") + break + + elif flag[0] == "--evaluate": + """ + --evaluate: Once selected, Coach start testing, there is not training. + """ + tries = 5 + while time.time() - start_time < timeout and tries > 0: + + if find_string_in_logs(log_path=clres.stdout.name, + str=Def.Consts.TRAINING): + assert False, Def.Consts.ASSERT_MSG.format( + "Training Not Found", Def.Consts.TRAINING) + else: + time.sleep(1) + tries -= 1 + assert True, Def.Consts.ASSERT_MSG.format("Training Found", + Def.Consts.TRAINING) + + elif flag[0] == "--play": + """ + --play: Once selected alone, warning message should appear, it should + be with another flag. + """ + if find_string_in_logs(log_path=clres.stdout.name, + str=Def.Consts.PLAY_WARNING): + assert True, Def.Consts.ASSERT_MSG.format( + Def.Consts.ONNX_WARNING, "Not found") diff --git a/rl_coach/tests/utils/definitions.py b/rl_coach/tests/utils/definitions.py new file mode 100644 index 0000000..e79dd71 --- /dev/null +++ b/rl_coach/tests/utils/definitions.py @@ -0,0 +1,111 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Definitions file: + +It's main functionality are: +1) housing project constants and enums. +2) housing configuration parameters. +3) housing resource paths. +""" + + +class Definitions: + GROUP_NAME = "rl_coach" + PROCESS_NAME = "coach" + DASHBOARD_PROC = "firefox" + + class Flags: + css = "checkpoint_save_secs" + crd = "checkpoint_restore_dir" + et = "environment_type" + ept = "exploration_policy_type" + cp = "custom_parameter" + seed = "seed" + dccp = "distributed_coach_config_path" + + """ + Arguments that can be tested for python coach command + ** None = Flag - no need for string or int + ** {} = Add format for this parameter + """ + cmd_args = { + # '-l': None, + # '-e': '{}', + # '-r': None, + # '-n': '{' + enw + '}', + # '-c': None, + # '-ew': None, + '--play': None, + '--evaluate': None, + # '-v': None, + # '-tfv': None, + '--nocolor': None, + # '-s': '{' + css + '}', + # '-crd': '{' + crd + '}', + '-dg': None, + '-dm': None, + # '-et': '{' + et + '}', + # '-ept': '{' + ept + '}', + # '-lvl': '{level}', + # '-cp': '{' + cp + '}', + '--print_networks_summary': None, + '-tb': None, + '-ns': None, + '-d': None, + # '--seed': '{' + seed + '}', + '-onnx': None, + '-dc': None, + # '-dcp': '{' + dccp + '}', + '-asc': None, + '--dump_worker_logs': None, + } + + class Presets: + # Preset list for testing the flags/ arguments of python coach command + args_test = [ + "CartPole_A3C", + # "CartPole_NEC", + ] + + class Path: + experiments = "./experiments" + tensorboard = "tensorboard" + gifs = "gifs" + videos = "videos" + + class Consts: + ASSERT_MSG = "Expected: {}, Actual: {}." + RESULTS_SORTED = "Results stored at:" + TOTAL_RUNTIME = "Total runtime:" + DISCARD_EXP = "Do you want to discard the experiment results" + REACHED_REQ_ASC = "Reached required success rate. Exiting." + INPUT_EMBEDDER = "Input Embedder:" + MIDDLEWARE = "Middleware:" + OUTPUT_HEAD = "Output Head:" + ONNX_WARNING = "Exporting ONNX graphs requires setting the " \ + "--checkpoint_save_secs flag. The --export_onnx_graph" \ + " will have no effect." + COLOR_PREFIX = "## agent: Starting evaluation phase" + TRAINING = "Training - " + PLAY_WARNING = "Both the --preset and the --play flags were set. " \ + "These flags can not be used together. For human " \ + "control, please use the --play flag together with " \ + "the environment type flag (-et)" + + class TimeOuts: + test_time_limit = 60 * 60 + wait_for_files = 20 diff --git a/rl_coach/tests/utils/presets_utils.py b/rl_coach/tests/utils/presets_utils.py new file mode 100644 index 0000000..0dc8a06 --- /dev/null +++ b/rl_coach/tests/utils/presets_utils.py @@ -0,0 +1,85 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Manage all preset""" + +import os +from importlib import import_module +from rl_coach.tests.utils.definitions import Definitions as Def + + +def import_preset(preset_name): + """ + Import preset name module from presets directory + :param preset_name: preset name + :return: imported module + """ + return import_module('{}.presets.{}'.format(Def.GROUP_NAME, preset_name)) + + +def validation_params(preset_name): + """ + Validate parameters based on preset name + :param preset_name: preset name + :return: |bool| true if preset has params + """ + return import_preset(preset_name).graph_manager.preset_validation_params + + +def all_presets(): + """ + Get all preset from preset directory + :return: |Array| preset list + """ + return [ + f[:-3] for f in os.listdir(os.path.join(Def.GROUP_NAME, 'presets')) + if f[-3:] == '.py' and not f == '__init__.py' + ] + + +def importable(preset_name): + """ + Try to import preset name + :param preset_name: |name| preset name + :return: |bool| true if possible to import preset + """ + try: + import_preset(preset_name) + return True + except BaseException: + return False + + +def has_test_parameters(preset_name): + """ + Check if preset has parameters + :param preset_name: |string| preset name + :return: |bool| true: if preset have parameters + """ + return bool(validation_params(preset_name).test) + + +def collect_presets(): + """ + Collect all presets in presets directory + :yield: preset name + """ + for preset_name in all_presets(): + # if it isn't importable, still include it so we can fail the test + if not importable(preset_name): + yield preset_name + # otherwise, make sure it has test parameters before including it + elif has_test_parameters(preset_name): + yield preset_name diff --git a/rl_coach/tests/utils/test_utils.py b/rl_coach/tests/utils/test_utils.py new file mode 100644 index 0000000..f9cdbd1 --- /dev/null +++ b/rl_coach/tests/utils/test_utils.py @@ -0,0 +1,68 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Common functionality shared across tests.""" + +import glob +import sys +import time +from os import path + + +def print_progress(averaged_rewards, last_num_episodes, start_time, time_limit, + p_valid_params): + """ + Print progress bar for preset run test + :param averaged_rewards: average rewards of test + :param last_num_episodes: last episode number + :param start_time: start time of test + :param time_limit: time out of test + :param p_valid_params: preset validation parameters + :return: + """ + max_episodes_to_archive = p_valid_params.max_episodes_to_achieve_reward + min_reward = p_valid_params.min_reward_threshold + avg_reward = round(averaged_rewards[-1], 1) + percentage = int((100 * last_num_episodes) / max_episodes_to_archive) + cur_time = round(time.time() - start_time, 2) + + sys.stdout.write("\rReward: ({}/{})".format(avg_reward, min_reward)) + sys.stdout.write(' Time (sec): ({}/{})'.format(cur_time, time_limit)) + sys.stdout.write(' Episode: ({}/{})'.format(last_num_episodes, + max_episodes_to_archive)) + sys.stdout.write(' {}%|{}{}| ' + .format(percentage, '#' * int(percentage / 10), + ' ' * (10 - int(percentage / 10)))) + + sys.stdout.flush() + + +def read_csv_paths(test_path, filename_pattern, read_csv_tries=120): + """ + Return file path once it found + :param test_path: test folder path + :param filename_pattern: csv file pattern + :param read_csv_tries: number of iterations until file found + :return: |string| return csv file path + """ + csv_paths = [] + tries_counter = 0 + while not csv_paths: + csv_paths = glob.glob(path.join(test_path, '*', filename_pattern)) + if tries_counter > read_csv_tries: + break + tries_counter += 1 + time.sleep(1) + return csv_paths