mirror of
https://github.com/gryf/coach.git
synced 2026-02-14 12:55:51 +01:00
Added ability to switch between tensorflow and mxnet using -f commandline argument. (#48)
NOTE: tensorflow framework works fine if mxnet is not installed in env, but mxnet will not work if tensorflow is not installed because of the code in network_wrapper.
This commit is contained in:
committed by
Scott Leishman
parent
2046358ab0
commit
95b4fc6888
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from rl_coach.base_parameters import TaskParameters
|
||||
from rl_coach.base_parameters import TaskParameters, Frameworks
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
import tensorflow as tf
|
||||
@@ -16,7 +16,7 @@ def test_get_QActionStateValue_predictions():
|
||||
from rl_coach.presets.CartPole_DQN import graph_manager as cartpole_dqn_graph_manager
|
||||
assert cartpole_dqn_graph_manager
|
||||
cartpole_dqn_graph_manager.create_graph(task_parameters=
|
||||
TaskParameters(framework_type="tensorflow",
|
||||
TaskParameters(framework_type=Frameworks.tensorflow,
|
||||
experiment_path="./experiments/test"))
|
||||
cartpole_dqn_graph_manager.improve_steps.num_steps = 1
|
||||
cartpole_dqn_graph_manager.steps_between_evaluation_periods.num_steps = 5
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
import tensorflow as tf
|
||||
from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters
|
||||
from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks
|
||||
from rl_coach.utils import get_open_port
|
||||
from multiprocessing import Process
|
||||
from tensorflow import logging
|
||||
@@ -16,7 +16,7 @@ def test_basic_rl_graph_manager_with_pong_a3c():
|
||||
from rl_coach.presets.Atari_A3C import graph_manager
|
||||
assert graph_manager
|
||||
graph_manager.env_params.level = "PongDeterministic-v4"
|
||||
graph_manager.create_graph(task_parameters=TaskParameters(framework_type="tensorflow",
|
||||
graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow,
|
||||
experiment_path="./experiments/test"))
|
||||
# graph_manager.improve()
|
||||
|
||||
@@ -27,7 +27,7 @@ def test_basic_rl_graph_manager_with_pong_nec():
|
||||
from rl_coach.presets.Atari_NEC import graph_manager
|
||||
assert graph_manager
|
||||
graph_manager.env_params.level = "PongDeterministic-v4"
|
||||
graph_manager.create_graph(task_parameters=TaskParameters(framework_type="tensorflow",
|
||||
graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow,
|
||||
experiment_path="./experiments/test"))
|
||||
# graph_manager.improve()
|
||||
|
||||
@@ -37,7 +37,7 @@ def test_basic_rl_graph_manager_with_cartpole_dqn():
|
||||
tf.reset_default_graph()
|
||||
from rl_coach.presets.CartPole_DQN import graph_manager
|
||||
assert graph_manager
|
||||
graph_manager.create_graph(task_parameters=TaskParameters(framework_type="tensorflow",
|
||||
graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow,
|
||||
experiment_path="./experiments/test"))
|
||||
# graph_manager.improve()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user