mirror of
https://github.com/gryf/coach.git
synced 2026-02-15 05:25:55 +01:00
pre-release 0.10.0
This commit is contained in:
0
rl_coach/tests/graph_managers/__init__.py
Normal file
0
rl_coach/tests/graph_managers/__init__.py
Normal file
52
rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py
Normal file
52
rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import os
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
import tensorflow as tf
|
||||
from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters
|
||||
from rl_coach.utils import get_open_port
|
||||
from multiprocessing import Process
|
||||
from tensorflow import logging
|
||||
import pytest
|
||||
logging.set_verbosity(logging.INFO)
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_basic_rl_graph_manager_with_pong_a3c():
|
||||
tf.reset_default_graph()
|
||||
from rl_coach.presets.Atari_A3C import graph_manager
|
||||
assert graph_manager
|
||||
graph_manager.env_params.level = "PongDeterministic-v4"
|
||||
graph_manager.create_graph(task_parameters=TaskParameters(framework_type="tensorflow",
|
||||
experiment_path="./experiments/test"))
|
||||
# graph_manager.improve()
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_basic_rl_graph_manager_with_pong_nec():
|
||||
tf.reset_default_graph()
|
||||
from rl_coach.presets.Atari_NEC import graph_manager
|
||||
assert graph_manager
|
||||
graph_manager.env_params.level = "PongDeterministic-v4"
|
||||
graph_manager.create_graph(task_parameters=TaskParameters(framework_type="tensorflow",
|
||||
experiment_path="./experiments/test"))
|
||||
# graph_manager.improve()
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_basic_rl_graph_manager_with_cartpole_dqn():
|
||||
tf.reset_default_graph()
|
||||
from rl_coach.presets.CartPole_DQN import graph_manager
|
||||
assert graph_manager
|
||||
graph_manager.create_graph(task_parameters=TaskParameters(framework_type="tensorflow",
|
||||
experiment_path="./experiments/test"))
|
||||
# graph_manager.improve()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
# test_basic_rl_graph_manager_with_pong_a3c()
|
||||
# test_basic_rl_graph_manager_with_ant_a3c()
|
||||
# test_basic_rl_graph_manager_with_pong_nec()
|
||||
# test_basic_rl_graph_manager_with_cartpole_dqn()
|
||||
#test_basic_rl_graph_manager_multithreaded_with_pong_a3c()
|
||||
#test_basic_rl_graph_manager_with_doom_basic_dqn()
|
||||
Reference in New Issue
Block a user