1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

tests: Remove mxnet from functional tests + minor fix to the reward check (#362)

* ci: change workflow

* changed timeout

* fix the is_reward_reached function

* print logs

* removing mxnet

* res'
This commit is contained in:
anabwan
2019-06-27 18:52:29 +03:00
committed by GitHub
parent 30c64d0656
commit a576ab5659
4 changed files with 28 additions and 24 deletions

View File

@@ -731,7 +731,6 @@ workflows:
- functional_tests: - functional_tests:
requires: requires:
- build_base - build_base
- integration_tests
- functional_test_doom: - functional_test_doom:
requires: requires:
- build_doom_env - build_doom_env

View File

@@ -20,12 +20,12 @@ import time
import pytest import pytest
import signal import signal
import tempfile import tempfile
import numpy as np
import pandas as pd import pandas as pd
import rl_coach.tests.utils.args_utils as a_utils import rl_coach.tests.utils.args_utils as a_utils
import rl_coach.tests.utils.test_utils as test_utils import rl_coach.tests.utils.test_utils as test_utils
import rl_coach.tests.utils.presets_utils as p_utils import rl_coach.tests.utils.presets_utils as p_utils
from rl_coach import checkpoint from rl_coach import checkpoint
from rl_coach.logger import screen
from rl_coach.tests.utils.definitions import Definitions as Def from rl_coach.tests.utils.definitions import Definitions as Def
@@ -55,16 +55,14 @@ def test_get_checkpoint_state():
@pytest.mark.functional_test @pytest.mark.functional_test
@pytest.mark.parametrize("framework", ["mxnet", "tensorflow"]) @pytest.mark.parametrize("framework", ["tensorflow"])
def test_restore_checkpoint(preset_args, clres, framework, def test_restore_checkpoint(preset_args, clres, framework,
start_time=time.time(),
timeout=Def.TimeOuts.test_time_limit): timeout=Def.TimeOuts.test_time_limit):
""" """
Create checkpoints and restore them in second run. Create checkpoints and restore them in second run.
:param preset_args: all preset that can be tested for argument tests :param preset_args: all preset that can be tested for argument tests
:param clres: logs and csv files :param clres: logs and csv files
:param framework: name of the test framework :param framework: name of the test framework
:param start_time: test started time
:param timeout: max time for test :param timeout: max time for test
""" """
@@ -79,7 +77,7 @@ def test_restore_checkpoint(preset_args, clres, framework,
'python3', 'rl_coach/coach.py', 'python3', 'rl_coach/coach.py',
'-p', '{}'.format(preset_args), '-p', '{}'.format(preset_args),
'-e', '{}'.format("ExpName_" + preset_args), '-e', '{}'.format("ExpName_" + preset_args),
'--seed', '{}'.format(42), '--seed', '{}'.format(4),
'-f', '{}'.format(framework), '-f', '{}'.format(framework),
] ]
@@ -90,6 +88,8 @@ def test_restore_checkpoint(preset_args, clres, framework,
return p return p
start_time=time.time()
if framework == "mxnet": if framework == "mxnet":
# update preset name - for mxnet framework we are using *_DQN # update preset name - for mxnet framework we are using *_DQN
preset_args = Def.Presets.mxnet_args_test[0] preset_args = Def.Presets.mxnet_args_test[0]
@@ -113,9 +113,12 @@ def test_restore_checkpoint(preset_args, clres, framework,
if os.path.exists(checkpoint_test_dir): if os.path.exists(checkpoint_test_dir):
shutil.rmtree(checkpoint_test_dir) shutil.rmtree(checkpoint_test_dir)
assert a_utils.is_reward_reached(csv_path=csv_list[0], res = a_utils.is_reward_reached(csv_path=csv_list[0],
p_valid_params=p_valid_params, p_valid_params=p_valid_params,
start_time=start_time, time_limit=timeout) start_time=start_time, time_limit=timeout)
if not res:
screen.error(open(clres.stdout.name).read(), crash=False)
assert False
entities = a_utils.get_files_from_dir(checkpoint_dir) entities = a_utils.get_files_from_dir(checkpoint_dir)

View File

@@ -117,32 +117,33 @@ def is_reward_reached(csv_path, p_valid_params, start_time, time_limit):
last_num_episodes = 0 last_num_episodes = 0
csv = None csv = None
reward_reached = False reward_reached = False
reward_str = 'Evaluation Reward'
while csv is None or (csv['Episode #'].values[-1] while csv is None or (csv[csv.columns[0]].values[
< p_valid_params.max_episodes_to_achieve_reward and -1] < p_valid_params.max_episodes_to_achieve_reward and time.time() - start_time < time_limit):
time.time() - start_time < time_limit): try:
csv = pd.read_csv(csv_path)
csv = pd.read_csv(csv_path) except:
# sometimes the csv is being written at the same time we are
if 'Evaluation Reward' not in csv.keys(): # trying to read it. no problem -> try again
continue continue
rewards = csv['Evaluation Reward'].values if reward_str not in csv.keys():
continue
rewards = csv[reward_str].values
rewards = rewards[~np.isnan(rewards)] rewards = rewards[~np.isnan(rewards)]
if len(rewards) >= 1:
averaged_rewards = np.convolve(rewards, np.ones(
min(len(rewards), win_size)) / win_size, mode='valid')
if len(rewards) >= 1:
averaged_rewards = np.convolve(rewards, np.ones(min(len(rewards), win_size)) / win_size, mode='valid')
else: else:
# May be in heat-up steps
time.sleep(1) time.sleep(1)
continue continue
if csv['Episode #'].shape[0] - last_num_episodes <= 0: if csv[csv.columns[0]].shape[0] - last_num_episodes <= 0:
continue continue
last_num_episodes = csv['Episode #'].values[-1] last_num_episodes = csv[csv.columns[0]].values[-1]
# check if reward is enough # check if reward is enough
if np.any(averaged_rewards >= p_valid_params.min_reward_threshold): if np.any(averaged_rewards >= p_valid_params.min_reward_threshold):
@@ -408,6 +409,7 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
csv_path = get_csv_path(clres=clres) csv_path = get_csv_path(clres=clres)
assert len(csv_path) > 0, \ assert len(csv_path) > 0, \
Def.Consts.ASSERT_MSG.format("path not found", csv_path) Def.Consts.ASSERT_MSG.format("path not found", csv_path)
time.sleep(5)
get_reward = is_reward_reached(csv_path=csv_path[0], get_reward = is_reward_reached(csv_path=csv_path[0],
p_valid_params=p_valid_params, p_valid_params=p_valid_params,

View File

@@ -139,7 +139,7 @@ class Definitions:
N_csv_lines = 100 # number of lines to validate on csv file N_csv_lines = 100 # number of lines to validate on csv file
class TimeOuts: class TimeOuts:
test_time_limit = 60 * 60 test_time_limit = 30 * 60
wait_for_files = 20 wait_for_files = 20
wait_for_csv = 240 wait_for_csv = 240
test_run = 60 test_run = 60