diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py
index af55b5f..eb5bf62 100644
--- a/rl_coach/agents/agent.py
+++ b/rl_coach/agents/agent.py
@@ -26,7 +26,7 @@ from six.moves import range
 
 from rl_coach.agents.agent_interface import AgentInterface
 from rl_coach.architectures.network_wrapper import NetworkWrapper
-from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters, Frameworks
+from rl_coach.base_parameters import AgentParameters, Device, DeviceType, DistributedTaskParameters, Frameworks
 from rl_coach.core_types import RunPhase, PredictionType, EnvironmentEpisodes, ActionType, Batch, Episode, StateType
 from rl_coach.core_types import Transition, ActionInfo, TrainingSteps, EnvironmentSteps, EnvResponse
 from rl_coach.logger import screen, Logger, EpisodeLogger
@@ -98,14 +98,18 @@ class Agent(AgentInterface):
             self.has_global = True
             self.replicated_device = agent_parameters.task_parameters.device
             self.worker_device = "/job:worker/task:{}".format(self.task_id)
+            if agent_parameters.task_parameters.use_cpu:
+                self.worker_device += "/cpu:0"
+            else:
+                self.worker_device += "/device:GPU:0"
         else:
             self.has_global = False
             self.replicated_device = None
-            self.worker_device = ""
-        if agent_parameters.task_parameters.use_cpu:
-            self.worker_device += "/cpu:0"
-        else:
-            self.worker_device += "/device:GPU:0"
+            if agent_parameters.task_parameters.use_cpu:
+                self.worker_device = Device(DeviceType.CPU)
+            else:
+                self.worker_device = [Device(DeviceType.GPU, i)
+                                      for i in range(agent_parameters.task_parameters.num_gpu)]
 
         # filters
         self.input_filter = self.ap.input_filter
diff --git a/rl_coach/architectures/architecture.py b/rl_coach/architectures/architecture.py
index 637eef6..90dbd6e 100644
--- a/rl_coach/architectures/architecture.py
+++ b/rl_coach/architectures/architecture.py
@@ -24,6 +24,18 @@ from rl_coach.spaces import SpacesDefinition
 
 
 class Architecture(object):
+    @staticmethod
+    def construct(variable_scope: str, devices: List[str], *args, **kwargs) -> 'Architecture':
+        """
+        Construct a network class using the provided variable scope and on requested devices
+        :param variable_scope: string specifying variable scope under which to create network variables
+        :param devices: list of devices (can be list of Device objects, or string for TF distributed)
+        :param args: all other arguments for class initializer
+        :param kwargs: all other keyword arguments for class initializer
+        :return: an object which is a child of Architecture
+        """
+        raise NotImplementedError
+
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str= ""):
         """
         Creates a neural network 'architecture', that can be trained and used for inference.
diff --git a/rl_coach/architectures/mxnet_components/architecture.py b/rl_coach/architectures/mxnet_components/architecture.py
index 8f4c6a1..0f665d3 100644
--- a/rl_coach/architectures/mxnet_components/architecture.py
+++ b/rl_coach/architectures/mxnet_components/architecture.py
@@ -32,8 +32,8 @@ from rl_coach.utils import force_list, squeeze_list
 
 
 class MxnetArchitecture(Architecture):
-    def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str= "",
-                 global_network=None, network_is_local: bool=True, network_is_trainable: bool=False):
+    def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, devices: List[mx.Context],
+                 name: str = "", global_network=None, network_is_local: bool=True, network_is_trainable: bool=False):
         """
         :param agent_parameters: the agent parameters
         :param spaces: the spaces definition of the agent
@@ -58,6 +58,7 @@ class MxnetArchitecture(Architecture):
         self.network_is_trainable = network_is_trainable
         self.is_training = False
         self.model = None  # type: GeneralModel
+        self._devices = devices
         self.is_chief = self.ap.task_parameters.task_index == 0
         self.network_is_global = not self.network_is_local and global_network is None
 
@@ -75,13 +76,17 @@ class MxnetArchitecture(Architecture):
     def __str__(self):
         return self.model.summary(*self._dummy_model_inputs())
 
-    @property
-    def _model_grads(self) -> Generator[NDArray, NDArray, Any]:
+    def _model_grads(self, index: int=0) ->\
+            Union[Generator[NDArray, NDArray, Any], Generator[List[NDArray], List[NDArray], Any]]:
         """
         Creates a copy of model gradients and returns them in a list, in the same order as collect_params()
+        :param index: device index. Set to -1 to get a tuple of list of NDArrays for all devices
         :return: a generator for model gradient values
         """
-        return (p.list_grad()[0].copy() for p in self.model.collect_params().values() if p.grad_req != 'null')
+        if index < 0:
+            return (p.list_grad() for p in self.model.collect_params().values() if p.grad_req != 'null')
+        else:
+            return (p.list_grad()[index] for p in self.model.collect_params().values() if p.grad_req != 'null')
 
     def _model_input_shapes(self) -> List[List[int]]:
         """
@@ -101,7 +106,7 @@
         :return: tuple of inputs for model forward pass
         """
         input_shapes = self._model_input_shapes()
-        inputs = tuple(nd.zeros(tuple(shape)) for shape in input_shapes)
+        inputs = tuple(nd.zeros(tuple(shape), ctx=self._devices[0]) for shape in input_shapes)
         return inputs
 
     def construct_model(self) -> None:
@@ -117,9 +122,8 @@
         :param sess: must be None
         """
         assert sess is None
-        # FIXME Add GPU initialization
         # FIXME Add initializer
-        self.model.collect_params().initialize(ctx=mx.cpu())
+        self.model.collect_params().initialize(ctx=self._devices)
         # Hybridize model and losses
         self.model.hybridize()
         for l in self.losses:
@@ -145,7 +149,7 @@
             for a in self.accumulated_gradients:
                 a *= 0
         else:
-            self.accumulated_gradients = [g.copy() for g in self._model_grads]
+            self.accumulated_gradients = [g.copy() for g in self._model_grads()]
 
     def accumulate_gradients(self,
                              inputs: Dict[str, np.ndarray],
@@ -175,7 +179,7 @@
             self.reset_accumulated_gradients()
 
         embedders = [emb.embedder_name for emb in self.model.nets[0].input_embedders]
-        nd_inputs = tuple(nd.array(inputs[emb]) for emb in embedders)
+        nd_inputs = tuple(nd.array(inputs[emb], ctx=self._devices[0]) for emb in embedders)
 
         assert self.middleware.__class__.__name__ != 'LSTMMiddleware', "LSTM middleware not supported"
 
@@ -190,7 +194,7 @@
             for h, h_loss, h_out, l_tgt in zip(self.model.output_heads, self.losses, out_per_head, tgt_per_loss):
                 l_in = utils.get_loss_agent_inputs(inputs, head_type_idx=h.head_type_idx, loss=h_loss)
                 # Align arguments with loss.loss_forward and convert to NDArray
-                l_args = utils.to_mx_ndarray(utils.align_loss_args(h_out, l_in, l_tgt, h_loss))
+                l_args = utils.to_mx_ndarray(utils.align_loss_args(h_out, l_in, l_tgt, h_loss), h_out[0].context)
                 # Calculate loss and all auxiliary outputs
                 loss_outputs = utils.loss_output_dict(utils.to_list(h_loss(*l_args)), h_loss.output_schema)
                 if LOSS_OUT_TYPE_LOSS in loss_outputs:
@@ -216,25 +220,26 @@
         # allreduce gradients from all contexts
         self.trainer.allreduce_grads()
+        model_grads_cpy = [g.copy() for g in self._model_grads()]
 
         # Calculate global norm of gradients
         # FIXME global norm is returned even when not used for clipping! Is this necessary?
         # FIXME global norm might be calculated twice if clipping method is global norm
-        norm_unclipped_grads = utils.global_norm(self._model_grads)
+        norm_unclipped_grads = utils.global_norm(model_grads_cpy)
 
         # Clip gradients
         if self.network_parameters.clip_gradients:
             utils.clip_grad(
-                self._model_grads,
+                model_grads_cpy,
                 clip_method=self.network_parameters.gradients_clipping_method,
                 clip_val=self.network_parameters.clip_gradients,
                 inplace=True)
 
         # Update self.accumulated_gradients depending on no_accumulation flag
         if no_accumulation:
-            for acc_grad, model_grad in zip(self.accumulated_gradients, self._model_grads):
+            for acc_grad, model_grad in zip(self.accumulated_gradients, model_grads_cpy):
                 acc_grad[:] = model_grad
         else:
-            for acc_grad, model_grad in zip(self.accumulated_gradients, self._model_grads):
+            for acc_grad, model_grad in zip(self.accumulated_gradients, model_grads_cpy):
                 acc_grad += model_grad
 
         # result of of additional fetches
@@ -269,8 +274,9 @@
             batch_size = self.ap.task_parameters.num_training_tasks
 
         # set parameter gradients to gradients passed in
-        for param_grad, gradient in zip(self._model_grads, gradients):
-            param_grad[:] = gradient
+        for param_grad, gradient in zip(self._model_grads(-1), gradients):
+            for pg in param_grad:
+                pg[:] = gradient
 
         # update gradients
         self.trainer.update(batch_size=batch_size)
@@ -283,7 +289,7 @@
         WARNING: must only call once per state since each call is assumed by LSTM to be a new time step.
""" embedders = [emb.embedder_name for emb in self.model.nets[0].input_embedders] - nd_inputs = tuple(nd.array(inputs[emb]) for emb in embedders) + nd_inputs = tuple(nd.array(inputs[emb], ctx=self._devices[0]) for emb in embedders) assert self.middleware.__class__.__name__ != 'LSTMMiddleware' diff --git a/rl_coach/architectures/mxnet_components/general_network.py b/rl_coach/architectures/mxnet_components/general_network.py index 5ded582..ef79ab3 100644 --- a/rl_coach/architectures/mxnet_components/general_network.py +++ b/rl_coach/architectures/mxnet_components/general_network.py @@ -37,7 +37,7 @@ from rl_coach.architectures.mxnet_components.embedders import ImageEmbedder, Ten from rl_coach.architectures.mxnet_components.heads import Head, HeadLoss, PPOHead, PPOVHead, VHead, QHead from rl_coach.architectures.mxnet_components.middlewares import FCMiddleware, LSTMMiddleware from rl_coach.architectures.mxnet_components import utils -from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType +from rl_coach.base_parameters import AgentParameters, Device, DeviceType, EmbeddingMergerType from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace @@ -45,9 +45,40 @@ class GeneralMxnetNetwork(MxnetArchitecture): """ A generalized version of all possible networks implemented using mxnet. """ + @staticmethod + def construct(variable_scope: str, devices: List[str], *args, **kwargs) -> 'GeneralTensorFlowNetwork': + """ + Construct a network class using the provided variable scope and on requested devices + :param variable_scope: string specifying variable scope under which to create network variables + :param devices: list of devices (can be list of Device objects, or string for TF distributed) + :param args: all other arguments for class initializer + :param kwargs: all other keyword arguments for class initializer + :return: a GeneralTensorFlowNetwork object + """ + return GeneralMxnetNetwork(*args, devices=[GeneralMxnetNetwork._mx_device(d) for d in devices], **kwargs) + + @staticmethod + def _mx_device(device: Union[str, Device]) -> mx.Context: + """ + Convert device to tensorflow-specific device representation + :param device: either a specific string (used in distributed mode) which is returned without + any change or a Device type + :return: tensorflow-specific string for device + """ + if isinstance(device, Device): + if device.device_type == DeviceType.CPU: + return mx.cpu() + elif device.device_type == DeviceType.GPU: + return mx.gpu(device.index) + else: + raise ValueError("Invalid device_type: {}".format(device.device_type)) + else: + raise ValueError("Invalid device instance type: {}".format(type(device))) + def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, + devices: List[mx.Context], name: str, global_network=None, network_is_local: bool=True, @@ -55,6 +86,7 @@ class GeneralMxnetNetwork(MxnetArchitecture): """ :param agent_parameters: the agent parameters :param spaces: the spaces definition of the agent + :param devices: list of devices to run the network on :param name: the name of the network :param global_network: the global network replica that is shared between all the workers :param network_is_local: is the network global (shared between workers) or local (dedicated to the worker) @@ -69,7 +101,7 @@ class GeneralMxnetNetwork(MxnetArchitecture): self.num_heads_per_network = len(self.network_parameters.heads_parameters) self.num_networks = 1 - super().__init__(agent_parameters, spaces, name, 
global_network, + super().__init__(agent_parameters, spaces, devices, name, global_network, network_is_local, network_is_trainable) def construct_model(self): diff --git a/rl_coach/architectures/mxnet_components/utils.py b/rl_coach/architectures/mxnet_components/utils.py index cfd497f..bf243dd 100644 --- a/rl_coach/architectures/mxnet_components/utils.py +++ b/rl_coach/architectures/mxnet_components/utils.py @@ -15,24 +15,26 @@ from rl_coach.core_types import GradientClippingMethod nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol] -def to_mx_ndarray(data: Union[list, tuple, np.ndarray, NDArray, int, float]) ->\ +def to_mx_ndarray(data: Union[list, tuple, np.ndarray, NDArray, int, float], ctx: mx.Context=None) ->\ Union[List[NDArray], Tuple[NDArray], NDArray]: """ Convert data to mx.nd.NDArray. Data can be a list or tuple of np.ndarray, int, or float or it can be np.ndarray, int, or float :param data: input data to be converted + :param ctx: context of the data (CPU, GPU0, GPU1, etc.) :return: converted output data """ if isinstance(data, list): - data = [to_mx_ndarray(d) for d in data] + data = [to_mx_ndarray(d, ctx=ctx) for d in data] elif isinstance(data, tuple): - data = tuple(to_mx_ndarray(d) for d in data) + data = tuple(to_mx_ndarray(d, ctx=ctx) for d in data) elif isinstance(data, np.ndarray): - data = nd.array(data) + data = nd.array(data, ctx=ctx) elif isinstance(data, NDArray): + assert data.context == ctx pass elif isinstance(data, int) or isinstance(data, float): - data = nd.array([data]) + data = nd.array([data], ctx=ctx) else: raise TypeError('Unsupported data type: {}'.format(type(data))) return data diff --git a/rl_coach/architectures/network_wrapper.py b/rl_coach/architectures/network_wrapper.py index 61a3d14..644a151 100644 --- a/rl_coach/architectures/network_wrapper.py +++ b/rl_coach/architectures/network_wrapper.py @@ -20,6 +20,7 @@ from rl_coach.base_parameters import Frameworks, AgentParameters from rl_coach.logger import failed_imports from rl_coach.saver import SaverCollection from rl_coach.spaces import SpacesDefinition +from rl_coach.utils import force_list try: import tensorflow as tf from rl_coach.architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork @@ -53,52 +54,55 @@ class NetworkWrapper(object): if self.network_parameters.framework == Frameworks.tensorflow: if "tensorflow" not in failed_imports: - general_network = GeneralTensorFlowNetwork + general_network = GeneralTensorFlowNetwork.construct else: raise Exception('Install tensorflow before using it as framework') elif self.network_parameters.framework == Frameworks.mxnet: if "mxnet" not in failed_imports: - general_network = GeneralMxnetNetwork + general_network = GeneralMxnetNetwork.construct else: raise Exception('Install mxnet before using it as framework') else: raise Exception("{} Framework is not supported" .format(Frameworks().to_string(self.network_parameters.framework))) - with tf.variable_scope("{}/{}".format(self.ap.full_name_id, name)): + variable_scope = "{}/{}".format(self.ap.full_name_id, name) - # Global network - the main network shared between threads - self.global_network = None - if self.has_global: - # we assign the parameters of this network on the parameters server - with tf.device(replicated_device): - self.global_network = general_network(agent_parameters=agent_parameters, - name='{}/global'.format(name), - global_network=None, - network_is_local=False, - spaces=spaces, - network_is_trainable=True) + # Global network - the main network shared between 
threads + self.global_network = None + if self.has_global: + # we assign the parameters of this network on the parameters server + self.global_network = general_network(variable_scope=variable_scope, + devices=force_list(replicated_device), + agent_parameters=agent_parameters, + name='{}/global'.format(name), + global_network=None, + network_is_local=False, + spaces=spaces, + network_is_trainable=True) - # Online network - local copy of the main network used for playing - self.online_network = None - with tf.device(worker_device): - self.online_network = general_network(agent_parameters=agent_parameters, - name='{}/online'.format(name), - global_network=self.global_network, - network_is_local=True, - spaces=spaces, - network_is_trainable=True) + # Online network - local copy of the main network used for playing + self.online_network = None + self.online_network = general_network(variable_scope=variable_scope, + devices=force_list(worker_device), + agent_parameters=agent_parameters, + name='{}/online'.format(name), + global_network=self.global_network, + network_is_local=True, + spaces=spaces, + network_is_trainable=True) - # Target network - a local, slow updating network used for stabilizing the learning - self.target_network = None - if self.has_target: - with tf.device(worker_device): - self.target_network = general_network(agent_parameters=agent_parameters, - name='{}/target'.format(name), - global_network=self.global_network, - network_is_local=True, - spaces=spaces, - network_is_trainable=False) + # Target network - a local, slow updating network used for stabilizing the learning + self.target_network = None + if self.has_target: + self.target_network = general_network(variable_scope=variable_scope, + devices=force_list(worker_device), + agent_parameters=agent_parameters, + name='{}/target'.format(name), + global_network=self.global_network, + network_is_local=True, + spaces=spaces, + network_is_trainable=False) def sync(self): """ @@ -198,26 +202,6 @@ class NetworkWrapper(object): """ return type(self.online_network).parallel_predict(self.sess, network_input_tuples) - def get_local_variables(self): - """ - Get all the variables that are local to the thread - - :return: a list of all the variables that are local to the thread - """ - local_variables = [v for v in tf.local_variables() if self.online_network.name in v.name] - if self.has_target: - local_variables += [v for v in tf.local_variables() if self.target_network.name in v.name] - return local_variables - - def get_global_variables(self): - """ - Get all the variables that are shared between threads - - :return: a list of all the variables that are shared between threads - """ - global_variables = [v for v in tf.global_variables() if self.global_network.name in v.name] - return global_variables - def set_is_training(self, state: bool): """ Set the phase of the network between training and testing diff --git a/rl_coach/architectures/tensorflow_components/general_network.py b/rl_coach/architectures/tensorflow_components/general_network.py index 28b9c60..cecc157 100644 --- a/rl_coach/architectures/tensorflow_components/general_network.py +++ b/rl_coach/architectures/tensorflow_components/general_network.py @@ -15,7 +15,7 @@ # import copy -from typing import Dict +from typing import Dict, List, Union import numpy as np import tensorflow as tf @@ -25,8 +25,9 @@ from rl_coach.architectures.head_parameters import HeadParameters from rl_coach.architectures.middleware_parameters import MiddlewareParameters from 
rl_coach.architectures.tensorflow_components.architecture import TensorFlowArchitecture from rl_coach.architectures.tensorflow_components import utils -from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType +from rl_coach.base_parameters import AgentParameters, Device, DeviceType, EmbeddingMergerType from rl_coach.core_types import PredictionType +from rl_coach.logger import screen from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace from rl_coach.utils import get_all_subclasses, dynamic_import_and_instantiate_module_from_params, indent_string @@ -35,6 +36,62 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture): """ A generalized version of all possible networks implemented using tensorflow. """ + # dictionary of variable-scope name to variable-scope object to prevent tensorflow from + # creating a new auxiliary variable scope even when name is properly specified + variable_scopes_dict = dict() + + @staticmethod + def construct(variable_scope: str, devices: List[str], *args, **kwargs) -> 'GeneralTensorFlowNetwork': + """ + Construct a network class using the provided variable scope and on requested devices + :param variable_scope: string specifying variable scope under which to create network variables + :param devices: list of devices (can be list of Device objects, or string for TF distributed) + :param args: all other arguments for class initializer + :param kwargs: all other keyword arguments for class initializer + :return: a GeneralTensorFlowNetwork object + """ + if len(devices) > 1: + screen.warning("Tensorflow implementation only support a single device. Using {}".format(devices[0])) + + def construct_on_device(): + with tf.device(GeneralTensorFlowNetwork._tf_device(devices[0])): + return GeneralTensorFlowNetwork(*args, **kwargs) + + # If variable_scope is in our dictionary, then this is not the first time that this variable_scope + # is being used with construct(). 
So to avoid TF adding an incrementing number to the end of the + # variable_scope to uniquify it, we have to both pass the previous variable_scope object to the new + # variable_scope() call and also recover the name space using name_scope + if variable_scope in GeneralTensorFlowNetwork.variable_scopes_dict: + variable_scope = GeneralTensorFlowNetwork.variable_scopes_dict[variable_scope] + with tf.variable_scope(variable_scope, auxiliary_name_scope=False) as vs: + with tf.name_scope(vs.original_name_scope): + return construct_on_device() + else: + with tf.variable_scope(variable_scope, auxiliary_name_scope=True) as vs: + # Add variable_scope object to dictionary for next call to construct + GeneralTensorFlowNetwork.variable_scopes_dict[variable_scope] = vs + return construct_on_device() + + @staticmethod + def _tf_device(device: Union[str, Device]) -> str: + """ + Convert device to tensorflow-specific device representation + :param device: either a specific string (used in distributed mode) which is returned without + any change or a Device type + :return: tensorflow-specific string for device + """ + if isinstance(device, str): + return device + elif isinstance(device, Device): + if device.device_type == DeviceType.CPU: + return "/cpu:0" + elif device.device_type == DeviceType.GPU: + return "/device:GPU:{}".format(device.index) + else: + raise ValueError("Invalid device_type: {}".format(device.device_type)) + else: + raise ValueError("Invalid device instance type: {}".format(type(device))) + def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str, global_network=None, network_is_local: bool=True, network_is_trainable: bool=False): """ diff --git a/rl_coach/base_parameters.py b/rl_coach/base_parameters.py index 5ba65f8..d3b5999 100644 --- a/rl_coach/base_parameters.py +++ b/rl_coach/base_parameters.py @@ -61,6 +61,35 @@ class RunType(Enum): return self.value +class DeviceType(Enum): + CPU = 'cpu' + GPU = 'gpu' + + +class Device(object): + def __init__(self, device_type: DeviceType, index: int=0): + """ + :param device_type: type of device (CPU/GPU) + :param index: index of device (only used if device type is GPU) + """ + self._device_type = device_type + self._index = index + + @property + def device_type(self): + return self._device_type + + @property + def index(self): + return self._index + + def __str__(self): + return "{}{}".format(self._device_type, self._index) + + def __repr__(self): + return str(self) + + # DistributedCoachSynchronizationType provides the synchronization type for distributed Coach. # The default value is None, which means the algorithm or preset cannot be used with distributed Coach. class DistributedCoachSynchronizationType(Enum): @@ -520,7 +549,8 @@ class AgentParameters(Parameters): class TaskParameters(Parameters): def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: bool=False, use_cpu: bool=False, experiment_path='/tmp', seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None, - checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False): + checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False, + num_gpu: int=1): """ :param framework_type: deep learning framework type. 
currently only tensorflow is supported :param evaluate_only: the task will be used only for evaluating the model @@ -532,7 +562,7 @@ class TaskParameters(Parameters): :param checkpoint_save_dir: the directory to store the checkpoints in :param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved :param apply_stop_condition: If set to True, this will apply the stop condition defined by reaching a target success rate - + :param num_gpu: number of GPUs to use """ self.framework_type = framework_type self.task_index = 0 # TODO: not really needed @@ -545,6 +575,7 @@ class TaskParameters(Parameters): self.seed = seed self.export_onnx_graph = export_onnx_graph self.apply_stop_condition = apply_stop_condition + self.num_gpu = num_gpu class DistributedTaskParameters(TaskParameters): diff --git a/rl_coach/exploration_policies/parameter_noise.py b/rl_coach/exploration_policies/parameter_noise.py index fabbe65..34381c4 100644 --- a/rl_coach/exploration_policies/parameter_noise.py +++ b/rl_coach/exploration_policies/parameter_noise.py @@ -19,7 +19,7 @@ from typing import List, Dict import numpy as np from rl_coach.agents.dqn_agent import DQNAgentParameters -from rl_coach.architectures.tensorflow_components.layers import NoisyNetDense +from rl_coach.architectures.layers import NoisyNetDense from rl_coach.base_parameters import AgentParameters, NetworkParameters from rl_coach.spaces import ActionSpace, BoxActionSpace, DiscreteActionSpace diff --git a/rl_coach/filters/observation/observation_normalization_filter.py b/rl_coach/filters/observation/observation_normalization_filter.py index 219fc0e..6edd6c9 100644 --- a/rl_coach/filters/observation/observation_normalization_filter.py +++ b/rl_coach/filters/observation/observation_normalization_filter.py @@ -19,7 +19,6 @@ from typing import List import numpy as np -from rl_coach.architectures.tensorflow_components.shared_variables import SharedRunningStats, TFSharedRunningStats from rl_coach.core_types import ObservationType from rl_coach.filters.observation.observation_filter import ObservationFilter from rl_coach.spaces import ObservationSpace @@ -54,6 +53,7 @@ class ObservationNormalizationFilter(ObservationFilter): :return: None """ if mode == 'tf': + from rl_coach.architectures.tensorflow_components.shared_variables import TFSharedRunningStats self.running_observation_stats = TFSharedRunningStats(device, name=self.name, create_ops=False, pubsub_params=memory_backend_params) elif mode == 'numpy': diff --git a/rl_coach/filters/reward/reward_normalization_filter.py b/rl_coach/filters/reward/reward_normalization_filter.py index c6c489c..1541966 100644 --- a/rl_coach/filters/reward/reward_normalization_filter.py +++ b/rl_coach/filters/reward/reward_normalization_filter.py @@ -17,7 +17,6 @@ import os import numpy as np -from rl_coach.architectures.tensorflow_components.shared_variables import TFSharedRunningStats from rl_coach.core_types import RewardType from rl_coach.filters.reward.reward_filter import RewardFilter from rl_coach.spaces import RewardSpace @@ -48,6 +47,7 @@ class RewardNormalizationFilter(RewardFilter): """ if mode == 'tf': + from rl_coach.architectures.tensorflow_components.shared_variables import TFSharedRunningStats self.running_rewards_stats = TFSharedRunningStats(device, name='rewards_stats', create_ops=False, pubsub_params=memory_backend_params) elif mode == 'numpy':
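
Below is a brief illustrative sketch, not part of the patch itself, of how the device abstraction introduced above might be exercised from user code. It only uses names added in this diff (Device, DeviceType, the num_gpu argument of TaskParameters, and the GeneralTensorFlowNetwork._tf_device helper) and assumes tensorflow is installed; the surrounding graph-manager setup is omitted.

# Illustrative sketch only -- exercises the Device/DeviceType plumbing added in this patch.
from rl_coach.base_parameters import Device, DeviceType, TaskParameters
from rl_coach.architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork

# A non-distributed run on two GPUs: the agent builds one Device per GPU index,
# mirroring the new branch in Agent.__init__.
task_parameters = TaskParameters(use_cpu=False, num_gpu=2)
worker_devices = [Device(DeviceType.GPU, i) for i in range(task_parameters.num_gpu)]

# Each framework maps the generic Device to its own representation;
# the TensorFlow backend returns device strings such as "/device:GPU:0".
print([GeneralTensorFlowNetwork._tf_device(d) for d in worker_devices])

# A CPU-only run falls back to a single CPU device ("/cpu:0").
print(GeneralTensorFlowNetwork._tf_device(Device(DeviceType.CPU)))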