1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

tests: Removed mxnet from functional tests + minor fix on rewards (#362)

* ci: change workflow

* changed timeout

* fix function reach reward

* print logs

* removing mxnet

* res'
This commit is contained in:
anabwan
2019-06-27 18:52:29 +03:00
committed by GitHub
parent 30c64d0656
commit a576ab5659
4 changed files with 28 additions and 24 deletions

View File

@@ -20,12 +20,12 @@ import time
import pytest
import signal
import tempfile
import numpy as np
import pandas as pd
import rl_coach.tests.utils.args_utils as a_utils
import rl_coach.tests.utils.test_utils as test_utils
import rl_coach.tests.utils.presets_utils as p_utils
from rl_coach import checkpoint
from rl_coach.logger import screen
from rl_coach.tests.utils.definitions import Definitions as Def
@@ -55,16 +55,14 @@ def test_get_checkpoint_state():
@pytest.mark.functional_test
@pytest.mark.parametrize("framework", ["mxnet", "tensorflow"])
@pytest.mark.parametrize("framework", ["tensorflow"])
def test_restore_checkpoint(preset_args, clres, framework,
start_time=time.time(),
timeout=Def.TimeOuts.test_time_limit):
"""
Create checkpoints and restore them in second run.
:param preset_args: all preset that can be tested for argument tests
:param clres: logs and csv files
:param framework: name of the test framework
:param start_time: test started time
:param timeout: max time for test
"""
@@ -79,7 +77,7 @@ def test_restore_checkpoint(preset_args, clres, framework,
'python3', 'rl_coach/coach.py',
'-p', '{}'.format(preset_args),
'-e', '{}'.format("ExpName_" + preset_args),
'--seed', '{}'.format(42),
'--seed', '{}'.format(4),
'-f', '{}'.format(framework),
]
@@ -90,6 +88,8 @@ def test_restore_checkpoint(preset_args, clres, framework,
return p
start_time=time.time()
if framework == "mxnet":
# update preset name - for mxnet framework we are using *_DQN
preset_args = Def.Presets.mxnet_args_test[0]
@@ -113,9 +113,12 @@ def test_restore_checkpoint(preset_args, clres, framework,
if os.path.exists(checkpoint_test_dir):
shutil.rmtree(checkpoint_test_dir)
assert a_utils.is_reward_reached(csv_path=csv_list[0],
p_valid_params=p_valid_params,
start_time=start_time, time_limit=timeout)
res = a_utils.is_reward_reached(csv_path=csv_list[0],
p_valid_params=p_valid_params,
start_time=start_time, time_limit=timeout)
if not res:
screen.error(open(clres.stdout.name).read(), crash=False)
assert False
entities = a_utils.get_files_from_dir(checkpoint_dir)