1
0
mirror of https://github.com/gryf/coach.git synced 2026-02-02 05:45:45 +01:00

tests: added new setup configuration + test args (#211)

- added utils for future tests and conftest
- added test args
This commit is contained in:
anabwan
2019-02-13 07:43:59 -05:00
committed by GitHub
parent 9d0fed84a3
commit 7253f511ed
10 changed files with 815 additions and 2 deletions

View File

View File

@@ -0,0 +1,354 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Manage all command arguments."""
import os
import re
import signal
import time
import psutil as psutil
from rl_coach.logger import screen
from rl_coach.tests.utils import test_utils
from rl_coach.tests.utils.definitions import Definitions as Def
def collect_preset_for_args():
    """
    Yield the presets that are used for argument testing.
    The names come from the definitions file (args_test under Presets);
    every name is asserted to be non-empty before it is yielded.
    :yield: one preset name per iteration
    """
    for preset_name in Def.Presets.args_test:
        assert preset_name, \
            Def.Consts.ASSERT_MSG.format("Preset name", preset_name)
        yield preset_name
def collect_args():
    """
    Yield the command line arguments to test, one per test iteration.
    Each item is a list holding the flag and, when the flag takes one,
    its value.
    :yield: |Array| [flag] or [flag, value]
    """
    for flag_name, flag_value in Def.Flags.cmd_args.items():
        # flags mapped to None take no value
        cmd = [flag_name] if flag_value is None else [flag_name, flag_value]
        assert cmd, Def.Consts.ASSERT_MSG.format("cmd array", str(cmd))
        yield cmd
def add_one_flag_value(flag):
    """
    Fill in the value of a flag so the python command can be run with it.
    The list is updated in place: flag[1] is replaced when it contains one
    of the known placeholder names.
    :param flag: |Array| [flag] or [flag, value placeholder]
    :return: |Array| the same list with its value filled in; an empty list
             when the input is empty or has more than two items
    """
    if not flag or len(flag) > 2:
        return []

    if len(flag) == 1:
        # flag without a value - nothing to fill in
        return flag

    # (placeholder name, value to use) - checked in order, first match wins
    replacements = (
        (Def.Flags.css, 30),
        # TODO: check dir of checkpoint
        (Def.Flags.crd, os.path.join(Def.Path.experiments)),
        # TODO: add valid value
        (Def.Flags.et, ""),
        # TODO: add valid value
        (Def.Flags.ept, ""),
        # TODO: add valid value
        (Def.Flags.cp, ""),
        (Def.Flags.seed, 0),
        # TODO: add valid value
        (Def.Flags.dccp, ""),
    )
    for placeholder, value in replacements:
        if placeholder in flag[1]:
            flag[1] = value
            break

    return flag
def check_files_in_dir(dir_path, timeout=None):
    """
    Wait until a folder exists and contains at least one entry.
    :param dir_path: |string| folder path
    :param timeout: |number| seconds to wait for the entries; when None,
                    Def.TimeOuts.wait_for_files is used
    :return: |Array| names of the entries found in the folder
    :raises AssertionError: if the folder is missing or still empty once
                            the timeout has passed
    """
    if timeout is None:
        timeout = Def.TimeOuts.wait_for_files

    start_time = time.time()
    # start from an empty list so a folder that never shows up fails the
    # assertion below instead of raising TypeError on len(None)
    entities = []
    while time.time() - start_time < timeout:
        # wait until logs created
        if os.path.exists(dir_path):
            entities = os.listdir(dir_path)
            if entities:
                break
        time.sleep(1)

    assert entities, \
        Def.Consts.ASSERT_MSG.format("num files > 0", len(entities))
    return entities
def find_string_in_logs(log_path, str, timeout=None):
    """
    Look for a string in a log file, waiting for the file to be created.
    Only the creation of the file is waited for; the content is read once.
    :param log_path: |string| log path
    :param str: |string| search text (the name shadows the builtin, it is
                kept because callers pass it by keyword)
    :param timeout: |number| seconds to wait for the log file to appear;
                    when None, Def.TimeOuts.wait_for_files is used
    :return: |bool| true if string found in the log file, false if it is
             not there or the file was not created in time
    """
    if timeout is None:
        timeout = Def.TimeOuts.wait_for_files

    start_time = time.time()
    # wait until logs created
    while not os.path.exists(log_path):
        if time.time() - start_time >= timeout:
            return False
        time.sleep(1)

    # context manager so the file handle is closed right after reading
    with open(log_path, 'r') as log_file:
        return str in log_file.read()
def get_csv_path(clres):
    """
    Look up the csv file(s) holding the experiment results. Finding the
    files can take some time, see test_utils.read_csv_paths.
    :param clres: object of files that test is creating; its exp_path and
                  fn_pattern attributes are used
    :return: |Array| csv path(s)
    """
    experiment_dir = clres.exp_path
    csv_pattern = clres.fn_pattern
    return test_utils.read_csv_paths(test_path=experiment_dir,
                                     filename_pattern=csv_pattern)
def _deadline(start_time, timeout):
    """Return the absolute time at which waiting for a result must stop."""
    if start_time is None:
        start_time = time.time()
    if timeout is None:
        timeout = Def.TimeOuts.wait_for_files
    return start_time + timeout


def _wait_for_log_text(log_path, text, deadline):
    """Poll the log once a second until text shows up or deadline passes."""
    while True:
        if find_string_in_logs(log_path=log_path, str=text):
            return True
        if time.time() >= deadline:
            return False
        time.sleep(1)


def _wait_for_dir(dir_path, deadline):
    """Poll once a second until dir_path is a folder or deadline passes."""
    while True:
        if os.path.isdir(dir_path):
            return True
        if time.time() >= deadline:
            return False
        time.sleep(1)


def _find_process_id(proc_name, timeout):
    """Return the pid of a process named proc_name, None after timeout."""
    start = time.time()
    while True:
        for proc in psutil.process_iter():
            try:
                if proc.name() == proc_name:
                    return proc.pid
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                # process ended or is not ours to inspect - skip it
                continue
        if time.time() - start >= timeout:
            return None
        time.sleep(1)


def _experiment_dir(clres):
    """Return the experiment folder, located through its csv results."""
    csv_path = get_csv_path(clres)
    assert len(csv_path) > 0, \
        Def.Consts.ASSERT_MSG.format("path not found", csv_path)
    return os.path.dirname(csv_path[0])


def validate_args_results(flag, clres=None, process=None, start_time=None,
                          timeout=None):
    """
    Validate results of one argument.
    Flags without a dedicated check are accepted silently.
    :param flag: flag to check; only flag[0] (the flag name) is used
    :param clres: object of files paths (results of test experiment)
    :param process: process object
    :param start_time: start time of the test; when None the waiting
                       starts from the moment of the check
    :param timeout: timeout of the test- fail test once over the timeout;
                    when None, Def.TimeOuts.wait_for_files is used
    :raises AssertionError: if the expected result of the flag is missing
    """
    name = flag[0]

    if name in ("-ns", "--no-summary"):
        # --no-summary: Once selected, summary lines shouldn't appear in logs
        # send CTRL+C to close experiment
        process.send_signal(signal.SIGINT)

        unexpected = (
            ("No Result summary", Def.Consts.RESULTS_SORTED),
            ("No Total runtime summary", Def.Consts.TOTAL_RUNTIME),
            ("No discard message", Def.Consts.DISCARD_EXP),
        )
        for label, text in unexpected:
            assert not find_string_in_logs(log_path=clres.stdout.name,
                                           str=text), \
                Def.Consts.ASSERT_MSG.format(label, text)

    elif name in ("-asc", "--apply_stop_condition"):
        # -asc, --apply_stop_condition: Once selected, coach stopped when
        # required success rate reached
        assert _wait_for_log_text(clres.stdout.name,
                                  Def.Consts.REACHED_REQ_ASC,
                                  _deadline(start_time, timeout)), \
            Def.Consts.ASSERT_MSG.format(Def.Consts.REACHED_REQ_ASC,
                                         "Message Not Found")

    elif name in ("-d", "--open_dashboard"):
        # -d, --open_dashboard: Once selected, firefox browser will open to
        # show coach's Dashboard.
        proc_id = _find_process_id(Def.DASHBOARD_PROC,
                                   Def.TimeOuts.wait_for_files)
        assert proc_id is not None, \
            Def.Consts.ASSERT_MSG.format("Found Firefox process", proc_id)
        # kill firefox process
        os.kill(proc_id, signal.SIGKILL)

    elif name == "--print_networks_summary":
        # --print_networks_summary: Once selected, agent summary should
        # appear in stdout.
        deadline = _deadline(start_time, timeout)
        for text in (Def.Consts.INPUT_EMBEDDER, Def.Consts.MIDDLEWARE,
                     Def.Consts.OUTPUT_HEAD):
            assert _wait_for_log_text(clres.stdout.name, text, deadline), \
                Def.Consts.ASSERT_MSG.format(text, "Not found")

    elif name in ("-tb", "--tensorboard"):
        # -tb, --tensorboard: Once selected, a new folder should be created
        # in experiment folder.
        tensorboard_path = os.path.join(_experiment_dir(clres),
                                        Def.Path.tensorboard)
        assert os.path.isdir(tensorboard_path), \
            Def.Consts.ASSERT_MSG.format("tensorboard path", tensorboard_path)
        # check if folder contain files
        check_files_in_dir(dir_path=tensorboard_path)

    elif name in ("-onnx", "--export_onnx_graph"):
        # -onnx, --export_onnx_graph: Once selected, warning message should
        # appear, it should be with another flag.
        assert _wait_for_log_text(clres.stdout.name,
                                  Def.Consts.ONNX_WARNING,
                                  _deadline(start_time, timeout)), \
            Def.Consts.ASSERT_MSG.format(Def.Consts.ONNX_WARNING, "Not found")

    elif name in ("-dg", "--dump_gifs"):
        # -dg, --dump_gifs: Once selected, a new folder should be created in
        # experiment folder for gifs files.
        gifs_path = os.path.join(_experiment_dir(clres), Def.Path.gifs)
        # wait until gif folder were created
        assert _wait_for_dir(gifs_path, _deadline(start_time, timeout)), \
            Def.Consts.ASSERT_MSG.format("gifs path", gifs_path)
        # check if folder contain files
        check_files_in_dir(dir_path=gifs_path)
        # TODO: check if play window is opened

    elif name in ("-dm", "--dump_mp4"):
        # -dm, --dump_mp4: Once selected, a new folder should be created in
        # experiment folder for videos files.
        videos_path = os.path.join(_experiment_dir(clres), Def.Path.videos)
        # wait until video folder were created
        assert _wait_for_dir(videos_path, _deadline(start_time, timeout)), \
            Def.Consts.ASSERT_MSG.format("videos path", videos_path)
        # check if folder contain files
        check_files_in_dir(dir_path=videos_path)
        # TODO: check if play window is opened

    elif name == "--nocolor":
        # --nocolor: Once selected, check if color prefix is replacing the
        # actual color; example: '## agent: ...'
        assert _wait_for_log_text(clres.stdout.name,
                                  Def.Consts.COLOR_PREFIX,
                                  _deadline(start_time, timeout)), \
            Def.Consts.ASSERT_MSG.format(Def.Consts.COLOR_PREFIX,
                                         "Color Prefix Not Found")

    elif name == "--evaluate":
        # --evaluate: Once selected, Coach start testing, there is not
        # training. The log is sampled a few times, one second apart, and
        # no training line may show up in any of the samples.
        deadline = _deadline(start_time, timeout)
        tries = 5
        while time.time() < deadline and tries > 0:
            assert not find_string_in_logs(log_path=clres.stdout.name,
                                           str=Def.Consts.TRAINING), \
                Def.Consts.ASSERT_MSG.format("Training Not Found",
                                             Def.Consts.TRAINING)
            time.sleep(1)
            tries -= 1

    elif name == "--play":
        # --play: Once selected alone, warning message should appear, it
        # should be with another flag.
        assert _wait_for_log_text(clres.stdout.name,
                                  Def.Consts.PLAY_WARNING,
                                  _deadline(start_time, timeout)), \
            Def.Consts.ASSERT_MSG.format(Def.Consts.PLAY_WARNING, "Not found")

View File

@@ -0,0 +1,111 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Definitions file:
Its main responsibilities are:
1) housing project constants and enums.
2) housing configuration parameters.
3) housing resource paths.
"""
class Definitions:
    # Names of the package, the coach process and the dashboard process
    GROUP_NAME = "rl_coach"
    PROCESS_NAME = "coach"
    DASHBOARD_PROC = "firefox"

    class Flags:
        # Placeholder names for flags that take a value
        css = "checkpoint_save_secs"
        crd = "checkpoint_restore_dir"
        et = "environment_type"
        ept = "exploration_policy_type"
        cp = "custom_parameter"
        seed = "seed"
        dccp = "distributed_coach_config_path"

        # Arguments that can be tested for python coach command
        # ** None = Flag - no need for string or int
        # ** {} = Add format for this parameter
        cmd_args = {
            # '-l': None,
            # '-e': '{}',
            # '-r': None,
            # '-n': '{' + enw + '}',
            # '-c': None,
            # '-ew': None,
            '--play': None,
            '--evaluate': None,
            # '-v': None,
            # '-tfv': None,
            '--nocolor': None,
            # '-s': '{' + css + '}',
            # '-crd': '{' + crd + '}',
            '-dg': None,
            '-dm': None,
            # '-et': '{' + et + '}',
            # '-ept': '{' + ept + '}',
            # '-lvl': '{level}',
            # '-cp': '{' + cp + '}',
            '--print_networks_summary': None,
            '-tb': None,
            '-ns': None,
            '-d': None,
            # '--seed': '{' + seed + '}',
            '-onnx': None,
            '-dc': None,
            # '-dcp': '{' + dccp + '}',
            '-asc': None,
            '--dump_worker_logs': None,
        }

    class Presets:
        # Preset list for testing the flags/ arguments of python coach command
        args_test = [
            "CartPole_A3C",
            # "CartPole_NEC",
        ]

    class Path:
        # Experiments root and the sub folders created inside an experiment
        experiments = "./experiments"
        tensorboard = "tensorboard"
        gifs = "gifs"
        videos = "videos"

    class Consts:
        # Template of the assertion messages
        ASSERT_MSG = "Expected: {}, Actual: {}."

        # Texts that are searched for in the log of the experiment
        RESULTS_SORTED = "Results stored at:"
        TOTAL_RUNTIME = "Total runtime:"
        DISCARD_EXP = "Do you want to discard the experiment results"
        REACHED_REQ_ASC = "Reached required success rate. Exiting."
        INPUT_EMBEDDER = "Input Embedder:"
        MIDDLEWARE = "Middleware:"
        OUTPUT_HEAD = "Output Head:"
        ONNX_WARNING = (
            "Exporting ONNX graphs requires setting the "
            "--checkpoint_save_secs flag. The --export_onnx_graph"
            " will have no effect."
        )
        COLOR_PREFIX = "## agent: Starting evaluation phase"
        TRAINING = "Training - "
        PLAY_WARNING = (
            "Both the --preset and the --play flags were set. "
            "These flags can not be used together. For human "
            "control, please use the --play flag together with "
            "the environment type flag (-et)"
        )

    class TimeOuts:
        # Values are in seconds
        test_time_limit = 60 * 60
        wait_for_files = 20

View File

@@ -0,0 +1,85 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Manage all preset"""
import os
from importlib import import_module
from rl_coach.tests.utils.definitions import Definitions as Def
def import_preset(preset_name):
    """
    Import the module of a preset from the presets package.
    :param preset_name: |string| preset name (module name, no extension)
    :return: imported module
    """
    module_path = '{}.presets.{}'.format(Def.GROUP_NAME, preset_name)
    return import_module(module_path)
def validation_params(preset_name):
    """
    Get the validation parameters of a preset.
    :param preset_name: |string| preset name
    :return: the preset_validation_params of the preset's graph manager
    """
    preset_module = import_preset(preset_name)
    return preset_module.graph_manager.preset_validation_params
def all_presets():
    """
    List the presets found in the presets directory. The directory is
    looked up relative to the current working directory.
    :return: |Array| preset names (python file names without extension)
    """
    presets_dir = os.path.join(Def.GROUP_NAME, 'presets')
    preset_names = []
    for file_name in os.listdir(presets_dir):
        # python modules only, the package init file is not a preset
        if file_name.endswith('.py') and file_name != '__init__.py':
            preset_names.append(file_name[:-3])
    return preset_names
def importable(preset_name):
    """
    Try to import preset name
    :param preset_name: |string| preset name
    :return: |bool| true if possible to import preset
    """
    try:
        import_preset(preset_name)
    except (Exception, SystemExit):
        # any failure of the preset module, including a module that exits
        # on import, means "not importable"; KeyboardInterrupt is left to
        # propagate so CTRL+C still stops the run
        return False
    return True
def has_test_parameters(preset_name):
    """
    Tell whether a preset is marked for testing.
    :param preset_name: |string| preset name
    :return: |bool| truth value of the `test` attribute of the preset's
             validation parameters
    """
    params = validation_params(preset_name)
    return bool(params.test)
def collect_presets():
    """
    Yield the presets of the presets directory that should be tested.
    A preset that can't be imported is yielded as well, so that its test
    fails; an importable preset is yielded only when it has test
    parameters.
    :yield: preset name
    """
    for candidate in all_presets():
        if not importable(candidate) or has_test_parameters(candidate):
            yield candidate

View File

@@ -0,0 +1,68 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Common functionality shared across tests."""
import glob
import sys
import time
from os import path
def print_progress(averaged_rewards, last_num_episodes, start_time, time_limit,
                   p_valid_params):
    """
    Print progress bar for preset run test
    The line starts with a carriage return, so repeated calls overwrite
    the previous progress line.
    :param averaged_rewards: average rewards of test, the last item is shown
    :param last_num_episodes: last episode number
    :param start_time: start time of test
    :param time_limit: time out of test
    :param p_valid_params: preset validation parameters; its
                           max_episodes_to_achieve_reward and
                           min_reward_threshold attributes are used
    :return: None
    """
    bar_width = 10
    max_episodes_to_achieve = p_valid_params.max_episodes_to_achieve_reward
    min_reward = p_valid_params.min_reward_threshold
    avg_reward = round(averaged_rewards[-1], 1)
    cur_time = round(time.time() - start_time, 2)

    # no episode target means no progress to show (and no division by zero)
    if max_episodes_to_achieve:
        percentage = int((100 * last_num_episodes) / max_episodes_to_achieve)
    else:
        percentage = 0

    # keep the bar inside its cells even when the percentage leaves 0..100
    filled = min(max(int(percentage / 10), 0), bar_width)

    sys.stdout.write("\rReward: ({}/{})".format(avg_reward, min_reward))
    sys.stdout.write(' Time (sec): ({}/{})'.format(cur_time, time_limit))
    sys.stdout.write(' Episode: ({}/{})'.format(last_num_episodes,
                                                max_episodes_to_achieve))
    sys.stdout.write(' {}%|{}{}| '.format(percentage, '#' * filled,
                                          ' ' * (bar_width - filled)))
    sys.stdout.flush()
def read_csv_paths(test_path, filename_pattern, read_csv_tries=120):
    """
    Wait for csv files to show up in the sub folders of the test folder.
    The folder is checked once a second and the paths are returned as soon
    as at least one file matches.
    :param test_path: test folder path
    :param filename_pattern: csv file pattern
    :param read_csv_tries: number of retries until the search is given up
    :return: |Array| csv file paths, empty if no file was found in time
    """
    search_pattern = path.join(test_path, '*', filename_pattern)
    tries_counter = 0
    while True:
        csv_paths = glob.glob(search_pattern)
        # return right away - no point in sleeping once the files are there
        # or once the retries are used up
        if csv_paths or tries_counter > read_csv_tries:
            return csv_paths
        tries_counter += 1
        time.sleep(1)