mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
Added ability to switch between tensorflow and mxnet using -f commandline argument. (#48)
NOTE: tensorflow framework works fine if mxnet is not installed in env, but mxnet will not work if tensorflow is not installed because of the code in network_wrapper.
This commit is contained in:
committed by
Scott Leishman
parent
2046358ab0
commit
95b4fc6888
@@ -29,6 +29,7 @@ from rl_coach.filters.filter import NoInputFilter
|
||||
|
||||
class Frameworks(Enum):
|
||||
tensorflow = "TensorFlow"
|
||||
mxnet = "MXNet"
|
||||
|
||||
|
||||
class EmbedderScheme(Enum):
|
||||
@@ -415,7 +416,7 @@ class AgentParameters(Parameters):
|
||||
|
||||
|
||||
class TaskParameters(Parameters):
|
||||
def __init__(self, framework_type: str='tensorflow', evaluate_only: bool=False, use_cpu: bool=False,
|
||||
def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: bool=False, use_cpu: bool=False,
|
||||
experiment_path='/tmp', seed=None, checkpoint_save_secs=None):
|
||||
"""
|
||||
:param framework_type: deep learning framework type. currently only tensorflow is supported
|
||||
@@ -435,7 +436,7 @@ class TaskParameters(Parameters):
|
||||
|
||||
|
||||
class DistributedTaskParameters(TaskParameters):
|
||||
def __init__(self, framework_type: str, parameters_server_hosts: str, worker_hosts: str, job_type: str,
|
||||
def __init__(self, framework_type: Frameworks, parameters_server_hosts: str, worker_hosts: str, job_type: str,
|
||||
task_index: int, evaluate_only: bool=False, num_tasks: int=None,
|
||||
num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None,
|
||||
shared_memory_scratchpad=None, seed=None):
|
||||
|
||||
Reference in New Issue
Block a user