mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Moved tf.variable_scope and tf.device calls to framework-specific architecture (#136)
This commit is contained in:
committed by
Gal Leibovich
parent
559969d3dd
commit
87a7848b0a
@@ -61,6 +61,35 @@ class RunType(Enum):
|
||||
return self.value
|
||||
|
||||
|
||||
class DeviceType(Enum):
|
||||
CPU = 'cpu'
|
||||
GPU = 'gpu'
|
||||
|
||||
|
||||
class Device(object):
|
||||
def __init__(self, device_type: DeviceType, index: int=0):
|
||||
"""
|
||||
:param device_type: type of device (CPU/GPU)
|
||||
:param index: index of device (only used if device type is GPU)
|
||||
"""
|
||||
self._device_type = device_type
|
||||
self._index = index
|
||||
|
||||
@property
|
||||
def device_type(self):
|
||||
return self._device_type
|
||||
|
||||
@property
|
||||
def index(self):
|
||||
return self._index
|
||||
|
||||
def __str__(self):
|
||||
return "{}{}".format(self._device_type, self._index)
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
|
||||
# DistributedCoachSynchronizationType provides the synchronization type for distributed Coach.
|
||||
# The default value is None, which means the algorithm or preset cannot be used with distributed Coach.
|
||||
class DistributedCoachSynchronizationType(Enum):
|
||||
@@ -520,7 +549,8 @@ class AgentParameters(Parameters):
|
||||
class TaskParameters(Parameters):
|
||||
def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: bool=False, use_cpu: bool=False,
|
||||
experiment_path='/tmp', seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
|
||||
checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False):
|
||||
checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False,
|
||||
num_gpu: int=1):
|
||||
"""
|
||||
:param framework_type: deep learning framework type. currently only tensorflow is supported
|
||||
:param evaluate_only: the task will be used only for evaluating the model
|
||||
@@ -532,7 +562,7 @@ class TaskParameters(Parameters):
|
||||
:param checkpoint_save_dir: the directory to store the checkpoints in
|
||||
:param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved
|
||||
:param apply_stop_condition: If set to True, this will apply the stop condition defined by reaching a target success rate
|
||||
|
||||
:param num_gpu: number of GPUs to use
|
||||
"""
|
||||
self.framework_type = framework_type
|
||||
self.task_index = 0 # TODO: not really needed
|
||||
@@ -545,6 +575,7 @@ class TaskParameters(Parameters):
|
||||
self.seed = seed
|
||||
self.export_onnx_graph = export_onnx_graph
|
||||
self.apply_stop_condition = apply_stop_condition
|
||||
self.num_gpu = num_gpu
|
||||
|
||||
|
||||
class DistributedTaskParameters(TaskParameters):
|
||||
|
||||
Reference in New Issue
Block a user