1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

parameter noise exploration - using Noisy Nets

This commit is contained in:
Gal Leibovich
2018-08-27 18:19:01 +03:00
parent 658b437079
commit 1aa2ab0590
49 changed files with 536 additions and 433 deletions

View File

@@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import math
import time
from typing import List
from typing import List, Union
import numpy as np
import tensorflow as tf
@@ -73,20 +73,87 @@ class Conv2d(object):
class Dense(object):
def __init__(self, params: List):
def __init__(self, params: Union[List, int]):
"""
:param params: list of [num_output_neurons]
"""
self.params = params
self.params = force_list(params)
def __call__(self, input_layer, name: str):
def __call__(self, input_layer, name: str, kernel_initializer=None, activation=None):
"""
returns a tensorflow dense layer
:param input_layer: previous layer
:param name: layer name
:return: dense layer
"""
return tf.layers.dense(input_layer, self.params[0], name=name)
return tf.layers.dense(input_layer, self.params[0], name=name, kernel_initializer=kernel_initializer,
activation=activation)
class NoisyNetDense(object):
"""
A factorized Noisy Net layer
https://arxiv.org/abs/1706.10295.
"""
def __init__(self, params: List):
"""
:param params: list of [num_output_neurons]
"""
self.params = force_list(params)
self.sigma0 = 0.5
def __call__(self, input_layer, name: str, kernel_initializer=None, activation=None):
"""
returns a NoisyNet dense layer
:param input_layer: previous layer
:param name: layer name
:param kernel_initializer: initializer for kernels. Default is to use Gaussian noise that preserves stddev.
:param activation: the activation function
:return: dense layer
"""
#TODO: noise sampling should be externally controlled. DQN is fine with sampling noise for every
# forward (either act or train, both for online and target networks).
# A3C, on the other hand, should sample noise only when policy changes (i.e. after every t_max steps)
num_inputs = input_layer.get_shape()[-1].value
num_outputs = self.params[0]
stddev = 1 / math.sqrt(num_inputs)
activation = activation if activation is not None else (lambda x: x)
if kernel_initializer is None:
kernel_mean_initializer = tf.random_uniform_initializer(-stddev, stddev)
kernel_stddev_initializer = tf.random_uniform_initializer(-stddev * self.sigma0, stddev * self.sigma0)
else:
kernel_mean_initializer = kernel_stddev_initializer = kernel_initializer
with tf.variable_scope(None, default_name=name):
weight_mean = tf.get_variable('weight_mean', shape=(num_inputs, num_outputs),
initializer=kernel_mean_initializer)
bias_mean = tf.get_variable('bias_mean', shape=(num_outputs,), initializer=tf.zeros_initializer())
weight_stddev = tf.get_variable('weight_stddev', shape=(num_inputs, num_outputs),
initializer=kernel_stddev_initializer)
bias_stddev = tf.get_variable('bias_stddev', shape=(num_outputs,),
initializer=kernel_stddev_initializer)
bias_noise = self.f(tf.random_normal((num_outputs,)))
weight_noise = self.factorized_noise(num_inputs, num_outputs)
bias = bias_mean + bias_stddev * bias_noise
weight = weight_mean + weight_stddev * weight_noise
return activation(tf.matmul(input_layer, weight) + bias)
def factorized_noise(self, inputs, outputs):
# TODO: use factorized noise only for compute intensive algos (e.g. DQN).
# lighter algos (e.g. DQN) should not use it
noise1 = self.f(tf.random_normal((inputs, 1)))
noise2 = self.f(tf.random_normal((1, outputs)))
return tf.matmul(noise1, noise2)
@staticmethod
def f(values):
return tf.sqrt(tf.abs(values)) * tf.sign(values)
def variable_summaries(var):

View File

@@ -19,11 +19,40 @@ from typing import List, Union
import numpy as np
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import batchnorm_activation_dropout
from rl_coach.base_parameters import EmbedderScheme
from rl_coach.architectures.tensorflow_components.architecture import batchnorm_activation_dropout, Dense
from rl_coach.base_parameters import EmbedderScheme, NetworkComponentParameters
from rl_coach.core_types import InputEmbedding
class InputEmbedderParameters(NetworkComponentParameters):
def __init__(self, activation_function: str='relu', scheme: Union[List, EmbedderScheme]=EmbedderScheme.Medium,
batchnorm: bool=False, dropout=False, name: str='embedder', input_rescaling=None, input_offset=None,
input_clipping=None, dense_layer=Dense):
super().__init__(dense_layer=dense_layer)
self.activation_function = activation_function
self.scheme = scheme
self.batchnorm = batchnorm
self.dropout = dropout
if input_rescaling is None:
input_rescaling = {'image': 255.0, 'vector': 1.0}
if input_offset is None:
input_offset = {'image': 0.0, 'vector': 0.0}
self.input_rescaling = input_rescaling
self.input_offset = input_offset
self.input_clipping = input_clipping
self.name = name
@property
def path(self):
return {
"image": 'image_embedder:ImageEmbedder',
"vector": 'vector_embedder:VectorEmbedder'
}
class InputEmbedder(object):
"""
An input embedder is the first part of the network, which takes the input from the state and produces a vector
@@ -32,7 +61,7 @@ class InputEmbedder(object):
"""
def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
scheme: EmbedderScheme=None, batchnorm: bool=False, dropout: bool=False,
name: str= "embedder", input_rescaling=1.0, input_offset=0.0, input_clipping=None):
name: str= "embedder", input_rescaling=1.0, input_offset=0.0, input_clipping=None, dense_layer=Dense):
self.name = name
self.input_size = input_size
self.activation_function = activation_function
@@ -47,6 +76,7 @@ class InputEmbedder(object):
self.input_rescaling = input_rescaling
self.input_offset = input_offset
self.input_clipping = input_clipping
self.dense_layer = dense_layer
def __call__(self, prev_input_placeholder=None):
with tf.variable_scope(self.get_name()):

View File

@@ -18,7 +18,7 @@ from typing import List
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import Conv2d
from rl_coach.architectures.tensorflow_components.architecture import Conv2d, Dense
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedder
from rl_coach.base_parameters import EmbedderScheme
from rl_coach.core_types import InputImageEmbedding
@@ -30,45 +30,49 @@ class ImageEmbedder(InputEmbedder):
The embedder is intended for image like inputs, where the channels are expected to be the last axis.
The embedder also allows custom rescaling of the input prior to the neural network.
"""
schemes = {
EmbedderScheme.Empty:
[],
EmbedderScheme.Shallow:
[
Conv2d([32, 3, 1])
],
# atari dqn
EmbedderScheme.Medium:
[
Conv2d([32, 8, 4]),
Conv2d([64, 4, 2]),
Conv2d([64, 3, 1])
],
# carla
EmbedderScheme.Deep: \
[
Conv2d([32, 5, 2]),
Conv2d([32, 3, 1]),
Conv2d([64, 3, 2]),
Conv2d([64, 3, 1]),
Conv2d([128, 3, 2]),
Conv2d([128, 3, 1]),
Conv2d([256, 3, 2]),
Conv2d([256, 3, 1])
]
}
def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
scheme: EmbedderScheme=EmbedderScheme.Medium, batchnorm: bool=False, dropout: bool=False,
name: str= "embedder", input_rescaling: float=255.0, input_offset: float=0.0, input_clipping=None):
name: str= "embedder", input_rescaling: float=255.0, input_offset: float=0.0, input_clipping=None,
dense_layer=Dense):
super().__init__(input_size, activation_function, scheme, batchnorm, dropout, name, input_rescaling,
input_offset, input_clipping)
input_offset, input_clipping, dense_layer=dense_layer)
self.return_type = InputImageEmbedding
if len(input_size) != 3 and scheme != EmbedderScheme.Empty:
raise ValueError("Image embedders expect the input size to have 3 dimensions. The given size is: {}"
.format(input_size))
@property
def schemes(self):
return {
EmbedderScheme.Empty:
[],
EmbedderScheme.Shallow:
[
Conv2d([32, 3, 1])
],
# atari dqn
EmbedderScheme.Medium:
[
Conv2d([32, 8, 4]),
Conv2d([64, 4, 2]),
Conv2d([64, 3, 1])
],
# carla
EmbedderScheme.Deep: \
[
Conv2d([32, 5, 2]),
Conv2d([32, 3, 1]),
Conv2d([64, 3, 2]),
Conv2d([64, 3, 1]),
Conv2d([128, 3, 2]),
Conv2d([128, 3, 1]),
Conv2d([256, 3, 2]),
Conv2d([256, 3, 1])
]
}

View File

@@ -29,36 +29,40 @@ class VectorEmbedder(InputEmbedder):
An input embedder that is intended for inputs that can be represented as vectors.
The embedder flattens the input, applies several dense layers to it and returns the output.
"""
schemes = {
EmbedderScheme.Empty:
[],
EmbedderScheme.Shallow:
[
Dense([128])
],
# dqn
EmbedderScheme.Medium:
[
Dense([256])
],
# carla
EmbedderScheme.Deep: \
[
Dense([128]),
Dense([128]),
Dense([128])
]
}
def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
scheme: EmbedderScheme=EmbedderScheme.Medium, batchnorm: bool=False, dropout: bool=False,
name: str= "embedder", input_rescaling: float=1.0, input_offset:float=0.0, input_clipping=None):
name: str= "embedder", input_rescaling: float=1.0, input_offset:float=0.0, input_clipping=None,
dense_layer=Dense):
super().__init__(input_size, activation_function, scheme, batchnorm, dropout, name,
input_rescaling, input_offset, input_clipping)
input_rescaling, input_offset, input_clipping, dense_layer=dense_layer)
self.return_type = InputVectorEmbedding
if len(self.input_size) != 1 and scheme != EmbedderScheme.Empty:
raise ValueError("The input size of a vector embedder must contain only a single dimension")
@property
def schemes(self):
return {
EmbedderScheme.Empty:
[],
EmbedderScheme.Shallow:
[
self.dense_layer([128])
],
# dqn
EmbedderScheme.Medium:
[
self.dense_layer([256])
],
# carla
EmbedderScheme.Deep: \
[
self.dense_layer([128]),
self.dense_layer([128]),
self.dense_layer([128])
]
}

View File

@@ -20,10 +20,11 @@ from typing import Dict
import numpy as np
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
from rl_coach.architectures.tensorflow_components.architecture import TensorFlowArchitecture
from rl_coach.architectures.tensorflow_components.heads.head import HeadParameters
from rl_coach.architectures.tensorflow_components.middlewares.middleware import MiddlewareParameters
from rl_coach.base_parameters import AgentParameters, InputEmbedderParameters, EmbeddingMergerType
from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
from rl_coach.core_types import PredictionType
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace
from rl_coach.utils import get_all_subclasses, dynamic_import_and_instantiate_module_from_params

View File

@@ -16,6 +16,8 @@
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import QActionStateValue
@@ -23,14 +25,17 @@ from rl_coach.spaces import SpacesDefinition
class CategoricalQHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='categorical_q_head_params'):
super().__init__(parameterized_class=CategoricalQHead, activation_function=activation_function, name=name)
def __init__(self, activation_function: str ='relu', name: str='categorical_q_head_params', dense_layer=Dense):
super().__init__(parameterized_class=CategoricalQHead, activation_function=activation_function, name=name,
dense_layer=dense_layer)
class CategoricalQHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str ='relu'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str ='relu',
dense_layer=Dense):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer)
self.name = 'categorical_dqn_head'
self.num_actions = len(self.spaces.action.actions)
self.num_atoms = agent_parameters.algorithm.atoms
@@ -40,7 +45,7 @@ class CategoricalQHead(Head):
self.actions = tf.placeholder(tf.int32, [None], name="actions")
self.input = [self.actions]
values_distribution = tf.layers.dense(input_layer, self.num_actions * self.num_atoms, name='output')
values_distribution = self.dense_layer(self.num_actions * self.num_atoms)(input_layer, name='output')
values_distribution = tf.reshape(values_distribution, (tf.shape(values_distribution)[0], self.num_actions,
self.num_atoms))
# softmax on atoms dimension

View File

@@ -16,7 +16,7 @@
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import batchnorm_activation_dropout
from rl_coach.architectures.tensorflow_components.architecture import batchnorm_activation_dropout, Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import ActionProbabilities
@@ -24,16 +24,19 @@ from rl_coach.spaces import SpacesDefinition
class DDPGActorHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='tanh', name: str='policy_head_params', batchnorm: bool=True):
super().__init__(parameterized_class=DDPGActor, activation_function=activation_function, name=name)
def __init__(self, activation_function: str ='tanh', name: str='policy_head_params', batchnorm: bool=True,
dense_layer=Dense):
super().__init__(parameterized_class=DDPGActor, activation_function=activation_function, name=name,
dense_layer=dense_layer)
self.batchnorm = batchnorm
class DDPGActor(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh',
batchnorm: bool=True):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
batchnorm: bool=True, dense_layer=Dense):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer)
self.name = 'ddpg_actor_head'
self.return_type = ActionProbabilities
@@ -50,7 +53,7 @@ class DDPGActor(Head):
def _build_module(self, input_layer):
# mean
pre_activation_policy_values_mean = tf.layers.dense(input_layer, self.num_actions, name='fc_mean')
pre_activation_policy_values_mean = self.dense_layer(self.num_actions)(input_layer, name='fc_mean')
policy_values_mean = batchnorm_activation_dropout(pre_activation_policy_values_mean, self.batchnorm,
self.activation_function,
False, 0, 0)[-1]

View File

@@ -15,6 +15,7 @@
#
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.architectures.tensorflow_components.heads.head import HeadParameters
from rl_coach.architectures.tensorflow_components.heads.q_head import QHead
from rl_coach.base_parameters import AgentParameters
@@ -23,14 +24,17 @@ from rl_coach.spaces import SpacesDefinition
class DNDQHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='dnd_q_head_params'):
super().__init__(parameterized_class=DNDQHead, activation_function=activation_function, name=name)
def __init__(self, activation_function: str ='relu', name: str='dnd_q_head_params', dense_layer=Dense):
super().__init__(parameterized_class=DNDQHead, activation_function=activation_function, name=name,
dense_layer=dense_layer)
class DNDQHead(QHead):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
dense_layer=Dense):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer)
self.name = 'dnd_q_values_head'
self.DND_size = agent_parameters.algorithm.dnd_size
self.DND_key_error_threshold = agent_parameters.algorithm.DND_key_error_threshold

View File

@@ -16,6 +16,7 @@
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.architectures.tensorflow_components.heads.head import HeadParameters
from rl_coach.architectures.tensorflow_components.heads.q_head import QHead
from rl_coach.base_parameters import AgentParameters
@@ -23,27 +24,29 @@ from rl_coach.spaces import SpacesDefinition
class DuelingQHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='dueling_q_head_params'):
super().__init__(parameterized_class=DuelingQHead, activation_function=activation_function, name=name)
def __init__(self, activation_function: str ='relu', name: str='dueling_q_head_params', dense_layer=Dense):
super().__init__(parameterized_class=DuelingQHead, activation_function=activation_function, name=name, dense_layer=dense_layer)
class DuelingQHead(QHead):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
dense_layer=Dense):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer)
self.name = 'dueling_q_values_head'
def _build_module(self, input_layer):
# state value tower - V
with tf.variable_scope("state_value"):
state_value = tf.layers.dense(input_layer, 512, activation=self.activation_function, name='fc1')
state_value = tf.layers.dense(state_value, 1, name='fc2')
state_value = self.dense_layer(512)(input_layer, activation=self.activation_function, name='fc1')
state_value = self.dense_layer(1)(state_value, name='fc2')
# state_value = tf.expand_dims(state_value, axis=-1)
# action advantage tower - A
with tf.variable_scope("action_advantage"):
action_advantage = tf.layers.dense(input_layer, 512, activation=self.activation_function, name='fc1')
action_advantage = tf.layers.dense(action_advantage, self.num_actions, name='fc2')
action_advantage = self.dense_layer(512)(input_layer, activation=self.activation_function, name='fc1')
action_advantage = self.dense_layer(self.num_actions)(action_advantage, name='fc2')
action_advantage = action_advantage - tf.reduce_mean(action_advantage)
# merge to state-action value function Q

View File

@@ -18,8 +18,8 @@ from typing import Type
import numpy as np
import tensorflow as tf
from tensorflow.python.ops.losses.losses_impl import Reduction
from rl_coach.base_parameters import AgentParameters, Parameters
from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.base_parameters import AgentParameters, Parameters, NetworkComponentParameters
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import force_list
@@ -33,9 +33,10 @@ def normalized_columns_initializer(std=1.0):
return _initializer
class HeadParameters(Parameters):
def __init__(self, parameterized_class: Type['Head'], activation_function: str = 'relu', name: str= 'head'):
super().__init__()
class HeadParameters(NetworkComponentParameters):
def __init__(self, parameterized_class: Type['Head'], activation_function: str = 'relu', name: str= 'head',
dense_layer=Dense):
super().__init__(dense_layer=dense_layer)
self.activation_function = activation_function
self.name = name
self.parameterized_class_name = parameterized_class.__name__
@@ -48,7 +49,8 @@ class Head(object):
an assigned loss function. The heads are algorithm dependent.
"""
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int=0, loss_weight: float=1., is_local: bool=True, activation_function: str='relu'):
head_idx: int=0, loss_weight: float=1., is_local: bool=True, activation_function: str='relu',
dense_layer=Dense):
self.head_idx = head_idx
self.network_name = network_name
self.network_parameters = agent_parameters.network_wrappers[self.network_name]
@@ -66,6 +68,7 @@ class Head(object):
self.spaces = spaces
self.return_type = None
self.activation_function = activation_function
self.dense_layer = dense_layer
def __call__(self, input_layer):
"""

View File

@@ -16,6 +16,8 @@
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import Measurements
@@ -23,15 +25,18 @@ from rl_coach.spaces import SpacesDefinition
class MeasurementsPredictionHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='measurements_prediction_head_params'):
def __init__(self, activation_function: str ='relu', name: str='measurements_prediction_head_params',
dense_layer=Dense):
super().__init__(parameterized_class=MeasurementsPredictionHead,
activation_function=activation_function, name=name)
activation_function=activation_function, name=name, dense_layer=dense_layer)
class MeasurementsPredictionHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
dense_layer=Dense):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer)
self.name = 'future_measurements_head'
self.num_actions = len(self.spaces.action.actions)
self.num_measurements = self.spaces.state['measurements'].shape[0]
@@ -43,15 +48,15 @@ class MeasurementsPredictionHead(Head):
# This is almost exactly the same as Dueling Network but we predict the future measurements for each action
# actions expectation tower (expectation stream) - E
with tf.variable_scope("expectation_stream"):
expectation_stream = tf.layers.dense(input_layer, 256, activation=self.activation_function, name='fc1')
expectation_stream = tf.layers.dense(expectation_stream, self.multi_step_measurements_size, name='output')
expectation_stream = self.dense_layer(256)(input_layer, activation=self.activation_function, name='fc1')
expectation_stream = self.dense_layer(self.multi_step_measurements_size)(expectation_stream, name='output')
expectation_stream = tf.expand_dims(expectation_stream, axis=1)
# action fine differences tower (action stream) - A
with tf.variable_scope("action_stream"):
action_stream = tf.layers.dense(input_layer, 256, activation=self.activation_function, name='fc1')
action_stream = tf.layers.dense(action_stream, self.num_actions * self.multi_step_measurements_size,
name='output')
action_stream = self.dense_layer(256)(input_layer, activation=self.activation_function, name='fc1')
action_stream = self.dense_layer(self.num_actions * self.multi_step_measurements_size)(action_stream,
name='output')
action_stream = tf.reshape(action_stream,
(tf.shape(action_stream)[0], self.num_actions, self.multi_step_measurements_size))
action_stream = action_stream - tf.reduce_mean(action_stream, reduction_indices=1, keepdims=True)

View File

@@ -16,6 +16,7 @@
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import QActionStateValue
@@ -24,14 +25,17 @@ from rl_coach.spaces import SpacesDefinition
class NAFHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='tanh', name: str='naf_head_params'):
super().__init__(parameterized_class=NAFHead, activation_function=activation_function, name=name)
def __init__(self, activation_function: str ='tanh', name: str='naf_head_params', dense_layer=Dense):
super().__init__(parameterized_class=NAFHead, activation_function=activation_function, name=name,
dense_layer=dense_layer)
class NAFHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True,activation_function: str='relu'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True,activation_function: str='relu',
dense_layer=Dense):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer)
if not isinstance(self.spaces.action, BoxActionSpace):
raise ValueError("NAF works only for continuous action spaces (BoxActionSpace)")
@@ -50,15 +54,15 @@ class NAFHead(Head):
self.input = self.action
# V Head
self.V = tf.layers.dense(input_layer, 1, name='V')
self.V = self.dense_layer(1)(input_layer, name='V')
# mu Head
mu_unscaled = tf.layers.dense(input_layer, self.num_actions, activation=self.activation_function, name='mu_unscaled')
mu_unscaled = self.dense_layer(self.num_actions)(input_layer, activation=self.activation_function, name='mu_unscaled')
self.mu = tf.multiply(mu_unscaled, self.output_scale, name='mu')
# A Head
# l_vector is a vector that includes a lower-triangular matrix values
self.l_vector = tf.layers.dense(input_layer, (self.num_actions * (self.num_actions + 1)) / 2, name='l_vector')
self.l_vector = self.dense_layer((self.num_actions * (self.num_actions + 1)) / 2)(input_layer, name='l_vector')
# Convert l to a lower triangular matrix and exponentiate its diagonal

View File

@@ -17,6 +17,7 @@
import numpy as np
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head, normalized_columns_initializer, HeadParameters
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import ActionProbabilities
@@ -27,14 +28,17 @@ from rl_coach.utils import eps
class PolicyHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='tanh', name: str='policy_head_params'):
super().__init__(parameterized_class=PolicyHead, activation_function=activation_function, name=name)
def __init__(self, activation_function: str ='tanh', name: str='policy_head_params', dense_layer=Dense):
super().__init__(parameterized_class=PolicyHead, activation_function=activation_function, name=name,
dense_layer=dense_layer)
class PolicyHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh',
dense_layer=Dense):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer)
self.name = 'policy_values_head'
self.return_type = ActionProbabilities
self.beta = None
@@ -90,7 +94,7 @@ class PolicyHead(Head):
num_actions = len(action_space.actions)
self.actions.append(tf.placeholder(tf.int32, [None], name="actions"))
policy_values = tf.layers.dense(input_layer, num_actions, name='fc')
policy_values = self.dense_layer(num_actions)(input_layer, name='fc')
self.policy_probs = tf.nn.softmax(policy_values, name="policy")
# define the distributions for the policy and the old policy
@@ -114,7 +118,7 @@ class PolicyHead(Head):
self.continuous_output_activation = None
# mean
pre_activation_policy_values_mean = tf.layers.dense(input_layer, num_actions, name='fc_mean')
pre_activation_policy_values_mean = self.dense_layer(num_actions)(input_layer, name='fc_mean')
policy_values_mean = self.continuous_output_activation(pre_activation_policy_values_mean)
self.policy_mean = tf.multiply(policy_values_mean, self.output_scale, name='output_mean')
@@ -123,8 +127,9 @@ class PolicyHead(Head):
# standard deviation
if isinstance(self.exploration_policy, ContinuousEntropyParameters):
# the stdev is an output of the network and uses a softplus activation as defined in A3C
policy_values_std = tf.layers.dense(input_layer, num_actions,
kernel_initializer=normalized_columns_initializer(0.01), name='fc_std')
policy_values_std = self.dense_layer(num_actions)(input_layer,
kernel_initializer=normalized_columns_initializer(0.01),
name='fc_std')
self.policy_std = tf.nn.softplus(policy_values_std, name='output_variance') + eps
self.output.append(self.policy_std)

View File

@@ -17,6 +17,7 @@
import numpy as np
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters, normalized_columns_initializer
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import ActionProbabilities
@@ -26,14 +27,17 @@ from rl_coach.utils import eps
class PPOHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='tanh', name: str='ppo_head_params'):
super().__init__(parameterized_class=PPOHead, activation_function=activation_function, name=name)
def __init__(self, activation_function: str ='tanh', name: str='ppo_head_params', dense_layer=Dense):
super().__init__(parameterized_class=PPOHead, activation_function=activation_function, name=name,
dense_layer=dense_layer)
class PPOHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh',
dense_layer=Dense):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer)
self.name = 'ppo_head'
self.return_type = ActionProbabilities
@@ -110,7 +114,7 @@ class PPOHead(Head):
# Policy Head
self.input = [self.actions, self.old_policy_mean]
policy_values = tf.layers.dense(input_layer, num_actions, name='policy_fc')
policy_values = self.dense_layer(num_actions)(input_layer, name='policy_fc')
self.policy_mean = tf.nn.softmax(policy_values, name="policy")
# define the distributions for the policy and the old policy
@@ -127,7 +131,7 @@ class PPOHead(Head):
self.old_policy_std = tf.placeholder(tf.float32, [None, num_actions], "old_policy_std")
self.input = [self.actions, self.old_policy_mean, self.old_policy_std]
self.policy_mean = tf.layers.dense(input_layer, num_actions, name='policy_mean',
self.policy_mean = self.dense_layer(num_actions)(input_layer, name='policy_mean',
kernel_initializer=normalized_columns_initializer(0.01))
if self.is_local:
self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32',

View File

@@ -16,6 +16,8 @@
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head, normalized_columns_initializer, HeadParameters
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import ActionProbabilities
@@ -23,14 +25,17 @@ from rl_coach.spaces import SpacesDefinition
class PPOVHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='ppo_v_head_params'):
super().__init__(parameterized_class=PPOVHead, activation_function=activation_function, name=name)
def __init__(self, activation_function: str ='relu', name: str='ppo_v_head_params', dense_layer=Dense):
super().__init__(parameterized_class=PPOVHead, activation_function=activation_function, name=name,
dense_layer=dense_layer)
class PPOVHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
dense_layer=Dense):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer)
self.name = 'ppo_v_head'
self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon
self.return_type = ActionProbabilities
@@ -38,7 +43,7 @@ class PPOVHead(Head):
def _build_module(self, input_layer):
self.old_policy_value = tf.placeholder(tf.float32, [None], "old_policy_values")
self.input = [self.old_policy_value]
self.output = tf.layers.dense(input_layer, 1, name='output',
self.output = self.dense_layer(1)(input_layer, name='output',
kernel_initializer=normalized_columns_initializer(1.0))
self.target = self.total_return = tf.placeholder(tf.float32, [None], name="total_return")

View File

@@ -16,6 +16,8 @@
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import QActionStateValue
@@ -23,14 +25,17 @@ from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpac
class QHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='q_head_params'):
super().__init__(parameterized_class=QHead, activation_function=activation_function, name=name)
def __init__(self, activation_function: str ='relu', name: str='q_head_params', dense_layer=Dense):
super().__init__(parameterized_class=QHead, activation_function=activation_function, name=name,
dense_layer=dense_layer)
class QHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
dense_layer=Dense):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer)
self.name = 'q_values_head'
if isinstance(self.spaces.action, BoxActionSpace):
self.num_actions = 1
@@ -44,7 +49,7 @@ class QHead(Head):
def _build_module(self, input_layer):
# Standard Q Network
self.output = tf.layers.dense(input_layer, self.num_actions, name='output')
self.output = self.dense_layer(self.num_actions)(input_layer, name='output')

View File

@@ -16,6 +16,8 @@
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import QActionStateValue
@@ -23,15 +25,18 @@ from rl_coach.spaces import SpacesDefinition
class QuantileRegressionQHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='quantile_regression_q_head_params'):
def __init__(self, activation_function: str ='relu', name: str='quantile_regression_q_head_params',
dense_layer=Dense):
super().__init__(parameterized_class=QuantileRegressionQHead, activation_function=activation_function,
name=name)
name=name, dense_layer=dense_layer)
class QuantileRegressionQHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
dense_layer=Dense):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer)
self.name = 'quantile_regression_dqn_head'
self.num_actions = len(self.spaces.action.actions)
self.num_atoms = agent_parameters.algorithm.atoms # we use atom / quantile interchangeably
@@ -44,7 +49,7 @@ class QuantileRegressionQHead(Head):
self.input = [self.actions, self.quantile_midpoints]
# the output of the head is the N unordered quantile locations {theta_1, ..., theta_N}
quantiles_locations = tf.layers.dense(input_layer, self.num_actions * self.num_atoms, name='output')
quantiles_locations = self.dense_layer(self.num_actions * self.num_atoms)(input_layer, name='output')
quantiles_locations = tf.reshape(quantiles_locations, (tf.shape(quantiles_locations)[0], self.num_actions, self.num_atoms))
self.output = quantiles_locations

View File

@@ -16,6 +16,8 @@
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head, normalized_columns_initializer, HeadParameters
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import VStateValue
@@ -23,14 +25,17 @@ from rl_coach.spaces import SpacesDefinition
class VHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='v_head_params'):
super().__init__(parameterized_class=VHead, activation_function=activation_function, name=name)
def __init__(self, activation_function: str ='relu', name: str='v_head_params', dense_layer=Dense):
super().__init__(parameterized_class=VHead, activation_function=activation_function, name=name,
dense_layer=dense_layer)
class VHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
dense_layer=Dense):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer)
self.name = 'v_values_head'
self.return_type = VStateValue
@@ -41,5 +46,5 @@ class VHead(Head):
def _build_module(self, input_layer):
# Standard V Network
self.output = tf.layers.dense(input_layer, 1, name='output',
kernel_initializer=normalized_columns_initializer(1.0))
self.output = self.dense_layer(1)(input_layer, name='output',
kernel_initializer=normalized_columns_initializer(1.0))

View File

@@ -27,42 +27,18 @@ class FCMiddlewareParameters(MiddlewareParameters):
def __init__(self, activation_function='relu',
scheme: Union[List, MiddlewareScheme] = MiddlewareScheme.Medium,
batchnorm: bool = False, dropout: bool = False,
name="middleware_fc_embedder"):
name="middleware_fc_embedder", dense_layer=Dense):
super().__init__(parameterized_class=FCMiddleware, activation_function=activation_function,
scheme=scheme, batchnorm=batchnorm, dropout=dropout, name=name)
scheme=scheme, batchnorm=batchnorm, dropout=dropout, name=name, dense_layer=dense_layer)
class FCMiddleware(Middleware):
schemes = {
MiddlewareScheme.Empty:
[],
# ppo
MiddlewareScheme.Shallow:
[
Dense([64])
],
# dqn
MiddlewareScheme.Medium:
[
Dense([512])
],
MiddlewareScheme.Deep: \
[
Dense([128]),
Dense([128]),
Dense([128])
]
}
def __init__(self, activation_function=tf.nn.relu,
scheme: MiddlewareScheme = MiddlewareScheme.Medium,
batchnorm: bool = False, dropout: bool = False,
name="middleware_fc_embedder"):
name="middleware_fc_embedder", dense_layer=Dense):
super().__init__(activation_function=activation_function, batchnorm=batchnorm,
dropout=dropout, scheme=scheme, name=name)
dropout=dropout, scheme=scheme, name=name, dense_layer=dense_layer)
self.return_type = Middleware_FC_Embedding
self.layers = []
@@ -70,7 +46,7 @@ class FCMiddleware(Middleware):
self.layers.append(self.input)
if isinstance(self.scheme, MiddlewareScheme):
layers_params = FCMiddleware.schemes[self.scheme]
layers_params = self.schemes[self.scheme]
else:
layers_params = self.scheme
for idx, layer_params in enumerate(layers_params):
@@ -84,3 +60,29 @@ class FCMiddleware(Middleware):
self.output = self.layers[-1]
@property
def schemes(self):
return {
MiddlewareScheme.Empty:
[],
# ppo
MiddlewareScheme.Shallow:
[
self.dense_layer([64])
],
# dqn
MiddlewareScheme.Medium:
[
self.dense_layer([512])
],
MiddlewareScheme.Deep: \
[
self.dense_layer([128]),
self.dense_layer([128]),
self.dense_layer([128])
]
}

View File

@@ -18,7 +18,7 @@
import numpy as np
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import batchnorm_activation_dropout
from rl_coach.architectures.tensorflow_components.architecture import batchnorm_activation_dropout, Dense
from rl_coach.architectures.tensorflow_components.middlewares.middleware import Middleware, MiddlewareParameters
from rl_coach.base_parameters import MiddlewareScheme
from rl_coach.core_types import Middleware_LSTM_Embedding
@@ -28,43 +28,19 @@ class LSTMMiddlewareParameters(MiddlewareParameters):
def __init__(self, activation_function='relu', number_of_lstm_cells=256,
scheme: MiddlewareScheme = MiddlewareScheme.Medium,
batchnorm: bool = False, dropout: bool = False,
name="middleware_lstm_embedder"):
name="middleware_lstm_embedder", dense_layer=Dense):
super().__init__(parameterized_class=LSTMMiddleware, activation_function=activation_function,
scheme=scheme, batchnorm=batchnorm, dropout=dropout, name=name)
scheme=scheme, batchnorm=batchnorm, dropout=dropout, name=name, dense_layer=dense_layer)
self.number_of_lstm_cells = number_of_lstm_cells
class LSTMMiddleware(Middleware):
schemes = {
MiddlewareScheme.Empty:
[],
# ppo
MiddlewareScheme.Shallow:
[
[64]
],
# dqn
MiddlewareScheme.Medium:
[
[512]
],
MiddlewareScheme.Deep: \
[
[128],
[128],
[128]
]
}
def __init__(self, activation_function=tf.nn.relu, number_of_lstm_cells: int=256,
scheme: MiddlewareScheme = MiddlewareScheme.Medium,
batchnorm: bool = False, dropout: bool = False,
name="middleware_lstm_embedder"):
name="middleware_lstm_embedder", dense_layer=Dense):
super().__init__(activation_function=activation_function, batchnorm=batchnorm,
dropout=dropout, scheme=scheme, name=name)
dropout=dropout, scheme=scheme, name=name, dense_layer=dense_layer)
self.return_type = Middleware_LSTM_Embedding
self.number_of_lstm_cells = number_of_lstm_cells
self.layers = []
@@ -83,7 +59,7 @@ class LSTMMiddleware(Middleware):
# optionally insert some dense layers before the LSTM
if isinstance(self.scheme, MiddlewareScheme):
layers_params = LSTMMiddleware.schemes[self.scheme]
layers_params = self.schemes[self.scheme]
else:
layers_params = self.scheme
for idx, layer_params in enumerate(layers_params):
@@ -111,3 +87,30 @@ class LSTMMiddleware(Middleware):
lstm_c, lstm_h = lstm_state
self.state_out = (lstm_c[:1, :], lstm_h[:1, :])
self.output = tf.reshape(lstm_outputs, [-1, self.number_of_lstm_cells])
@property
def schemes(self):
return {
MiddlewareScheme.Empty:
[],
# ppo
MiddlewareScheme.Shallow:
[
[64]
],
# dqn
MiddlewareScheme.Medium:
[
[512]
],
MiddlewareScheme.Deep: \
[
[128],
[128],
[128]
]
}

View File

@@ -17,16 +17,16 @@ from typing import Type, Union, List
import tensorflow as tf
from rl_coach.base_parameters import MiddlewareScheme, Parameters
from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.base_parameters import MiddlewareScheme, Parameters, NetworkComponentParameters
from rl_coach.core_types import MiddlewareEmbedding
class MiddlewareParameters(Parameters):
class MiddlewareParameters(NetworkComponentParameters):
def __init__(self, parameterized_class: Type['Middleware'],
activation_function: str='relu', scheme: Union[List, MiddlewareScheme]=MiddlewareScheme.Medium,
batchnorm: bool=False, dropout: bool=False,
name='middleware'):
super().__init__()
batchnorm: bool=False, dropout: bool=False, name='middleware', dense_layer=Dense):
super().__init__(dense_layer=dense_layer)
self.activation_function = activation_function
self.scheme = scheme
self.batchnorm = batchnorm
@@ -43,7 +43,7 @@ class Middleware(object):
"""
def __init__(self, activation_function=tf.nn.relu,
scheme: MiddlewareScheme = MiddlewareScheme.Medium,
batchnorm: bool = False, dropout: bool = False, name="middleware_embedder"):
batchnorm: bool = False, dropout: bool = False, name="middleware_embedder", dense_layer=Dense):
self.name = name
self.input = None
self.output = None
@@ -53,6 +53,7 @@ class Middleware(object):
self.dropout_rate = 0
self.scheme = scheme
self.return_type = MiddlewareEmbedding
self.dense_layer = dense_layer
def __call__(self, input_layer):
with tf.variable_scope(self.get_name()):
@@ -66,3 +67,8 @@ class Middleware(object):
def get_name(self):
return self.name
@property
def schemes(self):
raise NotImplementedError("Inheriting middleware must define schemes matching its allowed default "
"configurations.")