mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
tests: added new setup configuration + test args (#211)
- added utils for future tests and conftest - added test args
This commit is contained in:
@@ -14,3 +14,4 @@ kubernetes>=8.0.0b1
|
|||||||
redis>=2.10.6
|
redis>=2.10.6
|
||||||
minio>=4.0.5
|
minio>=4.0.5
|
||||||
pytest>=3.8.2
|
pytest>=3.8.2
|
||||||
|
psutil>=5.5.0
|
||||||
|
|||||||
@@ -454,7 +454,7 @@ class CoachLauncher(object):
|
|||||||
"effect and the CPU will be used either way.",
|
"effect and the CPU will be used either way.",
|
||||||
action='store_true')
|
action='store_true')
|
||||||
parser.add_argument('-ew', '--evaluation_worker',
|
parser.add_argument('-ew', '--evaluation_worker',
|
||||||
help="(int) If multiple workers are used, add an evaluation worker as well which will "
|
help="(flag) If multiple workers are used, add an evaluation worker as well which will "
|
||||||
"evaluate asynchronously and independently during the training. NOTE: this worker will "
|
"evaluate asynchronously and independently during the training. NOTE: this worker will "
|
||||||
"ignore the evaluation settings in the preset's ScheduleParams.",
|
"ignore the evaluation settings in the preset's ScheduleParams.",
|
||||||
action='store_true')
|
action='store_true')
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ several parts, each testing the framework in different areas and strictness.
|
|||||||
The golden tests can be run using the following command:
|
The golden tests can be run using the following command:
|
||||||
|
|
||||||
```
|
```
|
||||||
python3 rl_coach/tests/test_golden.py
|
python3 -m pytest rl_coach/tests -m golden_test
|
||||||
```
|
```
|
||||||
|
|
||||||
* **Trace tests** -
|
* **Trace tests** -
|
||||||
@@ -59,3 +59,19 @@ several parts, each testing the framework in different areas and strictness.
|
|||||||
```
|
```
|
||||||
python3 rl_coach/tests/trace_tests.py -prl
|
python3 rl_coach/tests/trace_tests.py -prl
|
||||||
```
|
```
|
||||||
|
|
||||||
|
* **Optional PyTest Flags** -
|
||||||
|
|
||||||
|
Using -k expr to select tests based on their name;
|
||||||
|
The -k command line option to specify an expression which implements a substring match on the test names
|
||||||
|
instead of the exact match on markers that -m provides. This makes it easy to select tests based on their names:
|
||||||
|
```
|
||||||
|
python3 -m pytest rl_coach/tests -k Doom
|
||||||
|
```
|
||||||
|
Using -v (--verbose) expr to show tests progress during running the tests, -v can be added with -m or with -k, to use -v see
|
||||||
|
the following commands:
|
||||||
|
```
|
||||||
|
python3 -m pytest rl_coach/tests -v -m golden_test
|
||||||
|
OR
|
||||||
|
python3 -m pytest rl_coach/tests -v -k Doom
|
||||||
|
```
|
||||||
127
rl_coach/tests/conftest.py
Normal file
127
rl_coach/tests/conftest.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2019 Intel Corporation
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
"""PyTest configuration."""
|
||||||
|
|
||||||
|
import configparser as cfgparser
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import shutil
|
||||||
|
import pytest
|
||||||
|
import rl_coach.tests.utils.args_utils as a_utils
|
||||||
|
import rl_coach.tests.utils.presets_utils as p_utils
|
||||||
|
from rl_coach.tests.utils.definitions import Definitions as Def
|
||||||
|
from os import path
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_collection_modifyitems(config, items):
|
||||||
|
"""pytest built in method to pre-process cli options"""
|
||||||
|
global test_config
|
||||||
|
test_config = cfgparser.ConfigParser()
|
||||||
|
str_rootdir = str(config.rootdir)
|
||||||
|
str_inifile = str(config.inifile)
|
||||||
|
# Get the relative path of the inifile
|
||||||
|
# By default is an absolute path but relative path when -c option used
|
||||||
|
config_path = os.path.relpath(str_inifile, str_rootdir)
|
||||||
|
config_path = os.path.join(str_rootdir, config_path)
|
||||||
|
assert (os.path.exists(config_path))
|
||||||
|
test_config.read(config_path)
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_runtest_setup(item):
|
||||||
|
"""Called before test is run."""
|
||||||
|
if len(item.own_markers) < 1:
|
||||||
|
return
|
||||||
|
if (item.own_markers[0].name == "unstable" and
|
||||||
|
"unstable" not in item.config.getoption("-m")):
|
||||||
|
pytest.skip("skipping unstable test")
|
||||||
|
|
||||||
|
if item.own_markers[0].name == "linux_only":
|
||||||
|
if platform.system() != 'Linux':
|
||||||
|
pytest.skip("Skipping test that isn't Linux OS.")
|
||||||
|
|
||||||
|
if item.own_markers[0].name == "golden_test":
|
||||||
|
""" do some custom configuration for golden tests. """
|
||||||
|
# TODO: add custom functionality
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", params=list(p_utils.collect_presets()))
|
||||||
|
def preset_name(request):
|
||||||
|
"""
|
||||||
|
Return all preset names
|
||||||
|
"""
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function", params=list(a_utils.collect_args()))
|
||||||
|
def flag(request):
|
||||||
|
"""
|
||||||
|
Return flags names in function scope
|
||||||
|
"""
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", params=list(a_utils.collect_preset_for_args()))
|
||||||
|
def preset_args(request):
|
||||||
|
"""
|
||||||
|
Return preset names that can be used for args testing only; working in
|
||||||
|
module scope
|
||||||
|
"""
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def clres(request):
|
||||||
|
"""
|
||||||
|
Create both file csv and log for testing
|
||||||
|
:yield: class of both files paths
|
||||||
|
"""
|
||||||
|
|
||||||
|
class CreateCsvLog:
|
||||||
|
"""
|
||||||
|
Create a test and log paths
|
||||||
|
"""
|
||||||
|
def __init__(self, csv, log, pattern):
|
||||||
|
self.exp_path = csv
|
||||||
|
self.stdout = log
|
||||||
|
self.fn_pattern = pattern
|
||||||
|
|
||||||
|
# get preset name from test request params
|
||||||
|
idx = 0 if 'preset' in list(request.node.funcargs.items())[0][0] else 1
|
||||||
|
p_name = list(request.node.funcargs.items())[idx][1]
|
||||||
|
|
||||||
|
p_valid_params = p_utils.validation_params(p_name)
|
||||||
|
|
||||||
|
test_name = 'ExpName_{}'.format(p_name)
|
||||||
|
test_path = os.path.join(Def.Path.experiments, test_name)
|
||||||
|
if path.exists(test_path):
|
||||||
|
shutil.rmtree(test_path)
|
||||||
|
|
||||||
|
# get the stdout for logs results
|
||||||
|
log_file_name = 'test_log_{}.txt'.format(p_name)
|
||||||
|
stdout = open(log_file_name, 'w')
|
||||||
|
fn_pattern = 'worker_0*.csv' if p_valid_params.num_workers > 1 else '*.csv'
|
||||||
|
|
||||||
|
res = CreateCsvLog(test_path, stdout, fn_pattern)
|
||||||
|
|
||||||
|
yield res
|
||||||
|
|
||||||
|
# clean files
|
||||||
|
if path.exists(res.exp_path):
|
||||||
|
shutil.rmtree(res.exp_path)
|
||||||
|
|
||||||
|
if os.path.exists(res.exp_path):
|
||||||
|
os.remove(res.stdout)
|
||||||
51
rl_coach/tests/test_args.py
Normal file
51
rl_coach/tests/test_args.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2019 Intel Corporation
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
import rl_coach.tests.utils.args_utils as a_utils
|
||||||
|
import rl_coach.tests.utils.presets_utils as p_utils
|
||||||
|
from rl_coach.tests.utils.definitions import Definitions as Def
|
||||||
|
|
||||||
|
|
||||||
|
def test_preset_args(preset_args, flag, clres, start_time=time.time(),
|
||||||
|
time_limit=Def.TimeOuts.test_time_limit):
|
||||||
|
""" Test command arguments - the test will check all flags one-by-one."""
|
||||||
|
|
||||||
|
p_valid_params = p_utils.validation_params(preset_args)
|
||||||
|
|
||||||
|
run_cmd = [
|
||||||
|
'python3', 'rl_coach/coach.py',
|
||||||
|
'-p', '{}'.format(preset_args),
|
||||||
|
'-e', '{}'.format("ExpName_" + preset_args),
|
||||||
|
]
|
||||||
|
|
||||||
|
if p_valid_params.reward_test_level:
|
||||||
|
lvl = ['-lvl', '{}'.format(p_valid_params.reward_test_level)]
|
||||||
|
run_cmd.extend(lvl)
|
||||||
|
|
||||||
|
# add flags to run command
|
||||||
|
test_flag = a_utils.add_one_flag_value(flag=flag)
|
||||||
|
run_cmd.extend(test_flag)
|
||||||
|
print(str(run_cmd))
|
||||||
|
|
||||||
|
# run command
|
||||||
|
p = subprocess.Popen(run_cmd, stdout=clres.stdout, stderr=clres.stdout)
|
||||||
|
|
||||||
|
# validate results
|
||||||
|
a_utils.validate_args_results(test_flag, clres, p, start_time, time_limit)
|
||||||
|
|
||||||
|
# Close process
|
||||||
|
p.kill()
|
||||||
0
rl_coach/tests/utils/__init__.py
Normal file
0
rl_coach/tests/utils/__init__.py
Normal file
354
rl_coach/tests/utils/args_utils.py
Normal file
354
rl_coach/tests/utils/args_utils.py
Normal file
@@ -0,0 +1,354 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2019 Intel Corporation
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
"""Manage all command arguments."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import signal
|
||||||
|
import time
|
||||||
|
|
||||||
|
import psutil as psutil
|
||||||
|
|
||||||
|
from rl_coach.logger import screen
|
||||||
|
from rl_coach.tests.utils import test_utils
|
||||||
|
from rl_coach.tests.utils.definitions import Definitions as Def
|
||||||
|
|
||||||
|
|
||||||
|
def collect_preset_for_args():
|
||||||
|
"""
|
||||||
|
Collect presets that relevant for args testing only.
|
||||||
|
This used for testing arguments for specific presets that defined in the
|
||||||
|
definitions (args_test under Presets).
|
||||||
|
:return: preset(s) list
|
||||||
|
"""
|
||||||
|
for pn in Def.Presets.args_test:
|
||||||
|
assert pn, Def.Consts.ASSERT_MSG.format("Preset name", pn)
|
||||||
|
yield pn
|
||||||
|
|
||||||
|
|
||||||
|
def collect_args():
|
||||||
|
"""
|
||||||
|
Collect args from the cmd args list - on each test iteration, it will
|
||||||
|
yield one line (one arg).
|
||||||
|
:yield: one arg foe each test iteration
|
||||||
|
"""
|
||||||
|
for k, v in Def.Flags.cmd_args.items():
|
||||||
|
cmd = []
|
||||||
|
cmd.append(k)
|
||||||
|
if v is not None:
|
||||||
|
cmd.append(v)
|
||||||
|
assert cmd, Def.Consts.ASSERT_MSG.format("cmd array", str(cmd))
|
||||||
|
yield cmd
|
||||||
|
|
||||||
|
|
||||||
|
def add_one_flag_value(flag):
|
||||||
|
"""
|
||||||
|
Add value to flag format in order to run the python command with arguments.
|
||||||
|
:param flag: dict flag
|
||||||
|
:return: flag with format
|
||||||
|
"""
|
||||||
|
if not flag or len(flag) > 2 or len(flag) == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if len(flag) == 1:
|
||||||
|
return flag
|
||||||
|
|
||||||
|
if Def.Flags.css in flag[1]:
|
||||||
|
flag[1] = 30
|
||||||
|
|
||||||
|
elif Def.Flags.crd in flag[1]:
|
||||||
|
# TODO: check dir of checkpoint
|
||||||
|
flag[1] = os.path.join(Def.Path.experiments)
|
||||||
|
|
||||||
|
elif Def.Flags.et in flag[1]:
|
||||||
|
# TODO: add valid value
|
||||||
|
flag[1] = ""
|
||||||
|
|
||||||
|
elif Def.Flags.ept in flag[1]:
|
||||||
|
# TODO: add valid value
|
||||||
|
flag[1] = ""
|
||||||
|
|
||||||
|
elif Def.Flags.cp in flag[1]:
|
||||||
|
# TODO: add valid value
|
||||||
|
flag[1] = ""
|
||||||
|
|
||||||
|
elif Def.Flags.seed in flag[1]:
|
||||||
|
flag[1] = 0
|
||||||
|
|
||||||
|
elif Def.Flags.dccp in flag[1]:
|
||||||
|
# TODO: add valid value
|
||||||
|
flag[1] = ""
|
||||||
|
|
||||||
|
return flag
|
||||||
|
|
||||||
|
|
||||||
|
def check_files_in_dir(dir_path):
|
||||||
|
"""
|
||||||
|
Check if folder has files
|
||||||
|
:param dir_path: |string| folder path
|
||||||
|
:return: |Array| return files in folder
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
entities = None
|
||||||
|
while time.time() - start_time < Def.TimeOuts.wait_for_files:
|
||||||
|
# wait until logs created
|
||||||
|
if os.path.exists(dir_path):
|
||||||
|
entities = os.listdir(dir_path)
|
||||||
|
if len(entities) > 0:
|
||||||
|
break
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
assert len(entities) > 0, \
|
||||||
|
Def.Consts.ASSERT_MSG.format("num files > 0", len(entities))
|
||||||
|
return entities
|
||||||
|
|
||||||
|
|
||||||
|
def find_string_in_logs(log_path, str):
|
||||||
|
"""
|
||||||
|
Find string into the log file
|
||||||
|
:param log_path: |string| log path
|
||||||
|
:param str: |string| search text
|
||||||
|
:return: |bool| true if string found in the log file
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
while time.time() - start_time < Def.TimeOuts.wait_for_files:
|
||||||
|
# wait until logs created
|
||||||
|
if os.path.exists(log_path):
|
||||||
|
break
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
if not os.path.exists(log_path):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if str in open(log_path, 'r').read():
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_csv_path(clres):
|
||||||
|
"""
|
||||||
|
Get the csv path with the results - reading csv paths will take some time
|
||||||
|
:param clres: object of files that test is creating
|
||||||
|
:return: |Array| csv path
|
||||||
|
"""
|
||||||
|
return test_utils.read_csv_paths(test_path=clres.exp_path,
|
||||||
|
filename_pattern=clres.fn_pattern)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_args_results(flag, clres=None, process=None, start_time=None,
|
||||||
|
timeout=None):
|
||||||
|
"""
|
||||||
|
Validate results of one argument.
|
||||||
|
:param flag: flag to check
|
||||||
|
:param clres: object of files paths (results of test experiment)
|
||||||
|
:param process: process object
|
||||||
|
:param start_time: start time of the test
|
||||||
|
:param timeout: timeout of the test- fail test once over the timeout
|
||||||
|
"""
|
||||||
|
|
||||||
|
if flag[0] == "-ns" or flag[0] == "--no-summary":
|
||||||
|
"""
|
||||||
|
--no-summary: Once selected, summary lines shouldn't appear in logs
|
||||||
|
"""
|
||||||
|
# send CTRL+C to close experiment
|
||||||
|
process.send_signal(signal.SIGINT)
|
||||||
|
|
||||||
|
assert not find_string_in_logs(log_path=clres.stdout.name,
|
||||||
|
str=Def.Consts.RESULTS_SORTED), \
|
||||||
|
Def.Consts.ASSERT_MSG.format("No Result summary",
|
||||||
|
Def.Consts.RESULTS_SORTED)
|
||||||
|
|
||||||
|
assert not find_string_in_logs(log_path=clres.stdout.name,
|
||||||
|
str=Def.Consts.TOTAL_RUNTIME), \
|
||||||
|
Def.Consts.ASSERT_MSG.format("No Total runtime summary",
|
||||||
|
Def.Consts.TOTAL_RUNTIME)
|
||||||
|
|
||||||
|
assert not find_string_in_logs(log_path=clres.stdout.name,
|
||||||
|
str=Def.Consts.DISCARD_EXP), \
|
||||||
|
Def.Consts.ASSERT_MSG.format("No discard message",
|
||||||
|
Def.Consts.DISCARD_EXP)
|
||||||
|
|
||||||
|
elif flag[0] == "-asc" or flag[0] == "--apply_stop_condition":
|
||||||
|
"""
|
||||||
|
-asc, --apply_stop_condition: Once selected, coach stopped when
|
||||||
|
required success rate reached
|
||||||
|
"""
|
||||||
|
while time.time() - start_time < timeout:
|
||||||
|
|
||||||
|
if find_string_in_logs(log_path=clres.stdout.name,
|
||||||
|
str=Def.Consts.REACHED_REQ_ASC):
|
||||||
|
assert True, Def.Consts.ASSERT_MSG. \
|
||||||
|
format(Def.Consts.REACHED_REQ_ASC, "Message Not Found")
|
||||||
|
break
|
||||||
|
|
||||||
|
elif flag[0] == "-d" or flag[0] == "--open_dashboard":
|
||||||
|
"""
|
||||||
|
-d, --open_dashboard: Once selected, firefox browser will open to show
|
||||||
|
coach's Dashboard.
|
||||||
|
"""
|
||||||
|
proc_id = None
|
||||||
|
start_time = time.time()
|
||||||
|
while time.time() - start_time < Def.TimeOuts.wait_for_files:
|
||||||
|
for proc in psutil.process_iter():
|
||||||
|
if proc.name() == Def.DASHBOARD_PROC:
|
||||||
|
assert proc.name() == Def.DASHBOARD_PROC, \
|
||||||
|
Def.Consts.ASSERT_MSG. format(Def.DASHBOARD_PROC,
|
||||||
|
proc.name())
|
||||||
|
proc_id = proc.pid
|
||||||
|
break
|
||||||
|
if proc_id:
|
||||||
|
break
|
||||||
|
|
||||||
|
if proc_id:
|
||||||
|
# kill firefox process
|
||||||
|
os.kill(proc_id, signal.SIGKILL)
|
||||||
|
else:
|
||||||
|
assert False, Def.Consts.ASSERT_MSG.format("Found Firefox process",
|
||||||
|
proc_id)
|
||||||
|
|
||||||
|
elif flag[0] == "--print_networks_summary":
|
||||||
|
"""
|
||||||
|
--print_networks_summary: Once selected, agent summary should appear in
|
||||||
|
stdout.
|
||||||
|
"""
|
||||||
|
if find_string_in_logs(log_path=clres.stdout.name,
|
||||||
|
str=Def.Consts.INPUT_EMBEDDER):
|
||||||
|
assert True, Def.Consts.ASSERT_MSG.format(
|
||||||
|
Def.Consts.INPUT_EMBEDDER, "Not found")
|
||||||
|
|
||||||
|
if find_string_in_logs(log_path=clres.stdout.name,
|
||||||
|
str=Def.Consts.MIDDLEWARE):
|
||||||
|
assert True, Def.Consts.ASSERT_MSG.format(
|
||||||
|
Def.Consts.MIDDLEWARE, "Not found")
|
||||||
|
|
||||||
|
if find_string_in_logs(log_path=clres.stdout.name,
|
||||||
|
str=Def.Consts.OUTPUT_HEAD):
|
||||||
|
assert True, Def.Consts.ASSERT_MSG.format(
|
||||||
|
Def.Consts.OUTPUT_HEAD, "Not found")
|
||||||
|
|
||||||
|
elif flag[0] == "-tb" or flag[0] == "--tensorboard":
|
||||||
|
"""
|
||||||
|
-tb, --tensorboard: Once selected, a new folder should be created in
|
||||||
|
experiment folder.
|
||||||
|
"""
|
||||||
|
csv_path = get_csv_path(clres)
|
||||||
|
assert len(csv_path) > 0, \
|
||||||
|
Def.Consts.ASSERT_MSG.format("path not found", csv_path)
|
||||||
|
|
||||||
|
exp_path = os.path.dirname(csv_path[0])
|
||||||
|
tensorboard_path = os.path.join(exp_path, Def.Path.tensorboard)
|
||||||
|
|
||||||
|
assert os.path.isdir(tensorboard_path), \
|
||||||
|
Def.Consts.ASSERT_MSG.format("tensorboard path", tensorboard_path)
|
||||||
|
|
||||||
|
# check if folder contain files
|
||||||
|
check_files_in_dir(dir_path=tensorboard_path)
|
||||||
|
|
||||||
|
elif flag[0] == "-onnx" or flag[0] == "--export_onnx_graph":
|
||||||
|
"""
|
||||||
|
-onnx, --export_onnx_graph: Once selected, warning message should
|
||||||
|
appear, it should be with another flag.
|
||||||
|
"""
|
||||||
|
if find_string_in_logs(log_path=clres.stdout.name,
|
||||||
|
str=Def.Consts.ONNX_WARNING):
|
||||||
|
assert True, Def.Consts.ASSERT_MSG.format(
|
||||||
|
Def.Consts.ONNX_WARNING, "Not found")
|
||||||
|
|
||||||
|
elif flag[0] == "-dg" or flag[0] == "--dump_gifs":
|
||||||
|
"""
|
||||||
|
-dg, --dump_gifs: Once selected, a new folder should be created in
|
||||||
|
experiment folder for gifs files.
|
||||||
|
"""
|
||||||
|
csv_path = get_csv_path(clres)
|
||||||
|
assert len(csv_path) > 0, \
|
||||||
|
Def.Consts.ASSERT_MSG.format("path not found", csv_path)
|
||||||
|
|
||||||
|
exp_path = os.path.dirname(csv_path[0])
|
||||||
|
gifs_path = os.path.join(exp_path, Def.Path.gifs)
|
||||||
|
|
||||||
|
# wait until gif folder were created
|
||||||
|
while time.time() - start_time < timeout:
|
||||||
|
if os.path.isdir(gifs_path):
|
||||||
|
assert os.path.isdir(gifs_path), \
|
||||||
|
Def.Consts.ASSERT_MSG.format("gifs path", gifs_path)
|
||||||
|
break
|
||||||
|
|
||||||
|
# check if folder contain files
|
||||||
|
check_files_in_dir(dir_path=gifs_path)
|
||||||
|
# TODO: check if play window is opened
|
||||||
|
|
||||||
|
elif flag[0] == "-dm" or flag[0] == "--dump_mp4":
|
||||||
|
"""
|
||||||
|
-dm, --dump_mp4: Once selected, a new folder should be created in
|
||||||
|
experiment folder for videos files.
|
||||||
|
"""
|
||||||
|
csv_path = get_csv_path(clres)
|
||||||
|
assert len(csv_path) > 0, \
|
||||||
|
Def.Consts.ASSERT_MSG.format("path not found", csv_path)
|
||||||
|
|
||||||
|
exp_path = os.path.dirname(csv_path[0])
|
||||||
|
videos_path = os.path.join(exp_path, Def.Path.videos)
|
||||||
|
|
||||||
|
# wait until video folder were created
|
||||||
|
while time.time() - start_time < timeout:
|
||||||
|
if os.path.isdir(videos_path):
|
||||||
|
assert os.path.isdir(videos_path), \
|
||||||
|
Def.Consts.ASSERT_MSG.format("videos path", videos_path)
|
||||||
|
break
|
||||||
|
|
||||||
|
# check if folder contain files
|
||||||
|
check_files_in_dir(dir_path=videos_path)
|
||||||
|
# TODO: check if play window is opened
|
||||||
|
|
||||||
|
elif flag[0] == "--nocolor":
|
||||||
|
"""
|
||||||
|
--nocolor: Once selected, check if color prefix is replacing the actual
|
||||||
|
color; example: '## agent: ...'
|
||||||
|
"""
|
||||||
|
while time.time() - start_time < timeout:
|
||||||
|
|
||||||
|
if find_string_in_logs(log_path=clres.stdout.name,
|
||||||
|
str=Def.Consts.COLOR_PREFIX):
|
||||||
|
assert True, Def.Consts.ASSERT_MSG. \
|
||||||
|
format(Def.Consts.COLOR_PREFIX, "Color Prefix Not Found")
|
||||||
|
break
|
||||||
|
|
||||||
|
elif flag[0] == "--evaluate":
|
||||||
|
"""
|
||||||
|
--evaluate: Once selected, Coach start testing, there is not training.
|
||||||
|
"""
|
||||||
|
tries = 5
|
||||||
|
while time.time() - start_time < timeout and tries > 0:
|
||||||
|
|
||||||
|
if find_string_in_logs(log_path=clres.stdout.name,
|
||||||
|
str=Def.Consts.TRAINING):
|
||||||
|
assert False, Def.Consts.ASSERT_MSG.format(
|
||||||
|
"Training Not Found", Def.Consts.TRAINING)
|
||||||
|
else:
|
||||||
|
time.sleep(1)
|
||||||
|
tries -= 1
|
||||||
|
assert True, Def.Consts.ASSERT_MSG.format("Training Found",
|
||||||
|
Def.Consts.TRAINING)
|
||||||
|
|
||||||
|
elif flag[0] == "--play":
|
||||||
|
"""
|
||||||
|
--play: Once selected alone, warning message should appear, it should
|
||||||
|
be with another flag.
|
||||||
|
"""
|
||||||
|
if find_string_in_logs(log_path=clres.stdout.name,
|
||||||
|
str=Def.Consts.PLAY_WARNING):
|
||||||
|
assert True, Def.Consts.ASSERT_MSG.format(
|
||||||
|
Def.Consts.ONNX_WARNING, "Not found")
|
||||||
111
rl_coach/tests/utils/definitions.py
Normal file
111
rl_coach/tests/utils/definitions.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2019 Intel Corporation
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
"""
|
||||||
|
Definitions file:
|
||||||
|
|
||||||
|
It's main functionality are:
|
||||||
|
1) housing project constants and enums.
|
||||||
|
2) housing configuration parameters.
|
||||||
|
3) housing resource paths.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Definitions:
|
||||||
|
GROUP_NAME = "rl_coach"
|
||||||
|
PROCESS_NAME = "coach"
|
||||||
|
DASHBOARD_PROC = "firefox"
|
||||||
|
|
||||||
|
class Flags:
|
||||||
|
css = "checkpoint_save_secs"
|
||||||
|
crd = "checkpoint_restore_dir"
|
||||||
|
et = "environment_type"
|
||||||
|
ept = "exploration_policy_type"
|
||||||
|
cp = "custom_parameter"
|
||||||
|
seed = "seed"
|
||||||
|
dccp = "distributed_coach_config_path"
|
||||||
|
|
||||||
|
"""
|
||||||
|
Arguments that can be tested for python coach command
|
||||||
|
** None = Flag - no need for string or int
|
||||||
|
** {} = Add format for this parameter
|
||||||
|
"""
|
||||||
|
cmd_args = {
|
||||||
|
# '-l': None,
|
||||||
|
# '-e': '{}',
|
||||||
|
# '-r': None,
|
||||||
|
# '-n': '{' + enw + '}',
|
||||||
|
# '-c': None,
|
||||||
|
# '-ew': None,
|
||||||
|
'--play': None,
|
||||||
|
'--evaluate': None,
|
||||||
|
# '-v': None,
|
||||||
|
# '-tfv': None,
|
||||||
|
'--nocolor': None,
|
||||||
|
# '-s': '{' + css + '}',
|
||||||
|
# '-crd': '{' + crd + '}',
|
||||||
|
'-dg': None,
|
||||||
|
'-dm': None,
|
||||||
|
# '-et': '{' + et + '}',
|
||||||
|
# '-ept': '{' + ept + '}',
|
||||||
|
# '-lvl': '{level}',
|
||||||
|
# '-cp': '{' + cp + '}',
|
||||||
|
'--print_networks_summary': None,
|
||||||
|
'-tb': None,
|
||||||
|
'-ns': None,
|
||||||
|
'-d': None,
|
||||||
|
# '--seed': '{' + seed + '}',
|
||||||
|
'-onnx': None,
|
||||||
|
'-dc': None,
|
||||||
|
# '-dcp': '{' + dccp + '}',
|
||||||
|
'-asc': None,
|
||||||
|
'--dump_worker_logs': None,
|
||||||
|
}
|
||||||
|
|
||||||
|
class Presets:
|
||||||
|
# Preset list for testing the flags/ arguments of python coach command
|
||||||
|
args_test = [
|
||||||
|
"CartPole_A3C",
|
||||||
|
# "CartPole_NEC",
|
||||||
|
]
|
||||||
|
|
||||||
|
class Path:
|
||||||
|
experiments = "./experiments"
|
||||||
|
tensorboard = "tensorboard"
|
||||||
|
gifs = "gifs"
|
||||||
|
videos = "videos"
|
||||||
|
|
||||||
|
class Consts:
|
||||||
|
ASSERT_MSG = "Expected: {}, Actual: {}."
|
||||||
|
RESULTS_SORTED = "Results stored at:"
|
||||||
|
TOTAL_RUNTIME = "Total runtime:"
|
||||||
|
DISCARD_EXP = "Do you want to discard the experiment results"
|
||||||
|
REACHED_REQ_ASC = "Reached required success rate. Exiting."
|
||||||
|
INPUT_EMBEDDER = "Input Embedder:"
|
||||||
|
MIDDLEWARE = "Middleware:"
|
||||||
|
OUTPUT_HEAD = "Output Head:"
|
||||||
|
ONNX_WARNING = "Exporting ONNX graphs requires setting the " \
|
||||||
|
"--checkpoint_save_secs flag. The --export_onnx_graph" \
|
||||||
|
" will have no effect."
|
||||||
|
COLOR_PREFIX = "## agent: Starting evaluation phase"
|
||||||
|
TRAINING = "Training - "
|
||||||
|
PLAY_WARNING = "Both the --preset and the --play flags were set. " \
|
||||||
|
"These flags can not be used together. For human " \
|
||||||
|
"control, please use the --play flag together with " \
|
||||||
|
"the environment type flag (-et)"
|
||||||
|
|
||||||
|
class TimeOuts:
|
||||||
|
test_time_limit = 60 * 60
|
||||||
|
wait_for_files = 20
|
||||||
85
rl_coach/tests/utils/presets_utils.py
Normal file
85
rl_coach/tests/utils/presets_utils.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2019 Intel Corporation
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
"""Manage all preset"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from importlib import import_module
|
||||||
|
from rl_coach.tests.utils.definitions import Definitions as Def
|
||||||
|
|
||||||
|
|
||||||
|
def import_preset(preset_name):
|
||||||
|
"""
|
||||||
|
Import preset name module from presets directory
|
||||||
|
:param preset_name: preset name
|
||||||
|
:return: imported module
|
||||||
|
"""
|
||||||
|
return import_module('{}.presets.{}'.format(Def.GROUP_NAME, preset_name))
|
||||||
|
|
||||||
|
|
||||||
|
def validation_params(preset_name):
|
||||||
|
"""
|
||||||
|
Validate parameters based on preset name
|
||||||
|
:param preset_name: preset name
|
||||||
|
:return: |bool| true if preset has params
|
||||||
|
"""
|
||||||
|
return import_preset(preset_name).graph_manager.preset_validation_params
|
||||||
|
|
||||||
|
|
||||||
|
def all_presets():
|
||||||
|
"""
|
||||||
|
Get all preset from preset directory
|
||||||
|
:return: |Array| preset list
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
f[:-3] for f in os.listdir(os.path.join(Def.GROUP_NAME, 'presets'))
|
||||||
|
if f[-3:] == '.py' and not f == '__init__.py'
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def importable(preset_name):
|
||||||
|
"""
|
||||||
|
Try to import preset name
|
||||||
|
:param preset_name: |name| preset name
|
||||||
|
:return: |bool| true if possible to import preset
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import_preset(preset_name)
|
||||||
|
return True
|
||||||
|
except BaseException:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def has_test_parameters(preset_name):
|
||||||
|
"""
|
||||||
|
Check if preset has parameters
|
||||||
|
:param preset_name: |string| preset name
|
||||||
|
:return: |bool| true: if preset have parameters
|
||||||
|
"""
|
||||||
|
return bool(validation_params(preset_name).test)
|
||||||
|
|
||||||
|
|
||||||
|
def collect_presets():
|
||||||
|
"""
|
||||||
|
Collect all presets in presets directory
|
||||||
|
:yield: preset name
|
||||||
|
"""
|
||||||
|
for preset_name in all_presets():
|
||||||
|
# if it isn't importable, still include it so we can fail the test
|
||||||
|
if not importable(preset_name):
|
||||||
|
yield preset_name
|
||||||
|
# otherwise, make sure it has test parameters before including it
|
||||||
|
elif has_test_parameters(preset_name):
|
||||||
|
yield preset_name
|
||||||
68
rl_coach/tests/utils/test_utils.py
Normal file
68
rl_coach/tests/utils/test_utils.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2019 Intel Corporation
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
"""Common functionality shared across tests."""
|
||||||
|
|
||||||
|
import glob
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from os import path
|
||||||
|
|
||||||
|
|
||||||
|
def print_progress(averaged_rewards, last_num_episodes, start_time, time_limit,
|
||||||
|
p_valid_params):
|
||||||
|
"""
|
||||||
|
Print progress bar for preset run test
|
||||||
|
:param averaged_rewards: average rewards of test
|
||||||
|
:param last_num_episodes: last episode number
|
||||||
|
:param start_time: start time of test
|
||||||
|
:param time_limit: time out of test
|
||||||
|
:param p_valid_params: preset validation parameters
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
max_episodes_to_archive = p_valid_params.max_episodes_to_achieve_reward
|
||||||
|
min_reward = p_valid_params.min_reward_threshold
|
||||||
|
avg_reward = round(averaged_rewards[-1], 1)
|
||||||
|
percentage = int((100 * last_num_episodes) / max_episodes_to_archive)
|
||||||
|
cur_time = round(time.time() - start_time, 2)
|
||||||
|
|
||||||
|
sys.stdout.write("\rReward: ({}/{})".format(avg_reward, min_reward))
|
||||||
|
sys.stdout.write(' Time (sec): ({}/{})'.format(cur_time, time_limit))
|
||||||
|
sys.stdout.write(' Episode: ({}/{})'.format(last_num_episodes,
|
||||||
|
max_episodes_to_archive))
|
||||||
|
sys.stdout.write(' {}%|{}{}| '
|
||||||
|
.format(percentage, '#' * int(percentage / 10),
|
||||||
|
' ' * (10 - int(percentage / 10))))
|
||||||
|
|
||||||
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def read_csv_paths(test_path, filename_pattern, read_csv_tries=120):
|
||||||
|
"""
|
||||||
|
Return file path once it found
|
||||||
|
:param test_path: test folder path
|
||||||
|
:param filename_pattern: csv file pattern
|
||||||
|
:param read_csv_tries: number of iterations until file found
|
||||||
|
:return: |string| return csv file path
|
||||||
|
"""
|
||||||
|
csv_paths = []
|
||||||
|
tries_counter = 0
|
||||||
|
while not csv_paths:
|
||||||
|
csv_paths = glob.glob(path.join(test_path, '*', filename_pattern))
|
||||||
|
if tries_counter > read_csv_tries:
|
||||||
|
break
|
||||||
|
tries_counter += 1
|
||||||
|
time.sleep(1)
|
||||||
|
return csv_paths
|
||||||
Reference in New Issue
Block a user