mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Moved tf.variable_scope and tf.device calls to framework-specific architecture (#136)
This commit is contained in:
committed by
Gal Leibovich
parent
559969d3dd
commit
87a7848b0a
@@ -26,7 +26,7 @@ from six.moves import range
|
|||||||
|
|
||||||
from rl_coach.agents.agent_interface import AgentInterface
|
from rl_coach.agents.agent_interface import AgentInterface
|
||||||
from rl_coach.architectures.network_wrapper import NetworkWrapper
|
from rl_coach.architectures.network_wrapper import NetworkWrapper
|
||||||
from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters, Frameworks
|
from rl_coach.base_parameters import AgentParameters, Device, DeviceType, DistributedTaskParameters, Frameworks
|
||||||
from rl_coach.core_types import RunPhase, PredictionType, EnvironmentEpisodes, ActionType, Batch, Episode, StateType
|
from rl_coach.core_types import RunPhase, PredictionType, EnvironmentEpisodes, ActionType, Batch, Episode, StateType
|
||||||
from rl_coach.core_types import Transition, ActionInfo, TrainingSteps, EnvironmentSteps, EnvResponse
|
from rl_coach.core_types import Transition, ActionInfo, TrainingSteps, EnvironmentSteps, EnvResponse
|
||||||
from rl_coach.logger import screen, Logger, EpisodeLogger
|
from rl_coach.logger import screen, Logger, EpisodeLogger
|
||||||
@@ -98,14 +98,18 @@ class Agent(AgentInterface):
|
|||||||
self.has_global = True
|
self.has_global = True
|
||||||
self.replicated_device = agent_parameters.task_parameters.device
|
self.replicated_device = agent_parameters.task_parameters.device
|
||||||
self.worker_device = "/job:worker/task:{}".format(self.task_id)
|
self.worker_device = "/job:worker/task:{}".format(self.task_id)
|
||||||
|
if agent_parameters.task_parameters.use_cpu:
|
||||||
|
self.worker_device += "/cpu:0"
|
||||||
|
else:
|
||||||
|
self.worker_device += "/device:GPU:0"
|
||||||
else:
|
else:
|
||||||
self.has_global = False
|
self.has_global = False
|
||||||
self.replicated_device = None
|
self.replicated_device = None
|
||||||
self.worker_device = ""
|
if agent_parameters.task_parameters.use_cpu:
|
||||||
if agent_parameters.task_parameters.use_cpu:
|
self.worker_device = Device(DeviceType.CPU)
|
||||||
self.worker_device += "/cpu:0"
|
else:
|
||||||
else:
|
self.worker_device = [Device(DeviceType.GPU, i)
|
||||||
self.worker_device += "/device:GPU:0"
|
for i in range(agent_parameters.task_parameters.num_gpu)]
|
||||||
|
|
||||||
# filters
|
# filters
|
||||||
self.input_filter = self.ap.input_filter
|
self.input_filter = self.ap.input_filter
|
||||||
|
|||||||
@@ -24,6 +24,18 @@ from rl_coach.spaces import SpacesDefinition
|
|||||||
|
|
||||||
|
|
||||||
class Architecture(object):
|
class Architecture(object):
|
||||||
|
@staticmethod
def construct(variable_scope: str, devices: List[str], *args, **kwargs) -> 'Architecture':
    """
    Factory hook for building a network under a variable scope and on the requested devices.
    The base class offers no implementation and always raises.

    :param variable_scope: name of the variable scope under which the network variables are created
    :param devices: devices to place the network on (a list of Device objects, or strings for TF distributed)
    :param args: remaining positional arguments, meant for the class initializer
    :param kwargs: remaining keyword arguments, meant for the class initializer
    :return: an object which is a child of Architecture
    :raises NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError
|
||||||
|
|
||||||
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str= ""):
|
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str= ""):
|
||||||
"""
|
"""
|
||||||
Creates a neural network 'architecture', that can be trained and used for inference.
|
Creates a neural network 'architecture', that can be trained and used for inference.
|
||||||
|
|||||||
@@ -32,8 +32,8 @@ from rl_coach.utils import force_list, squeeze_list
|
|||||||
|
|
||||||
|
|
||||||
class MxnetArchitecture(Architecture):
|
class MxnetArchitecture(Architecture):
|
||||||
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str= "",
|
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, devices: List[mx.Context],
|
||||||
global_network=None, network_is_local: bool=True, network_is_trainable: bool=False):
|
name: str = "", global_network=None, network_is_local: bool=True, network_is_trainable: bool=False):
|
||||||
"""
|
"""
|
||||||
:param agent_parameters: the agent parameters
|
:param agent_parameters: the agent parameters
|
||||||
:param spaces: the spaces definition of the agent
|
:param spaces: the spaces definition of the agent
|
||||||
@@ -58,6 +58,7 @@ class MxnetArchitecture(Architecture):
|
|||||||
self.network_is_trainable = network_is_trainable
|
self.network_is_trainable = network_is_trainable
|
||||||
self.is_training = False
|
self.is_training = False
|
||||||
self.model = None # type: GeneralModel
|
self.model = None # type: GeneralModel
|
||||||
|
self._devices = devices
|
||||||
|
|
||||||
self.is_chief = self.ap.task_parameters.task_index == 0
|
self.is_chief = self.ap.task_parameters.task_index == 0
|
||||||
self.network_is_global = not self.network_is_local and global_network is None
|
self.network_is_global = not self.network_is_local and global_network is None
|
||||||
@@ -75,13 +76,17 @@ class MxnetArchitecture(Architecture):
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.model.summary(*self._dummy_model_inputs())
|
return self.model.summary(*self._dummy_model_inputs())
|
||||||
|
|
||||||
def _model_grads(self, index: int=0) ->\
        Union[Generator[NDArray, NDArray, Any], Generator[List[NDArray], List[NDArray], Any]]:
    """
    Lazily iterates over the gradient buffers of every model parameter whose grad_req is not 'null',
    in the same order as collect_params().

    The buffers are yielded as-is, NOT copied: writing into them changes the parameter gradients.
    Callers that need a snapshot must copy explicitly (e.g. [g.copy() for g in self._model_grads()]).

    :param index: position in each parameter's list_grad() result (presumably the device index).
        Pass a negative value (e.g. -1) to get the full list_grad() list for each parameter instead
    :return: a generator of NDArrays (index >= 0) or of lists of NDArrays (index < 0)
    """
    # the parameter collection is looked up now; the gradient buffers are only read on iteration
    trainable_params = (p for p in self.model.collect_params().values() if p.grad_req != 'null')
    if index < 0:
        return (p.list_grad() for p in trainable_params)
    return (p.list_grad()[index] for p in trainable_params)
|
||||||
|
|
||||||
def _model_input_shapes(self) -> List[List[int]]:
|
def _model_input_shapes(self) -> List[List[int]]:
|
||||||
"""
|
"""
|
||||||
@@ -101,7 +106,7 @@ class MxnetArchitecture(Architecture):
|
|||||||
:return: tuple of inputs for model forward pass
|
:return: tuple of inputs for model forward pass
|
||||||
"""
|
"""
|
||||||
input_shapes = self._model_input_shapes()
|
input_shapes = self._model_input_shapes()
|
||||||
inputs = tuple(nd.zeros(tuple(shape)) for shape in input_shapes)
|
inputs = tuple(nd.zeros(tuple(shape), ctx=self._devices[0]) for shape in input_shapes)
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
def construct_model(self) -> None:
|
def construct_model(self) -> None:
|
||||||
@@ -117,9 +122,8 @@ class MxnetArchitecture(Architecture):
|
|||||||
:param sess: must be None
|
:param sess: must be None
|
||||||
"""
|
"""
|
||||||
assert sess is None
|
assert sess is None
|
||||||
# FIXME Add GPU initialization
|
|
||||||
# FIXME Add initializer
|
# FIXME Add initializer
|
||||||
self.model.collect_params().initialize(ctx=mx.cpu())
|
self.model.collect_params().initialize(ctx=self._devices)
|
||||||
# Hybridize model and losses
|
# Hybridize model and losses
|
||||||
self.model.hybridize()
|
self.model.hybridize()
|
||||||
for l in self.losses:
|
for l in self.losses:
|
||||||
@@ -145,7 +149,7 @@ class MxnetArchitecture(Architecture):
|
|||||||
for a in self.accumulated_gradients:
|
for a in self.accumulated_gradients:
|
||||||
a *= 0
|
a *= 0
|
||||||
else:
|
else:
|
||||||
self.accumulated_gradients = [g.copy() for g in self._model_grads]
|
self.accumulated_gradients = [g.copy() for g in self._model_grads()]
|
||||||
|
|
||||||
def accumulate_gradients(self,
|
def accumulate_gradients(self,
|
||||||
inputs: Dict[str, np.ndarray],
|
inputs: Dict[str, np.ndarray],
|
||||||
@@ -175,7 +179,7 @@ class MxnetArchitecture(Architecture):
|
|||||||
self.reset_accumulated_gradients()
|
self.reset_accumulated_gradients()
|
||||||
|
|
||||||
embedders = [emb.embedder_name for emb in self.model.nets[0].input_embedders]
|
embedders = [emb.embedder_name for emb in self.model.nets[0].input_embedders]
|
||||||
nd_inputs = tuple(nd.array(inputs[emb]) for emb in embedders)
|
nd_inputs = tuple(nd.array(inputs[emb], ctx=self._devices[0]) for emb in embedders)
|
||||||
|
|
||||||
assert self.middleware.__class__.__name__ != 'LSTMMiddleware', "LSTM middleware not supported"
|
assert self.middleware.__class__.__name__ != 'LSTMMiddleware', "LSTM middleware not supported"
|
||||||
|
|
||||||
@@ -190,7 +194,7 @@ class MxnetArchitecture(Architecture):
|
|||||||
for h, h_loss, h_out, l_tgt in zip(self.model.output_heads, self.losses, out_per_head, tgt_per_loss):
|
for h, h_loss, h_out, l_tgt in zip(self.model.output_heads, self.losses, out_per_head, tgt_per_loss):
|
||||||
l_in = utils.get_loss_agent_inputs(inputs, head_type_idx=h.head_type_idx, loss=h_loss)
|
l_in = utils.get_loss_agent_inputs(inputs, head_type_idx=h.head_type_idx, loss=h_loss)
|
||||||
# Align arguments with loss.loss_forward and convert to NDArray
|
# Align arguments with loss.loss_forward and convert to NDArray
|
||||||
l_args = utils.to_mx_ndarray(utils.align_loss_args(h_out, l_in, l_tgt, h_loss))
|
l_args = utils.to_mx_ndarray(utils.align_loss_args(h_out, l_in, l_tgt, h_loss), h_out[0].context)
|
||||||
# Calculate loss and all auxiliary outputs
|
# Calculate loss and all auxiliary outputs
|
||||||
loss_outputs = utils.loss_output_dict(utils.to_list(h_loss(*l_args)), h_loss.output_schema)
|
loss_outputs = utils.loss_output_dict(utils.to_list(h_loss(*l_args)), h_loss.output_schema)
|
||||||
if LOSS_OUT_TYPE_LOSS in loss_outputs:
|
if LOSS_OUT_TYPE_LOSS in loss_outputs:
|
||||||
@@ -216,25 +220,26 @@ class MxnetArchitecture(Architecture):
|
|||||||
# allreduce gradients from all contexts
|
# allreduce gradients from all contexts
|
||||||
self.trainer.allreduce_grads()
|
self.trainer.allreduce_grads()
|
||||||
|
|
||||||
|
model_grads_cpy = [g.copy() for g in self._model_grads()]
|
||||||
# Calculate global norm of gradients
|
# Calculate global norm of gradients
|
||||||
# FIXME global norm is returned even when not used for clipping! Is this necessary?
|
# FIXME global norm is returned even when not used for clipping! Is this necessary?
|
||||||
# FIXME global norm might be calculated twice if clipping method is global norm
|
# FIXME global norm might be calculated twice if clipping method is global norm
|
||||||
norm_unclipped_grads = utils.global_norm(self._model_grads)
|
norm_unclipped_grads = utils.global_norm(model_grads_cpy)
|
||||||
|
|
||||||
# Clip gradients
|
# Clip gradients
|
||||||
if self.network_parameters.clip_gradients:
|
if self.network_parameters.clip_gradients:
|
||||||
utils.clip_grad(
|
utils.clip_grad(
|
||||||
self._model_grads,
|
model_grads_cpy,
|
||||||
clip_method=self.network_parameters.gradients_clipping_method,
|
clip_method=self.network_parameters.gradients_clipping_method,
|
||||||
clip_val=self.network_parameters.clip_gradients,
|
clip_val=self.network_parameters.clip_gradients,
|
||||||
inplace=True)
|
inplace=True)
|
||||||
|
|
||||||
# Update self.accumulated_gradients depending on no_accumulation flag
|
# Update self.accumulated_gradients depending on no_accumulation flag
|
||||||
if no_accumulation:
|
if no_accumulation:
|
||||||
for acc_grad, model_grad in zip(self.accumulated_gradients, self._model_grads):
|
for acc_grad, model_grad in zip(self.accumulated_gradients, model_grads_cpy):
|
||||||
acc_grad[:] = model_grad
|
acc_grad[:] = model_grad
|
||||||
else:
|
else:
|
||||||
for acc_grad, model_grad in zip(self.accumulated_gradients, self._model_grads):
|
for acc_grad, model_grad in zip(self.accumulated_gradients, model_grads_cpy):
|
||||||
acc_grad += model_grad
|
acc_grad += model_grad
|
||||||
|
|
||||||
# result of of additional fetches
|
# result of of additional fetches
|
||||||
@@ -269,8 +274,9 @@ class MxnetArchitecture(Architecture):
|
|||||||
batch_size = self.ap.task_parameters.num_training_tasks
|
batch_size = self.ap.task_parameters.num_training_tasks
|
||||||
|
|
||||||
# set parameter gradients to gradients passed in
|
# set parameter gradients to gradients passed in
|
||||||
for param_grad, gradient in zip(self._model_grads, gradients):
|
for param_grad, gradient in zip(self._model_grads(-1), gradients):
|
||||||
param_grad[:] = gradient
|
for pg in param_grad:
|
||||||
|
pg[:] = gradient
|
||||||
# update gradients
|
# update gradients
|
||||||
self.trainer.update(batch_size=batch_size)
|
self.trainer.update(batch_size=batch_size)
|
||||||
|
|
||||||
@@ -283,7 +289,7 @@ class MxnetArchitecture(Architecture):
|
|||||||
WARNING: must only call once per state since each call is assumed by LSTM to be a new time step.
|
WARNING: must only call once per state since each call is assumed by LSTM to be a new time step.
|
||||||
"""
|
"""
|
||||||
embedders = [emb.embedder_name for emb in self.model.nets[0].input_embedders]
|
embedders = [emb.embedder_name for emb in self.model.nets[0].input_embedders]
|
||||||
nd_inputs = tuple(nd.array(inputs[emb]) for emb in embedders)
|
nd_inputs = tuple(nd.array(inputs[emb], ctx=self._devices[0]) for emb in embedders)
|
||||||
|
|
||||||
assert self.middleware.__class__.__name__ != 'LSTMMiddleware'
|
assert self.middleware.__class__.__name__ != 'LSTMMiddleware'
|
||||||
|
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ from rl_coach.architectures.mxnet_components.embedders import ImageEmbedder, Ten
|
|||||||
from rl_coach.architectures.mxnet_components.heads import Head, HeadLoss, PPOHead, PPOVHead, VHead, QHead
|
from rl_coach.architectures.mxnet_components.heads import Head, HeadLoss, PPOHead, PPOVHead, VHead, QHead
|
||||||
from rl_coach.architectures.mxnet_components.middlewares import FCMiddleware, LSTMMiddleware
|
from rl_coach.architectures.mxnet_components.middlewares import FCMiddleware, LSTMMiddleware
|
||||||
from rl_coach.architectures.mxnet_components import utils
|
from rl_coach.architectures.mxnet_components import utils
|
||||||
from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
|
from rl_coach.base_parameters import AgentParameters, Device, DeviceType, EmbeddingMergerType
|
||||||
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace
|
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace
|
||||||
|
|
||||||
|
|
||||||
@@ -45,9 +45,40 @@ class GeneralMxnetNetwork(MxnetArchitecture):
|
|||||||
"""
|
"""
|
||||||
A generalized version of all possible networks implemented using mxnet.
|
A generalized version of all possible networks implemented using mxnet.
|
||||||
"""
|
"""
|
||||||
|
@staticmethod
def construct(variable_scope: str, devices: List[Device], *args, **kwargs) -> 'GeneralMxnetNetwork':
    """
    Construct a GeneralMxnetNetwork on the requested devices

    :param variable_scope: accepted for signature parity with the other framework implementations;
        it is not used here
    :param devices: list of Device objects. Each one is converted to an mx.Context with _mx_device
        (plain strings are not supported and raise ValueError)
    :param args: all other arguments for class initializer
    :param kwargs: all other keyword arguments for class initializer
    :return: a GeneralMxnetNetwork object
    :raises ValueError: if an entry of devices is not a Device or has an unsupported device type
    """
    mx_devices = [GeneralMxnetNetwork._mx_device(device) for device in devices]
    return GeneralMxnetNetwork(*args, devices=mx_devices, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
def _mx_device(device: Union[str, Device]) -> mx.Context:
    """
    Convert a framework-agnostic Device to the mxnet-specific device representation (mx.Context)

    :param device: a Device object. Although the type hint also allows str, strings are NOT
        supported by this implementation and are rejected with ValueError
    :return: mx.cpu() for a CPU device, mx.gpu(device.index) for a GPU device
    :raises ValueError: if device is not a Device instance, or its device_type is neither CPU nor GPU
    """
    if not isinstance(device, Device):
        raise ValueError("Invalid device instance type: {}".format(type(device)))

    if device.device_type == DeviceType.CPU:
        return mx.cpu()
    if device.device_type == DeviceType.GPU:
        return mx.gpu(device.index)
    raise ValueError("Invalid device_type: {}".format(device.device_type))
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
agent_parameters: AgentParameters,
|
agent_parameters: AgentParameters,
|
||||||
spaces: SpacesDefinition,
|
spaces: SpacesDefinition,
|
||||||
|
devices: List[mx.Context],
|
||||||
name: str,
|
name: str,
|
||||||
global_network=None,
|
global_network=None,
|
||||||
network_is_local: bool=True,
|
network_is_local: bool=True,
|
||||||
@@ -55,6 +86,7 @@ class GeneralMxnetNetwork(MxnetArchitecture):
|
|||||||
"""
|
"""
|
||||||
:param agent_parameters: the agent parameters
|
:param agent_parameters: the agent parameters
|
||||||
:param spaces: the spaces definition of the agent
|
:param spaces: the spaces definition of the agent
|
||||||
|
:param devices: list of devices to run the network on
|
||||||
:param name: the name of the network
|
:param name: the name of the network
|
||||||
:param global_network: the global network replica that is shared between all the workers
|
:param global_network: the global network replica that is shared between all the workers
|
||||||
:param network_is_local: is the network global (shared between workers) or local (dedicated to the worker)
|
:param network_is_local: is the network global (shared between workers) or local (dedicated to the worker)
|
||||||
@@ -69,7 +101,7 @@ class GeneralMxnetNetwork(MxnetArchitecture):
|
|||||||
self.num_heads_per_network = len(self.network_parameters.heads_parameters)
|
self.num_heads_per_network = len(self.network_parameters.heads_parameters)
|
||||||
self.num_networks = 1
|
self.num_networks = 1
|
||||||
|
|
||||||
super().__init__(agent_parameters, spaces, name, global_network,
|
super().__init__(agent_parameters, spaces, devices, name, global_network,
|
||||||
network_is_local, network_is_trainable)
|
network_is_local, network_is_trainable)
|
||||||
|
|
||||||
def construct_model(self):
|
def construct_model(self):
|
||||||
|
|||||||
@@ -15,24 +15,26 @@ from rl_coach.core_types import GradientClippingMethod
|
|||||||
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
|
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
|
||||||
|
|
||||||
|
|
||||||
def to_mx_ndarray(data: Union[list, tuple, np.ndarray, NDArray, int, float], ctx: mx.Context=None) ->\
        Union[List[NDArray], Tuple[NDArray], NDArray]:
    """
    Convert data to mx.nd.NDArray. Data can be a list or tuple of np.ndarray, int, or float or
    it can be np.ndarray, int, or float. Lists and tuples are converted element by element
    (recursively) and keep their container type.
    :param data: input data to be converted
    :param ctx: context of the data (CPU, GPU0, GPU1, etc.). When None, newly created arrays are
        created without an explicit context and existing NDArrays are accepted on whatever context
        they already live on
    :return: converted output data
    :raises AssertionError: if data is already an NDArray that lives on a context other than ctx
    :raises TypeError: if data (or one of its elements) is of an unsupported type
    """
    if isinstance(data, list):
        data = [to_mx_ndarray(d, ctx=ctx) for d in data]
    elif isinstance(data, tuple):
        data = tuple(to_mx_ndarray(d, ctx=ctx) for d in data)
    elif isinstance(data, np.ndarray):
        data = nd.array(data, ctx=ctx)
    elif isinstance(data, NDArray):
        # Existing NDArrays are returned untouched. The context is only checked when the caller
        # asked for a specific one: comparing a context against None is never equal, so an
        # unconditional check would reject every NDArray passed with the default ctx.
        assert ctx is None or data.context == ctx
    elif isinstance(data, (int, float)):
        data = nd.array([data], ctx=ctx)
    else:
        raise TypeError('Unsupported data type: {}'.format(type(data)))
    return data
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from rl_coach.base_parameters import Frameworks, AgentParameters
|
|||||||
from rl_coach.logger import failed_imports
|
from rl_coach.logger import failed_imports
|
||||||
from rl_coach.saver import SaverCollection
|
from rl_coach.saver import SaverCollection
|
||||||
from rl_coach.spaces import SpacesDefinition
|
from rl_coach.spaces import SpacesDefinition
|
||||||
|
from rl_coach.utils import force_list
|
||||||
try:
|
try:
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from rl_coach.architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork
|
from rl_coach.architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork
|
||||||
@@ -53,52 +54,55 @@ class NetworkWrapper(object):
|
|||||||
|
|
||||||
if self.network_parameters.framework == Frameworks.tensorflow:
|
if self.network_parameters.framework == Frameworks.tensorflow:
|
||||||
if "tensorflow" not in failed_imports:
|
if "tensorflow" not in failed_imports:
|
||||||
general_network = GeneralTensorFlowNetwork
|
general_network = GeneralTensorFlowNetwork.construct
|
||||||
else:
|
else:
|
||||||
raise Exception('Install tensorflow before using it as framework')
|
raise Exception('Install tensorflow before using it as framework')
|
||||||
elif self.network_parameters.framework == Frameworks.mxnet:
|
elif self.network_parameters.framework == Frameworks.mxnet:
|
||||||
if "mxnet" not in failed_imports:
|
if "mxnet" not in failed_imports:
|
||||||
general_network = GeneralMxnetNetwork
|
general_network = GeneralMxnetNetwork.construct
|
||||||
else:
|
else:
|
||||||
raise Exception('Install mxnet before using it as framework')
|
raise Exception('Install mxnet before using it as framework')
|
||||||
else:
|
else:
|
||||||
raise Exception("{} Framework is not supported"
|
raise Exception("{} Framework is not supported"
|
||||||
.format(Frameworks().to_string(self.network_parameters.framework)))
|
.format(Frameworks().to_string(self.network_parameters.framework)))
|
||||||
|
|
||||||
with tf.variable_scope("{}/{}".format(self.ap.full_name_id, name)):
|
variable_scope = "{}/{}".format(self.ap.full_name_id, name)
|
||||||
|
|
||||||
# Global network - the main network shared between threads
|
# Global network - the main network shared between threads
|
||||||
self.global_network = None
|
self.global_network = None
|
||||||
if self.has_global:
|
if self.has_global:
|
||||||
# we assign the parameters of this network on the parameters server
|
# we assign the parameters of this network on the parameters server
|
||||||
with tf.device(replicated_device):
|
self.global_network = general_network(variable_scope=variable_scope,
|
||||||
self.global_network = general_network(agent_parameters=agent_parameters,
|
devices=force_list(replicated_device),
|
||||||
name='{}/global'.format(name),
|
agent_parameters=agent_parameters,
|
||||||
global_network=None,
|
name='{}/global'.format(name),
|
||||||
network_is_local=False,
|
global_network=None,
|
||||||
spaces=spaces,
|
network_is_local=False,
|
||||||
network_is_trainable=True)
|
spaces=spaces,
|
||||||
|
network_is_trainable=True)
|
||||||
|
|
||||||
# Online network - local copy of the main network used for playing
|
# Online network - local copy of the main network used for playing
|
||||||
self.online_network = None
|
self.online_network = None
|
||||||
with tf.device(worker_device):
|
self.online_network = general_network(variable_scope=variable_scope,
|
||||||
self.online_network = general_network(agent_parameters=agent_parameters,
|
devices=force_list(worker_device),
|
||||||
name='{}/online'.format(name),
|
agent_parameters=agent_parameters,
|
||||||
global_network=self.global_network,
|
name='{}/online'.format(name),
|
||||||
network_is_local=True,
|
global_network=self.global_network,
|
||||||
spaces=spaces,
|
network_is_local=True,
|
||||||
network_is_trainable=True)
|
spaces=spaces,
|
||||||
|
network_is_trainable=True)
|
||||||
|
|
||||||
# Target network - a local, slow updating network used for stabilizing the learning
|
# Target network - a local, slow updating network used for stabilizing the learning
|
||||||
self.target_network = None
|
self.target_network = None
|
||||||
if self.has_target:
|
if self.has_target:
|
||||||
with tf.device(worker_device):
|
self.target_network = general_network(variable_scope=variable_scope,
|
||||||
self.target_network = general_network(agent_parameters=agent_parameters,
|
devices=force_list(worker_device),
|
||||||
name='{}/target'.format(name),
|
agent_parameters=agent_parameters,
|
||||||
global_network=self.global_network,
|
name='{}/target'.format(name),
|
||||||
network_is_local=True,
|
global_network=self.global_network,
|
||||||
spaces=spaces,
|
network_is_local=True,
|
||||||
network_is_trainable=False)
|
spaces=spaces,
|
||||||
|
network_is_trainable=False)
|
||||||
|
|
||||||
def sync(self):
|
def sync(self):
|
||||||
"""
|
"""
|
||||||
@@ -198,26 +202,6 @@ class NetworkWrapper(object):
|
|||||||
"""
|
"""
|
||||||
return type(self.online_network).parallel_predict(self.sess, network_input_tuples)
|
return type(self.online_network).parallel_predict(self.sess, network_input_tuples)
|
||||||
|
|
||||||
def get_local_variables(self):
|
|
||||||
"""
|
|
||||||
Get all the variables that are local to the thread
|
|
||||||
|
|
||||||
:return: a list of all the variables that are local to the thread
|
|
||||||
"""
|
|
||||||
local_variables = [v for v in tf.local_variables() if self.online_network.name in v.name]
|
|
||||||
if self.has_target:
|
|
||||||
local_variables += [v for v in tf.local_variables() if self.target_network.name in v.name]
|
|
||||||
return local_variables
|
|
||||||
|
|
||||||
def get_global_variables(self):
|
|
||||||
"""
|
|
||||||
Get all the variables that are shared between threads
|
|
||||||
|
|
||||||
:return: a list of all the variables that are shared between threads
|
|
||||||
"""
|
|
||||||
global_variables = [v for v in tf.global_variables() if self.global_network.name in v.name]
|
|
||||||
return global_variables
|
|
||||||
|
|
||||||
def set_is_training(self, state: bool):
|
def set_is_training(self, state: bool):
|
||||||
"""
|
"""
|
||||||
Set the phase of the network between training and testing
|
Set the phase of the network between training and testing
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
from typing import Dict
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@@ -25,8 +25,9 @@ from rl_coach.architectures.head_parameters import HeadParameters
|
|||||||
from rl_coach.architectures.middleware_parameters import MiddlewareParameters
|
from rl_coach.architectures.middleware_parameters import MiddlewareParameters
|
||||||
from rl_coach.architectures.tensorflow_components.architecture import TensorFlowArchitecture
|
from rl_coach.architectures.tensorflow_components.architecture import TensorFlowArchitecture
|
||||||
from rl_coach.architectures.tensorflow_components import utils
|
from rl_coach.architectures.tensorflow_components import utils
|
||||||
from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
|
from rl_coach.base_parameters import AgentParameters, Device, DeviceType, EmbeddingMergerType
|
||||||
from rl_coach.core_types import PredictionType
|
from rl_coach.core_types import PredictionType
|
||||||
|
from rl_coach.logger import screen
|
||||||
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace
|
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace
|
||||||
from rl_coach.utils import get_all_subclasses, dynamic_import_and_instantiate_module_from_params, indent_string
|
from rl_coach.utils import get_all_subclasses, dynamic_import_and_instantiate_module_from_params, indent_string
|
||||||
|
|
||||||
@@ -35,6 +36,62 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
|||||||
"""
|
"""
|
||||||
A generalized version of all possible networks implemented using tensorflow.
|
A generalized version of all possible networks implemented using tensorflow.
|
||||||
"""
|
"""
|
||||||
|
# dictionary of variable-scope name to variable-scope object to prevent tensorflow from
|
||||||
|
# creating a new auxiliary variable scope even when name is properly specified
|
||||||
|
variable_scopes_dict = dict()
|
||||||
|
|
||||||
|
@staticmethod
def construct(variable_scope: str, devices: List[str], *args, **kwargs) -> 'GeneralTensorFlowNetwork':
    """
    Build a GeneralTensorFlowNetwork inside the given variable scope, placed on the first requested device

    :param variable_scope: string specifying variable scope under which to create network variables
    :param devices: list of devices (can be list of Device objects, or string for TF distributed).
        Only the first entry is used; a warning is printed when more than one is given
    :param args: all other arguments for class initializer
    :param kwargs: all other keyword arguments for class initializer
    :return: a GeneralTensorFlowNetwork object
    """
    if len(devices) > 1:
        screen.warning("Tensorflow implementation only support a single device. Using {}".format(devices[0]))

    def build_on_first_device():
        target = GeneralTensorFlowNetwork._tf_device(devices[0])
        with tf.device(target):
            return GeneralTensorFlowNetwork(*args, **kwargs)

    known_scopes = GeneralTensorFlowNetwork.variable_scopes_dict

    if variable_scope not in known_scopes:
        # First use of this scope name: let TF open it normally and remember the scope object
        # so that later calls can re-enter it
        with tf.variable_scope(variable_scope, auxiliary_name_scope=True) as new_scope:
            known_scopes[variable_scope] = new_scope
            return build_on_first_device()

    # The scope name was used before. Re-enter the remembered scope object instead of the name, so
    # TF does not uniquify it with an incrementing suffix, and restore its original name scope.
    with tf.variable_scope(known_scopes[variable_scope], auxiliary_name_scope=False) as reopened, \
            tf.name_scope(reopened.original_name_scope):
        return build_on_first_device()
|
||||||
|
|
||||||
|
@staticmethod
def _tf_device(device: Union[str, Device]) -> str:
    """
    Translate a device specification into the device string tensorflow expects

    :param device: either a ready-made device string (used in distributed mode), which is handed
        back unchanged, or a Device object
    :return: the input string itself, "/cpu:0" for a CPU Device, or "/device:GPU:<index>" for a GPU Device
    :raises ValueError: if device is neither a str nor a Device, or its device_type is not CPU/GPU
    """
    if isinstance(device, str):
        return device
    if not isinstance(device, Device):
        raise ValueError("Invalid device instance type: {}".format(type(device)))

    kind = device.device_type
    if kind == DeviceType.CPU:
        return "/cpu:0"
    if kind == DeviceType.GPU:
        return "/device:GPU:{}".format(device.index)
    raise ValueError("Invalid device_type: {}".format(kind))
|
||||||
|
|
||||||
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str,
|
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str,
|
||||||
global_network=None, network_is_local: bool=True, network_is_trainable: bool=False):
|
global_network=None, network_is_local: bool=True, network_is_trainable: bool=False):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -61,6 +61,35 @@ class RunType(Enum):
|
|||||||
return self.value
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceType(Enum):
    """
    Kinds of compute device a Device object can refer to
    """
    # Member values are the lower-case names of the device kinds
    CPU = 'cpu'
    GPU = 'gpu'
|
||||||
|
|
||||||
|
|
||||||
|
class Device(object):
    """
    Read-only pairing of a device type with a device index
    """
    def __init__(self, device_type: DeviceType, index: int=0):
        """
        :param device_type: type of device (CPU/GPU)
        :param index: index of device (only used if device type is GPU)
        """
        self._index = index
        self._device_type = device_type

    @property
    def device_type(self):
        """
        :return: the device type given at construction
        """
        return self._device_type

    @property
    def index(self):
        """
        :return: the device index given at construction
        """
        return self._index

    def __str__(self):
        # Formatted device type immediately followed by the index, with no separator
        return "{kind}{idx}".format(kind=self._device_type, idx=self._index)

    def __repr__(self):
        # repr intentionally mirrors the string form
        return str(self)
|
||||||
|
|
||||||
|
|
||||||
# DistributedCoachSynchronizationType provides the synchronization type for distributed Coach.
|
# DistributedCoachSynchronizationType provides the synchronization type for distributed Coach.
|
||||||
# The default value is None, which means the algorithm or preset cannot be used with distributed Coach.
|
# The default value is None, which means the algorithm or preset cannot be used with distributed Coach.
|
||||||
class DistributedCoachSynchronizationType(Enum):
|
class DistributedCoachSynchronizationType(Enum):
|
||||||
@@ -520,7 +549,8 @@ class AgentParameters(Parameters):
|
|||||||
class TaskParameters(Parameters):
|
class TaskParameters(Parameters):
|
||||||
def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: bool=False, use_cpu: bool=False,
|
def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: bool=False, use_cpu: bool=False,
|
||||||
experiment_path='/tmp', seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
|
experiment_path='/tmp', seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
|
||||||
checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False):
|
checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False,
|
||||||
|
num_gpu: int=1):
|
||||||
"""
|
"""
|
||||||
:param framework_type: deep learning framework type. currently only tensorflow is supported
|
:param framework_type: deep learning framework type. currently only tensorflow is supported
|
||||||
:param evaluate_only: the task will be used only for evaluating the model
|
:param evaluate_only: the task will be used only for evaluating the model
|
||||||
@@ -532,7 +562,7 @@ class TaskParameters(Parameters):
|
|||||||
:param checkpoint_save_dir: the directory to store the checkpoints in
|
:param checkpoint_save_dir: the directory to store the checkpoints in
|
||||||
:param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved
|
:param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved
|
||||||
:param apply_stop_condition: If set to True, this will apply the stop condition defined by reaching a target success rate
|
:param apply_stop_condition: If set to True, this will apply the stop condition defined by reaching a target success rate
|
||||||
|
:param num_gpu: number of GPUs to use
|
||||||
"""
|
"""
|
||||||
self.framework_type = framework_type
|
self.framework_type = framework_type
|
||||||
self.task_index = 0 # TODO: not really needed
|
self.task_index = 0 # TODO: not really needed
|
||||||
@@ -545,6 +575,7 @@ class TaskParameters(Parameters):
|
|||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.export_onnx_graph = export_onnx_graph
|
self.export_onnx_graph = export_onnx_graph
|
||||||
self.apply_stop_condition = apply_stop_condition
|
self.apply_stop_condition = apply_stop_condition
|
||||||
|
self.num_gpu = num_gpu
|
||||||
|
|
||||||
|
|
||||||
class DistributedTaskParameters(TaskParameters):
|
class DistributedTaskParameters(TaskParameters):
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from typing import List, Dict
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from rl_coach.agents.dqn_agent import DQNAgentParameters
|
from rl_coach.agents.dqn_agent import DQNAgentParameters
|
||||||
from rl_coach.architectures.tensorflow_components.layers import NoisyNetDense
|
from rl_coach.architectures.layers import NoisyNetDense
|
||||||
from rl_coach.base_parameters import AgentParameters, NetworkParameters
|
from rl_coach.base_parameters import AgentParameters, NetworkParameters
|
||||||
from rl_coach.spaces import ActionSpace, BoxActionSpace, DiscreteActionSpace
|
from rl_coach.spaces import ActionSpace, BoxActionSpace, DiscreteActionSpace
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ from typing import List
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from rl_coach.architectures.tensorflow_components.shared_variables import SharedRunningStats, TFSharedRunningStats
|
|
||||||
from rl_coach.core_types import ObservationType
|
from rl_coach.core_types import ObservationType
|
||||||
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
||||||
from rl_coach.spaces import ObservationSpace
|
from rl_coach.spaces import ObservationSpace
|
||||||
@@ -54,6 +53,7 @@ class ObservationNormalizationFilter(ObservationFilter):
|
|||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
if mode == 'tf':
|
if mode == 'tf':
|
||||||
|
from rl_coach.architectures.tensorflow_components.shared_variables import TFSharedRunningStats
|
||||||
self.running_observation_stats = TFSharedRunningStats(device, name=self.name, create_ops=False,
|
self.running_observation_stats = TFSharedRunningStats(device, name=self.name, create_ops=False,
|
||||||
pubsub_params=memory_backend_params)
|
pubsub_params=memory_backend_params)
|
||||||
elif mode == 'numpy':
|
elif mode == 'numpy':
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ import os
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from rl_coach.architectures.tensorflow_components.shared_variables import TFSharedRunningStats
|
|
||||||
from rl_coach.core_types import RewardType
|
from rl_coach.core_types import RewardType
|
||||||
from rl_coach.filters.reward.reward_filter import RewardFilter
|
from rl_coach.filters.reward.reward_filter import RewardFilter
|
||||||
from rl_coach.spaces import RewardSpace
|
from rl_coach.spaces import RewardSpace
|
||||||
@@ -48,6 +47,7 @@ class RewardNormalizationFilter(RewardFilter):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if mode == 'tf':
|
if mode == 'tf':
|
||||||
|
from rl_coach.architectures.tensorflow_components.shared_variables import TFSharedRunningStats
|
||||||
self.running_rewards_stats = TFSharedRunningStats(device, name='rewards_stats', create_ops=False,
|
self.running_rewards_stats = TFSharedRunningStats(device, name='rewards_stats', create_ops=False,
|
||||||
pubsub_params=memory_backend_params)
|
pubsub_params=memory_backend_params)
|
||||||
elif mode == 'numpy':
|
elif mode == 'numpy':
|
||||||
|
|||||||
Reference in New Issue
Block a user