mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
tests: added new tests + utils code improved (#221)
* tests: added new tests + utils code improved * new tests: - test_preset_args_combination - test_preset_mxnet_framework * added more flags to test_preset_args * added validation for flags in utils * tests: added new tests + fixed utils * tests: added new checkpoint test * tests: added checkpoint test improve utils * tests: added tests + improve validations * bump integration CI run timeout. * tests: improve timerun + add functional test marker
This commit is contained in:
@@ -157,7 +157,7 @@ jobs:
|
||||
name: run integration tests
|
||||
command: |
|
||||
python3 rl_coach/tests/test_eks.py -c coach-test -bn ${CIRCLE_BUILD_NUM} -tn integration-test -tc 'make integration_tests_without_docker' -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach:$(git describe --tags --always --dirty) -cpu 2048 -mem 4096
|
||||
no_output_timeout: 20m
|
||||
no_output_timeout: 30m
|
||||
- run:
|
||||
name: cleanup
|
||||
command: |
|
||||
|
||||
@@ -15,10 +15,9 @@
|
||||
#
|
||||
"""PyTest configuration."""
|
||||
|
||||
import configparser as cfgparser
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import sys
|
||||
import pytest
|
||||
import rl_coach.tests.utils.args_utils as a_utils
|
||||
import rl_coach.tests.utils.presets_utils as p_utils
|
||||
@@ -26,44 +25,12 @@ from rl_coach.tests.utils.definitions import Definitions as Def
|
||||
from os import path
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
"""pytest built in method to pre-process cli options"""
|
||||
global test_config
|
||||
test_config = cfgparser.ConfigParser()
|
||||
str_rootdir = str(config.rootdir)
|
||||
str_inifile = str(config.inifile)
|
||||
# Get the relative path of the inifile
|
||||
# By default is an absolute path but relative path when -c option used
|
||||
config_path = os.path.relpath(str_inifile, str_rootdir)
|
||||
config_path = os.path.join(str_rootdir, config_path)
|
||||
assert (os.path.exists(config_path))
|
||||
test_config.read(config_path)
|
||||
|
||||
|
||||
def pytest_runtest_setup(item):
|
||||
"""Called before test is run."""
|
||||
if len(item.own_markers) < 1:
|
||||
return
|
||||
if (item.own_markers[0].name == "unstable" and
|
||||
"unstable" not in item.config.getoption("-m")):
|
||||
pytest.skip("skipping unstable test")
|
||||
|
||||
if item.own_markers[0].name == "linux_only":
|
||||
if platform.system() != 'Linux':
|
||||
pytest.skip("Skipping test that isn't Linux OS.")
|
||||
|
||||
if item.own_markers[0].name == "golden_test":
|
||||
""" do some custom configuration for golden tests. """
|
||||
# TODO: add custom functionality
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=list(p_utils.collect_presets()))
|
||||
def preset_name(request):
|
||||
"""
|
||||
Return all preset names
|
||||
"""
|
||||
return request.param
|
||||
yield request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", params=list(a_utils.collect_args()))
|
||||
@@ -71,7 +38,7 @@ def flag(request):
|
||||
"""
|
||||
Return flags names in function scope
|
||||
"""
|
||||
return request.param
|
||||
yield request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=list(a_utils.collect_preset_for_args()))
|
||||
@@ -80,7 +47,26 @@ def preset_args(request):
|
||||
Return preset names that can be used for args testing only; working in
|
||||
module scope
|
||||
"""
|
||||
return request.param
|
||||
yield request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=list(a_utils.collect_preset_for_seed()))
|
||||
def preset_args_for_seed(request):
|
||||
"""
|
||||
Return preset names that can be used for args testing only and for special
|
||||
action when using seed argument; working in module scope
|
||||
"""
|
||||
yield request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module",
|
||||
params=list(a_utils.collect_preset_for_mxnet()))
|
||||
def preset_for_mxnet_args(request):
|
||||
"""
|
||||
Return preset names that can be used for args testing only; this special
|
||||
fixture will be used for mxnet framework only. working in module scope
|
||||
"""
|
||||
yield request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@@ -105,6 +91,7 @@ def clres(request):
|
||||
|
||||
p_valid_params = p_utils.validation_params(p_name)
|
||||
|
||||
sys.path.append('.')
|
||||
test_name = 'ExpName_{}'.format(p_name)
|
||||
test_path = os.path.join(Def.Path.experiments, test_name)
|
||||
if path.exists(test_path):
|
||||
@@ -113,7 +100,7 @@ def clres(request):
|
||||
# get the stdout for logs results
|
||||
log_file_name = 'test_log_{}.txt'.format(p_name)
|
||||
stdout = open(log_file_name, 'w')
|
||||
fn_pattern = 'worker_0*.csv' if p_valid_params.num_workers > 1 else '*.csv'
|
||||
fn_pattern = '*.csv' if p_valid_params.num_workers > 1 else 'worker_0*.csv'
|
||||
|
||||
res = CreateCsvLog(test_path, stdout, fn_pattern)
|
||||
|
||||
@@ -123,5 +110,5 @@ def clres(request):
|
||||
if path.exists(res.exp_path):
|
||||
shutil.rmtree(res.exp_path)
|
||||
|
||||
if os.path.exists(res.exp_path):
|
||||
os.remove(res.stdout)
|
||||
if path.exists(res.stdout.name):
|
||||
os.remove(res.stdout.name)
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2019 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import subprocess
|
||||
import time
|
||||
import rl_coach.tests.utils.args_utils as a_utils
|
||||
import rl_coach.tests.utils.presets_utils as p_utils
|
||||
from rl_coach.tests.utils.definitions import Definitions as Def
|
||||
|
||||
|
||||
def test_preset_args(preset_args, flag, clres, start_time=time.time(),
|
||||
time_limit=Def.TimeOuts.test_time_limit):
|
||||
""" Test command arguments - the test will check all flags one-by-one."""
|
||||
|
||||
p_valid_params = p_utils.validation_params(preset_args)
|
||||
|
||||
run_cmd = [
|
||||
'python3', 'rl_coach/coach.py',
|
||||
'-p', '{}'.format(preset_args),
|
||||
'-e', '{}'.format("ExpName_" + preset_args),
|
||||
]
|
||||
|
||||
if p_valid_params.reward_test_level:
|
||||
lvl = ['-lvl', '{}'.format(p_valid_params.reward_test_level)]
|
||||
run_cmd.extend(lvl)
|
||||
|
||||
# add flags to run command
|
||||
test_flag = a_utils.add_one_flag_value(flag=flag)
|
||||
run_cmd.extend(test_flag)
|
||||
print(str(run_cmd))
|
||||
|
||||
# run command
|
||||
p = subprocess.Popen(run_cmd, stdout=clres.stdout, stderr=clres.stdout)
|
||||
|
||||
# validate results
|
||||
a_utils.validate_args_results(test_flag, clres, p, start_time, time_limit)
|
||||
|
||||
# Close process
|
||||
p.kill()
|
||||
@@ -1,24 +1,125 @@
|
||||
#
|
||||
# Copyright (c) 2019 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
import pytest
|
||||
import signal
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import rl_coach.tests.utils.args_utils as a_utils
|
||||
import rl_coach.tests.utils.test_utils as test_utils
|
||||
from rl_coach import checkpoint
|
||||
from rl_coach.tests.utils.definitions import Definitions as Def
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_get_checkpoint_state():
|
||||
files = ['4.test.ckpt.ext', '2.test.ckpt.ext', '3.test.ckpt.ext', '1.test.ckpt.ext', 'prefix.10.test.ckpt.ext']
|
||||
files = ['4.test.ckpt.ext', '2.test.ckpt.ext', '3.test.ckpt.ext',
|
||||
'1.test.ckpt.ext', 'prefix.10.test.ckpt.ext']
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
[open(os.path.join(temp_dir, fn), 'a').close() for fn in files]
|
||||
checkpoint_state = checkpoint.get_checkpoint_state(temp_dir, all_checkpoints=True)
|
||||
assert checkpoint_state.model_checkpoint_path == os.path.join(temp_dir, '4.test.ckpt')
|
||||
checkpoint_state = \
|
||||
checkpoint.get_checkpoint_state(temp_dir,
|
||||
all_checkpoints=True)
|
||||
assert checkpoint_state.model_checkpoint_path == os.path.join(
|
||||
temp_dir, '4.test.ckpt')
|
||||
assert checkpoint_state.all_model_checkpoint_paths == \
|
||||
[os.path.join(temp_dir, f[:-4]) for f in sorted(files[:-1])]
|
||||
|
||||
reader = checkpoint.CheckpointStateReader(temp_dir, checkpoint_state_optional=False)
|
||||
reader = \
|
||||
checkpoint.CheckpointStateReader(temp_dir,
|
||||
checkpoint_state_optional=False)
|
||||
assert reader.get_latest() is None
|
||||
assert len(reader.get_all()) == 0
|
||||
|
||||
reader = checkpoint.CheckpointStateReader(temp_dir)
|
||||
assert reader.get_latest().num == 4
|
||||
assert [ckp.num for ckp in reader.get_all()] == [1, 2, 3, 4]
|
||||
|
||||
|
||||
@pytest.mark.functional_test
|
||||
def test_restore_checkpoint(preset_args, clres, start_time=time.time()):
|
||||
""" Create checkpoint and restore them in second run."""
|
||||
|
||||
def _create_cmd_and_run(flag):
|
||||
|
||||
run_cmd = [
|
||||
'python3', 'rl_coach/coach.py',
|
||||
'-p', '{}'.format(preset_args),
|
||||
'-e', '{}'.format("ExpName_" + preset_args),
|
||||
]
|
||||
test_flag = a_utils.add_one_flag_value(flag=flag)
|
||||
run_cmd.extend(test_flag)
|
||||
|
||||
p = subprocess.Popen(run_cmd, stdout=clres.stdout, stderr=clres.stdout)
|
||||
|
||||
return p
|
||||
|
||||
create_cp_proc = _create_cmd_and_run(flag=['--checkpoint_save_secs', '5'])
|
||||
|
||||
# wait for checkpoint files
|
||||
csv_list = a_utils.get_csv_path(clres=clres)
|
||||
assert len(csv_list) > 0
|
||||
exp_dir = os.path.dirname(csv_list[0])
|
||||
|
||||
checkpoint_dir = os.path.join(exp_dir, Def.Path.checkpoint)
|
||||
|
||||
checkpoint_test_dir = os.path.join(Def.Path.experiments, Def.Path.test_dir)
|
||||
if os.path.exists(checkpoint_test_dir):
|
||||
shutil.rmtree(checkpoint_test_dir)
|
||||
|
||||
entities = a_utils.get_files_from_dir(checkpoint_dir)
|
||||
|
||||
while not any("10_Step" in file for file in entities) and time.time() - \
|
||||
start_time < Def.TimeOuts.test_time_limit:
|
||||
entities = a_utils.get_files_from_dir(checkpoint_dir)
|
||||
time.sleep(1)
|
||||
|
||||
assert len(entities) > 0
|
||||
assert "checkpoint" in entities
|
||||
assert any(".ckpt." in file for file in entities)
|
||||
|
||||
# send CTRL+C to close experiment
|
||||
create_cp_proc.send_signal(signal.SIGINT)
|
||||
|
||||
csv = pd.read_csv(csv_list[0])
|
||||
rewards = csv['Evaluation Reward'].values
|
||||
rewards = rewards[~np.isnan(rewards)]
|
||||
min_reward = np.amin(rewards)
|
||||
|
||||
if os.path.isdir(checkpoint_dir):
|
||||
shutil.copytree(exp_dir, checkpoint_test_dir)
|
||||
shutil.rmtree(exp_dir)
|
||||
|
||||
create_cp_proc.kill()
|
||||
checkpoint_test_dir = "{}/{}".format(checkpoint_test_dir,
|
||||
Def.Path.checkpoint)
|
||||
# run second time with checkpoint folder (restore)
|
||||
restore_cp_proc = _create_cmd_and_run(flag=['-crd', checkpoint_test_dir,
|
||||
'--evaluate'])
|
||||
|
||||
new_csv_list = test_utils.get_csv_path(clres=clres)
|
||||
time.sleep(10)
|
||||
|
||||
csv = pd.read_csv(new_csv_list[0])
|
||||
res = csv['Episode Length'].values[-1]
|
||||
assert res >= min_reward, \
|
||||
Def.Consts.ASSERT_MSG.format(str(res) + ">=" + str(min_reward),
|
||||
str(res) + " < " + str(min_reward))
|
||||
restore_cp_proc.kill()
|
||||
|
||||
145
rl_coach/tests/test_coach_args.py
Normal file
145
rl_coach/tests/test_coach_args.py
Normal file
@@ -0,0 +1,145 @@
|
||||
#
|
||||
# Copyright (c) 2019 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import subprocess
|
||||
import time
|
||||
import pytest
|
||||
import rl_coach.tests.utils.args_utils as a_utils
|
||||
import rl_coach.tests.utils.presets_utils as p_utils
|
||||
from rl_coach.tests.utils.definitions import Definitions as Def
|
||||
|
||||
|
||||
@pytest.mark.functional_test
|
||||
def test_preset_args(preset_args, flag, clres, start_time=time.time(),
|
||||
time_limit=Def.TimeOuts.test_time_limit):
|
||||
""" Test command arguments - the test will check all flags one-by-one."""
|
||||
|
||||
p_valid_params = p_utils.validation_params(preset_args)
|
||||
|
||||
run_cmd = [
|
||||
'python3', 'rl_coach/coach.py',
|
||||
'-p', '{}'.format(preset_args),
|
||||
'-e', '{}'.format("ExpName_" + preset_args),
|
||||
]
|
||||
|
||||
if p_valid_params.reward_test_level:
|
||||
lvl = ['-lvl', '{}'.format(p_valid_params.reward_test_level)]
|
||||
run_cmd.extend(lvl)
|
||||
|
||||
# add flags to run command
|
||||
test_flag = a_utils.add_one_flag_value(flag=flag)
|
||||
run_cmd.extend(test_flag)
|
||||
print(str(run_cmd))
|
||||
|
||||
proc = subprocess.Popen(run_cmd, stdout=clres.stdout, stderr=clres.stdout)
|
||||
|
||||
try:
|
||||
a_utils.validate_arg_result(flag=test_flag,
|
||||
p_valid_params=p_valid_params, clres=clres,
|
||||
process=proc, start_time=start_time,
|
||||
timeout=time_limit)
|
||||
except AssertionError:
|
||||
# close process once get assert false
|
||||
proc.kill()
|
||||
assert False
|
||||
|
||||
proc.kill()
|
||||
|
||||
|
||||
@pytest.mark.functional_test
|
||||
def test_preset_mxnet_framework(preset_for_mxnet_args, clres,
|
||||
start_time=time.time(),
|
||||
time_limit=Def.TimeOuts.test_time_limit):
|
||||
""" Test command arguments - the test will check mxnet framework"""
|
||||
|
||||
flag = ['-f', 'mxnet']
|
||||
p_valid_params = p_utils.validation_params(preset_for_mxnet_args)
|
||||
|
||||
run_cmd = [
|
||||
'python3', 'rl_coach/coach.py',
|
||||
'-p', '{}'.format(preset_for_mxnet_args),
|
||||
'-e', '{}'.format("ExpName_" + preset_for_mxnet_args),
|
||||
]
|
||||
|
||||
# add flags to run command
|
||||
test_flag = a_utils.add_one_flag_value(flag=flag)
|
||||
run_cmd.extend(test_flag)
|
||||
|
||||
print(str(run_cmd))
|
||||
|
||||
proc = subprocess.Popen(run_cmd, stdout=clres.stdout, stderr=clres.stdout)
|
||||
|
||||
try:
|
||||
a_utils.validate_arg_result(flag=test_flag,
|
||||
p_valid_params=p_valid_params, clres=clres,
|
||||
process=proc, start_time=start_time,
|
||||
timeout=time_limit)
|
||||
except AssertionError:
|
||||
# close process once get assert false
|
||||
proc.kill()
|
||||
assert False
|
||||
|
||||
proc.kill()
|
||||
|
||||
|
||||
@pytest.mark.functional_test
|
||||
def test_preset_seed(preset_args_for_seed, clres, start_time=time.time(),
|
||||
time_limit=Def.TimeOuts.test_time_limit):
|
||||
"""
|
||||
Test command arguments - the test will check seed argument with all
|
||||
presets
|
||||
"""
|
||||
|
||||
def close_processes():
|
||||
"""
|
||||
close all processes that still active in the process list
|
||||
"""
|
||||
for i in range(seed_num):
|
||||
proc[i].kill()
|
||||
|
||||
proc = []
|
||||
seed_num = 2
|
||||
flag = ["--seed", str(seed_num)]
|
||||
p_valid_params = p_utils.validation_params(preset_args_for_seed)
|
||||
|
||||
run_cmd = [
|
||||
'python3', 'rl_coach/coach.py',
|
||||
'-p', '{}'.format(preset_args_for_seed),
|
||||
'-e', '{}'.format("ExpName_" + preset_args_for_seed),
|
||||
]
|
||||
|
||||
if p_valid_params.trace_test_levels:
|
||||
lvl = ['-lvl', '{}'.format(p_valid_params.trace_test_levels[0])]
|
||||
run_cmd.extend(lvl)
|
||||
|
||||
# add flags to run command
|
||||
test_flag = a_utils.add_one_flag_value(flag=flag)
|
||||
run_cmd.extend(test_flag)
|
||||
print(str(run_cmd))
|
||||
|
||||
for _ in range(seed_num):
|
||||
proc.append(subprocess.Popen(run_cmd, stdout=clres.stdout,
|
||||
stderr=clres.stdout))
|
||||
|
||||
try:
|
||||
a_utils.validate_arg_result(flag=test_flag,
|
||||
p_valid_params=p_valid_params, clres=clres,
|
||||
process=proc, start_time=start_time,
|
||||
timeout=time_limit)
|
||||
except AssertionError:
|
||||
close_processes()
|
||||
assert False
|
||||
|
||||
close_processes()
|
||||
@@ -16,17 +16,27 @@
|
||||
"""Manage all command arguments."""
|
||||
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import time
|
||||
|
||||
import psutil as psutil
|
||||
|
||||
from rl_coach.logger import screen
|
||||
from rl_coach.tests.utils import test_utils
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from rl_coach.tests.utils.test_utils import get_csv_path, get_files_from_dir, \
|
||||
find_string_in_logs
|
||||
from rl_coach.tests.utils.definitions import Definitions as Def
|
||||
|
||||
|
||||
def collect_preset_for_mxnet():
|
||||
"""
|
||||
Collect presets that relevant for args testing only.
|
||||
This used for testing arguments for specific presets that defined in the
|
||||
definitions (args_test under Presets).
|
||||
:return: preset(s) list
|
||||
"""
|
||||
for pn in Def.Presets.mxnet_args_test:
|
||||
assert pn, Def.Consts.ASSERT_MSG.format("Preset name", pn)
|
||||
yield pn
|
||||
|
||||
|
||||
def collect_preset_for_args():
|
||||
"""
|
||||
Collect presets that relevant for args testing only.
|
||||
@@ -39,19 +49,27 @@ def collect_preset_for_args():
|
||||
yield pn
|
||||
|
||||
|
||||
def collect_preset_for_seed():
|
||||
"""
|
||||
Collect presets that relevant for seed argument testing only.
|
||||
This used for testing arguments for specific presets that defined in the
|
||||
definitions (args_test under Presets).
|
||||
:return: preset(s) list
|
||||
"""
|
||||
for pn in Def.Presets.seed_args_test:
|
||||
assert pn, Def.Consts.ASSERT_MSG.format("Preset name", pn)
|
||||
yield pn
|
||||
|
||||
|
||||
def collect_args():
|
||||
"""
|
||||
Collect args from the cmd args list - on each test iteration, it will
|
||||
yield one line (one arg).
|
||||
:yield: one arg foe each test iteration
|
||||
"""
|
||||
for k, v in Def.Flags.cmd_args.items():
|
||||
cmd = []
|
||||
cmd.append(k)
|
||||
if v is not None:
|
||||
cmd.append(v)
|
||||
assert cmd, Def.Consts.ASSERT_MSG.format("cmd array", str(cmd))
|
||||
yield cmd
|
||||
for i in Def.Flags.cmd_args:
|
||||
assert i, Def.Consts.ASSERT_MSG.format("flag list", str(i))
|
||||
yield i
|
||||
|
||||
|
||||
def add_one_flag_value(flag):
|
||||
@@ -60,99 +78,86 @@ def add_one_flag_value(flag):
|
||||
:param flag: dict flag
|
||||
:return: flag with format
|
||||
"""
|
||||
if not flag or len(flag) > 2 or len(flag) == 0:
|
||||
if not flag or len(flag) == 0:
|
||||
return []
|
||||
|
||||
if len(flag) == 1:
|
||||
return flag
|
||||
|
||||
if Def.Flags.css in flag[1]:
|
||||
flag[1] = 30
|
||||
if Def.Flags.enw in flag[1]:
|
||||
flag[1] = '2'
|
||||
|
||||
elif Def.Flags.crd in flag[1]:
|
||||
# TODO: check dir of checkpoint
|
||||
flag[1] = os.path.join(Def.Path.experiments)
|
||||
elif Def.Flags.css in flag[1]:
|
||||
flag[1] = '5'
|
||||
|
||||
elif Def.Flags.et in flag[1]:
|
||||
# TODO: add valid value
|
||||
flag[1] = ""
|
||||
elif Def.Flags.fw_ten in flag[1]:
|
||||
flag[1] = "tensorflow"
|
||||
|
||||
elif Def.Flags.ept in flag[1]:
|
||||
# TODO: add valid value
|
||||
flag[1] = ""
|
||||
elif Def.Flags.fw_mx in flag[1]:
|
||||
flag[1] = "mxnet"
|
||||
|
||||
elif Def.Flags.cp in flag[1]:
|
||||
# TODO: add valid value
|
||||
flag[1] = ""
|
||||
|
||||
elif Def.Flags.seed in flag[1]:
|
||||
flag[1] = 0
|
||||
|
||||
elif Def.Flags.dccp in flag[1]:
|
||||
# TODO: add valid value
|
||||
flag[1] = ""
|
||||
flag[1] = "heatup_steps=EnvironmentSteps({})".format(Def.Consts.num_hs)
|
||||
|
||||
return flag
|
||||
|
||||
|
||||
def check_files_in_dir(dir_path):
|
||||
def is_reward_reached(csv_path, p_valid_params, start_time, time_limit):
|
||||
"""
|
||||
Check if folder has files
|
||||
:param dir_path: |string| folder path
|
||||
:return: |Array| return files in folder
|
||||
Check the result of the experiment, by collecting all the Evaluation Reward
|
||||
and average should be bigger than the min reward threshold.
|
||||
:param csv_path: csv file (results)
|
||||
:param p_valid_params: experiment test params
|
||||
:param start_time: start time of the test
|
||||
:param time_limit: timeout of the test
|
||||
:return: |Bool| true if reached the reward minimum
|
||||
"""
|
||||
start_time = time.time()
|
||||
entities = None
|
||||
while time.time() - start_time < Def.TimeOuts.wait_for_files:
|
||||
# wait until logs created
|
||||
if os.path.exists(dir_path):
|
||||
entities = os.listdir(dir_path)
|
||||
if len(entities) > 0:
|
||||
break
|
||||
time.sleep(1)
|
||||
win_size = 10
|
||||
last_num_episodes = 0
|
||||
csv = None
|
||||
reward_reached = False
|
||||
|
||||
assert len(entities) > 0, \
|
||||
Def.Consts.ASSERT_MSG.format("num files > 0", len(entities))
|
||||
return entities
|
||||
while csv is None or (csv['Episode #'].values[-1]
|
||||
< p_valid_params.max_episodes_to_achieve_reward and
|
||||
time.time() - start_time < time_limit):
|
||||
|
||||
csv = pd.read_csv(csv_path)
|
||||
|
||||
def find_string_in_logs(log_path, str):
|
||||
"""
|
||||
Find string into the log file
|
||||
:param log_path: |string| log path
|
||||
:param str: |string| search text
|
||||
:return: |bool| true if string found in the log file
|
||||
"""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < Def.TimeOuts.wait_for_files:
|
||||
# wait until logs created
|
||||
if os.path.exists(log_path):
|
||||
if 'Evaluation Reward' not in csv.keys():
|
||||
continue
|
||||
|
||||
rewards = csv['Evaluation Reward'].values
|
||||
|
||||
rewards = rewards[~np.isnan(rewards)]
|
||||
if len(rewards) >= 1:
|
||||
averaged_rewards = np.convolve(rewards, np.ones(
|
||||
min(len(rewards), win_size)) / win_size, mode='valid')
|
||||
|
||||
else:
|
||||
# May be in heat-up steps
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
||||
if csv['Episode #'].shape[0] - last_num_episodes <= 0:
|
||||
continue
|
||||
|
||||
last_num_episodes = csv['Episode #'].values[-1]
|
||||
|
||||
# check if reward is enough
|
||||
if np.any(averaged_rewards >= p_valid_params.min_reward_threshold):
|
||||
reward_reached = True
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
if not os.path.exists(log_path):
|
||||
return False
|
||||
|
||||
if str in open(log_path, 'r').read():
|
||||
return True
|
||||
return False
|
||||
return reward_reached
|
||||
|
||||
|
||||
def get_csv_path(clres):
|
||||
"""
|
||||
Get the csv path with the results - reading csv paths will take some time
|
||||
:param clres: object of files that test is creating
|
||||
:return: |Array| csv path
|
||||
"""
|
||||
return test_utils.read_csv_paths(test_path=clres.exp_path,
|
||||
filename_pattern=clres.fn_pattern)
|
||||
|
||||
|
||||
def validate_args_results(flag, clres=None, process=None, start_time=None,
|
||||
timeout=None):
|
||||
def validate_arg_result(flag, p_valid_params, clres=None, process=None,
|
||||
start_time=None, timeout=Def.TimeOuts.test_time_limit):
|
||||
"""
|
||||
Validate results of one argument.
|
||||
:param flag: flag to check
|
||||
:param p_valid_params: params test per preset
|
||||
:param clres: object of files paths (results of test experiment)
|
||||
:param process: process object
|
||||
:param start_time: start time of the test
|
||||
@@ -186,38 +191,11 @@ def validate_args_results(flag, clres=None, process=None, start_time=None,
|
||||
-asc, --apply_stop_condition: Once selected, coach stopped when
|
||||
required success rate reached
|
||||
"""
|
||||
while time.time() - start_time < timeout:
|
||||
|
||||
if find_string_in_logs(log_path=clres.stdout.name,
|
||||
str=Def.Consts.REACHED_REQ_ASC):
|
||||
assert True, Def.Consts.ASSERT_MSG. \
|
||||
format(Def.Consts.REACHED_REQ_ASC, "Message Not Found")
|
||||
break
|
||||
|
||||
elif flag[0] == "-d" or flag[0] == "--open_dashboard":
|
||||
"""
|
||||
-d, --open_dashboard: Once selected, firefox browser will open to show
|
||||
coach's Dashboard.
|
||||
"""
|
||||
proc_id = None
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < Def.TimeOuts.wait_for_files:
|
||||
for proc in psutil.process_iter():
|
||||
if proc.name() == Def.DASHBOARD_PROC:
|
||||
assert proc.name() == Def.DASHBOARD_PROC, \
|
||||
Def.Consts.ASSERT_MSG. format(Def.DASHBOARD_PROC,
|
||||
proc.name())
|
||||
proc_id = proc.pid
|
||||
break
|
||||
if proc_id:
|
||||
break
|
||||
|
||||
if proc_id:
|
||||
# kill firefox process
|
||||
os.kill(proc_id, signal.SIGKILL)
|
||||
else:
|
||||
assert False, Def.Consts.ASSERT_MSG.format("Found Firefox process",
|
||||
proc_id)
|
||||
assert find_string_in_logs(log_path=clres.stdout.name,
|
||||
str=Def.Consts.REACHED_REQ_ASC,
|
||||
wait_and_find=True), \
|
||||
Def.Consts.ASSERT_MSG.format(Def.Consts.REACHED_REQ_ASC,
|
||||
"Message Not Found")
|
||||
|
||||
elif flag[0] == "--print_networks_summary":
|
||||
"""
|
||||
@@ -254,18 +232,19 @@ def validate_args_results(flag, clres=None, process=None, start_time=None,
|
||||
assert os.path.isdir(tensorboard_path), \
|
||||
Def.Consts.ASSERT_MSG.format("tensorboard path", tensorboard_path)
|
||||
|
||||
# check if folder contain files
|
||||
check_files_in_dir(dir_path=tensorboard_path)
|
||||
# check if folder contain files and check extensions
|
||||
files = get_files_from_dir(dir_path=tensorboard_path)
|
||||
assert any(".tfevents." in file for file in files)
|
||||
|
||||
elif flag[0] == "-onnx" or flag[0] == "--export_onnx_graph":
|
||||
"""
|
||||
-onnx, --export_onnx_graph: Once selected, warning message should
|
||||
appear, it should be with another flag.
|
||||
"""
|
||||
if find_string_in_logs(log_path=clres.stdout.name,
|
||||
str=Def.Consts.ONNX_WARNING):
|
||||
assert True, Def.Consts.ASSERT_MSG.format(
|
||||
Def.Consts.ONNX_WARNING, "Not found")
|
||||
assert find_string_in_logs(log_path=clres.stdout.name,
|
||||
str=Def.Consts.ONNX_WARNING,
|
||||
wait_and_find=True), \
|
||||
Def.Consts.ASSERT_MSG.format(Def.Consts.ONNX_WARNING, "Not found")
|
||||
|
||||
elif flag[0] == "-dg" or flag[0] == "--dump_gifs":
|
||||
"""
|
||||
@@ -287,7 +266,7 @@ def validate_args_results(flag, clres=None, process=None, start_time=None,
|
||||
break
|
||||
|
||||
# check if folder contain files
|
||||
check_files_in_dir(dir_path=gifs_path)
|
||||
get_files_from_dir(dir_path=gifs_path)
|
||||
# TODO: check if play window is opened
|
||||
|
||||
elif flag[0] == "-dm" or flag[0] == "--dump_mp4":
|
||||
@@ -310,7 +289,7 @@ def validate_args_results(flag, clres=None, process=None, start_time=None,
|
||||
break
|
||||
|
||||
# check if folder contain files
|
||||
check_files_in_dir(dir_path=videos_path)
|
||||
get_files_from_dir(dir_path=videos_path)
|
||||
# TODO: check if play window is opened
|
||||
|
||||
elif flag[0] == "--nocolor":
|
||||
@@ -318,37 +297,181 @@ def validate_args_results(flag, clres=None, process=None, start_time=None,
|
||||
--nocolor: Once selected, check if color prefix is replacing the actual
|
||||
color; example: '## agent: ...'
|
||||
"""
|
||||
while time.time() - start_time < timeout:
|
||||
|
||||
if find_string_in_logs(log_path=clres.stdout.name,
|
||||
str=Def.Consts.COLOR_PREFIX):
|
||||
assert True, Def.Consts.ASSERT_MSG. \
|
||||
format(Def.Consts.COLOR_PREFIX, "Color Prefix Not Found")
|
||||
break
|
||||
assert find_string_in_logs(log_path=clres.stdout.name,
|
||||
str=Def.Consts.COLOR_PREFIX,
|
||||
wait_and_find=True), \
|
||||
Def.Consts.ASSERT_MSG.format(Def.Consts.COLOR_PREFIX,
|
||||
"Color Prefix Not Found")
|
||||
|
||||
elif flag[0] == "--evaluate":
|
||||
"""
|
||||
--evaluate: Once selected, Coach start testing, there is not training.
|
||||
"""
|
||||
tries = 5
|
||||
while time.time() - start_time < timeout and tries > 0:
|
||||
|
||||
if find_string_in_logs(log_path=clres.stdout.name,
|
||||
str=Def.Consts.TRAINING):
|
||||
assert False, Def.Consts.ASSERT_MSG.format(
|
||||
"Training Not Found", Def.Consts.TRAINING)
|
||||
else:
|
||||
time.sleep(1)
|
||||
tries -= 1
|
||||
assert True, Def.Consts.ASSERT_MSG.format("Training Found",
|
||||
Def.Consts.TRAINING)
|
||||
# wait until files created
|
||||
get_csv_path(clres=clres)
|
||||
time.sleep(15)
|
||||
assert not find_string_in_logs(log_path=clres.stdout.name,
|
||||
str=Def.Consts.TRAINING), \
|
||||
Def.Consts.ASSERT_MSG.format("Training Not Found",
|
||||
Def.Consts.TRAINING)
|
||||
|
||||
elif flag[0] == "--play":
|
||||
"""
|
||||
--play: Once selected alone, warning message should appear, it should
|
||||
be with another flag.
|
||||
--play: Once selected alone, an warning message should appear, it
|
||||
should be with another flag.
|
||||
"""
|
||||
if find_string_in_logs(log_path=clres.stdout.name,
|
||||
str=Def.Consts.PLAY_WARNING):
|
||||
assert True, Def.Consts.ASSERT_MSG.format(
|
||||
Def.Consts.ONNX_WARNING, "Not found")
|
||||
assert find_string_in_logs(log_path=clres.stdout.name,
|
||||
str=Def.Consts.PLAY_WARNING,
|
||||
wait_and_find=True), \
|
||||
Def.Consts.ASSERT_MSG.format(Def.Consts.PLAY_WARNING, "Not found")
|
||||
|
||||
elif flag[0] == "-et" or flag[0] == "--environment_type":
|
||||
"""
|
||||
-et, --environment_type: Once selected check csv results is created.
|
||||
"""
|
||||
csv_path = get_csv_path(clres)
|
||||
assert len(csv_path) > 0, \
|
||||
Def.Consts.ASSERT_MSG.format("path not found", csv_path)
|
||||
|
||||
elif flag[0] == "-s" or flag[0] == "--checkpoint_save_secs":
|
||||
"""
|
||||
-s, --checkpoint_save_secs: Once selected, check if files added to the
|
||||
experiment path.
|
||||
"""
|
||||
csv_path = get_csv_path(clres)
|
||||
assert len(csv_path) > 0, \
|
||||
Def.Consts.ASSERT_MSG.format("path not found", csv_path)
|
||||
|
||||
exp_path = os.path.dirname(csv_path[0])
|
||||
checkpoint_path = os.path.join(exp_path, Def.Path.checkpoint)
|
||||
|
||||
# wait until video folder were created
|
||||
while time.time() - start_time < timeout:
|
||||
if os.path.isdir(checkpoint_path):
|
||||
assert os.path.isdir(checkpoint_path), \
|
||||
Def.Consts.ASSERT_MSG.format("checkpoint path",
|
||||
checkpoint_path)
|
||||
break
|
||||
|
||||
# check if folder contain files
|
||||
get_files_from_dir(dir_path=checkpoint_path)
|
||||
|
||||
elif flag[0] == "-ew" or flag[0] == "--evaluation_worker":
|
||||
"""
|
||||
-ew, --evaluation_worker: Once selected, check that an evaluation
|
||||
worker is created. e.g. by checking that it's
|
||||
csv file is created.
|
||||
"""
|
||||
# wait until files created
|
||||
csv_path = get_csv_path(clres=clres)
|
||||
assert len(csv_path) > 0, \
|
||||
Def.Consts.ASSERT_MSG.format("path not found", csv_path)
|
||||
|
||||
elif flag[0] == "-cp" or flag[0] == "--custom_parameter":
|
||||
"""
|
||||
-cp, --custom_parameter: Once selected, check that the total steps are
|
||||
around the given param with +/- gap.
|
||||
also, check the heat-up param
|
||||
"""
|
||||
# wait until files created
|
||||
csv_path = get_csv_path(clres=clres)
|
||||
assert len(csv_path) > 0, \
|
||||
Def.Consts.ASSERT_MSG.format("path not found", csv_path)
|
||||
|
||||
# read csv file
|
||||
csv = pd.read_csv(csv_path[0])
|
||||
|
||||
# check heat-up value
|
||||
results = []
|
||||
while csv["In Heatup"].values[-1] == 1:
|
||||
csv = pd.read_csv(csv_path[0])
|
||||
last_step = csv["Total steps"].values
|
||||
time.sleep(1)
|
||||
results.append(last_step[-1])
|
||||
|
||||
assert results[-1] >= Def.Consts.num_hs, \
|
||||
Def.Consts.ASSERT_MSG.format("bigger than " + Def.Consts.num_hs,
|
||||
results[-1])
|
||||
|
||||
elif flag[0] == "-f" or flag[0] == "--framework":
|
||||
"""
|
||||
-f, --framework: Once selected, f = tensorflow or mxnet
|
||||
"""
|
||||
# wait until files created
|
||||
csv_path = get_csv_path(clres=clres)
|
||||
assert len(csv_path) > 0, \
|
||||
Def.Consts.ASSERT_MSG.format("path not found", csv_path)
|
||||
|
||||
get_reward = is_reward_reached(csv_path=csv_path[0],
|
||||
p_valid_params=p_valid_params,
|
||||
start_time=start_time,
|
||||
time_limit=timeout)
|
||||
|
||||
# check if experiment is working and reached the reward
|
||||
assert get_reward, Def.Consts.ASSERT_MSG.format(
|
||||
"Doesn't reached the reward", get_reward)
|
||||
|
||||
# check if there is no exception
|
||||
assert not find_string_in_logs(log_path=clres.stdout.name,
|
||||
str=Def.Consts.LOG_ERROR)
|
||||
|
||||
ret_val = process.poll()
|
||||
assert ret_val is None, Def.Consts.ASSERT_MSG.format("None", ret_val)
|
||||
|
||||
elif flag[0] == "-crd" or flag[0] == "--checkpoint_restore_dir":
|
||||
|
||||
"""
|
||||
-crd, --checkpoint_restore_dir: Once selected alone, check that can't
|
||||
restore checkpoint dir (negative test).
|
||||
"""
|
||||
# wait until files created
|
||||
csv_path = get_csv_path(clres=clres)
|
||||
assert len(csv_path) > 0, \
|
||||
Def.Consts.ASSERT_MSG.format("path not found", csv_path)
|
||||
assert find_string_in_logs(log_path=clres.stdout.name,
|
||||
str=Def.Consts.NO_CHECKPOINT), \
|
||||
Def.Consts.ASSERT_MSG.format(Def.Consts.NO_CHECKPOINT, "Not found")
|
||||
|
||||
elif flag[0] == "--seed":
|
||||
"""
|
||||
--seed: Once selected, check logs of process list if all are the same
|
||||
results.
|
||||
"""
|
||||
lst_csv = []
|
||||
# wait until files created
|
||||
csv_path = get_csv_path(clres=clres, extra_tries=10)
|
||||
|
||||
assert len(csv_path) > 0, \
|
||||
Def.Consts.ASSERT_MSG.format("paths are not found", csv_path)
|
||||
|
||||
assert int(flag[1]) == len(csv_path), Def.Consts.ASSERT_MSG. \
|
||||
format(len(csv_path), int(flag[1]))
|
||||
|
||||
# wait for getting results in csv's
|
||||
for i in range(len(csv_path)):
|
||||
|
||||
lines_in_file = pd.read_csv(csv_path[i])
|
||||
while len(lines_in_file['Episode #'].values) < 100 and \
|
||||
time.time() - start_time < Def.TimeOuts.test_time_limit:
|
||||
lines_in_file = pd.read_csv(csv_path[i])
|
||||
time.sleep(1)
|
||||
|
||||
lst_csv.append(pd.read_csv(csv_path[i],
|
||||
nrows=Def.Consts.N_csv_lines))
|
||||
|
||||
assert len(lst_csv) > 1, Def.Consts.ASSERT_MSG.format("> 1",
|
||||
len(lst_csv))
|
||||
|
||||
df1 = lst_csv[0]
|
||||
for df in lst_csv[1:]:
|
||||
assert list(df1['Training Iter'].values) == list(
|
||||
df['Training Iter'].values)
|
||||
|
||||
assert list(df1['ER #Transitions'].values) == list(
|
||||
df['ER #Transitions'].values)
|
||||
|
||||
assert list(df1['Total steps'].values) == list(
|
||||
df['Total steps'].values)
|
||||
|
||||
elif flag[0] == "-c" or flag[0] == "--use_cpu":
|
||||
pass
|
||||
|
||||
@@ -36,56 +36,112 @@ class Definitions:
|
||||
cp = "custom_parameter"
|
||||
seed = "seed"
|
||||
dccp = "distributed_coach_config_path"
|
||||
enw = "num_workers"
|
||||
fw_ten = "framework_tensorflow"
|
||||
fw_mx = "framework_mxnet"
|
||||
et = "rl_coach.environments.gym_environment:Atari"
|
||||
|
||||
"""
|
||||
Arguments that can be tested for python coach command
|
||||
** None = Flag - no need for string or int
|
||||
** {} = Add format for this parameter
|
||||
** 1 parameter = Flag - no need for string or int
|
||||
** 2 parameters = add value for the selected flag
|
||||
"""
|
||||
cmd_args = {
|
||||
# '-l': None,
|
||||
# '-e': '{}',
|
||||
# '-r': None,
|
||||
# '-n': '{' + enw + '}',
|
||||
# '-c': None,
|
||||
# '-ew': None,
|
||||
'--play': None,
|
||||
'--evaluate': None,
|
||||
# '-v': None,
|
||||
# '-tfv': None,
|
||||
'--nocolor': None,
|
||||
# '-s': '{' + css + '}',
|
||||
# '-crd': '{' + crd + '}',
|
||||
'-dg': None,
|
||||
'-dm': None,
|
||||
# '-et': '{' + et + '}',
|
||||
# '-ept': '{' + ept + '}',
|
||||
# '-lvl': '{level}',
|
||||
# '-cp': '{' + cp + '}',
|
||||
'--print_networks_summary': None,
|
||||
'-tb': None,
|
||||
'-ns': None,
|
||||
'-d': None,
|
||||
# '--seed': '{' + seed + '}',
|
||||
'-onnx': None,
|
||||
'-dc': None,
|
||||
# '-dcp': '{' + dccp + '}',
|
||||
'-asc': None,
|
||||
'--dump_worker_logs': None,
|
||||
}
|
||||
|
||||
cmd_args = [
|
||||
['-ew'],
|
||||
['--play'],
|
||||
['--evaluate'],
|
||||
['-f', fw_ten],
|
||||
['--nocolor'],
|
||||
['-s', css],
|
||||
# ['-crd', crd], # Tested in checkpoint test
|
||||
['-dg'],
|
||||
['-dm'],
|
||||
['-cp', cp],
|
||||
['--print_networks_summary'],
|
||||
['-tb'],
|
||||
['-ns'],
|
||||
['-onnx'],
|
||||
['-asc'],
|
||||
['--dump_worker_logs'],
|
||||
# ['-et', et],
|
||||
# '-lvl': '{level}', # TODO: Add test validation on args_utils
|
||||
# '-e': '{}', # TODO: Add test validation on args_utils
|
||||
# '-l': None, # TODO: Add test validation on args_utils
|
||||
# '-c': None, # TODO: Add test validation using nvidia-smi
|
||||
# '-v': None, # TODO: Add test validation on args_utils
|
||||
# '--seed': '{' + seed + '}', # DONE - new test added
|
||||
# '-dc': None, # TODO: Add test validation on args_utils
|
||||
# '-dcp': '{}' # TODO: Add test validation on args_utils
|
||||
# ['-n', enw], # Duplicated arg test
|
||||
# ['-d'], # Arg can't be automated - no GUI in the CI
|
||||
# '-r': None, # No automation test
|
||||
# '-tfv': None, # No automation test
|
||||
# '-ept': '{' + ept + '}', # No automation test - not supported
|
||||
]
|
||||
|
||||
class Presets:
|
||||
# Preset list for testing the flags/ arguments of python coach command
|
||||
args_test = [
|
||||
"CartPole_A3C",
|
||||
# "CartPole_NEC",
|
||||
]
|
||||
|
||||
# Preset list for mxnet framework testing
|
||||
mxnet_args_test = [
|
||||
"CartPole_DQN"
|
||||
]
|
||||
|
||||
# Preset for testing seed argument
|
||||
seed_args_test = [
|
||||
"Atari_A3C",
|
||||
"Atari_A3C_LSTM",
|
||||
"Atari_Bootstrapped_DQN",
|
||||
"Atari_C51",
|
||||
"Atari_DDQN",
|
||||
"Atari_DQN_with_PER",
|
||||
"Atari_DQN",
|
||||
"Atari_DQN_with_PER",
|
||||
"Atari_Dueling_DDQN",
|
||||
"Atari_Dueling_DDQN_with_PER_OpenAI",
|
||||
"Atari_NStepQ",
|
||||
"Atari_QR_DQN",
|
||||
"Atari_Rainbow",
|
||||
"Atari_UCB_with_Q_Ensembles",
|
||||
"BitFlip_DQN",
|
||||
"BitFlip_DQN_HER",
|
||||
"CartPole_A3C",
|
||||
"CartPole_ClippedPPO",
|
||||
"CartPole_DFP",
|
||||
"CartPole_DQN",
|
||||
"CartPole_Dueling_DDQN",
|
||||
"CartPole_NStepQ",
|
||||
"CartPole_PAL",
|
||||
"CartPole_PG",
|
||||
"ControlSuite_DDPG",
|
||||
"ExplorationChain_Bootstrapped_DQN",
|
||||
"ExplorationChain_Dueling_DDQN",
|
||||
"ExplorationChain_UCB_Q_ensembles",
|
||||
"Fetch_DDPG_HER_baselines",
|
||||
"InvertedPendulum_PG",
|
||||
"MontezumaRevenge_BC",
|
||||
"Mujoco_A3C",
|
||||
"Mujoco_A3C_LSTM",
|
||||
"Mujoco_ClippedPPO",
|
||||
"Mujoco_DDPG",
|
||||
"Mujoco_NAF",
|
||||
"Mujoco_PPO",
|
||||
"Pendulum_HAC",
|
||||
"Starcraft_CollectMinerals_A3C",
|
||||
"Starcraft_CollectMinerals_Dueling_DDQN",
|
||||
]
|
||||
|
||||
class Path:
|
||||
experiments = "./experiments"
|
||||
tensorboard = "tensorboard"
|
||||
test_dir = "test_dir"
|
||||
gifs = "gifs"
|
||||
videos = "videos"
|
||||
checkpoint = "checkpoint"
|
||||
|
||||
class Consts:
|
||||
ASSERT_MSG = "Expected: {}, Actual: {}."
|
||||
@@ -105,7 +161,17 @@ class Definitions:
|
||||
"These flags can not be used together. For human " \
|
||||
"control, please use the --play flag together with " \
|
||||
"the environment type flag (-et)"
|
||||
NO_CHECKPOINT = "No checkpoint to restore in:"
|
||||
LOG_ERROR = "KeyError:"
|
||||
|
||||
num_hs = 200 # heat-up steps (used for agent custom parameters)
|
||||
|
||||
f_comb = 2 # number of flags in cmd for creating flags combinations
|
||||
|
||||
N_csv_lines = 100 # number of lines to validate on csv file
|
||||
|
||||
class TimeOuts:
|
||||
test_time_limit = 60 * 60
|
||||
wait_for_files = 20
|
||||
wait_for_csv = 240
|
||||
test_run = 60
|
||||
|
||||
@@ -18,7 +18,9 @@
|
||||
import glob
|
||||
import sys
|
||||
import time
|
||||
import os
|
||||
from os import path
|
||||
from rl_coach.tests.utils.definitions import Definitions as Def
|
||||
|
||||
|
||||
def print_progress(averaged_rewards, last_num_episodes, start_time, time_limit,
|
||||
@@ -49,20 +51,99 @@ def print_progress(averaged_rewards, last_num_episodes, start_time, time_limit,
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def read_csv_paths(test_path, filename_pattern, read_csv_tries=120):
|
||||
def read_csv_paths(test_path, filename_pattern, read_csv_tries=120,
|
||||
extra_tries=0):
|
||||
"""
|
||||
Return file path once it found
|
||||
:param test_path: test folder path
|
||||
:param filename_pattern: csv file pattern
|
||||
:param read_csv_tries: number of iterations until file found
|
||||
:param extra_tries: add number of extra tries to check after getting all
|
||||
the paths.
|
||||
:return: |string| return csv file path
|
||||
"""
|
||||
csv_paths = []
|
||||
tries_counter = 0
|
||||
while not csv_paths:
|
||||
while not csv_paths or extra_tries > 0:
|
||||
csv_paths = glob.glob(path.join(test_path, '*', filename_pattern))
|
||||
if tries_counter > read_csv_tries:
|
||||
break
|
||||
tries_counter += 1
|
||||
time.sleep(1)
|
||||
tries_counter += 1
|
||||
|
||||
if csv_paths and extra_tries > 0:
|
||||
extra_tries -= 1
|
||||
|
||||
return csv_paths
|
||||
|
||||
|
||||
def get_files_from_dir(dir_path):
|
||||
"""
|
||||
Check if folder has files
|
||||
:param dir_path: |string| folder path
|
||||
:return: |list| return files in folder
|
||||
"""
|
||||
start_time = time.time()
|
||||
entities = None
|
||||
while time.time() - start_time < Def.TimeOuts.wait_for_files:
|
||||
# wait until logs created
|
||||
if os.path.exists(dir_path):
|
||||
entities = os.listdir(dir_path)
|
||||
if len(entities) > 0:
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
assert len(entities) > 0, \
|
||||
Def.Consts.ASSERT_MSG.format("num files > 0", len(entities))
|
||||
return entities
|
||||
|
||||
|
||||
def find_string_in_logs(log_path, str, timeout=Def.TimeOuts.wait_for_files,
|
||||
wait_and_find=False):
|
||||
"""
|
||||
Find string into the log file
|
||||
:param log_path: |string| log path
|
||||
:param str: |string| search text
|
||||
:param timeout: |int| timeout for searching on file
|
||||
:param wait_and_find: |bool| true if need to wait until reaching timeout
|
||||
:return: |bool| true if string found in the log file
|
||||
"""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
# wait until logs created
|
||||
if os.path.exists(log_path):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
if not os.path.exists(log_path):
|
||||
return False
|
||||
|
||||
with open(log_path, 'r') as fr:
|
||||
if str in fr.read():
|
||||
return True
|
||||
fr.close()
|
||||
|
||||
while time.time() - start_time < Def.TimeOuts.test_time_limit \
|
||||
and wait_and_find:
|
||||
with open(log_path, 'r') as fr:
|
||||
if str in fr.read():
|
||||
return True
|
||||
fr.close()
|
||||
return False
|
||||
|
||||
|
||||
def get_csv_path(clres, tries_for_csv=Def.TimeOuts.wait_for_csv,
|
||||
extra_tries=0):
|
||||
"""
|
||||
Get the csv path with the results - reading csv paths will take some time
|
||||
:param clres: object of files that test is creating
|
||||
:param tries_for_csv: timeout of tires until getting all csv files
|
||||
:param extra_tries: add number of extra tries to check after getting all
|
||||
the paths.
|
||||
:return: |list| csv path
|
||||
"""
|
||||
return read_csv_paths(test_path=clres.exp_path,
|
||||
filename_pattern=clres.fn_pattern,
|
||||
read_csv_tries=tries_for_csv,
|
||||
extra_tries=extra_tries)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user