Mirror of https://github.com/gryf/coach.git

Moved tf.variable_scope and tf.device calls to framework-specific architecture (#136)

Sina Afrooze authored on 2018-11-22 12:52:22 -08:00, committed by Gal Leibovich
parent 559969d3dd
commit 87a7848b0a

11 changed files with 219 additions and 91 deletions

View File

@@ -26,7 +26,7 @@ from six.moves import range
 from rl_coach.agents.agent_interface import AgentInterface
 from rl_coach.architectures.network_wrapper import NetworkWrapper
-from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters, Frameworks
+from rl_coach.base_parameters import AgentParameters, Device, DeviceType, DistributedTaskParameters, Frameworks
 from rl_coach.core_types import RunPhase, PredictionType, EnvironmentEpisodes, ActionType, Batch, Episode, StateType
 from rl_coach.core_types import Transition, ActionInfo, TrainingSteps, EnvironmentSteps, EnvResponse
 from rl_coach.logger import screen, Logger, EpisodeLogger
@@ -98,14 +98,18 @@ class Agent(AgentInterface):
             self.has_global = True
             self.replicated_device = agent_parameters.task_parameters.device
             self.worker_device = "/job:worker/task:{}".format(self.task_id)
-        else:
-            self.has_global = False
-            self.replicated_device = None
-            self.worker_device = ""
-        if agent_parameters.task_parameters.use_cpu:
-            self.worker_device += "/cpu:0"
-        else:
-            self.worker_device += "/device:GPU:0"
+            if agent_parameters.task_parameters.use_cpu:
+                self.worker_device += "/cpu:0"
+            else:
+                self.worker_device += "/device:GPU:0"
+        else:
+            self.has_global = False
+            self.replicated_device = None
+            if agent_parameters.task_parameters.use_cpu:
+                self.worker_device = Device(DeviceType.CPU)
+            else:
+                self.worker_device = [Device(DeviceType.GPU, i)
+                                      for i in range(agent_parameters.task_parameters.num_gpu)]

         # filters
         self.input_filter = self.ap.input_filter
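The single-node branch now describes compute resources with the new Device/DeviceType value objects rather than TensorFlow device strings. A minimal sketch of what self.worker_device ends up holding in each case, assuming this commit's rl_coach is importable (the variable names are illustrative):

from rl_coach.base_parameters import Device, DeviceType, TaskParameters

# CPU branch: a single Device descriptor
cpu_worker_device = Device(DeviceType.CPU)

# GPU branch: one Device per GPU, driven by the new num_gpu task parameter
gpu_task = TaskParameters(use_cpu=False, num_gpu=2)
gpu_worker_devices = [Device(DeviceType.GPU, i) for i in range(gpu_task.num_gpu)]

print(cpu_worker_device, gpu_worker_devices)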

View File

@@ -24,6 +24,18 @@ from rl_coach.spaces import SpacesDefinition
 class Architecture(object):
+    @staticmethod
+    def construct(variable_scope: str, devices: List[str], *args, **kwargs) -> 'Architecture':
+        """
+        Construct a network class using the provided variable scope and on requested devices
+        :param variable_scope: string specifying variable scope under which to create network variables
+        :param devices: list of devices (can be list of Device objects, or string for TF distributed)
+        :param args: all other arguments for class initializer
+        :param kwargs: all other keyword arguments for class initializer
+        :return: an object which is a child of Architecture
+        """
+        raise NotImplementedError
+
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str= ""):
         """
         Creates a neural network 'architecture', that can be trained and used for inference.
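Since construct() is now the framework-agnostic entry point, each backend binds its own scoping and device placement inside it while callers stay framework-neutral. A self-contained toy sketch of the contract (ToyArchitecture and ToyTFNetwork are illustrative names, not rl_coach classes):

from typing import List

class ToyArchitecture:
    @staticmethod
    def construct(variable_scope: str, devices: List[str], *args, **kwargs) -> 'ToyArchitecture':
        # each framework implementation decides how to honor the scope and devices
        raise NotImplementedError

class ToyTFNetwork(ToyArchitecture):
    @staticmethod
    def construct(variable_scope, devices, *args, **kwargs):
        # a real implementation would enter tf.variable_scope(variable_scope) and tf.device(devices[0]) here
        print("building under scope={} on {}".format(variable_scope, devices[0]))
        return ToyTFNetwork()

# callers such as NetworkWrapper only hold a reference to the bound construct function:
general_network = ToyTFNetwork.construct
net = general_network("agent/main", ["/cpu:0"])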

View File

@@ -32,8 +32,8 @@ from rl_coach.utils import force_list, squeeze_list
 class MxnetArchitecture(Architecture):
-    def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str= "",
-                 global_network=None, network_is_local: bool=True, network_is_trainable: bool=False):
+    def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, devices: List[mx.Context],
+                 name: str = "", global_network=None, network_is_local: bool=True, network_is_trainable: bool=False):
         """
         :param agent_parameters: the agent parameters
         :param spaces: the spaces definition of the agent
@@ -58,6 +58,7 @@ class MxnetArchitecture(Architecture):
         self.network_is_trainable = network_is_trainable
         self.is_training = False
         self.model = None  # type: GeneralModel
+        self._devices = devices
         self.is_chief = self.ap.task_parameters.task_index == 0
         self.network_is_global = not self.network_is_local and global_network is None
@@ -75,13 +76,17 @@ class MxnetArchitecture(Architecture):
     def __str__(self):
         return self.model.summary(*self._dummy_model_inputs())

-    @property
-    def _model_grads(self) -> Generator[NDArray, NDArray, Any]:
+    def _model_grads(self, index: int=0) ->\
+            Union[Generator[NDArray, NDArray, Any], Generator[List[NDArray], List[NDArray], Any]]:
         """
         Creates a copy of model gradients and returns them in a list, in the same order as collect_params()
+        :param index: device index. Set to -1 to get a tuple of list of NDArrays for all devices
         :return: a generator for model gradient values
         """
-        return (p.list_grad()[0].copy() for p in self.model.collect_params().values() if p.grad_req != 'null')
+        if index < 0:
+            return (p.list_grad() for p in self.model.collect_params().values() if p.grad_req != 'null')
+        else:
+            return (p.list_grad()[index] for p in self.model.collect_params().values() if p.grad_req != 'null')

     def _model_input_shapes(self) -> List[List[int]]:
         """
@@ -101,7 +106,7 @@ class MxnetArchitecture(Architecture):
         :return: tuple of inputs for model forward pass
         """
         input_shapes = self._model_input_shapes()
-        inputs = tuple(nd.zeros(tuple(shape)) for shape in input_shapes)
+        inputs = tuple(nd.zeros(tuple(shape), ctx=self._devices[0]) for shape in input_shapes)
         return inputs

     def construct_model(self) -> None:
@@ -117,9 +122,8 @@ class MxnetArchitecture(Architecture):
         :param sess: must be None
         """
         assert sess is None
-        # FIXME Add GPU initialization
         # FIXME Add initializer
-        self.model.collect_params().initialize(ctx=mx.cpu())
+        self.model.collect_params().initialize(ctx=self._devices)
         # Hybridize model and losses
         self.model.hybridize()
         for l in self.losses:
@@ -145,7 +149,7 @@ class MxnetArchitecture(Architecture):
             for a in self.accumulated_gradients:
                 a *= 0
         else:
-            self.accumulated_gradients = [g.copy() for g in self._model_grads]
+            self.accumulated_gradients = [g.copy() for g in self._model_grads()]

     def accumulate_gradients(self,
                              inputs: Dict[str, np.ndarray],
@@ -175,7 +179,7 @@ class MxnetArchitecture(Architecture):
             self.reset_accumulated_gradients()

         embedders = [emb.embedder_name for emb in self.model.nets[0].input_embedders]
-        nd_inputs = tuple(nd.array(inputs[emb]) for emb in embedders)
+        nd_inputs = tuple(nd.array(inputs[emb], ctx=self._devices[0]) for emb in embedders)

         assert self.middleware.__class__.__name__ != 'LSTMMiddleware', "LSTM middleware not supported"
@@ -190,7 +194,7 @@ class MxnetArchitecture(Architecture):
         for h, h_loss, h_out, l_tgt in zip(self.model.output_heads, self.losses, out_per_head, tgt_per_loss):
             l_in = utils.get_loss_agent_inputs(inputs, head_type_idx=h.head_type_idx, loss=h_loss)
             # Align arguments with loss.loss_forward and convert to NDArray
-            l_args = utils.to_mx_ndarray(utils.align_loss_args(h_out, l_in, l_tgt, h_loss))
+            l_args = utils.to_mx_ndarray(utils.align_loss_args(h_out, l_in, l_tgt, h_loss), h_out[0].context)
             # Calculate loss and all auxiliary outputs
             loss_outputs = utils.loss_output_dict(utils.to_list(h_loss(*l_args)), h_loss.output_schema)
             if LOSS_OUT_TYPE_LOSS in loss_outputs:
@@ -216,25 +220,26 @@ class MxnetArchitecture(Architecture):
         # allreduce gradients from all contexts
         self.trainer.allreduce_grads()

+        model_grads_cpy = [g.copy() for g in self._model_grads()]
+
         # Calculate global norm of gradients
         # FIXME global norm is returned even when not used for clipping! Is this necessary?
         # FIXME global norm might be calculated twice if clipping method is global norm
-        norm_unclipped_grads = utils.global_norm(self._model_grads)
+        norm_unclipped_grads = utils.global_norm(model_grads_cpy)

         # Clip gradients
         if self.network_parameters.clip_gradients:
             utils.clip_grad(
-                self._model_grads,
+                model_grads_cpy,
                 clip_method=self.network_parameters.gradients_clipping_method,
                 clip_val=self.network_parameters.clip_gradients,
                 inplace=True)

         # Update self.accumulated_gradients depending on no_accumulation flag
         if no_accumulation:
-            for acc_grad, model_grad in zip(self.accumulated_gradients, self._model_grads):
+            for acc_grad, model_grad in zip(self.accumulated_gradients, model_grads_cpy):
                 acc_grad[:] = model_grad
         else:
-            for acc_grad, model_grad in zip(self.accumulated_gradients, self._model_grads):
+            for acc_grad, model_grad in zip(self.accumulated_gradients, model_grads_cpy):
                 acc_grad += model_grad

         # result of of additional fetches
@@ -269,8 +274,9 @@ class MxnetArchitecture(Architecture):
         batch_size = self.ap.task_parameters.num_training_tasks

         # set parameter gradients to gradients passed in
-        for param_grad, gradient in zip(self._model_grads, gradients):
-            param_grad[:] = gradient
+        for param_grad, gradient in zip(self._model_grads(-1), gradients):
+            for pg in param_grad:
+                pg[:] = gradient

         # update gradients
         self.trainer.update(batch_size=batch_size)
@@ -283,7 +289,7 @@ class MxnetArchitecture(Architecture):
         WARNING: must only call once per state since each call is assumed by LSTM to be a new time step.
         """
         embedders = [emb.embedder_name for emb in self.model.nets[0].input_embedders]
-        nd_inputs = tuple(nd.array(inputs[emb]) for emb in embedders)
+        nd_inputs = tuple(nd.array(inputs[emb], ctx=self._devices[0]) for emb in embedders)

         assert self.middleware.__class__.__name__ != 'LSTMMiddleware'
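The reworked _model_grads() leans on Gluon's Parameter.list_grad(), which returns one gradient NDArray per context. A small standalone Gluon sketch of the two access patterns (single device vs. all devices), assuming mxnet is installed:

import mxnet as mx
from mxnet import autograd, gluon

ctx = mx.cpu()
net = gluon.nn.Dense(4)
net.initialize(ctx=ctx)

with autograd.record():
    out = net(mx.nd.ones((2, 3), ctx=ctx))
out.backward()

# index >= 0: one NDArray per parameter, taken from a single context (what _model_grads(0) yields)
grads_on_device_0 = [p.list_grad()[0] for p in net.collect_params().values() if p.grad_req != 'null']
# index < 0: a list of NDArrays per parameter, one entry per context (what _model_grads(-1) yields)
grads_per_device = [p.list_grad() for p in net.collect_params().values() if p.grad_req != 'null']
print(len(grads_on_device_0), [len(g) for g in grads_per_device])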

View File

@@ -37,7 +37,7 @@ from rl_coach.architectures.mxnet_components.embedders import ImageEmbedder, Ten
 from rl_coach.architectures.mxnet_components.heads import Head, HeadLoss, PPOHead, PPOVHead, VHead, QHead
 from rl_coach.architectures.mxnet_components.middlewares import FCMiddleware, LSTMMiddleware
 from rl_coach.architectures.mxnet_components import utils
-from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
+from rl_coach.base_parameters import AgentParameters, Device, DeviceType, EmbeddingMergerType
 from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace
@@ -45,9 +45,40 @@ class GeneralMxnetNetwork(MxnetArchitecture):
""" """
A generalized version of all possible networks implemented using mxnet. A generalized version of all possible networks implemented using mxnet.
""" """
@staticmethod
def construct(variable_scope: str, devices: List[str], *args, **kwargs) -> 'GeneralTensorFlowNetwork':
"""
Construct a network class using the provided variable scope and on requested devices
:param variable_scope: string specifying variable scope under which to create network variables
:param devices: list of devices (can be list of Device objects, or string for TF distributed)
:param args: all other arguments for class initializer
:param kwargs: all other keyword arguments for class initializer
:return: a GeneralTensorFlowNetwork object
"""
return GeneralMxnetNetwork(*args, devices=[GeneralMxnetNetwork._mx_device(d) for d in devices], **kwargs)
@staticmethod
def _mx_device(device: Union[str, Device]) -> mx.Context:
"""
Convert device to tensorflow-specific device representation
:param device: either a specific string (used in distributed mode) which is returned without
any change or a Device type
:return: tensorflow-specific string for device
"""
if isinstance(device, Device):
if device.device_type == DeviceType.CPU:
return mx.cpu()
elif device.device_type == DeviceType.GPU:
return mx.gpu(device.index)
else:
raise ValueError("Invalid device_type: {}".format(device.device_type))
else:
raise ValueError("Invalid device instance type: {}".format(type(device)))
def __init__(self, def __init__(self,
agent_parameters: AgentParameters, agent_parameters: AgentParameters,
spaces: SpacesDefinition, spaces: SpacesDefinition,
devices: List[mx.Context],
name: str, name: str,
global_network=None, global_network=None,
network_is_local: bool=True, network_is_local: bool=True,
@@ -55,6 +86,7 @@ class GeneralMxnetNetwork(MxnetArchitecture):
""" """
:param agent_parameters: the agent parameters :param agent_parameters: the agent parameters
:param spaces: the spaces definition of the agent :param spaces: the spaces definition of the agent
:param devices: list of devices to run the network on
:param name: the name of the network :param name: the name of the network
:param global_network: the global network replica that is shared between all the workers :param global_network: the global network replica that is shared between all the workers
:param network_is_local: is the network global (shared between workers) or local (dedicated to the worker) :param network_is_local: is the network global (shared between workers) or local (dedicated to the worker)
@@ -69,7 +101,7 @@ class GeneralMxnetNetwork(MxnetArchitecture):
         self.num_heads_per_network = len(self.network_parameters.heads_parameters)
         self.num_networks = 1

-        super().__init__(agent_parameters, spaces, name, global_network,
+        super().__init__(agent_parameters, spaces, devices, name, global_network,
                          network_is_local, network_is_trainable)

     def construct_model(self):
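The _mx_device() helper is just a mapping from the framework-neutral Device descriptor to an MXNet context. A minimal sketch of that mapping as a standalone function (to_mx_context is a local illustrative helper, not rl_coach API), assuming mxnet and this commit's rl_coach.base_parameters are importable:

import mxnet as mx
from rl_coach.base_parameters import Device, DeviceType

def to_mx_context(device: Device) -> mx.Context:
    # same mapping as GeneralMxnetNetwork._mx_device, minus the error handling
    if device.device_type == DeviceType.CPU:
        return mx.cpu()
    return mx.gpu(device.index)

ctx = to_mx_context(Device(DeviceType.GPU, 1))
assert ctx.device_type == 'gpu' and ctx.device_id == 1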

View File

@@ -15,24 +15,26 @@ from rl_coach.core_types import GradientClippingMethod
 nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]


-def to_mx_ndarray(data: Union[list, tuple, np.ndarray, NDArray, int, float]) ->\
+def to_mx_ndarray(data: Union[list, tuple, np.ndarray, NDArray, int, float], ctx: mx.Context=None) ->\
         Union[List[NDArray], Tuple[NDArray], NDArray]:
     """
     Convert data to mx.nd.NDArray. Data can be a list or tuple of np.ndarray, int, or float or
     it can be np.ndarray, int, or float
     :param data: input data to be converted
+    :param ctx: context of the data (CPU, GPU0, GPU1, etc.)
     :return: converted output data
     """
     if isinstance(data, list):
-        data = [to_mx_ndarray(d) for d in data]
+        data = [to_mx_ndarray(d, ctx=ctx) for d in data]
     elif isinstance(data, tuple):
-        data = tuple(to_mx_ndarray(d) for d in data)
+        data = tuple(to_mx_ndarray(d, ctx=ctx) for d in data)
     elif isinstance(data, np.ndarray):
-        data = nd.array(data)
+        data = nd.array(data, ctx=ctx)
     elif isinstance(data, NDArray):
+        assert data.context == ctx
         pass
     elif isinstance(data, int) or isinstance(data, float):
-        data = nd.array([data])
+        data = nd.array([data], ctx=ctx)
     else:
         raise TypeError('Unsupported data type: {}'.format(type(data)))
     return data
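With the new ctx argument, nested Python/NumPy structures are converted recursively and every resulting NDArray lands on the requested context. A small usage sketch, assuming mxnet and rl_coach are importable:

import numpy as np
import mxnet as mx
from rl_coach.architectures.mxnet_components.utils import to_mx_ndarray

ctx = mx.cpu()
converted = to_mx_ndarray([np.zeros((2, 2)), (1, 2.5)], ctx=ctx)
assert converted[0].context == ctx           # np.ndarray -> NDArray on ctx
assert converted[1][1].context == ctx        # scalars inside the tuple end up on ctx too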

View File

@@ -20,6 +20,7 @@ from rl_coach.base_parameters import Frameworks, AgentParameters
 from rl_coach.logger import failed_imports
 from rl_coach.saver import SaverCollection
 from rl_coach.spaces import SpacesDefinition
+from rl_coach.utils import force_list

 try:
     import tensorflow as tf
     from rl_coach.architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork
@@ -53,26 +54,27 @@ class NetworkWrapper(object):
         if self.network_parameters.framework == Frameworks.tensorflow:
             if "tensorflow" not in failed_imports:
-                general_network = GeneralTensorFlowNetwork
+                general_network = GeneralTensorFlowNetwork.construct
             else:
                 raise Exception('Install tensorflow before using it as framework')
         elif self.network_parameters.framework == Frameworks.mxnet:
             if "mxnet" not in failed_imports:
-                general_network = GeneralMxnetNetwork
+                general_network = GeneralMxnetNetwork.construct
             else:
                 raise Exception('Install mxnet before using it as framework')
         else:
             raise Exception("{} Framework is not supported"
                             .format(Frameworks().to_string(self.network_parameters.framework)))

-        with tf.variable_scope("{}/{}".format(self.ap.full_name_id, name)):
-            # Global network - the main network shared between threads
-            self.global_network = None
-            if self.has_global:
-                # we assign the parameters of this network on the parameters server
-                with tf.device(replicated_device):
-                    self.global_network = general_network(agent_parameters=agent_parameters,
-                                                          name='{}/global'.format(name),
-                                                          global_network=None,
-                                                          network_is_local=False,
+        variable_scope = "{}/{}".format(self.ap.full_name_id, name)
+
+        # Global network - the main network shared between threads
+        self.global_network = None
+        if self.has_global:
+            # we assign the parameters of this network on the parameters server
+            self.global_network = general_network(variable_scope=variable_scope,
+                                                  devices=force_list(replicated_device),
+                                                  agent_parameters=agent_parameters,
+                                                  name='{}/global'.format(name),
+                                                  global_network=None,
+                                                  network_is_local=False,
@@ -81,8 +83,9 @@ class NetworkWrapper(object):
         # Online network - local copy of the main network used for playing
         self.online_network = None
-        with tf.device(worker_device):
-            self.online_network = general_network(agent_parameters=agent_parameters,
-                                                  name='{}/online'.format(name),
-                                                  global_network=self.global_network,
-                                                  network_is_local=True,
+        self.online_network = general_network(variable_scope=variable_scope,
+                                              devices=force_list(worker_device),
+                                              agent_parameters=agent_parameters,
+                                              name='{}/online'.format(name),
+                                              global_network=self.global_network,
+                                              network_is_local=True,
@@ -92,8 +95,9 @@ class NetworkWrapper(object):
         # Target network - a local, slow updating network used for stabilizing the learning
         self.target_network = None
         if self.has_target:
-            with tf.device(worker_device):
-                self.target_network = general_network(agent_parameters=agent_parameters,
-                                                      name='{}/target'.format(name),
-                                                      global_network=self.global_network,
-                                                      network_is_local=True,
+            self.target_network = general_network(variable_scope=variable_scope,
+                                                  devices=force_list(worker_device),
+                                                  agent_parameters=agent_parameters,
+                                                  name='{}/target'.format(name),
+                                                  global_network=self.global_network,
+                                                  network_is_local=True,
@@ -198,26 +202,6 @@ class NetworkWrapper(object):
""" """
return type(self.online_network).parallel_predict(self.sess, network_input_tuples) return type(self.online_network).parallel_predict(self.sess, network_input_tuples)
def get_local_variables(self):
"""
Get all the variables that are local to the thread
:return: a list of all the variables that are local to the thread
"""
local_variables = [v for v in tf.local_variables() if self.online_network.name in v.name]
if self.has_target:
local_variables += [v for v in tf.local_variables() if self.target_network.name in v.name]
return local_variables
def get_global_variables(self):
"""
Get all the variables that are shared between threads
:return: a list of all the variables that are shared between threads
"""
global_variables = [v for v in tf.global_variables() if self.global_network.name in v.name]
return global_variables
def set_is_training(self, state: bool): def set_is_training(self, state: bool):
""" """
Set the phase of the network between training and testing Set the phase of the network between training and testing
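NetworkWrapper now routes every network through the framework-specific construct() factory, and the newly imported force_list lets the same code path accept either a single device (the replicated or worker device) or an already-built list. A tiny sketch of that behaviour, assuming force_list wraps non-list values in a one-element list as its use here suggests:

from rl_coach.utils import force_list
from rl_coach.base_parameters import Device, DeviceType

worker_device = Device(DeviceType.CPU)
assert force_list(worker_device) == [worker_device]     # single device -> one-element list
assert force_list([worker_device]) == [worker_device]   # lists pass through unchanged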

View File

@@ -15,7 +15,7 @@
 #
 import copy
-from typing import Dict
+from typing import Dict, List, Union

 import numpy as np
 import tensorflow as tf
@@ -25,8 +25,9 @@ from rl_coach.architectures.head_parameters import HeadParameters
 from rl_coach.architectures.middleware_parameters import MiddlewareParameters
 from rl_coach.architectures.tensorflow_components.architecture import TensorFlowArchitecture
 from rl_coach.architectures.tensorflow_components import utils
-from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
+from rl_coach.base_parameters import AgentParameters, Device, DeviceType, EmbeddingMergerType
 from rl_coach.core_types import PredictionType
+from rl_coach.logger import screen
 from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace
 from rl_coach.utils import get_all_subclasses, dynamic_import_and_instantiate_module_from_params, indent_string
@@ -35,6 +36,62 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
""" """
A generalized version of all possible networks implemented using tensorflow. A generalized version of all possible networks implemented using tensorflow.
""" """
# dictionary of variable-scope name to variable-scope object to prevent tensorflow from
# creating a new auxiliary variable scope even when name is properly specified
variable_scopes_dict = dict()
@staticmethod
def construct(variable_scope: str, devices: List[str], *args, **kwargs) -> 'GeneralTensorFlowNetwork':
"""
Construct a network class using the provided variable scope and on requested devices
:param variable_scope: string specifying variable scope under which to create network variables
:param devices: list of devices (can be list of Device objects, or string for TF distributed)
:param args: all other arguments for class initializer
:param kwargs: all other keyword arguments for class initializer
:return: a GeneralTensorFlowNetwork object
"""
if len(devices) > 1:
screen.warning("Tensorflow implementation only support a single device. Using {}".format(devices[0]))
def construct_on_device():
with tf.device(GeneralTensorFlowNetwork._tf_device(devices[0])):
return GeneralTensorFlowNetwork(*args, **kwargs)
# If variable_scope is in our dictionary, then this is not the first time that this variable_scope
# is being used with construct(). So to avoid TF adding an incrementing number to the end of the
# variable_scope to uniquify it, we have to both pass the previous variable_scope object to the new
# variable_scope() call and also recover the name space using name_scope
if variable_scope in GeneralTensorFlowNetwork.variable_scopes_dict:
variable_scope = GeneralTensorFlowNetwork.variable_scopes_dict[variable_scope]
with tf.variable_scope(variable_scope, auxiliary_name_scope=False) as vs:
with tf.name_scope(vs.original_name_scope):
return construct_on_device()
else:
with tf.variable_scope(variable_scope, auxiliary_name_scope=True) as vs:
# Add variable_scope object to dictionary for next call to construct
GeneralTensorFlowNetwork.variable_scopes_dict[variable_scope] = vs
return construct_on_device()
@staticmethod
def _tf_device(device: Union[str, Device]) -> str:
"""
Convert device to tensorflow-specific device representation
:param device: either a specific string (used in distributed mode) which is returned without
any change or a Device type
:return: tensorflow-specific string for device
"""
if isinstance(device, str):
return device
elif isinstance(device, Device):
if device.device_type == DeviceType.CPU:
return "/cpu:0"
elif device.device_type == DeviceType.GPU:
return "/device:GPU:{}".format(device.index)
else:
raise ValueError("Invalid device_type: {}".format(device.device_type))
else:
raise ValueError("Invalid device instance type: {}".format(type(device)))
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str, def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str,
global_network=None, network_is_local: bool=True, network_is_trainable: bool=False): global_network=None, network_is_local: bool=True, network_is_trainable: bool=False):
""" """

View File

@@ -61,6 +61,35 @@ class RunType(Enum):
         return self.value

+
+class DeviceType(Enum):
+    CPU = 'cpu'
+    GPU = 'gpu'
+
+
+class Device(object):
+    def __init__(self, device_type: DeviceType, index: int=0):
+        """
+        :param device_type: type of device (CPU/GPU)
+        :param index: index of device (only used if device type is GPU)
+        """
+        self._device_type = device_type
+        self._index = index
+
+    @property
+    def device_type(self):
+        return self._device_type
+
+    @property
+    def index(self):
+        return self._index
+
+    def __str__(self):
+        return "{}{}".format(self._device_type, self._index)
+
+    def __repr__(self):
+        return str(self)
+
+
 # DistributedCoachSynchronizationType provides the synchronization type for distributed Coach.
 # The default value is None, which means the algorithm or preset cannot be used with distributed Coach.
 class DistributedCoachSynchronizationType(Enum):
@@ -520,7 +549,8 @@ class AgentParameters(Parameters):
 class TaskParameters(Parameters):
     def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: bool=False, use_cpu: bool=False,
                  experiment_path='/tmp', seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
-                 checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False):
+                 checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False,
+                 num_gpu: int=1):
         """
         :param framework_type: deep learning framework type. currently only tensorflow is supported
         :param evaluate_only: the task will be used only for evaluating the model
@@ -532,7 +562,7 @@ class TaskParameters(Parameters):
         :param checkpoint_save_dir: the directory to store the checkpoints in
         :param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved
         :param apply_stop_condition: If set to True, this will apply the stop condition defined by reaching a target success rate
+        :param num_gpu: number of GPUs to use
         """
         self.framework_type = framework_type
         self.task_index = 0  # TODO: not really needed
@@ -545,6 +575,7 @@ class TaskParameters(Parameters):
         self.seed = seed
         self.export_onnx_graph = export_onnx_graph
         self.apply_stop_condition = apply_stop_condition
+        self.num_gpu = num_gpu


 class DistributedTaskParameters(TaskParameters):

View File

@@ -19,7 +19,7 @@ from typing import List, Dict
 import numpy as np

 from rl_coach.agents.dqn_agent import DQNAgentParameters
-from rl_coach.architectures.tensorflow_components.layers import NoisyNetDense
+from rl_coach.architectures.layers import NoisyNetDense
 from rl_coach.base_parameters import AgentParameters, NetworkParameters
 from rl_coach.spaces import ActionSpace, BoxActionSpace, DiscreteActionSpace

View File

@@ -19,7 +19,6 @@ from typing import List
 import numpy as np

-from rl_coach.architectures.tensorflow_components.shared_variables import SharedRunningStats, TFSharedRunningStats
 from rl_coach.core_types import ObservationType
 from rl_coach.filters.observation.observation_filter import ObservationFilter
 from rl_coach.spaces import ObservationSpace
@@ -54,6 +53,7 @@ class ObservationNormalizationFilter(ObservationFilter):
         :return: None
         """
         if mode == 'tf':
+            from rl_coach.architectures.tensorflow_components.shared_variables import TFSharedRunningStats
             self.running_observation_stats = TFSharedRunningStats(device, name=self.name, create_ops=False,
                                                                   pubsub_params=memory_backend_params)
         elif mode == 'numpy':

View File

@@ -17,7 +17,6 @@ import os
 import numpy as np

-from rl_coach.architectures.tensorflow_components.shared_variables import TFSharedRunningStats
 from rl_coach.core_types import RewardType
 from rl_coach.filters.reward.reward_filter import RewardFilter
 from rl_coach.spaces import RewardSpace
@@ -48,6 +47,7 @@ class RewardNormalizationFilter(RewardFilter):
""" """
if mode == 'tf': if mode == 'tf':
from rl_coach.architectures.tensorflow_components.shared_variables import TFSharedRunningStats
self.running_rewards_stats = TFSharedRunningStats(device, name='rewards_stats', create_ops=False, self.running_rewards_stats = TFSharedRunningStats(device, name='rewards_stats', create_ops=False,
pubsub_params=memory_backend_params) pubsub_params=memory_backend_params)
elif mode == 'numpy': elif mode == 'numpy':
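Both normalization filters now import TFSharedRunningStats inside the 'tf' branch, so numpy-mode filters no longer pull in TensorFlow at module import time. A minimal sketch of the deferred-import pattern (make_stats_backend is an illustrative helper, not rl_coach API):

def make_stats_backend(mode: str):
    if mode == 'tf':
        # imported lazily: tensorflow is only touched when a TF-backed filter is actually requested
        from rl_coach.architectures.tensorflow_components.shared_variables import TFSharedRunningStats
        return TFSharedRunningStats
    # numpy mode (and anything else) never triggers the tensorflow import
    return None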