mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
tests: Removed mxnet from functional tests + minor fix on rewards (#362)
* ci: change workflow * changed timeout * fix function reach reward * print logs * removing mxnet * res'
This commit is contained in:
@@ -20,12 +20,12 @@ import time
|
||||
import pytest
|
||||
import signal
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import rl_coach.tests.utils.args_utils as a_utils
|
||||
import rl_coach.tests.utils.test_utils as test_utils
|
||||
import rl_coach.tests.utils.presets_utils as p_utils
|
||||
from rl_coach import checkpoint
|
||||
from rl_coach.logger import screen
|
||||
from rl_coach.tests.utils.definitions import Definitions as Def
|
||||
|
||||
|
||||
@@ -55,16 +55,14 @@ def test_get_checkpoint_state():
|
||||
|
||||
|
||||
@pytest.mark.functional_test
|
||||
@pytest.mark.parametrize("framework", ["mxnet", "tensorflow"])
|
||||
@pytest.mark.parametrize("framework", ["tensorflow"])
|
||||
def test_restore_checkpoint(preset_args, clres, framework,
|
||||
start_time=time.time(),
|
||||
timeout=Def.TimeOuts.test_time_limit):
|
||||
"""
|
||||
Create checkpoints and restore them in second run.
|
||||
:param preset_args: all preset that can be tested for argument tests
|
||||
:param clres: logs and csv files
|
||||
:param framework: name of the test framework
|
||||
:param start_time: test started time
|
||||
:param timeout: max time for test
|
||||
"""
|
||||
|
||||
@@ -79,7 +77,7 @@ def test_restore_checkpoint(preset_args, clres, framework,
|
||||
'python3', 'rl_coach/coach.py',
|
||||
'-p', '{}'.format(preset_args),
|
||||
'-e', '{}'.format("ExpName_" + preset_args),
|
||||
'--seed', '{}'.format(42),
|
||||
'--seed', '{}'.format(4),
|
||||
'-f', '{}'.format(framework),
|
||||
]
|
||||
|
||||
@@ -90,6 +88,8 @@ def test_restore_checkpoint(preset_args, clres, framework,
|
||||
|
||||
return p
|
||||
|
||||
start_time=time.time()
|
||||
|
||||
if framework == "mxnet":
|
||||
# update preset name - for mxnet framework we are using *_DQN
|
||||
preset_args = Def.Presets.mxnet_args_test[0]
|
||||
@@ -113,9 +113,12 @@ def test_restore_checkpoint(preset_args, clres, framework,
|
||||
if os.path.exists(checkpoint_test_dir):
|
||||
shutil.rmtree(checkpoint_test_dir)
|
||||
|
||||
assert a_utils.is_reward_reached(csv_path=csv_list[0],
|
||||
p_valid_params=p_valid_params,
|
||||
start_time=start_time, time_limit=timeout)
|
||||
res = a_utils.is_reward_reached(csv_path=csv_list[0],
|
||||
p_valid_params=p_valid_params,
|
||||
start_time=start_time, time_limit=timeout)
|
||||
if not res:
|
||||
screen.error(open(clres.stdout.name).read(), crash=False)
|
||||
assert False
|
||||
|
||||
entities = a_utils.get_files_from_dir(checkpoint_dir)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user