1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

Moved tf.variable_scope and tf.device calls to framework-specific architecture (#136)

This commit is contained in:
Sina Afrooze
2018-11-22 12:52:22 -08:00
committed by Gal Leibovich
parent 559969d3dd
commit 87a7848b0a
11 changed files with 219 additions and 91 deletions

View File

@@ -37,7 +37,7 @@ from rl_coach.architectures.mxnet_components.embedders import ImageEmbedder, Ten
from rl_coach.architectures.mxnet_components.heads import Head, HeadLoss, PPOHead, PPOVHead, VHead, QHead
from rl_coach.architectures.mxnet_components.middlewares import FCMiddleware, LSTMMiddleware
from rl_coach.architectures.mxnet_components import utils
from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
from rl_coach.base_parameters import AgentParameters, Device, DeviceType, EmbeddingMergerType
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace
@@ -45,9 +45,40 @@ class GeneralMxnetNetwork(MxnetArchitecture):
"""
A generalized version of all possible networks implemented using mxnet.
"""
@staticmethod
def construct(variable_scope: str, devices: List[str], *args, **kwargs) -> 'GeneralTensorFlowNetwork':
"""
Construct a network class using the provided variable scope and on requested devices
:param variable_scope: string specifying variable scope under which to create network variables
:param devices: list of devices (can be list of Device objects, or string for TF distributed)
:param args: all other arguments for class initializer
:param kwargs: all other keyword arguments for class initializer
:return: a GeneralTensorFlowNetwork object
"""
return GeneralMxnetNetwork(*args, devices=[GeneralMxnetNetwork._mx_device(d) for d in devices], **kwargs)
@staticmethod
def _mx_device(device: Union[str, Device]) -> mx.Context:
"""
Convert device to tensorflow-specific device representation
:param device: either a specific string (used in distributed mode) which is returned without
any change or a Device type
:return: tensorflow-specific string for device
"""
if isinstance(device, Device):
if device.device_type == DeviceType.CPU:
return mx.cpu()
elif device.device_type == DeviceType.GPU:
return mx.gpu(device.index)
else:
raise ValueError("Invalid device_type: {}".format(device.device_type))
else:
raise ValueError("Invalid device instance type: {}".format(type(device)))
def __init__(self,
agent_parameters: AgentParameters,
spaces: SpacesDefinition,
devices: List[mx.Context],
name: str,
global_network=None,
network_is_local: bool=True,
@@ -55,6 +86,7 @@ class GeneralMxnetNetwork(MxnetArchitecture):
"""
:param agent_parameters: the agent parameters
:param spaces: the spaces definition of the agent
:param devices: list of devices to run the network on
:param name: the name of the network
:param global_network: the global network replica that is shared between all the workers
:param network_is_local: is the network global (shared between workers) or local (dedicated to the worker)
@@ -69,7 +101,7 @@ class GeneralMxnetNetwork(MxnetArchitecture):
self.num_heads_per_network = len(self.network_parameters.heads_parameters)
self.num_networks = 1
super().__init__(agent_parameters, spaces, name, global_network,
super().__init__(agent_parameters, spaces, devices, name, global_network,
network_is_local, network_is_trainable)
def construct_model(self):