1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Added ability to switch between tensorflow and mxnet using -f commandline argument. (#48)

NOTE: tensorflow framework works fine if mxnet is not installed in env, but mxnet will not work if tensorflow is not installed because of the code in network_wrapper.
This commit is contained in:
Sina Afrooze
2018-10-30 15:29:34 -07:00
committed by Scott Leishman
parent 2046358ab0
commit 95b4fc6888
8 changed files with 47 additions and 21 deletions

View File

@@ -29,6 +29,7 @@ from rl_coach.filters.filter import NoInputFilter
class Frameworks(Enum):
tensorflow = "TensorFlow"
mxnet = "MXNet"
class EmbedderScheme(Enum):
@@ -415,7 +416,7 @@ class AgentParameters(Parameters):
class TaskParameters(Parameters):
def __init__(self, framework_type: str='tensorflow', evaluate_only: bool=False, use_cpu: bool=False,
def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: bool=False, use_cpu: bool=False,
experiment_path='/tmp', seed=None, checkpoint_save_secs=None):
"""
:param framework_type: deep learning framework type. currently only tensorflow is supported
@@ -435,7 +436,7 @@ class TaskParameters(Parameters):
class DistributedTaskParameters(TaskParameters):
def __init__(self, framework_type: str, parameters_server_hosts: str, worker_hosts: str, job_type: str,
def __init__(self, framework_type: Frameworks, parameters_server_hosts: str, worker_hosts: str, job_type: str,
task_index: int, evaluate_only: bool=False, num_tasks: int=None,
num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None,
shared_memory_scratchpad=None, seed=None):