1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 19:50:17 +01:00

pre-release 0.10.0

This commit is contained in:
Gal Novik
2018-08-13 17:11:34 +03:00
parent d44c329bb8
commit 19ca5c24b1
485 changed files with 33292 additions and 16770 deletions

View File

@@ -0,0 +1,54 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import tensorflow as tf
from rl_coach.base_parameters import AgentParameters
from rl_coach.spaces import SpacesDefinition
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
from rl_coach.core_types import QActionStateValue
class CategoricalQHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='categorical_q_head_params'):
super().__init__(parameterized_class=CategoricalQHead, activation_function=activation_function, name=name)
class CategoricalQHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str ='relu'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
self.name = 'categorical_dqn_head'
self.num_actions = len(self.spaces.action.actions)
self.num_atoms = agent_parameters.algorithm.atoms
self.return_type = QActionStateValue
def _build_module(self, input_layer):
self.actions = tf.placeholder(tf.int32, [None], name="actions")
self.input = [self.actions]
values_distribution = tf.layers.dense(input_layer, self.num_actions * self.num_atoms, name='output')
values_distribution = tf.reshape(values_distribution, (tf.shape(values_distribution)[0], self.num_actions,
self.num_atoms))
# softmax on atoms dimension
self.output = tf.nn.softmax(values_distribution)
# calculate cross entropy loss
self.distributions = tf.placeholder(tf.float32, shape=(None, self.num_actions, self.num_atoms),
name="distributions")
self.target = self.distributions
self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.target, logits=values_distribution)
tf.losses.add_loss(self.loss)

View File

@@ -0,0 +1,66 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import batchnorm_activation_dropout
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
from rl_coach.base_parameters import AgentParameters
from rl_coach.spaces import SpacesDefinition
from rl_coach.core_types import ActionProbabilities
class DDPGActorHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='tanh', name: str='policy_head_params', batchnorm: bool=True):
super().__init__(parameterized_class=DDPGActor, activation_function=activation_function, name=name)
self.batchnorm = batchnorm
class DDPGActor(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh',
batchnorm: bool=True):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
self.name = 'ddpg_actor_head'
self.return_type = ActionProbabilities
self.num_actions = self.spaces.action.shape
self.batchnorm = batchnorm
# bounded actions
self.output_scale = self.spaces.action.max_abs_range
# a scalar weight that penalizes high activation values (before the activation function) for the final layer
if hasattr(agent_parameters.algorithm, 'action_penalty'):
self.action_penalty = agent_parameters.algorithm.action_penalty
def _build_module(self, input_layer):
# mean
pre_activation_policy_values_mean = tf.layers.dense(input_layer, self.num_actions, name='fc_mean')
policy_values_mean = batchnorm_activation_dropout(pre_activation_policy_values_mean, self.batchnorm,
self.activation_function,
False, 0, 0)[-1]
self.policy_mean = tf.multiply(policy_values_mean, self.output_scale, name='output_mean')
if self.is_local:
# add a squared penalty on the squared pre-activation features of the action
if self.action_penalty and self.action_penalty != 0:
self.regularizations += \
[self.action_penalty * tf.reduce_mean(tf.square(pre_activation_policy_values_mean))]
self.output = [self.policy_mean]

View File

@@ -0,0 +1,87 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.heads.head import HeadParameters
from rl_coach.base_parameters import AgentParameters
from rl_coach.architectures.tensorflow_components.heads.q_head import QHead
from rl_coach.spaces import SpacesDefinition
from rl_coach.memories.non_episodic import differentiable_neural_dictionary
class DNDQHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='dnd_q_head_params'):
super().__init__(parameterized_class=DNDQHead, activation_function=activation_function, name=name)
class DNDQHead(QHead):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
self.name = 'dnd_q_values_head'
self.DND_size = agent_parameters.algorithm.dnd_size
self.DND_key_error_threshold = agent_parameters.algorithm.DND_key_error_threshold
self.l2_norm_added_delta = agent_parameters.algorithm.l2_norm_added_delta
self.new_value_shift_coefficient = agent_parameters.algorithm.new_value_shift_coefficient
self.number_of_nn = agent_parameters.algorithm.number_of_knn
self.ap = agent_parameters
self.dnd_embeddings = [None] * self.num_actions
self.dnd_values = [None] * self.num_actions
self.dnd_indices = [None] * self.num_actions
self.dnd_distances = [None] * self.num_actions
if self.ap.memory.shared_memory:
self.shared_memory_scratchpad = self.ap.task_parameters.shared_memory_scratchpad
def _build_module(self, input_layer):
if hasattr(self.ap.task_parameters, 'checkpoint_restore_dir') and self.ap.task_parameters.checkpoint_restore_dir:
self.DND = differentiable_neural_dictionary.load_dnd(self.ap.task_parameters.checkpoint_restore_dir)
else:
self.DND = differentiable_neural_dictionary.QDND(
self.DND_size, input_layer.get_shape()[-1], self.num_actions, self.new_value_shift_coefficient,
key_error_threshold=self.DND_key_error_threshold,
learning_rate=self.network_parameters.learning_rate,
num_neighbors=self.number_of_nn,
override_existing_keys=True)
# Retrieve info from DND dictionary
# We assume that all actions have enough entries in the DND
self.output = tf.transpose([
self._q_value(input_layer, action)
for action in range(self.num_actions)
])
def _q_value(self, input_layer, action):
result = tf.py_func(self.DND.query,
[input_layer, action, self.number_of_nn],
[tf.float64, tf.float64, tf.int64])
self.dnd_embeddings[action] = tf.to_float(result[0])
self.dnd_values[action] = tf.to_float(result[1])
self.dnd_indices[action] = result[2]
# DND calculation
square_diff = tf.square(self.dnd_embeddings[action] - tf.expand_dims(input_layer, 1))
distances = tf.reduce_sum(square_diff, axis=2) + [self.l2_norm_added_delta]
self.dnd_distances[action] = distances
weights = 1.0 / distances
normalised_weights = weights / tf.reduce_sum(weights, axis=1, keep_dims=True)
q_value = tf.reduce_sum(self.dnd_values[action] * normalised_weights, axis=1)
q_value.set_shape((None,))
return q_value
def _post_build(self):
# DND gradients
self.dnd_embeddings_grad = tf.gradients(self.loss[0], self.dnd_embeddings)
self.dnd_values_grad = tf.gradients(self.loss[0], self.dnd_values)

View File

@@ -0,0 +1,50 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.heads.head import HeadParameters
from rl_coach.base_parameters import AgentParameters
from rl_coach.architectures.tensorflow_components.heads.q_head import QHead
from rl_coach.spaces import SpacesDefinition
class DuelingQHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='dueling_q_head_params'):
super().__init__(parameterized_class=DuelingQHead, activation_function=activation_function, name=name)
class DuelingQHead(QHead):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
self.name = 'dueling_q_values_head'
def _build_module(self, input_layer):
# state value tower - V
with tf.variable_scope("state_value"):
state_value = tf.layers.dense(input_layer, 512, activation=self.activation_function, name='fc1')
state_value = tf.layers.dense(state_value, 1, name='fc2')
# state_value = tf.expand_dims(state_value, axis=-1)
# action advantage tower - A
with tf.variable_scope("action_advantage"):
action_advantage = tf.layers.dense(input_layer, 512, activation=self.activation_function, name='fc1')
action_advantage = tf.layers.dense(action_advantage, self.num_actions, name='fc2')
action_advantage = action_advantage - tf.reduce_mean(action_advantage)
# merge to state-action value function Q
self.output = tf.add(state_value, action_advantage, name='output')

View File

@@ -0,0 +1,165 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Type
import numpy as np
import tensorflow as tf
from rl_coach.base_parameters import AgentParameters, Parameters
from rl_coach.spaces import SpacesDefinition
from tensorflow.python.ops.losses.losses_impl import Reduction
from rl_coach.utils import force_list
# Used to initialize weights for policy and value output layers
def normalized_columns_initializer(std=1.0):
def _initializer(shape, dtype=None, partition_info=None):
out = np.random.randn(*shape).astype(np.float32)
out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
return tf.constant(out)
return _initializer
class HeadParameters(Parameters):
def __init__(self, parameterized_class: Type['Head'], activation_function: str = 'relu', name: str= 'head'):
super().__init__()
self.activation_function = activation_function
self.name = name
self.parameterized_class_name = parameterized_class.__name__
class Head(object):
"""
A head is the final part of the network. It takes the embedding from the middleware embedder and passes it through
a neural network to produce the output of the network. There can be multiple heads in a network, and each one has
an assigned loss function. The heads are algorithm dependent.
"""
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int=0, loss_weight: float=1., is_local: bool=True, activation_function: str='relu'):
self.head_idx = head_idx
self.network_name = network_name
self.network_parameters = agent_parameters.network_wrappers[self.network_name]
self.name = "head"
self.output = []
self.loss = []
self.loss_type = []
self.regularizations = []
self.loss_weight = force_list(loss_weight)
self.target = []
self.importance_weight = []
self.input = []
self.is_local = is_local
self.ap = agent_parameters
self.spaces = spaces
self.return_type = None
self.activation_function = activation_function
def __call__(self, input_layer):
"""
Wrapper for building the module graph including scoping and loss creation
:param input_layer: the input to the graph
:return: the output of the last layer and the target placeholder
"""
with tf.variable_scope(self.get_name(), initializer=tf.contrib.layers.xavier_initializer()):
self._build_module(input_layer)
self.output = force_list(self.output)
self.target = force_list(self.target)
self.input = force_list(self.input)
self.loss_type = force_list(self.loss_type)
self.loss = force_list(self.loss)
self.regularizations = force_list(self.regularizations)
if self.is_local:
self.set_loss()
self._post_build()
if self.is_local:
return self.output, self.target, self.input, self.importance_weight
else:
return self.output, self.input
def _build_module(self, input_layer):
"""
Builds the graph of the module
This method is called early on from __call__. It is expected to store the graph
in self.output.
:param input_layer: the input to the graph
:return: None
"""
pass
def _post_build(self):
"""
Optional function that allows adding any extra definitions after the head has been fully defined
For example, this allows doing additional calculations that are based on the loss
:return: None
"""
pass
def get_name(self):
"""
Get a formatted name for the module
:return: the formatted name
"""
return '{}_{}'.format(self.name, self.head_idx)
def set_loss(self):
"""
Creates a target placeholder and loss function for each loss_type and regularization
:param loss_type: a tensorflow loss function
:param scope: the name scope to include the tensors in
:return: None
"""
# there are heads that define the loss internally, but we need to create additional placeholders for them
for idx in range(len(self.loss)):
importance_weight = tf.placeholder('float',
[None] + [1] * (len(self.target[idx].shape) - 1),
'{}_importance_weight'.format(self.get_name()))
self.importance_weight.append(importance_weight)
# add losses and target placeholder
for idx in range(len(self.loss_type)):
# create target placeholder
target = tf.placeholder('float', self.output[idx].shape, '{}_target'.format(self.get_name()))
self.target.append(target)
# create importance sampling weights placeholder
num_target_dims = len(self.target[idx].shape)
importance_weight = tf.placeholder('float', [None] + [1] * (num_target_dims - 1),
'{}_importance_weight'.format(self.get_name()))
self.importance_weight.append(importance_weight)
# compute the weighted loss. importance_weight weights over the samples in the batch, while self.loss_weight
# weights the specific loss of this head against other losses in this head or in other heads
loss_weight = self.loss_weight[idx]*importance_weight
loss = self.loss_type[idx](self.target[-1], self.output[idx],
scope=self.get_name(), reduction=Reduction.NONE, loss_collection=None)
# the loss is first summed over each sample in the batch and then the mean over the batch is taken
loss = tf.reduce_mean(loss_weight*tf.reduce_sum(loss, axis=list(range(1, num_target_dims))))
# we add the loss to the losses collection and later we will extract it in general_network
tf.losses.add_loss(loss)
self.loss.append(loss)
# add regularizations
for regularization in self.regularizations:
self.loss.append(regularization)
@classmethod
def path(cls):
return cls.__class__.__name__

View File

@@ -0,0 +1,65 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import tensorflow as tf
from rl_coach.base_parameters import AgentParameters
from rl_coach.spaces import SpacesDefinition
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
from rl_coach.core_types import Measurements
class MeasurementsPredictionHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='measurements_prediction_head_params'):
super().__init__(parameterized_class=MeasurementsPredictionHead,
activation_function=activation_function, name=name)
class MeasurementsPredictionHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
self.name = 'future_measurements_head'
self.num_actions = len(self.spaces.action.actions)
self.num_measurements = self.spaces.state['measurements'].shape[0]
self.num_prediction_steps = agent_parameters.algorithm.num_predicted_steps_ahead
self.multi_step_measurements_size = self.num_measurements * self.num_prediction_steps
self.return_type = Measurements
def _build_module(self, input_layer):
# This is almost exactly the same as Dueling Network but we predict the future measurements for each action
# actions expectation tower (expectation stream) - E
with tf.variable_scope("expectation_stream"):
expectation_stream = tf.layers.dense(input_layer, 256, activation=self.activation_function, name='fc1')
expectation_stream = tf.layers.dense(expectation_stream, self.multi_step_measurements_size, name='output')
expectation_stream = tf.expand_dims(expectation_stream, axis=1)
# action fine differences tower (action stream) - A
with tf.variable_scope("action_stream"):
action_stream = tf.layers.dense(input_layer, 256, activation=self.activation_function, name='fc1')
action_stream = tf.layers.dense(action_stream, self.num_actions * self.multi_step_measurements_size,
name='output')
action_stream = tf.reshape(action_stream,
(tf.shape(action_stream)[0], self.num_actions, self.multi_step_measurements_size))
action_stream = action_stream - tf.reduce_mean(action_stream, reduction_indices=1, keepdims=True)
# merge to future measurements predictions
self.output = tf.add(expectation_stream, action_stream, name='output')
self.target = tf.placeholder(tf.float32, [None, self.num_actions, self.multi_step_measurements_size],
name="targets")
targets_nonan = tf.where(tf.is_nan(self.target), self.output, self.target)
self.loss = tf.reduce_sum(tf.reduce_mean(tf.square(targets_nonan - self.output), reduction_indices=0))
tf.losses.add_loss(self.loss_weight[0] * self.loss)

View File

@@ -0,0 +1,88 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import tensorflow as tf
from rl_coach.base_parameters import AgentParameters
from rl_coach.spaces import BoxActionSpace
from rl_coach.spaces import SpacesDefinition
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
from rl_coach.core_types import QActionStateValue
class NAFHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='tanh', name: str='naf_head_params'):
super().__init__(parameterized_class=NAFHead, activation_function=activation_function, name=name)
class NAFHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True,activation_function: str='relu'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
if not isinstance(self.spaces.action, BoxActionSpace):
raise ValueError("NAF works only for continuous action spaces (BoxActionSpace)")
self.name = 'naf_q_values_head'
self.num_actions = self.spaces.action.shape[0]
self.output_scale = self.spaces.action.max_abs_range
self.return_type = QActionStateValue
if agent_parameters.network_wrappers[self.network_name].replace_mse_with_huber_loss:
self.loss_type = tf.losses.huber_loss
else:
self.loss_type = tf.losses.mean_squared_error
def _build_module(self, input_layer):
# NAF
self.action = tf.placeholder(tf.float32, [None, self.num_actions], name="action")
self.input = self.action
# V Head
self.V = tf.layers.dense(input_layer, 1, name='V')
# mu Head
mu_unscaled = tf.layers.dense(input_layer, self.num_actions, activation=self.activation_function, name='mu_unscaled')
self.mu = tf.multiply(mu_unscaled, self.output_scale, name='mu')
# A Head
# l_vector is a vector that includes a lower-triangular matrix values
self.l_vector = tf.layers.dense(input_layer, (self.num_actions * (self.num_actions + 1)) / 2, name='l_vector')
# Convert l to a lower triangular matrix and exponentiate its diagonal
i = 0
columns = []
for col in range(self.num_actions):
start_row = col
num_non_zero_elements = self.num_actions - start_row
zeros_column_part = tf.zeros_like(self.l_vector[:, 0:start_row])
diag_element = tf.expand_dims(tf.exp(self.l_vector[:, i]), 1)
non_zeros_non_diag_column_part = self.l_vector[:, (i + 1):(i + num_non_zero_elements)]
columns.append(tf.concat([zeros_column_part, diag_element, non_zeros_non_diag_column_part], axis=1))
i += num_non_zero_elements
self.L = tf.transpose(tf.stack(columns, axis=1), (0, 2, 1))
# P = L*L^T
self.P = tf.matmul(self.L, tf.transpose(self.L, (0, 2, 1)))
# A = -1/2 * (u - mu)^T * P * (u - mu)
action_diff = tf.expand_dims(self.action - self.mu, -1)
a_matrix_form = -0.5 * tf.matmul(tf.transpose(action_diff, (0, 2, 1)), tf.matmul(self.P, action_diff))
self.A = tf.reshape(a_matrix_form, [-1, 1])
# Q Head
self.Q = tf.add(self.V, self.A, name='Q')
self.output = self.Q

View File

@@ -0,0 +1,151 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.heads.head import Head, normalized_columns_initializer, HeadParameters
from rl_coach.base_parameters import AgentParameters
from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace, CompoundActionSpace
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import eps
from rl_coach.core_types import ActionProbabilities
from rl_coach.exploration_policies.continuous_entropy import ContinuousEntropyParameters
class PolicyHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='tanh', name: str='policy_head_params'):
super().__init__(parameterized_class=PolicyHead, activation_function=activation_function, name=name)
class PolicyHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
self.name = 'policy_values_head'
self.return_type = ActionProbabilities
self.beta = None
self.action_penalty = None
self.exploration_policy = agent_parameters.exploration
# a scalar weight that penalizes low entropy values to encourage exploration
if hasattr(agent_parameters.algorithm, 'beta_entropy'):
self.beta = agent_parameters.algorithm.beta_entropy
# a scalar weight that penalizes high activation values (before the activation function) for the final layer
if hasattr(agent_parameters.algorithm, 'action_penalty'):
self.action_penalty = agent_parameters.algorithm.action_penalty
def _build_module(self, input_layer):
self.actions = []
self.input = self.actions
self.policy_distributions = []
self.output = []
action_spaces = [self.spaces.action]
if isinstance(self.spaces.action, CompoundActionSpace):
action_spaces = self.spaces.action.sub_action_spaces
# create a compound action network
for action_space_idx, action_space in enumerate(action_spaces):
with tf.variable_scope("sub_action_{}".format(action_space_idx)):
if isinstance(action_space, DiscreteActionSpace):
# create a discrete action network (softmax probabilities output)
self._build_discrete_net(input_layer, action_space)
elif isinstance(action_space, BoxActionSpace):
# create a continuous action network (bounded mean and stdev outputs)
self._build_continuous_net(input_layer, action_space)
if self.is_local:
# add entropy regularization
if self.beta:
self.entropy = tf.add_n([tf.reduce_mean(dist.entropy()) for dist in self.policy_distributions])
self.regularizations += [-tf.multiply(self.beta, self.entropy, name='entropy_regularization')]
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, self.regularizations)
# calculate loss
self.action_log_probs_wrt_policy = \
tf.add_n([dist.log_prob(action) for dist, action in zip(self.policy_distributions, self.actions)])
self.advantages = tf.placeholder(tf.float32, [None], name="advantages")
self.target = self.advantages
self.loss = -tf.reduce_mean(self.action_log_probs_wrt_policy * self.advantages)
tf.losses.add_loss(self.loss_weight[0] * self.loss)
def _build_discrete_net(self, input_layer, action_space):
num_actions = len(action_space.actions)
self.actions.append(tf.placeholder(tf.int32, [None], name="actions"))
policy_values = tf.layers.dense(input_layer, num_actions, name='fc')
self.policy_probs = tf.nn.softmax(policy_values, name="policy")
# define the distributions for the policy and the old policy
# (the + eps is to prevent probability 0 which will cause the log later on to be -inf)
policy_distribution = tf.contrib.distributions.Categorical(probs=(self.policy_probs + eps))
self.policy_distributions.append(policy_distribution)
self.output.append(self.policy_probs)
def _build_continuous_net(self, input_layer, action_space):
num_actions = action_space.shape
self.actions.append(tf.placeholder(tf.float32, [None, num_actions], name="actions"))
# output activation function
if np.all(self.spaces.action.max_abs_range < np.inf):
# bounded actions
self.output_scale = action_space.max_abs_range
self.continuous_output_activation = self.activation_function
else:
# unbounded actions
self.output_scale = 1
self.continuous_output_activation = None
# mean
pre_activation_policy_values_mean = tf.layers.dense(input_layer, num_actions, name='fc_mean')
policy_values_mean = self.continuous_output_activation(pre_activation_policy_values_mean)
self.policy_mean = tf.multiply(policy_values_mean, self.output_scale, name='output_mean')
self.output.append(self.policy_mean)
# standard deviation
if isinstance(self.exploration_policy, ContinuousEntropyParameters):
# the stdev is an output of the network and uses a softplus activation as defined in A3C
policy_values_std = tf.layers.dense(input_layer, num_actions,
kernel_initializer=normalized_columns_initializer(0.01), name='fc_std')
self.policy_std = tf.nn.softplus(policy_values_std, name='output_variance') + eps
self.output.append(self.policy_std)
else:
# the stdev is an externally given value
# Warning: we need to explicitly put this variable in the local variables collections, since defining
# it as not trainable puts it for some reason in the global variables collections. If this is not done,
# the variable won't be initialized and when working with multiple workers they will get stuck.
self.policy_std = tf.Variable(np.ones(num_actions), dtype='float32', trainable=False,
name='policy_stdev', collections=[tf.GraphKeys.LOCAL_VARIABLES])
# assign op for the policy std
self.policy_std_placeholder = tf.placeholder('float32', (num_actions,))
self.assign_policy_std = tf.assign(self.policy_std, self.policy_std_placeholder)
# define the distributions for the policy and the old policy
policy_distribution = tf.contrib.distributions.MultivariateNormalDiag(self.policy_mean, self.policy_std)
self.policy_distributions.append(policy_distribution)
if self.is_local:
# add a squared penalty on the squared pre-activation features of the action
if self.action_penalty and self.action_penalty != 0:
self.regularizations += [
self.action_penalty * tf.reduce_mean(tf.square(pre_activation_policy_values_mean))]

View File

@@ -0,0 +1,144 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np
import tensorflow as tf
from rl_coach.base_parameters import AgentParameters
from rl_coach.spaces import BoxActionSpace, DiscreteActionSpace
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import eps
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters, normalized_columns_initializer
from rl_coach.core_types import ActionProbabilities
class PPOHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='tanh', name: str='ppo_head_params'):
super().__init__(parameterized_class=PPOHead, activation_function=activation_function, name=name)
class PPOHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
self.name = 'ppo_head'
self.return_type = ActionProbabilities
# used in regular PPO
self.use_kl_regularization = agent_parameters.algorithm.use_kl_regularization
if self.use_kl_regularization:
# kl coefficient and its corresponding assignment operation and placeholder
self.kl_coefficient = tf.Variable(agent_parameters.algorithm.initial_kl_coefficient,
trainable=False, name='kl_coefficient')
self.kl_coefficient_ph = tf.placeholder('float', name='kl_coefficient_ph')
self.assign_kl_coefficient = tf.assign(self.kl_coefficient, self.kl_coefficient_ph)
self.kl_cutoff = 2 * agent_parameters.algorithm.target_kl_divergence
self.high_kl_penalty_coefficient = agent_parameters.algorithm.high_kl_penalty_coefficient
self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon
self.beta = agent_parameters.algorithm.beta_entropy
def _build_module(self, input_layer):
if isinstance(self.spaces.action, DiscreteActionSpace):
self._build_discrete_net(input_layer, self.spaces.action)
elif isinstance(self.spaces.action, BoxActionSpace):
self._build_continuous_net(input_layer, self.spaces.action)
else:
raise ValueError("only discrete or continuous action spaces are supported for PPO")
self.action_probs_wrt_policy = self.policy_distribution.log_prob(self.actions)
self.action_probs_wrt_old_policy = self.old_policy_distribution.log_prob(self.actions)
self.entropy = tf.reduce_mean(self.policy_distribution.entropy())
# Used by regular PPO only
# add kl divergence regularization
self.kl_divergence = tf.reduce_mean(tf.distributions.kl_divergence(self.old_policy_distribution, self.policy_distribution))
if self.use_kl_regularization:
# no clipping => use kl regularization
self.weighted_kl_divergence = tf.multiply(self.kl_coefficient, self.kl_divergence)
self.regularizations = self.weighted_kl_divergence + self.high_kl_penalty_coefficient * \
tf.square(tf.maximum(0.0, self.kl_divergence - self.kl_cutoff))
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, self.regularizations)
# calculate surrogate loss
self.advantages = tf.placeholder(tf.float32, [None], name="advantages")
self.target = self.advantages
# action_probs_wrt_old_policy != 0 because it is e^...
self.likelihood_ratio = tf.exp(self.action_probs_wrt_policy - self.action_probs_wrt_old_policy)
if self.clip_likelihood_ratio_using_epsilon is not None:
self.clip_param_rescaler = tf.placeholder(tf.float32, ())
self.input.append(self.clip_param_rescaler)
max_value = 1 + self.clip_likelihood_ratio_using_epsilon * self.clip_param_rescaler
min_value = 1 - self.clip_likelihood_ratio_using_epsilon * self.clip_param_rescaler
self.clipped_likelihood_ratio = tf.clip_by_value(self.likelihood_ratio, min_value, max_value)
self.scaled_advantages = tf.minimum(self.likelihood_ratio * self.advantages,
self.clipped_likelihood_ratio * self.advantages)
else:
self.scaled_advantages = self.likelihood_ratio * self.advantages
# minus sign is in order to set an objective to minimize (we actually strive for maximizing the surrogate loss)
self.surrogate_loss = -tf.reduce_mean(self.scaled_advantages)
if self.is_local:
# add entropy regularization
if self.beta:
self.entropy = tf.reduce_mean(self.policy_distribution.entropy())
self.regularizations = -tf.multiply(self.beta, self.entropy, name='entropy_regularization')
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, self.regularizations)
self.loss = self.surrogate_loss
tf.losses.add_loss(self.loss)
def _build_discrete_net(self, input_layer, action_space):
num_actions = len(action_space.actions)
self.actions = tf.placeholder(tf.int32, [None], name="actions")
self.old_policy_mean = tf.placeholder(tf.float32, [None, num_actions], "old_policy_mean")
self.old_policy_std = tf.placeholder(tf.float32, [None, num_actions], "old_policy_std")
# Policy Head
self.input = [self.actions, self.old_policy_mean]
policy_values = tf.layers.dense(input_layer, num_actions, name='policy_fc')
self.policy_mean = tf.nn.softmax(policy_values, name="policy")
# define the distributions for the policy and the old policy
self.policy_distribution = tf.contrib.distributions.Categorical(probs=self.policy_mean)
self.old_policy_distribution = tf.contrib.distributions.Categorical(probs=self.old_policy_mean)
self.output = self.policy_mean
def _build_continuous_net(self, input_layer, action_space):
num_actions = action_space.shape[0]
self.actions = tf.placeholder(tf.float32, [None, num_actions], name="actions")
self.old_policy_mean = tf.placeholder(tf.float32, [None, num_actions], "old_policy_mean")
self.old_policy_std = tf.placeholder(tf.float32, [None, num_actions], "old_policy_std")
self.input = [self.actions, self.old_policy_mean, self.old_policy_std]
self.policy_mean = tf.layers.dense(input_layer, num_actions, name='policy_mean',
kernel_initializer=normalized_columns_initializer(0.01))
if self.is_local:
self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32',
collections=[tf.GraphKeys.LOCAL_VARIABLES])
else:
self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32')
self.policy_std = tf.tile(tf.exp(self.policy_logstd), [tf.shape(input_layer)[0], 1], name='policy_std')
# define the distributions for the policy and the old policy
self.policy_distribution = tf.contrib.distributions.MultivariateNormalDiag(self.policy_mean, self.policy_std + eps)
self.old_policy_distribution = tf.contrib.distributions.MultivariateNormalDiag(self.old_policy_mean, self.old_policy_std + eps)
self.output = [self.policy_mean, self.policy_std]

View File

@@ -0,0 +1,52 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import tensorflow as tf
from rl_coach.base_parameters import AgentParameters
from rl_coach.spaces import SpacesDefinition
from rl_coach.architectures.tensorflow_components.heads.head import Head, normalized_columns_initializer, HeadParameters
from rl_coach.core_types import ActionProbabilities
class PPOVHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='ppo_v_head_params'):
super().__init__(parameterized_class=PPOVHead, activation_function=activation_function, name=name)
class PPOVHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
self.name = 'ppo_v_head'
self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon
self.return_type = ActionProbabilities
def _build_module(self, input_layer):
self.old_policy_value = tf.placeholder(tf.float32, [None], "old_policy_values")
self.input = [self.old_policy_value]
self.output = tf.layers.dense(input_layer, 1, name='output',
kernel_initializer=normalized_columns_initializer(1.0))
self.target = self.total_return = tf.placeholder(tf.float32, [None], name="total_return")
value_loss_1 = tf.square(self.output - self.target)
value_loss_2 = tf.square(self.old_policy_value +
tf.clip_by_value(self.output - self.old_policy_value,
-self.clip_likelihood_ratio_using_epsilon,
self.clip_likelihood_ratio_using_epsilon) - self.target)
self.vf_loss = tf.reduce_mean(tf.maximum(value_loss_1, value_loss_2))
self.loss = self.vf_loss
tf.losses.add_loss(self.loss)

View File

@@ -0,0 +1,50 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import tensorflow as tf
from rl_coach.base_parameters import AgentParameters
from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
from rl_coach.core_types import QActionStateValue
class QHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='q_head_params'):
super().__init__(parameterized_class=QHead, activation_function=activation_function, name=name)
class QHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
self.name = 'q_values_head'
if isinstance(self.spaces.action, BoxActionSpace):
self.num_actions = 1
elif isinstance(self.spaces.action, DiscreteActionSpace):
self.num_actions = len(self.spaces.action.actions)
self.return_type = QActionStateValue
if agent_parameters.network_wrappers[self.network_name].replace_mse_with_huber_loss:
self.loss_type = tf.losses.huber_loss
else:
self.loss_type = tf.losses.mean_squared_error
def _build_module(self, input_layer):
# Standard Q Network
self.output = tf.layers.dense(input_layer, self.num_actions, name='output')

View File

@@ -0,0 +1,76 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import tensorflow as tf
from rl_coach.base_parameters import AgentParameters
from rl_coach.spaces import SpacesDefinition
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
from rl_coach.core_types import QActionStateValue
class QuantileRegressionQHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='quantile_regression_q_head_params'):
super().__init__(parameterized_class=QuantileRegressionQHead, activation_function=activation_function,
name=name)
class QuantileRegressionQHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
self.name = 'quantile_regression_dqn_head'
self.num_actions = len(self.spaces.action.actions)
self.num_atoms = agent_parameters.algorithm.atoms # we use atom / quantile interchangeably
self.huber_loss_interval = agent_parameters.algorithm.huber_loss_interval # k
self.return_type = QActionStateValue
def _build_module(self, input_layer):
self.actions = tf.placeholder(tf.int32, [None, 2], name="actions")
self.quantile_midpoints = tf.placeholder(tf.float32, [None, self.num_atoms], name="quantile_midpoints")
self.input = [self.actions, self.quantile_midpoints]
# the output of the head is the N unordered quantile locations {theta_1, ..., theta_N}
quantiles_locations = tf.layers.dense(input_layer, self.num_actions * self.num_atoms, name='output')
quantiles_locations = tf.reshape(quantiles_locations, (tf.shape(quantiles_locations)[0], self.num_actions, self.num_atoms))
self.output = quantiles_locations
self.quantiles = tf.placeholder(tf.float32, shape=(None, self.num_atoms), name="quantiles")
self.target = self.quantiles
# only the quantiles of the taken action are taken into account
quantiles_for_used_actions = tf.gather_nd(quantiles_locations, self.actions)
# reorder the output quantiles and the target quantiles as a preparation step for calculating the loss
# the output quantiles vector and the quantile midpoints are tiled as rows of a NxN matrix (N = num quantiles)
# the target quantiles vector is tiled as column of a NxN matrix
theta_i = tf.tile(tf.expand_dims(quantiles_for_used_actions, -1), [1, 1, self.num_atoms])
T_theta_j = tf.tile(tf.expand_dims(self.target, -2), [1, self.num_atoms, 1])
tau_i = tf.tile(tf.expand_dims(self.quantile_midpoints, -1), [1, 1, self.num_atoms])
# Huber loss of T(theta_j) - theta_i
error = T_theta_j - theta_i
abs_error = tf.abs(error)
quadratic = tf.minimum(abs_error, self.huber_loss_interval)
huber_loss = self.huber_loss_interval * (abs_error - quadratic) + 0.5 * quadratic ** 2
# Quantile Huber loss
quantile_huber_loss = tf.abs(tau_i - tf.cast(error < 0, dtype=tf.float32)) * huber_loss
# Quantile regression loss (the probability for each quantile is 1/num_quantiles)
quantile_regression_loss = tf.reduce_sum(quantile_huber_loss) / float(self.num_atoms)
self.loss = quantile_regression_loss
tf.losses.add_loss(self.loss)

View File

@@ -0,0 +1,45 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import tensorflow as tf
from rl_coach.base_parameters import AgentParameters
from rl_coach.spaces import SpacesDefinition
from rl_coach.architectures.tensorflow_components.heads.head import Head, normalized_columns_initializer, HeadParameters
from rl_coach.core_types import VStateValue
class VHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='v_head_params'):
super().__init__(parameterized_class=VHead, activation_function=activation_function, name=name)
class VHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu'):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
self.name = 'v_values_head'
self.return_type = VStateValue
if agent_parameters.network_wrappers[self.network_name.split('/')[0]].replace_mse_with_huber_loss:
self.loss_type = tf.losses.huber_loss
else:
self.loss_type = tf.losses.mean_squared_error
def _build_module(self, input_layer):
# Standard V Network
self.output = tf.layers.dense(input_layer, 1, name='output',
kernel_initializer=normalized_columns_initializer(1.0))