tests: Removed mxnet from functional tests + minor fix on rewards (#362)

* ci: change workflow
* changed timeout
* fix function reach reward
* print logs
* removing mxnet
* res'
@@ -731,7 +731,6 @@ workflows:
       - functional_tests:
           requires:
             - build_base
-            - integration_tests
       - functional_test_doom:
           requires:
             - build_doom_env
@@ -20,12 +20,12 @@ import time
 import pytest
 import signal
 import tempfile
-import numpy as np
 import pandas as pd
 import rl_coach.tests.utils.args_utils as a_utils
 import rl_coach.tests.utils.test_utils as test_utils
 import rl_coach.tests.utils.presets_utils as p_utils
 from rl_coach import checkpoint
+from rl_coach.logger import screen
 from rl_coach.tests.utils.definitions import Definitions as Def


@@ -55,16 +55,14 @@ def test_get_checkpoint_state():


 @pytest.mark.functional_test
-@pytest.mark.parametrize("framework", ["mxnet", "tensorflow"])
+@pytest.mark.parametrize("framework", ["tensorflow"])
 def test_restore_checkpoint(preset_args, clres, framework,
-                            start_time=time.time(),
                             timeout=Def.TimeOuts.test_time_limit):
     """
     Create checkpoints and restore them in second run.
     :param preset_args: all preset that can be tested for argument tests
     :param clres: logs and csv files
     :param framework: name of the test framework
-    :param start_time: test started time
     :param timeout: max time for test
     """

@@ -79,7 +77,7 @@ def test_restore_checkpoint(preset_args, clres, framework,
         'python3', 'rl_coach/coach.py',
         '-p', '{}'.format(preset_args),
         '-e', '{}'.format("ExpName_" + preset_args),
-        '--seed', '{}'.format(42),
+        '--seed', '{}'.format(4),
         '-f', '{}'.format(framework),
     ]

@@ -90,6 +88,8 @@ def test_restore_checkpoint(preset_args, clres, framework,

         return p

+    start_time = time.time()
+
     if framework == "mxnet":
         # update preset name - for mxnet framework we are using *_DQN
         preset_args = Def.Presets.mxnet_args_test[0]
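A note on the two hunks above: `start_time=time.time()` as a default argument is evaluated once, when the module is imported, not on each call, so every test run would inherit the import-time timestamp and could exhaust its timeout before doing any work. Moving the call into the function body takes a fresh timestamp per run. A minimal sketch of the pitfall (illustrative names, not from the repo):

import time

def elapsed_bad(start=time.time()):
    # default evaluated once, at definition time: the clock starts
    # when the module loads, not when the test runs
    return time.time() - start

def elapsed_good():
    start = time.time()  # evaluated on every call, as in the fix above
    return time.time() - start

time.sleep(2)
print(elapsed_bad())   # ~2.0: two seconds already "used up"
print(elapsed_good())  # ~0.0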
@@ -113,9 +113,12 @@ def test_restore_checkpoint(preset_args, clres, framework,
     if os.path.exists(checkpoint_test_dir):
         shutil.rmtree(checkpoint_test_dir)

-    assert a_utils.is_reward_reached(csv_path=csv_list[0],
-                                     p_valid_params=p_valid_params,
-                                     start_time=start_time, time_limit=timeout)
+    res = a_utils.is_reward_reached(csv_path=csv_list[0],
+                                    p_valid_params=p_valid_params,
+                                    start_time=start_time, time_limit=timeout)
+    if not res:
+        screen.error(open(clres.stdout.name).read(), crash=False)
+        assert False

     entities = a_utils.get_files_from_dir(checkpoint_dir)

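This hunk replaces a bare assert with a check that first dumps the captured coach stdout, so a failed run leaves the subprocess log in the test report rather than failing silently ("print logs" in the commit message); `screen.error(..., crash=False)` prints without terminating the process. A hedged, self-contained sketch of the same pattern, with placeholder command and log handling not taken from the repo:

import subprocess
import tempfile

# run the workload with its output captured to a file
log = tempfile.NamedTemporaryFile(mode='w', delete=False)
p = subprocess.Popen(['python3', '-c', 'print("training...")'],
                     stdout=log, stderr=log)
p.wait()

reward_reached = False  # stand-in for is_reward_reached(...)
if not reward_reached:
    with open(log.name) as f:
        print(f.read())  # surface the captured output in the test report
    assert False, "reward was not reached within the time limit"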
@@ -117,32 +117,33 @@ def is_reward_reached(csv_path, p_valid_params, start_time, time_limit):
     last_num_episodes = 0
     csv = None
     reward_reached = False
+    reward_str = 'Evaluation Reward'

-    while csv is None or (csv['Episode #'].values[-1]
-                          < p_valid_params.max_episodes_to_achieve_reward and
-                          time.time() - start_time < time_limit):
-        csv = pd.read_csv(csv_path)
-        if 'Evaluation Reward' not in csv.keys():
+    while csv is None or (csv[csv.columns[0]].values[
+            -1] < p_valid_params.max_episodes_to_achieve_reward and time.time() - start_time < time_limit):
+        try:
+            csv = pd.read_csv(csv_path)
+        except:
+            # sometimes the csv is being written at the same time we are
+            # trying to read it. no problem -> try again
             continue

-        rewards = csv['Evaluation Reward'].values
+        if reward_str not in csv.keys():
+            continue
+
+        rewards = csv[reward_str].values
         rewards = rewards[~np.isnan(rewards)]
-        if len(rewards) >= 1:
-            averaged_rewards = np.convolve(rewards, np.ones(
-                min(len(rewards), win_size)) / win_size, mode='valid')
+
+        if len(rewards) >= 1:
+            averaged_rewards = np.convolve(rewards, np.ones(min(len(rewards), win_size)) / win_size, mode='valid')
         else:
-            # May be in heat-up steps
             time.sleep(1)
             continue

-        if csv['Episode #'].shape[0] - last_num_episodes <= 0:
+        if csv[csv.columns[0]].shape[0] - last_num_episodes <= 0:
             continue

-        last_num_episodes = csv['Episode #'].values[-1]
+        last_num_episodes = csv[csv.columns[0]].values[-1]

         # check if reward is enough
         if np.any(averaged_rewards >= p_valid_params.min_reward_threshold):
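The rewritten loop polls the CSV that coach appends to during training, tolerates partially written files via the try/except, and smooths the `Evaluation Reward` column with a moving average before testing the threshold. The smoothing step in isolation, as a sketch (win_size is defined elsewhere in args_utils; 10 and the threshold 5.0 are assumed values, and the kernel here is normalized by the effective window):

import numpy as np

win_size = 10  # assumed; the real value lives elsewhere in args_utils
rewards = np.array([np.nan, 1.0, 3.0, np.nan, 5.0, 7.0, 9.0])
rewards = rewards[~np.isnan(rewards)]  # episodes without an evaluation yet are NaN

k = min(len(rewards), win_size)
averaged = np.convolve(rewards, np.ones(k) / k, mode='valid')
print(averaged)                 # rolling mean over up to win_size episodes
print(np.any(averaged >= 5.0))  # the min_reward_threshold check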
@@ -408,6 +409,7 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
         csv_path = get_csv_path(clres=clres)
         assert len(csv_path) > 0, \
             Def.Consts.ASSERT_MSG.format("path not found", csv_path)
+        time.sleep(5)

         get_reward = is_reward_reached(csv_path=csv_path[0],
                                        p_valid_params=p_valid_params,
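The added `time.sleep(5)` gives the coach subprocess a head start to flush its first evaluation rows before polling begins. A more robust variant would poll for the file against a deadline, in the spirit of the `wait_for_csv = 240` timeout in the Definitions hunk below; a hypothetical helper, not from the repo:

import os
import time

def wait_for_csv(path, timeout=240):
    # hypothetical helper: block until `path` exists or the deadline passes
    deadline = time.time() + timeout
    while not os.path.exists(path):
        if time.time() >= deadline:
            raise TimeoutError("csv never appeared: {}".format(path))
        time.sleep(1)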
@@ -139,7 +139,7 @@ class Definitions:
     N_csv_lines = 100  # number of lines to validate on csv file

     class TimeOuts:
-        test_time_limit = 60 * 60
+        test_time_limit = 30 * 60
         wait_for_files = 20
         wait_for_csv = 240
         test_run = 60