From 93571306c37de65eaecbf79fa77b75bfb0b1f68b Mon Sep 17 00:00:00 2001 From: Sina Afrooze Date: Tue, 6 Nov 2018 07:39:29 -0800 Subject: [PATCH] Removed tensorflow specific code in presets (#59) * Add generic layer specification for using in presets * Modify presets to use the generic scheme --- rl_coach/architectures/layers.py | 78 ++++++++++++ .../embedders/embedder.py | 7 +- .../tensorflow_components/general_network.py | 28 +---- .../tensorflow_components/layers.py | 117 +++++++++++------- .../middlewares/middleware.py | 7 +- .../tensorflow_components/utils.py | 40 ++++++ rl_coach/presets/Atari_NStepQ.py | 2 +- rl_coach/presets/BitFlip_DQN.py | 2 +- rl_coach/presets/BitFlip_DQN_HER.py | 2 +- rl_coach/presets/CARLA_CIL.py | 33 +++-- rl_coach/presets/CartPole_PPO.py | 2 +- rl_coach/presets/ControlSuite_DDPG.py | 2 +- rl_coach/presets/Fetch_DDPG_HER_baselines.py | 2 +- rl_coach/presets/Mujoco_A3C_LSTM.py | 2 +- rl_coach/presets/Mujoco_ClippedPPO.py | 2 +- rl_coach/presets/Mujoco_DDPG.py | 2 +- rl_coach/presets/Mujoco_NAF.py | 2 +- rl_coach/presets/Mujoco_PPO.py | 2 +- rl_coach/presets/Pendulum_HAC.py | 2 +- 19 files changed, 233 insertions(+), 101 deletions(-) create mode 100644 rl_coach/architectures/layers.py create mode 100644 rl_coach/architectures/tensorflow_components/utils.py diff --git a/rl_coach/architectures/layers.py b/rl_coach/architectures/layers.py new file mode 100644 index 0000000..e295199 --- /dev/null +++ b/rl_coach/architectures/layers.py @@ -0,0 +1,78 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Module implementing base classes for common network layers used by preset schemes +""" + + +class Conv2d(object): + """ + Base class for framework specfic Conv2d layer + """ + def __init__(self, num_filters: int, kernel_size: int, strides: int): + self.num_filters = num_filters + self.kernel_size = kernel_size + self.strides = strides + + def __str__(self): + return "Convolution (num filters = {}, kernel size = {}, stride = {})"\ + .format(self.num_filters, self.kernel_size, self.strides) + + +class BatchnormActivationDropout(object): + """ + Base class for framework specific batchnorm->activation->dropout layer group + """ + def __init__(self, batchnorm: bool=False, activation_function: str=None, dropout_rate: float=0): + self.batchnorm = batchnorm + self.activation_function = activation_function + self.dropout_rate = dropout_rate + + def __str__(self): + result = [] + if self.batchnorm: + result += ["Batch Normalization"] + if self.activation_function: + result += ["Activation (type = {})".format(self.activation_function)] + if self.dropout_rate > 0: + result += ["Dropout (rate = {})".format(self.dropout_rate)] + return "\n".join(result) + + +class Dense(object): + """ + Base class for framework specific Dense layer + """ + def __init__(self, units: int): + self.units = units + + def __str__(self): + return "Dense (num outputs = {})".format(self.units) + + +class NoisyNetDense(object): + """ + Base class for framework specific factorized Noisy Net layer + + https://arxiv.org/abs/1706.10295. + """ + + def __init__(self, units: int): + self.units = units + self.sigma0 = 0.5 + + def __str__(self): + return "Noisy Dense (num outputs = {})".format(self.units) diff --git a/rl_coach/architectures/tensorflow_components/embedders/embedder.py b/rl_coach/architectures/tensorflow_components/embedders/embedder.py index 004c5c4..967b1ba 100644 --- a/rl_coach/architectures/tensorflow_components/embedders/embedder.py +++ b/rl_coach/architectures/tensorflow_components/embedders/embedder.py @@ -20,8 +20,7 @@ import copy import numpy as np import tensorflow as tf -from rl_coach.architectures.tensorflow_components.layers import batchnorm_activation_dropout, Dense, \ - BatchnormActivationDropout +from rl_coach.architectures.tensorflow_components.layers import BatchnormActivationDropout, convert_layer, Dense from rl_coach.base_parameters import EmbedderScheme, NetworkComponentParameters from rl_coach.core_types import InputEmbedding @@ -62,7 +61,9 @@ class InputEmbedder(object): if isinstance(self.scheme, EmbedderScheme): self.layers_params = copy.copy(self.schemes[self.scheme]) else: - self.layers_params = copy.copy(self.scheme) + # if scheme is specified directly, convert to TF layer if it's not a callable object + # NOTE: if layer object is callable, it must return a TF tensor when invoked + self.layers_params = [convert_layer(l) for l in copy.copy(self.scheme)] # we allow adding batchnorm, dropout or activation functions after each layer. # The motivation is to simplify the transition between a network with batchnorm and a network without diff --git a/rl_coach/architectures/tensorflow_components/general_network.py b/rl_coach/architectures/tensorflow_components/general_network.py index fa494b5..143cebd 100644 --- a/rl_coach/architectures/tensorflow_components/general_network.py +++ b/rl_coach/architectures/tensorflow_components/general_network.py @@ -24,6 +24,7 @@ from rl_coach.architectures.embedder_parameters import InputEmbedderParameters from rl_coach.architectures.head_parameters import HeadParameters from rl_coach.architectures.middleware_parameters import MiddlewareParameters from rl_coach.architectures.tensorflow_components.architecture import TensorFlowArchitecture +from rl_coach.architectures.tensorflow_components import utils from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType from rl_coach.core_types import PredictionType from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace @@ -99,27 +100,6 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture): return ret_dict - @staticmethod - def get_activation_function(activation_function_string: str): - """ - Map the activation function from a string to the tensorflow framework equivalent - :param activation_function_string: the type of the activation function - :return: the tensorflow activation function - """ - activation_functions = { - 'relu': tf.nn.relu, - 'tanh': tf.nn.tanh, - 'sigmoid': tf.nn.sigmoid, - 'elu': tf.nn.elu, - 'selu': tf.nn.selu, - 'leaky_relu': tf.nn.leaky_relu, - 'none': None - } - assert activation_function_string in activation_functions.keys(), \ - "Activation function must be one of the following {}. instead it was: {}"\ - .format(activation_functions.keys(), activation_function_string) - return activation_functions[activation_function_string] - def get_input_embedder(self, input_name: str, embedder_params: InputEmbedderParameters): """ Given an input embedder parameters class, creates the input embedder and returns it @@ -144,7 +124,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture): embedder_path = 'rl_coach.architectures.tensorflow_components.embedders:' + mod_names[emb_type] embedder_params_copy = copy.copy(embedder_params) - embedder_params_copy.activation_function = self.get_activation_function(embedder_params.activation_function) + embedder_params_copy.activation_function = utils.get_activation_function(embedder_params.activation_function) embedder_params_copy.input_rescaling = embedder_params_copy.input_rescaling[emb_type] embedder_params_copy.input_offset = embedder_params_copy.input_offset[emb_type] embedder_params_copy.name = input_name @@ -162,7 +142,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture): mod_name = middleware_params.parameterized_class_name middleware_path = 'rl_coach.architectures.tensorflow_components.middlewares:' + mod_name middleware_params_copy = copy.copy(middleware_params) - middleware_params_copy.activation_function = self.get_activation_function(middleware_params.activation_function) + middleware_params_copy.activation_function = utils.get_activation_function(middleware_params.activation_function) module = dynamic_import_and_instantiate_module_from_params(middleware_params_copy, path=middleware_path) return module @@ -176,7 +156,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture): mod_name = head_params.parameterized_class_name head_path = 'rl_coach.architectures.tensorflow_components.heads:' + mod_name head_params_copy = copy.copy(head_params) - head_params_copy.activation_function = self.get_activation_function(head_params_copy.activation_function) + head_params_copy.activation_function = utils.get_activation_function(head_params_copy.activation_function) return dynamic_import_and_instantiate_module_from_params(head_params_copy, path=head_path, extra_kwargs={ 'agent_parameters': self.ap, 'spaces': self.spaces, 'network_name': self.network_wrapper_name, 'head_idx': head_idx, 'is_local': self.network_is_local}) diff --git a/rl_coach/architectures/tensorflow_components/layers.py b/rl_coach/architectures/tensorflow_components/layers.py index a75ed6c..f9dc356 100644 --- a/rl_coach/architectures/tensorflow_components/layers.py +++ b/rl_coach/architectures/tensorflow_components/layers.py @@ -1,9 +1,11 @@ import math -from typing import List, Union +from types import FunctionType +from typing import Any import tensorflow as tf -from rl_coach.utils import force_list +from rl_coach.architectures import layers +from rl_coach.architectures.tensorflow_components import utils def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dropout, dropout_rate, is_training, name): @@ -17,6 +19,8 @@ def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dr # activation if activation_function: + if isinstance(activation_function, str): + activation_function = utils.get_activation_function(activation_function) layers.append( activation_function(layers[-1], name="{}_activation".format(name)) ) @@ -33,11 +37,35 @@ def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dr return layers -class Conv2d(object): +# define global dictionary for storing layer type to layer implementation mapping +tf_layer_dict = dict() + + +def reg_to_tf(layer_type) -> FunctionType: + """ function decorator that registers layer implementation + :return: decorated function + """ + def reg_impl_decorator(func): + assert layer_type not in tf_layer_dict + tf_layer_dict[layer_type] = func + return func + return reg_impl_decorator + + +def convert_layer(layer): + """ + If layer is callable, return layer, otherwise convert to TF type + :param layer: layer to be converted + :return: converted layer if not callable, otherwise layer itself + """ + if callable(layer): + return layer + return tf_layer_dict[type(layer)](layer) + + +class Conv2d(layers.Conv2d): def __init__(self, num_filters: int, kernel_size: int, strides: int): - self.num_filters = num_filters - self.kernel_size = kernel_size - self.strides = strides + super(Conv2d, self).__init__(num_filters=num_filters, kernel_size=kernel_size, strides=strides) def __call__(self, input_layer, name: str=None, is_training=None): """ @@ -49,16 +77,19 @@ class Conv2d(object): return tf.layers.conv2d(input_layer, filters=self.num_filters, kernel_size=self.kernel_size, strides=self.strides, data_format='channels_last', name=name) - def __str__(self): - return "Convolution (num filters = {}, kernel size = {}, stride = {})"\ - .format(self.num_filters, self.kernel_size, self.strides) + @staticmethod + @reg_to_tf(layers.Conv2d) + def to_tf(base: layers.Conv2d): + return Conv2d( + num_filters=base.num_filters, + kernel_size=base.kernel_size, + strides=base.strides) -class BatchnormActivationDropout(object): +class BatchnormActivationDropout(layers.BatchnormActivationDropout): def __init__(self, batchnorm: bool=False, activation_function=None, dropout_rate: float=0): - self.batchnorm = batchnorm - self.activation_function = activation_function - self.dropout_rate = dropout_rate + super(BatchnormActivationDropout, self).__init__( + batchnorm=batchnorm, activation_function=activation_function, dropout_rate=dropout_rate) def __call__(self, input_layer, name: str=None, is_training=None): """ @@ -72,20 +103,18 @@ class BatchnormActivationDropout(object): dropout=self.dropout_rate > 0, dropout_rate=self.dropout_rate, is_training=is_training, name=name) - def __str__(self): - result = [] - if self.batchnorm: - result += ["Batch Normalization"] - if self.activation_function: - result += ["Activation (type = {})".format(self.activation_function.__name__)] - if self.dropout_rate > 0: - result += ["Dropout (rate = {})".format(self.dropout_rate)] - return "\n".join(result) + @staticmethod + @reg_to_tf(layers.BatchnormActivationDropout) + def to_tf(base: layers.BatchnormActivationDropout): + return BatchnormActivationDropout( + batchnorm=base.batchnorm, + activation_function=base.activation_function, + dropout_rate=base.dropout_rate) -class Dense(object): +class Dense(layers.Dense): def __init__(self, units: int): - self.units = units + super(Dense, self).__init__(units=units) def __call__(self, input_layer, name: str=None, kernel_initializer=None, activation=None, is_training=None): """ @@ -97,11 +126,13 @@ class Dense(object): return tf.layers.dense(input_layer, self.units, name=name, kernel_initializer=kernel_initializer, activation=activation) - def __str__(self): - return "Dense (num outputs = {})".format(self.units) + @staticmethod + @reg_to_tf(layers.Dense) + def to_tf(base: layers.Dense): + return Dense(units=base.units) -class NoisyNetDense(object): +class NoisyNetDense(layers.NoisyNetDense): """ A factorized Noisy Net layer @@ -109,8 +140,7 @@ class NoisyNetDense(object): """ def __init__(self, units: int): - self.units = units - self.sigma0 = 0.5 + super(NoisyNetDense, self).__init__(units=units) def __call__(self, input_layer, name: str, kernel_initializer=None, activation=None, is_training=None): """ @@ -125,6 +155,16 @@ class NoisyNetDense(object): # forward (either act or train, both for online and target networks). # A3C, on the other hand, should sample noise only when policy changes (i.e. after every t_max steps) + def _f(values): + return tf.sqrt(tf.abs(values)) * tf.sign(values) + + def _factorized_noise(inputs, outputs): + # TODO: use factorized noise only for compute intensive algos (e.g. DQN). + # lighter algos (e.g. DQN) should not use it + noise1 = _f(tf.random_normal((inputs, 1))) + noise2 = _f(tf.random_normal((1, outputs))) + return tf.matmul(noise1, noise2) + num_inputs = input_layer.get_shape()[-1].value num_outputs = self.units @@ -145,23 +185,14 @@ class NoisyNetDense(object): initializer=kernel_stddev_initializer) bias_stddev = tf.get_variable('bias_stddev', shape=(num_outputs,), initializer=kernel_stddev_initializer) - bias_noise = self.f(tf.random_normal((num_outputs,))) - weight_noise = self.factorized_noise(num_inputs, num_outputs) + bias_noise = _f(tf.random_normal((num_outputs,))) + weight_noise = _factorized_noise(num_inputs, num_outputs) bias = bias_mean + bias_stddev * bias_noise weight = weight_mean + weight_stddev * weight_noise return activation(tf.matmul(input_layer, weight) + bias) - def factorized_noise(self, inputs, outputs): - # TODO: use factorized noise only for compute intensive algos (e.g. DQN). - # lighter algos (e.g. DQN) should not use it - noise1 = self.f(tf.random_normal((inputs, 1))) - noise2 = self.f(tf.random_normal((1, outputs))) - return tf.matmul(noise1, noise2) - @staticmethod - def f(values): - return tf.sqrt(tf.abs(values)) * tf.sign(values) - - def __str__(self): - return "Noisy Dense (num outputs = {})".format(self.units) + @reg_to_tf(layers.NoisyNetDense) + def to_tf(base: layers.NoisyNetDense): + return NoisyNetDense(units=base.units) diff --git a/rl_coach/architectures/tensorflow_components/middlewares/middleware.py b/rl_coach/architectures/tensorflow_components/middlewares/middleware.py index 02376de..bb10ea9 100644 --- a/rl_coach/architectures/tensorflow_components/middlewares/middleware.py +++ b/rl_coach/architectures/tensorflow_components/middlewares/middleware.py @@ -14,10 +14,11 @@ # limitations under the License. # import copy +from typing import Union import tensorflow as tf -from rl_coach.architectures.tensorflow_components.layers import Dense, BatchnormActivationDropout +from rl_coach.architectures.tensorflow_components.layers import BatchnormActivationDropout, convert_layer, Dense from rl_coach.base_parameters import MiddlewareScheme, NetworkComponentParameters from rl_coach.core_types import MiddlewareEmbedding @@ -50,7 +51,9 @@ class Middleware(object): if isinstance(self.scheme, MiddlewareScheme): self.layers_params = copy.copy(self.schemes[self.scheme]) else: - self.layers_params = copy.copy(self.scheme) + # if scheme is specified directly, convert to TF layer if it's not a callable object + # NOTE: if layer object is callable, it must return a TF tensor when invoked + self.layers_params = [convert_layer(l) for l in copy.copy(self.scheme)] # we allow adding batchnorm, dropout or activation functions after each layer. # The motivation is to simplify the transition between a network with batchnorm and a network without diff --git a/rl_coach/architectures/tensorflow_components/utils.py b/rl_coach/architectures/tensorflow_components/utils.py new file mode 100644 index 0000000..749a0ab --- /dev/null +++ b/rl_coach/architectures/tensorflow_components/utils.py @@ -0,0 +1,40 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Module containing utility functions +""" +import tensorflow as tf + + +def get_activation_function(activation_function_string: str): + """ + Map the activation function from a string to the tensorflow framework equivalent + :param activation_function_string: the type of the activation function + :return: the tensorflow activation function + """ + activation_functions = { + 'relu': tf.nn.relu, + 'tanh': tf.nn.tanh, + 'sigmoid': tf.nn.sigmoid, + 'elu': tf.nn.elu, + 'selu': tf.nn.selu, + 'leaky_relu': tf.nn.leaky_relu, + 'none': None + } + assert activation_function_string in activation_functions.keys(), \ + "Activation function must be one of the following {}. instead it was: {}" \ + .format(activation_functions.keys(), activation_function_string) + return activation_functions[activation_function_string] diff --git a/rl_coach/presets/Atari_NStepQ.py b/rl_coach/presets/Atari_NStepQ.py index 6807f86..4b86d80 100644 --- a/rl_coach/presets/Atari_NStepQ.py +++ b/rl_coach/presets/Atari_NStepQ.py @@ -1,5 +1,5 @@ from rl_coach.agents.n_step_q_agent import NStepQAgentParameters -from rl_coach.architectures.tensorflow_components.layers import Conv2d, Dense +from rl_coach.architectures.layers import Conv2d, Dense from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps from rl_coach.environments.environment import SingleLevelSelection diff --git a/rl_coach/presets/BitFlip_DQN.py b/rl_coach/presets/BitFlip_DQN.py index ed849f3..5b1c0ec 100644 --- a/rl_coach/presets/BitFlip_DQN.py +++ b/rl_coach/presets/BitFlip_DQN.py @@ -1,6 +1,6 @@ from rl_coach.agents.dqn_agent import DQNAgentParameters from rl_coach.architectures.embedder_parameters import InputEmbedderParameters -from rl_coach.architectures.tensorflow_components.layers import Dense +from rl_coach.architectures.layers import Dense from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, \ PresetValidationParameters from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps diff --git a/rl_coach/presets/BitFlip_DQN_HER.py b/rl_coach/presets/BitFlip_DQN_HER.py index 3d9c2f9..8f1d4e1 100644 --- a/rl_coach/presets/BitFlip_DQN_HER.py +++ b/rl_coach/presets/BitFlip_DQN_HER.py @@ -1,6 +1,6 @@ from rl_coach.agents.dqn_agent import DQNAgentParameters from rl_coach.architectures.embedder_parameters import InputEmbedderParameters -from rl_coach.architectures.tensorflow_components.layers import Dense +from rl_coach.architectures.layers import Dense from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, \ PresetValidationParameters from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps diff --git a/rl_coach/presets/CARLA_CIL.py b/rl_coach/presets/CARLA_CIL.py index 8477cdf..4d5efff 100644 --- a/rl_coach/presets/CARLA_CIL.py +++ b/rl_coach/presets/CARLA_CIL.py @@ -1,7 +1,6 @@ import os import numpy as np -import tensorflow as tf # make sure you have $CARLA_ROOT/PythonClient in your PYTHONPATH from carla.driving_benchmark.experiment_suites import CoRL2017 from rl_coach.logger import screen @@ -10,7 +9,7 @@ from rl_coach.agents.cil_agent import CILAgentParameters from rl_coach.architectures.embedder_parameters import InputEmbedderParameters from rl_coach.architectures.head_parameters import RegressionHeadParameters from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters -from rl_coach.architectures.tensorflow_components.layers import Conv2d, Dense, BatchnormActivationDropout +from rl_coach.architectures.layers import Conv2d, Dense, BatchnormActivationDropout from rl_coach.base_parameters import VisualizationParameters from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps from rl_coach.environments.carla_environment import CarlaEnvironmentParameters @@ -46,34 +45,34 @@ agent_params.network_wrappers['main'].input_embedders_parameters = { 'CameraRGB': InputEmbedderParameters( scheme=[ Conv2d(32, 5, 2), - BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh), + BatchnormActivationDropout(batchnorm=True, activation_function='tanh'), Conv2d(32, 3, 1), - BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh), + BatchnormActivationDropout(batchnorm=True, activation_function='tanh'), Conv2d(64, 3, 2), - BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh), + BatchnormActivationDropout(batchnorm=True, activation_function='tanh'), Conv2d(64, 3, 1), - BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh), + BatchnormActivationDropout(batchnorm=True, activation_function='tanh'), Conv2d(128, 3, 2), - BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh), + BatchnormActivationDropout(batchnorm=True, activation_function='tanh'), Conv2d(128, 3, 1), - BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh), + BatchnormActivationDropout(batchnorm=True, activation_function='tanh'), Conv2d(256, 3, 1), - BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh), + BatchnormActivationDropout(batchnorm=True, activation_function='tanh'), Conv2d(256, 3, 1), - BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh), + BatchnormActivationDropout(batchnorm=True, activation_function='tanh'), Dense(512), - BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.3), + BatchnormActivationDropout(activation_function='tanh', dropout_rate=0.3), Dense(512), - BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.3) + BatchnormActivationDropout(activation_function='tanh', dropout_rate=0.3) ], activation_function='none' # we define the activation function for each layer explicitly ), 'measurements': InputEmbedderParameters( scheme=[ Dense(128), - BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.5), + BatchnormActivationDropout(activation_function='tanh', dropout_rate=0.5), Dense(128), - BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.5) + BatchnormActivationDropout(activation_function='tanh', dropout_rate=0.5) ], activation_function='none' # we define the activation function for each layer explicitly ) @@ -84,7 +83,7 @@ agent_params.network_wrappers['main'].middleware_parameters = \ FCMiddlewareParameters( scheme=[ Dense(512), - BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.5) + BatchnormActivationDropout(activation_function='tanh', dropout_rate=0.5) ], activation_function='none' ) @@ -94,9 +93,9 @@ agent_params.network_wrappers['main'].heads_parameters = [ RegressionHeadParameters( scheme=[ Dense(256), - BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.5), + BatchnormActivationDropout(activation_function='tanh', dropout_rate=0.5), Dense(256), - BatchnormActivationDropout(activation_function=tf.tanh) + BatchnormActivationDropout(activation_function='tanh') ], num_output_head_copies=4 # follow lane, left, right, straight ) diff --git a/rl_coach/presets/CartPole_PPO.py b/rl_coach/presets/CartPole_PPO.py index 40e06d1..e8af4c1 100644 --- a/rl_coach/presets/CartPole_PPO.py +++ b/rl_coach/presets/CartPole_PPO.py @@ -1,5 +1,5 @@ from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters -from rl_coach.architectures.tensorflow_components.layers import Dense +from rl_coach.architectures.layers import Dense from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2 diff --git a/rl_coach/presets/ControlSuite_DDPG.py b/rl_coach/presets/ControlSuite_DDPG.py index 260ca52..bef6898 100644 --- a/rl_coach/presets/ControlSuite_DDPG.py +++ b/rl_coach/presets/ControlSuite_DDPG.py @@ -1,5 +1,5 @@ from rl_coach.agents.ddpg_agent import DDPGAgentParameters -from rl_coach.architectures.tensorflow_components.layers import Dense +from rl_coach.architectures.layers import Dense from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, PresetValidationParameters from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps from rl_coach.environments.control_suite_environment import ControlSuiteEnvironmentParameters, control_suite_envs diff --git a/rl_coach/presets/Fetch_DDPG_HER_baselines.py b/rl_coach/presets/Fetch_DDPG_HER_baselines.py index d3fa643..c2f17d7 100644 --- a/rl_coach/presets/Fetch_DDPG_HER_baselines.py +++ b/rl_coach/presets/Fetch_DDPG_HER_baselines.py @@ -1,7 +1,7 @@ from rl_coach.agents.ddpg_agent import DDPGAgentParameters from rl_coach.architectures.embedder_parameters import InputEmbedderParameters from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters -from rl_coach.architectures.tensorflow_components.layers import Dense +from rl_coach.architectures.layers import Dense from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, PresetValidationParameters from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps, TrainingSteps from rl_coach.environments.environment import SingleLevelSelection diff --git a/rl_coach/presets/Mujoco_A3C_LSTM.py b/rl_coach/presets/Mujoco_A3C_LSTM.py index 1027c01..323f28c 100644 --- a/rl_coach/presets/Mujoco_A3C_LSTM.py +++ b/rl_coach/presets/Mujoco_A3C_LSTM.py @@ -1,7 +1,7 @@ from rl_coach.agents.actor_critic_agent import ActorCriticAgentParameters from rl_coach.architectures.embedder_parameters import InputEmbedderParameters from rl_coach.architectures.middleware_parameters import LSTMMiddlewareParameters -from rl_coach.architectures.tensorflow_components.layers import Dense +from rl_coach.architectures.layers import Dense from rl_coach.base_parameters import VisualizationParameters, MiddlewareScheme, PresetValidationParameters from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps from rl_coach.environments.environment import SingleLevelSelection diff --git a/rl_coach/presets/Mujoco_ClippedPPO.py b/rl_coach/presets/Mujoco_ClippedPPO.py index d99f7ac..ca2d662 100644 --- a/rl_coach/presets/Mujoco_ClippedPPO.py +++ b/rl_coach/presets/Mujoco_ClippedPPO.py @@ -1,5 +1,5 @@ from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters -from rl_coach.architectures.tensorflow_components.layers import Dense +from rl_coach.architectures.layers import Dense from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps from rl_coach.environments.environment import SingleLevelSelection diff --git a/rl_coach/presets/Mujoco_DDPG.py b/rl_coach/presets/Mujoco_DDPG.py index 908039c..03a95a8 100644 --- a/rl_coach/presets/Mujoco_DDPG.py +++ b/rl_coach/presets/Mujoco_DDPG.py @@ -1,5 +1,5 @@ from rl_coach.agents.ddpg_agent import DDPGAgentParameters -from rl_coach.architectures.tensorflow_components.layers import Dense +from rl_coach.architectures.layers import Dense from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, EmbedderScheme from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps from rl_coach.environments.environment import SingleLevelSelection diff --git a/rl_coach/presets/Mujoco_NAF.py b/rl_coach/presets/Mujoco_NAF.py index 44d9262..6db8f66 100644 --- a/rl_coach/presets/Mujoco_NAF.py +++ b/rl_coach/presets/Mujoco_NAF.py @@ -1,5 +1,5 @@ from rl_coach.agents.naf_agent import NAFAgentParameters -from rl_coach.architectures.tensorflow_components.layers import Dense +from rl_coach.architectures.layers import Dense from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, GradientClippingMethod from rl_coach.environments.environment import SingleLevelSelection diff --git a/rl_coach/presets/Mujoco_PPO.py b/rl_coach/presets/Mujoco_PPO.py index 75aae60..4eb2f72 100644 --- a/rl_coach/presets/Mujoco_PPO.py +++ b/rl_coach/presets/Mujoco_PPO.py @@ -1,5 +1,5 @@ from rl_coach.agents.ppo_agent import PPOAgentParameters -from rl_coach.architectures.tensorflow_components.layers import Dense +from rl_coach.architectures.layers import Dense from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps from rl_coach.environments.environment import SingleLevelSelection diff --git a/rl_coach/presets/Pendulum_HAC.py b/rl_coach/presets/Pendulum_HAC.py index b6fa02c..d945f34 100644 --- a/rl_coach/presets/Pendulum_HAC.py +++ b/rl_coach/presets/Pendulum_HAC.py @@ -2,7 +2,7 @@ import numpy as np from rl_coach.agents.hac_ddpg_agent import HACDDPGAgentParameters from rl_coach.architectures.embedder_parameters import InputEmbedderParameters -from rl_coach.architectures.tensorflow_components.layers import Dense +from rl_coach.architectures.layers import Dense from rl_coach.base_parameters import VisualizationParameters, EmbeddingMergerType, EmbedderScheme from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps, TrainingSteps from rl_coach.environments.gym_environment import GymVectorEnvironment