mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
tests: fixed failed tests - stabling CI (#298)
* tests: stabling CI * tests: fix failed tests - stabling CI * fix get csv files. - fixed seed test * fix clres on conftest - now can modify paths during test run. - this fixed the mxnet checkpoint test * tests: fix comments
This commit is contained in:
@@ -82,9 +82,25 @@ def clres(request):
|
|||||||
"""
|
"""
|
||||||
def __init__(self, csv, log, pattern):
|
def __init__(self, csv, log, pattern):
|
||||||
self.exp_path = csv
|
self.exp_path = csv
|
||||||
self.stdout = log
|
self.stdout = open(log, 'w')
|
||||||
self.fn_pattern = pattern
|
self.fn_pattern = pattern
|
||||||
|
|
||||||
|
@property
|
||||||
|
def experiment_path(self):
|
||||||
|
return self.exp_path
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stdout_path(self):
|
||||||
|
return self.stdout
|
||||||
|
|
||||||
|
@experiment_path.setter
|
||||||
|
def experiment_path(self, val):
|
||||||
|
self.exp_path = val
|
||||||
|
|
||||||
|
@stdout_path.setter
|
||||||
|
def stdout_path(self, val):
|
||||||
|
self.stdout = open(val, 'w')
|
||||||
|
|
||||||
# get preset name from test request params
|
# get preset name from test request params
|
||||||
idx = 0 if 'preset' in list(request.node.funcargs.items())[0][0] else 1
|
idx = 0 if 'preset' in list(request.node.funcargs.items())[0][0] else 1
|
||||||
p_name = list(request.node.funcargs.items())[idx][1]
|
p_name = list(request.node.funcargs.items())[idx][1]
|
||||||
@@ -99,10 +115,9 @@ def clres(request):
|
|||||||
|
|
||||||
# get the stdout for logs results
|
# get the stdout for logs results
|
||||||
log_file_name = 'test_log_{}.txt'.format(p_name)
|
log_file_name = 'test_log_{}.txt'.format(p_name)
|
||||||
stdout = open(log_file_name, 'w')
|
|
||||||
fn_pattern = '*.csv' if p_valid_params.num_workers > 1 else 'worker_0*.csv'
|
fn_pattern = '*.csv' if p_valid_params.num_workers > 1 else 'worker_0*.csv'
|
||||||
|
|
||||||
res = CreateCsvLog(test_path, stdout, fn_pattern)
|
res = CreateCsvLog(test_path, log_file_name, fn_pattern)
|
||||||
|
|
||||||
yield res
|
yield res
|
||||||
|
|
||||||
|
|||||||
@@ -91,7 +91,13 @@ def test_restore_checkpoint(preset_args, clres, framework,
|
|||||||
return p
|
return p
|
||||||
|
|
||||||
if framework == "mxnet":
|
if framework == "mxnet":
|
||||||
|
# update preset name - for mxnet framework we are using *_DQN
|
||||||
preset_args = Def.Presets.mxnet_args_test[0]
|
preset_args = Def.Presets.mxnet_args_test[0]
|
||||||
|
# update logs paths
|
||||||
|
test_name = 'ExpName_{}'.format(preset_args)
|
||||||
|
test_path = os.path.join(Def.Path.experiments, test_name)
|
||||||
|
clres.experiment_path = test_path
|
||||||
|
clres.stdout_path = 'test_log_{}.txt'.format(preset_args)
|
||||||
|
|
||||||
p_valid_params = p_utils.validation_params(preset_args)
|
p_valid_params = p_utils.validation_params(preset_args)
|
||||||
create_cp_proc = _create_cmd_and_run(flag=['--checkpoint_save_secs', '5'])
|
create_cp_proc = _create_cmd_and_run(flag=['--checkpoint_save_secs', '5'])
|
||||||
|
|||||||
@@ -453,10 +453,10 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
|
|||||||
num_expected_files=int(flag[1]))
|
num_expected_files=int(flag[1]))
|
||||||
|
|
||||||
assert len(csv_path) > 0, \
|
assert len(csv_path) > 0, \
|
||||||
Def.Consts.ASSERT_MSG.format("paths are not found", csv_path)
|
Def.Consts.ASSERT_MSG.format("paths are not found", str(csv_path))
|
||||||
|
|
||||||
assert int(flag[1]) == len(csv_path), Def.Consts.ASSERT_MSG. \
|
assert int(flag[1]) == len(csv_path), Def.Consts.ASSERT_MSG. \
|
||||||
format(len(csv_path), int(flag[1]))
|
format(int(flag[1]), len(csv_path))
|
||||||
|
|
||||||
# wait for getting results in csv's
|
# wait for getting results in csv's
|
||||||
for i in range(len(csv_path)):
|
for i in range(len(csv_path)):
|
||||||
|
|||||||
@@ -32,7 +32,6 @@ def print_progress(averaged_rewards, last_num_episodes, start_time, time_limit,
|
|||||||
:param start_time: start time of test
|
:param start_time: start time of test
|
||||||
:param time_limit: time out of test
|
:param time_limit: time out of test
|
||||||
:param p_valid_params: preset validation parameters
|
:param p_valid_params: preset validation parameters
|
||||||
:return:
|
|
||||||
"""
|
"""
|
||||||
max_episodes_to_archive = p_valid_params.max_episodes_to_achieve_reward
|
max_episodes_to_archive = p_valid_params.max_episodes_to_achieve_reward
|
||||||
min_reward = p_valid_params.min_reward_threshold
|
min_reward = p_valid_params.min_reward_threshold
|
||||||
@@ -55,30 +54,36 @@ def read_csv_paths(test_path, filename_pattern, read_csv_tries=120,
|
|||||||
extra_tries=0, num_expected_files=None):
|
extra_tries=0, num_expected_files=None):
|
||||||
"""
|
"""
|
||||||
Return file path once it found
|
Return file path once it found
|
||||||
:param test_path: test folder path
|
:param test_path: |string| test folder path
|
||||||
:param filename_pattern: csv file pattern
|
:param filename_pattern: |string| csv file pattern
|
||||||
:param read_csv_tries: number of iterations until file found
|
:param read_csv_tries: |int| number of iterations until file found
|
||||||
:param extra_tries: add number of extra tries to check after getting all
|
:param extra_tries: |int| add number of extra tries to check after getting
|
||||||
the paths.
|
all the paths.
|
||||||
:param num_expected_files: find all expected file in experiment folder.
|
:param num_expected_files: find all expected file in experiment folder.
|
||||||
:return: |string| return csv file path
|
:return: |string| return csv file path
|
||||||
"""
|
"""
|
||||||
csv_paths = []
|
csv_paths = []
|
||||||
tries_counter = 0
|
tries_counter = 0
|
||||||
while not csv_paths or extra_tries > 0:
|
|
||||||
csv_paths = glob.glob(path.join(test_path, '*', filename_pattern))
|
|
||||||
if tries_counter > read_csv_tries:
|
|
||||||
break
|
|
||||||
|
|
||||||
if num_expected_files and num_expected_files == len(csv_paths):
|
if isinstance(extra_tries, int) and extra_tries >= 0:
|
||||||
|
read_csv_tries += extra_tries
|
||||||
|
|
||||||
|
while tries_counter < read_csv_tries:
|
||||||
|
csv_paths = glob.glob(path.join(test_path, '*', filename_pattern))
|
||||||
|
|
||||||
|
if num_expected_files:
|
||||||
|
if num_expected_files == len(csv_paths):
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
time.sleep(1)
|
||||||
|
tries_counter += 1
|
||||||
|
continue
|
||||||
|
elif csv_paths:
|
||||||
break
|
break
|
||||||
|
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
tries_counter += 1
|
tries_counter += 1
|
||||||
|
|
||||||
if csv_paths and extra_tries > 0:
|
|
||||||
extra_tries -= 1
|
|
||||||
|
|
||||||
return csv_paths
|
return csv_paths
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user