mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Moved tf.variable_scope and tf.device calls to framework-specific architecture (#136)
This commit is contained in:
committed by
Gal Leibovich
parent
559969d3dd
commit
87a7848b0a
@@ -37,7 +37,7 @@ from rl_coach.architectures.mxnet_components.embedders import ImageEmbedder, Ten
|
||||
from rl_coach.architectures.mxnet_components.heads import Head, HeadLoss, PPOHead, PPOVHead, VHead, QHead
|
||||
from rl_coach.architectures.mxnet_components.middlewares import FCMiddleware, LSTMMiddleware
|
||||
from rl_coach.architectures.mxnet_components import utils
|
||||
from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
|
||||
from rl_coach.base_parameters import AgentParameters, Device, DeviceType, EmbeddingMergerType
|
||||
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace
|
||||
|
||||
|
||||
@@ -45,9 +45,40 @@ class GeneralMxnetNetwork(MxnetArchitecture):
|
||||
"""
|
||||
A generalized version of all possible networks implemented using mxnet.
|
||||
"""
|
||||
@staticmethod
|
||||
def construct(variable_scope: str, devices: List[str], *args, **kwargs) -> 'GeneralTensorFlowNetwork':
|
||||
"""
|
||||
Construct a network class using the provided variable scope and on requested devices
|
||||
:param variable_scope: string specifying variable scope under which to create network variables
|
||||
:param devices: list of devices (can be list of Device objects, or string for TF distributed)
|
||||
:param args: all other arguments for class initializer
|
||||
:param kwargs: all other keyword arguments for class initializer
|
||||
:return: a GeneralTensorFlowNetwork object
|
||||
"""
|
||||
return GeneralMxnetNetwork(*args, devices=[GeneralMxnetNetwork._mx_device(d) for d in devices], **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _mx_device(device: Union[str, Device]) -> mx.Context:
|
||||
"""
|
||||
Convert device to tensorflow-specific device representation
|
||||
:param device: either a specific string (used in distributed mode) which is returned without
|
||||
any change or a Device type
|
||||
:return: tensorflow-specific string for device
|
||||
"""
|
||||
if isinstance(device, Device):
|
||||
if device.device_type == DeviceType.CPU:
|
||||
return mx.cpu()
|
||||
elif device.device_type == DeviceType.GPU:
|
||||
return mx.gpu(device.index)
|
||||
else:
|
||||
raise ValueError("Invalid device_type: {}".format(device.device_type))
|
||||
else:
|
||||
raise ValueError("Invalid device instance type: {}".format(type(device)))
|
||||
|
||||
def __init__(self,
|
||||
agent_parameters: AgentParameters,
|
||||
spaces: SpacesDefinition,
|
||||
devices: List[mx.Context],
|
||||
name: str,
|
||||
global_network=None,
|
||||
network_is_local: bool=True,
|
||||
@@ -55,6 +86,7 @@ class GeneralMxnetNetwork(MxnetArchitecture):
|
||||
"""
|
||||
:param agent_parameters: the agent parameters
|
||||
:param spaces: the spaces definition of the agent
|
||||
:param devices: list of devices to run the network on
|
||||
:param name: the name of the network
|
||||
:param global_network: the global network replica that is shared between all the workers
|
||||
:param network_is_local: is the network global (shared between workers) or local (dedicated to the worker)
|
||||
@@ -69,7 +101,7 @@ class GeneralMxnetNetwork(MxnetArchitecture):
|
||||
self.num_heads_per_network = len(self.network_parameters.heads_parameters)
|
||||
self.num_networks = 1
|
||||
|
||||
super().__init__(agent_parameters, spaces, name, global_network,
|
||||
super().__init__(agent_parameters, spaces, devices, name, global_network,
|
||||
network_is_local, network_is_trainable)
|
||||
|
||||
def construct_model(self):
|
||||
|
||||
Reference in New Issue
Block a user