1
0
mirror of https://github.com/gryf/coach.git synced 2026-03-23 19:13:33 +01:00

Moved tf.variable_scope and tf.device calls to framework-specific architecture (#136)

This commit is contained in:
Sina Afrooze
2018-11-22 12:52:22 -08:00
committed by Gal Leibovich
parent 559969d3dd
commit 87a7848b0a
11 changed files with 219 additions and 91 deletions

View File

@@ -32,8 +32,8 @@ from rl_coach.utils import force_list, squeeze_list
class MxnetArchitecture(Architecture):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str= "",
global_network=None, network_is_local: bool=True, network_is_trainable: bool=False):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, devices: List[mx.Context],
name: str = "", global_network=None, network_is_local: bool=True, network_is_trainable: bool=False):
"""
:param agent_parameters: the agent parameters
:param spaces: the spaces definition of the agent
@@ -58,6 +58,7 @@ class MxnetArchitecture(Architecture):
self.network_is_trainable = network_is_trainable
self.is_training = False
self.model = None # type: GeneralModel
self._devices = devices
self.is_chief = self.ap.task_parameters.task_index == 0
self.network_is_global = not self.network_is_local and global_network is None
@@ -75,13 +76,17 @@ class MxnetArchitecture(Architecture):
def __str__(self):
return self.model.summary(*self._dummy_model_inputs())
@property
def _model_grads(self) -> Generator[NDArray, NDArray, Any]:
def _model_grads(self, index: int=0) ->\
Union[Generator[NDArray, NDArray, Any], Generator[List[NDArray], List[NDArray], Any]]:
"""
Creates a copy of model gradients and returns them in a list, in the same order as collect_params()
:param index: device index. Set to -1 to get a tuple of list of NDArrays for all devices
:return: a generator for model gradient values
"""
return (p.list_grad()[0].copy() for p in self.model.collect_params().values() if p.grad_req != 'null')
if index < 0:
return (p.list_grad() for p in self.model.collect_params().values() if p.grad_req != 'null')
else:
return (p.list_grad()[index] for p in self.model.collect_params().values() if p.grad_req != 'null')
def _model_input_shapes(self) -> List[List[int]]:
"""
@@ -101,7 +106,7 @@ class MxnetArchitecture(Architecture):
:return: tuple of inputs for model forward pass
"""
input_shapes = self._model_input_shapes()
inputs = tuple(nd.zeros(tuple(shape)) for shape in input_shapes)
inputs = tuple(nd.zeros(tuple(shape), ctx=self._devices[0]) for shape in input_shapes)
return inputs
def construct_model(self) -> None:
@@ -117,9 +122,8 @@ class MxnetArchitecture(Architecture):
:param sess: must be None
"""
assert sess is None
# FIXME Add GPU initialization
# FIXME Add initializer
self.model.collect_params().initialize(ctx=mx.cpu())
self.model.collect_params().initialize(ctx=self._devices)
# Hybridize model and losses
self.model.hybridize()
for l in self.losses:
@@ -145,7 +149,7 @@ class MxnetArchitecture(Architecture):
for a in self.accumulated_gradients:
a *= 0
else:
self.accumulated_gradients = [g.copy() for g in self._model_grads]
self.accumulated_gradients = [g.copy() for g in self._model_grads()]
def accumulate_gradients(self,
inputs: Dict[str, np.ndarray],
@@ -175,7 +179,7 @@ class MxnetArchitecture(Architecture):
self.reset_accumulated_gradients()
embedders = [emb.embedder_name for emb in self.model.nets[0].input_embedders]
nd_inputs = tuple(nd.array(inputs[emb]) for emb in embedders)
nd_inputs = tuple(nd.array(inputs[emb], ctx=self._devices[0]) for emb in embedders)
assert self.middleware.__class__.__name__ != 'LSTMMiddleware', "LSTM middleware not supported"
@@ -190,7 +194,7 @@ class MxnetArchitecture(Architecture):
for h, h_loss, h_out, l_tgt in zip(self.model.output_heads, self.losses, out_per_head, tgt_per_loss):
l_in = utils.get_loss_agent_inputs(inputs, head_type_idx=h.head_type_idx, loss=h_loss)
# Align arguments with loss.loss_forward and convert to NDArray
l_args = utils.to_mx_ndarray(utils.align_loss_args(h_out, l_in, l_tgt, h_loss))
l_args = utils.to_mx_ndarray(utils.align_loss_args(h_out, l_in, l_tgt, h_loss), h_out[0].context)
# Calculate loss and all auxiliary outputs
loss_outputs = utils.loss_output_dict(utils.to_list(h_loss(*l_args)), h_loss.output_schema)
if LOSS_OUT_TYPE_LOSS in loss_outputs:
@@ -216,25 +220,26 @@ class MxnetArchitecture(Architecture):
# allreduce gradients from all contexts
self.trainer.allreduce_grads()
model_grads_cpy = [g.copy() for g in self._model_grads()]
# Calculate global norm of gradients
# FIXME global norm is returned even when not used for clipping! Is this necessary?
# FIXME global norm might be calculated twice if clipping method is global norm
norm_unclipped_grads = utils.global_norm(self._model_grads)
norm_unclipped_grads = utils.global_norm(model_grads_cpy)
# Clip gradients
if self.network_parameters.clip_gradients:
utils.clip_grad(
self._model_grads,
model_grads_cpy,
clip_method=self.network_parameters.gradients_clipping_method,
clip_val=self.network_parameters.clip_gradients,
inplace=True)
# Update self.accumulated_gradients depending on no_accumulation flag
if no_accumulation:
for acc_grad, model_grad in zip(self.accumulated_gradients, self._model_grads):
for acc_grad, model_grad in zip(self.accumulated_gradients, model_grads_cpy):
acc_grad[:] = model_grad
else:
for acc_grad, model_grad in zip(self.accumulated_gradients, self._model_grads):
for acc_grad, model_grad in zip(self.accumulated_gradients, model_grads_cpy):
acc_grad += model_grad
# result of of additional fetches
@@ -269,8 +274,9 @@ class MxnetArchitecture(Architecture):
batch_size = self.ap.task_parameters.num_training_tasks
# set parameter gradients to gradients passed in
for param_grad, gradient in zip(self._model_grads, gradients):
param_grad[:] = gradient
for param_grad, gradient in zip(self._model_grads(-1), gradients):
for pg in param_grad:
pg[:] = gradient
# update gradients
self.trainer.update(batch_size=batch_size)
@@ -283,7 +289,7 @@ class MxnetArchitecture(Architecture):
WARNING: must only call once per state since each call is assumed by LSTM to be a new time step.
"""
embedders = [emb.embedder_name for emb in self.model.nets[0].input_embedders]
nd_inputs = tuple(nd.array(inputs[emb]) for emb in embedders)
nd_inputs = tuple(nd.array(inputs[emb], ctx=self._devices[0]) for emb in embedders)
assert self.middleware.__class__.__name__ != 'LSTMMiddleware'

View File

@@ -37,7 +37,7 @@ from rl_coach.architectures.mxnet_components.embedders import ImageEmbedder, Ten
from rl_coach.architectures.mxnet_components.heads import Head, HeadLoss, PPOHead, PPOVHead, VHead, QHead
from rl_coach.architectures.mxnet_components.middlewares import FCMiddleware, LSTMMiddleware
from rl_coach.architectures.mxnet_components import utils
from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
from rl_coach.base_parameters import AgentParameters, Device, DeviceType, EmbeddingMergerType
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace
@@ -45,9 +45,40 @@ class GeneralMxnetNetwork(MxnetArchitecture):
"""
A generalized version of all possible networks implemented using mxnet.
"""
@staticmethod
def construct(variable_scope: str, devices: List[str], *args, **kwargs) -> 'GeneralTensorFlowNetwork':
"""
Construct a network class using the provided variable scope and on requested devices
:param variable_scope: string specifying variable scope under which to create network variables
:param devices: list of devices (can be list of Device objects, or string for TF distributed)
:param args: all other arguments for class initializer
:param kwargs: all other keyword arguments for class initializer
:return: a GeneralTensorFlowNetwork object
"""
return GeneralMxnetNetwork(*args, devices=[GeneralMxnetNetwork._mx_device(d) for d in devices], **kwargs)
@staticmethod
def _mx_device(device: Union[str, Device]) -> mx.Context:
"""
Convert device to tensorflow-specific device representation
:param device: either a specific string (used in distributed mode) which is returned without
any change or a Device type
:return: tensorflow-specific string for device
"""
if isinstance(device, Device):
if device.device_type == DeviceType.CPU:
return mx.cpu()
elif device.device_type == DeviceType.GPU:
return mx.gpu(device.index)
else:
raise ValueError("Invalid device_type: {}".format(device.device_type))
else:
raise ValueError("Invalid device instance type: {}".format(type(device)))
def __init__(self,
agent_parameters: AgentParameters,
spaces: SpacesDefinition,
devices: List[mx.Context],
name: str,
global_network=None,
network_is_local: bool=True,
@@ -55,6 +86,7 @@ class GeneralMxnetNetwork(MxnetArchitecture):
"""
:param agent_parameters: the agent parameters
:param spaces: the spaces definition of the agent
:param devices: list of devices to run the network on
:param name: the name of the network
:param global_network: the global network replica that is shared between all the workers
:param network_is_local: is the network global (shared between workers) or local (dedicated to the worker)
@@ -69,7 +101,7 @@ class GeneralMxnetNetwork(MxnetArchitecture):
self.num_heads_per_network = len(self.network_parameters.heads_parameters)
self.num_networks = 1
super().__init__(agent_parameters, spaces, name, global_network,
super().__init__(agent_parameters, spaces, devices, name, global_network,
network_is_local, network_is_trainable)
def construct_model(self):

View File

@@ -15,24 +15,26 @@ from rl_coach.core_types import GradientClippingMethod
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
def to_mx_ndarray(data: Union[list, tuple, np.ndarray, NDArray, int, float]) ->\
def to_mx_ndarray(data: Union[list, tuple, np.ndarray, NDArray, int, float], ctx: mx.Context=None) ->\
Union[List[NDArray], Tuple[NDArray], NDArray]:
"""
Convert data to mx.nd.NDArray. Data can be a list or tuple of np.ndarray, int, or float or
it can be np.ndarray, int, or float
:param data: input data to be converted
:param ctx: context of the data (CPU, GPU0, GPU1, etc.)
:return: converted output data
"""
if isinstance(data, list):
data = [to_mx_ndarray(d) for d in data]
data = [to_mx_ndarray(d, ctx=ctx) for d in data]
elif isinstance(data, tuple):
data = tuple(to_mx_ndarray(d) for d in data)
data = tuple(to_mx_ndarray(d, ctx=ctx) for d in data)
elif isinstance(data, np.ndarray):
data = nd.array(data)
data = nd.array(data, ctx=ctx)
elif isinstance(data, NDArray):
assert data.context == ctx
pass
elif isinstance(data, int) or isinstance(data, float):
data = nd.array([data])
data = nd.array([data], ctx=ctx)
else:
raise TypeError('Unsupported data type: {}'.format(type(data)))
return data