Mirror of https://github.com/gryf/coach.git, synced 2025-12-18 03:30:19 +01:00
Moved tf.variable_scope and tf.device calls to framework-specific architecture (#136)
Committed by: Gal Leibovich
Parent: 559969d3dd
Commit: 87a7848b0a
@@ -32,8 +32,8 @@ from rl_coach.utils import force_list, squeeze_list


 class MxnetArchitecture(Architecture):
-    def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str= "",
-                 global_network=None, network_is_local: bool=True, network_is_trainable: bool=False):
+    def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, devices: List[mx.Context],
+                 name: str = "", global_network=None, network_is_local: bool=True, network_is_trainable: bool=False):
         """
         :param agent_parameters: the agent parameters
         :param spaces: the spaces definition of the agent
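Note: the constructor now receives the placement decision from the caller as a list of mx.Context objects instead of hard-coding it inside the architecture. A rough sketch of what a caller might pass (the get_devices helper is hypothetical; mx.gpu, mx.cpu and mx.context.num_gpus are standard MXNet calls):

    import mxnet as mx

    def get_devices(num_gpus: int) -> list:
        # Hypothetical helper: use the first num_gpus GPUs, else fall back to CPU.
        if num_gpus > 0 and mx.context.num_gpus() >= num_gpus:
            return [mx.gpu(i) for i in range(num_gpus)]
        return [mx.cpu()]

    devices = get_devices(num_gpus=2)   # e.g. [gpu(0), gpu(1)]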
@@ -58,6 +58,7 @@ class MxnetArchitecture(Architecture):
         self.network_is_trainable = network_is_trainable
         self.is_training = False
         self.model = None  # type: GeneralModel
+        self._devices = devices

         self.is_chief = self.ap.task_parameters.task_index == 0
         self.network_is_global = not self.network_is_local and global_network is None
@@ -75,13 +76,17 @@ class MxnetArchitecture(Architecture):
     def __str__(self):
         return self.model.summary(*self._dummy_model_inputs())

-    @property
-    def _model_grads(self) -> Generator[NDArray, NDArray, Any]:
+    def _model_grads(self, index: int=0) ->\
+            Union[Generator[NDArray, NDArray, Any], Generator[List[NDArray], List[NDArray], Any]]:
         """
         Creates a copy of model gradients and returns them in a list, in the same order as collect_params()
+        :param index: device index. Set to -1 to get a tuple of list of NDArrays for all devices
         :return: a generator for model gradient values
         """
-        return (p.list_grad()[0].copy() for p in self.model.collect_params().values() if p.grad_req != 'null')
+        if index < 0:
+            return (p.list_grad() for p in self.model.collect_params().values() if p.grad_req != 'null')
+        else:
+            return (p.list_grad()[index] for p in self.model.collect_params().values() if p.grad_req != 'null')

     def _model_input_shapes(self) -> List[List[int]]:
         """
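For reference: a Gluon Parameter initialized on several contexts keeps one gradient NDArray per context, and Parameter.list_grad() returns them in that order; the new index argument simply selects one of those copies (or all of them with -1). A self-contained illustration, using two CPU contexts so it runs without GPUs:

    import mxnet as mx
    from mxnet import autograd, gluon, nd

    devices = [mx.cpu(0), mx.cpu(1)]
    net = gluon.nn.Dense(4)
    net.initialize(ctx=devices)

    # One forward/backward per context so every gradient copy is populated.
    for ctx in devices:
        with autograd.record():
            out = net(nd.ones((2, 3), ctx=ctx))
        out.backward()

    for p in net.collect_params().values():
        if p.grad_req != 'null':
            grads = p.list_grad()              # one NDArray per context
            print([g.context for g in grads])  # [cpu(0), cpu(1)]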
@@ -101,7 +106,7 @@ class MxnetArchitecture(Architecture):
         :return: tuple of inputs for model forward pass
         """
         input_shapes = self._model_input_shapes()
-        inputs = tuple(nd.zeros(tuple(shape)) for shape in input_shapes)
+        inputs = tuple(nd.zeros(tuple(shape), ctx=self._devices[0]) for shape in input_shapes)
         return inputs

     def construct_model(self) -> None:
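Allocating the dummy inputs on self._devices[0] keeps the shape-inference forward pass on the same device the model lives on, rather than on the default CPU context. The pattern in isolation (shapes are illustrative only):

    import mxnet as mx
    from mxnet import nd

    device = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu()
    input_shapes = [[1, 4], [1, 84, 84, 3]]   # illustrative shapes
    inputs = tuple(nd.zeros(tuple(shape), ctx=device) for shape in input_shapes)
    print([i.context for i in inputs])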
@@ -117,9 +122,8 @@ class MxnetArchitecture(Architecture):
         :param sess: must be None
         """
         assert sess is None
-        # FIXME Add GPU initialization
         # FIXME Add initializer
-        self.model.collect_params().initialize(ctx=mx.cpu())
+        self.model.collect_params().initialize(ctx=self._devices)
         # Hybridize model and losses
         self.model.hybridize()
         for l in self.losses:
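Passing the whole device list to initialize() replicates every parameter, and its gradient buffer, onto each context, which is what makes the per-device list_grad() indexing above meaningful (and resolves the old "Add GPU initialization" FIXME). A minimal demonstration:

    import mxnet as mx
    from mxnet import gluon

    net = gluon.nn.Dense(2, in_units=3)
    net.initialize(ctx=[mx.cpu(0), mx.cpu(1)])
    print(net.weight.list_ctx())        # [cpu(0), cpu(1)]
    print(len(net.weight.list_data()))  # 2 copies, one per context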
@@ -145,7 +149,7 @@ class MxnetArchitecture(Architecture):
             for a in self.accumulated_gradients:
                 a *= 0
         else:
-            self.accumulated_gradients = [g.copy() for g in self._model_grads]
+            self.accumulated_gradients = [g.copy() for g in self._model_grads()]

     def accumulate_gradients(self,
                              inputs: Dict[str, np.ndarray],
@@ -175,7 +179,7 @@ class MxnetArchitecture(Architecture):
             self.reset_accumulated_gradients()

         embedders = [emb.embedder_name for emb in self.model.nets[0].input_embedders]
-        nd_inputs = tuple(nd.array(inputs[emb]) for emb in embedders)
+        nd_inputs = tuple(nd.array(inputs[emb], ctx=self._devices[0]) for emb in embedders)

         assert self.middleware.__class__.__name__ != 'LSTMMiddleware', "LSTM middleware not supported"

@@ -190,7 +194,7 @@ class MxnetArchitecture(Architecture):
             for h, h_loss, h_out, l_tgt in zip(self.model.output_heads, self.losses, out_per_head, tgt_per_loss):
                 l_in = utils.get_loss_agent_inputs(inputs, head_type_idx=h.head_type_idx, loss=h_loss)
                 # Align arguments with loss.loss_forward and convert to NDArray
-                l_args = utils.to_mx_ndarray(utils.align_loss_args(h_out, l_in, l_tgt, h_loss))
+                l_args = utils.to_mx_ndarray(utils.align_loss_args(h_out, l_in, l_tgt, h_loss), h_out[0].context)
                 # Calculate loss and all auxiliary outputs
                 loss_outputs = utils.loss_output_dict(utils.to_list(h_loss(*l_args)), h_loss.output_schema)
                 if LOSS_OUT_TYPE_LOSS in loss_outputs:
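Passing h_out[0].context pins the converted loss arguments to the device that produced the head outputs, avoiding an implicit device-to-device copy inside the loss. A sketch of the conversion idiom (this is not Coach's utils.to_mx_ndarray, just the general shape of it):

    import numpy as np
    from mxnet import nd

    def to_ndarray_on(args, ctx):
        # Sketch: move a mixed list of numpy arrays / NDArrays onto one context.
        converted = []
        for a in args:
            if isinstance(a, nd.NDArray):
                converted.append(a.as_in_context(ctx))  # no-op if already there
            else:
                converted.append(nd.array(np.asarray(a), ctx=ctx))
        return converted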
@@ -216,25 +220,26 @@ class MxnetArchitecture(Architecture):
         # allreduce gradients from all contexts
         self.trainer.allreduce_grads()

+        model_grads_cpy = [g.copy() for g in self._model_grads()]
         # Calculate global norm of gradients
         # FIXME global norm is returned even when not used for clipping! Is this necessary?
         # FIXME global norm might be calculated twice if clipping method is global norm
-        norm_unclipped_grads = utils.global_norm(self._model_grads)
+        norm_unclipped_grads = utils.global_norm(model_grads_cpy)

         # Clip gradients
         if self.network_parameters.clip_gradients:
             utils.clip_grad(
-                self._model_grads,
+                model_grads_cpy,
                 clip_method=self.network_parameters.gradients_clipping_method,
                 clip_val=self.network_parameters.clip_gradients,
                 inplace=True)

         # Update self.accumulated_gradients depending on no_accumulation flag
         if no_accumulation:
-            for acc_grad, model_grad in zip(self.accumulated_gradients, self._model_grads):
+            for acc_grad, model_grad in zip(self.accumulated_gradients, model_grads_cpy):
                 acc_grad[:] = model_grad
         else:
-            for acc_grad, model_grad in zip(self.accumulated_gradients, self._model_grads):
+            for acc_grad, model_grad in zip(self.accumulated_gradients, model_grads_cpy):
                 acc_grad += model_grad

         # result of of additional fetches
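Trainer.allreduce_grads() (usable when the trainer is not updating weights on the kvstore, i.e. update_on_kvstore=False) sums each parameter's gradient over all contexts and writes the result back to every copy, so reading one copy per parameter afterwards is safe; taking explicit copies into model_grads_cpy then gives clipping and accumulation a stable snapshot that later backward passes cannot overwrite. A hand-rolled stand-in for the global-norm variant of utils.clip_grad, to show the in-place semantics:

    from mxnet import nd

    def clip_global_norm_(grads, clip_val):
        # Stand-in for utils.clip_grad(..., inplace=True) with global-norm clipping.
        total = sum(float((g ** 2).sum().asscalar()) for g in grads)
        norm = total ** 0.5
        if norm > clip_val:
            scale = clip_val / norm
            for g in grads:
                g *= scale        # scales the gradient copies in place
        return norm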
@@ -269,8 +274,9 @@ class MxnetArchitecture(Architecture):
         batch_size = self.ap.task_parameters.num_training_tasks

         # set parameter gradients to gradients passed in
-        for param_grad, gradient in zip(self._model_grads, gradients):
-            param_grad[:] = gradient
+        for param_grad, gradient in zip(self._model_grads(-1), gradients):
+            for pg in param_grad:
+                pg[:] = gradient
         # update gradients
         self.trainer.update(batch_size=batch_size)

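Because every context holds its own gradient buffer, externally supplied gradients (e.g. from a global network in distributed training) must be written into all copies before trainer.update() applies the step; that is what the new inner loop over _model_grads(-1) does. Schematically (variable names here are illustrative):

    # 'gradients' is a list of numpy arrays, one per trainable parameter.
    for param_grads, gradient in zip(model_grads_all_devices, gradients):
        for device_grad in param_grads:    # one NDArray per context
            device_grad[:] = gradient      # NDArray accepts numpy on assignment
    trainer.update(batch_size=batch_size)  # applies the step to every copy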
@@ -283,7 +289,7 @@ class MxnetArchitecture(Architecture):
         WARNING: must only call once per state since each call is assumed by LSTM to be a new time step.
         """
         embedders = [emb.embedder_name for emb in self.model.nets[0].input_embedders]
-        nd_inputs = tuple(nd.array(inputs[emb]) for emb in embedders)
+        nd_inputs = tuple(nd.array(inputs[emb], ctx=self._devices[0]) for emb in embedders)

         assert self.middleware.__class__.__name__ != 'LSTMMiddleware'