Mirror of https://github.com/gryf/coach.git
Synced 2025-12-18 11:40:18 +01:00
adding support in tensorboard (#52)
* bug-fix in architecture.py where additional fetches would acquire more entries than they should
* change in run_test to allow ignoring some test(s)
@@ -21,6 +21,20 @@ from configurations import Preset, MiddlewareTypes
 import numpy as np
 import time


+def variable_summaries(var):
+    """Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
+    with tf.name_scope('summaries'):
+        layer_weight_name = '_'.join(var.name.split('/')[-3:])[:-2]
+
+        with tf.name_scope(layer_weight_name):
+            mean = tf.reduce_mean(var)
+            tf.summary.scalar('mean', mean)
+            with tf.name_scope('stddev'):
+                stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
+            tf.summary.scalar('stddev', stddev)
+            tf.summary.scalar('max', tf.reduce_max(var))
+            tf.summary.scalar('min', tf.reduce_min(var))
+            tf.summary.histogram('histogram', var)
+
+
 class TensorFlowArchitecture(Architecture):
     def __init__(self, tuning_parameters, name="", global_network=None, network_is_local=True):
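For orientation, a minimal standalone sketch (TF 1.x; variable names are hypothetical, and it assumes variable_summaries from the hunk above is in scope) of what the helper attaches for a single weight tensor:

    import tensorflow as tf

    # build a variable whose name mimics coach's 'scope/layer/weight' layout
    with tf.variable_scope('network'):
        with tf.variable_scope('fc1'):
            w = tf.get_variable('kernel', shape=(16, 4))

    # '_'.join('network/fc1/kernel:0'.split('/')[-3:])[:-2] -> 'network_fc1_kernel',
    # so everything lands under the scope 'summaries/network_fc1_kernel'
    variable_summaries(w)  # attaches mean/stddev/max/min scalars and a histogram

    print([s.name for s in tf.get_collection(tf.GraphKeys.SUMMARIES)])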
@@ -44,6 +58,7 @@ class TensorFlowArchitecture(Architecture):
         self.curr_rnn_c_in = None
         self.curr_rnn_h_in = None
         self.gradients_wrt_inputs = []
+        self.train_writer = None

         self.optimizer_type = self.tp.agent.optimizer_type
         if self.tp.seed is not None:
@@ -75,6 +90,8 @@ class TensorFlowArchitecture(Architecture):
         for idx, var in enumerate(self.trainable_weights):
             placeholder = tf.placeholder(tf.float32, shape=var.get_shape(), name=str(idx) + '_holder')
             self.weights_placeholders.append(placeholder)
+            variable_summaries(var)

         self.update_weights_from_list = [weights.assign(holder) for holder, weights in
                                          zip(self.weights_placeholders, self.trainable_weights)]
@@ -106,12 +123,22 @@ class TensorFlowArchitecture(Architecture):
             self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
                 zip(self.weights_placeholders, self.trainable_weights), global_step=self.global_step)

+        # Merge all the summaries
+        current_scope_summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
+                                                    scope=tf.contrib.framework.get_name_scope())
+        self.merged = tf.summary.merge(current_scope_summaries)
+
         # initialize or restore model
         if not self.tp.distributed:
             self.init_op = tf.global_variables_initializer()

             if self.sess:
-                self.sess.run(self.init_op)
+                if self.tp.visualization.tensorboard:
+                    # Write the merged summaries to the current experiment directory
+                    self.train_writer = tf.summary.FileWriter(self.tp.experiment_path + '/tensorboard',
+                                                              self.sess.graph)
+                self.sess.run(self.init_op)

         self.accumulated_gradients = None
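Taken together, the merge and writer wiring above follow the standard TF 1.x summary round-trip. A self-contained sketch of that pattern (illustrative names; it uses merge_all where coach merges only the current name scope):

    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=(None, 4), name='x')
    y = tf.layers.dense(x, 2, name='fc1')
    tf.summary.scalar('y_mean', tf.reduce_mean(y))

    merged = tf.summary.merge_all()  # coach: tf.summary.merge(current_scope_summaries)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        writer = tf.summary.FileWriter('./tensorboard', sess.graph)
        summary, _ = sess.run([merged, y], feed_dict={x: [[0., 1., 2., 3.]]})
        writer.add_summary(summary, 0)  # second argument is the step (coach: episode)
        writer.close()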
@@ -169,8 +196,12 @@ class TensorFlowArchitecture(Architecture):
             feed_dict[self.middleware_embedder.c_in] = self.middleware_embedder.c_init
             feed_dict[self.middleware_embedder.h_in] = self.middleware_embedder.h_init

+        fetches += [self.merged]
+
         # get grads
         result = self.tp.sess.run(fetches, feed_dict=feed_dict)
+        if hasattr(self, 'train_writer') and self.train_writer is not None:
+            self.train_writer.add_summary(result[-1], self.tp.current_episode)

         # extract the fetches
         norm_unclipped_grads, grads, total_loss, losses = result[:4]
@@ -178,7 +209,8 @@ class TensorFlowArchitecture(Architecture):
             (self.curr_rnn_c_in, self.curr_rnn_h_in) = result[4]
         fetched_tensors = []
         if len(additional_fetches) > 0:
-            fetched_tensors = result[additional_fetches_start_idx:]
+            fetched_tensors = result[additional_fetches_start_idx:additional_fetches_start_idx +
+                                     len(additional_fetches)]

         # accumulate the gradients
         for idx, grad in enumerate(grads):
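This slice is the bug-fix called out in the commit message: with self.merged now appended to fetches, the old open-ended slice would also sweep the trailing summary entry into fetched_tensors. A toy illustration (values are made up):

    # illustrative fetch layout, not coach's actual values:
    # [norm_unclipped_grads, grads, total_loss, losses, *additional_fetches, merged]
    result = ['norms', 'grads', 'total_loss', 'losses', 'extra_a', 'extra_b', 'merged_summary']
    additional_fetches = ['extra_a_op', 'extra_b_op']
    additional_fetches_start_idx = 4

    # old, buggy slice: the merged summary leaks in as a third "additional fetch"
    assert result[additional_fetches_start_idx:] == ['extra_a', 'extra_b', 'merged_summary']

    # fixed slice: exactly len(additional_fetches) entries
    end = additional_fetches_start_idx + len(additional_fetches)
    assert result[additional_fetches_start_idx:end] == ['extra_a', 'extra_b']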
@@ -59,13 +59,17 @@ class ImageEmbedder(InputEmbedder):
             # same embedder as used in the original DQN paper
             self.observation_conv1 = tf.layers.conv2d(rescaled_observation_stack,
                                                       filters=32, kernel_size=(8, 8), strides=(4, 4),
-                                                      activation=self.activation_function, data_format='channels_last')
+                                                      activation=self.activation_function, data_format='channels_last',
+                                                      name='conv1')
             self.observation_conv2 = tf.layers.conv2d(self.observation_conv1,
                                                       filters=64, kernel_size=(4, 4), strides=(2, 2),
-                                                      activation=self.activation_function, data_format='channels_last')
+                                                      activation=self.activation_function, data_format='channels_last',
+                                                      name='conv2')
             self.observation_conv3 = tf.layers.conv2d(self.observation_conv2,
                                                       filters=64, kernel_size=(3, 3), strides=(1, 1),
-                                                      activation=self.activation_function, data_format='channels_last')
+                                                      activation=self.activation_function, data_format='channels_last',
+                                                      name='conv3')

             self.output = tf.contrib.layers.flatten(self.observation_conv3)
@@ -73,28 +77,36 @@ class ImageEmbedder(InputEmbedder):
             # the embedder used in the CARLA papers
             self.observation_conv1 = tf.layers.conv2d(rescaled_observation_stack,
                                                       filters=32, kernel_size=(5, 5), strides=(2, 2),
-                                                      activation=self.activation_function, data_format='channels_last')
+                                                      activation=self.activation_function, data_format='channels_last',
+                                                      name='conv1')
             self.observation_conv2 = tf.layers.conv2d(self.observation_conv1,
                                                       filters=32, kernel_size=(3, 3), strides=(1, 1),
-                                                      activation=self.activation_function, data_format='channels_last')
+                                                      activation=self.activation_function, data_format='channels_last',
+                                                      name='conv2')
             self.observation_conv3 = tf.layers.conv2d(self.observation_conv2,
                                                       filters=64, kernel_size=(3, 3), strides=(2, 2),
-                                                      activation=self.activation_function, data_format='channels_last')
+                                                      activation=self.activation_function, data_format='channels_last',
+                                                      name='conv3')
             self.observation_conv4 = tf.layers.conv2d(self.observation_conv3,
                                                       filters=64, kernel_size=(3, 3), strides=(1, 1),
-                                                      activation=self.activation_function, data_format='channels_last')
+                                                      activation=self.activation_function, data_format='channels_last',
+                                                      name='conv4')
             self.observation_conv5 = tf.layers.conv2d(self.observation_conv4,
                                                       filters=128, kernel_size=(3, 3), strides=(2, 2),
-                                                      activation=self.activation_function, data_format='channels_last')
+                                                      activation=self.activation_function, data_format='channels_last',
+                                                      name='conv5')
             self.observation_conv6 = tf.layers.conv2d(self.observation_conv5,
                                                       filters=128, kernel_size=(3, 3), strides=(1, 1),
-                                                      activation=self.activation_function, data_format='channels_last')
+                                                      activation=self.activation_function, data_format='channels_last',
+                                                      name='conv6')
             self.observation_conv7 = tf.layers.conv2d(self.observation_conv6,
                                                       filters=256, kernel_size=(3, 3), strides=(2, 2),
-                                                      activation=self.activation_function, data_format='channels_last')
+                                                      activation=self.activation_function, data_format='channels_last',
+                                                      name='conv7')
             self.observation_conv8 = tf.layers.conv2d(self.observation_conv7,
                                                       filters=256, kernel_size=(3, 3), strides=(1, 1),
-                                                      activation=self.activation_function, data_format='channels_last')
+                                                      activation=self.activation_function, data_format='channels_last',
+                                                      name='conv8')

             self.output = tf.contrib.layers.flatten(self.observation_conv8)
         else:
@@ -111,12 +123,16 @@ class VectorEmbedder(InputEmbedder):
         input_layer = tf.contrib.layers.flatten(self.input)

         if self.embedder_complexity == EmbedderComplexity.Shallow:
-            self.output = tf.layers.dense(input_layer, 256, activation=self.activation_function)
+            self.output = tf.layers.dense(input_layer, 256, activation=self.activation_function,
+                                          name='fc1')

         elif self.embedder_complexity == EmbedderComplexity.Deep:
             # the embedder used in the CARLA papers
-            self.observation_fc1 = tf.layers.dense(input_layer, 128, activation=self.activation_function)
-            self.observation_fc2 = tf.layers.dense(self.observation_fc1, 128, activation=self.activation_function)
-            self.output = tf.layers.dense(self.observation_fc2, 128, activation=self.activation_function)
+            self.observation_fc1 = tf.layers.dense(input_layer, 128, activation=self.activation_function,
+                                                   name='fc1')
+            self.observation_fc2 = tf.layers.dense(self.observation_fc1, 128, activation=self.activation_function,
+                                                   name='fc2')
+            self.output = tf.layers.dense(self.observation_fc2, 128, activation=self.activation_function,
+                                          name='fc3')
         else:
             raise ValueError("The defined embedder complexity value is invalid")
@@ -171,6 +171,8 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
         self.losses = tf.losses.get_losses(self.name)
         self.losses += tf.losses.get_regularization_losses(self.name)
         self.total_loss = tf.losses.compute_weighted_loss(self.losses, scope=self.name)
+        tf.summary.scalar('total_loss', self.total_loss)
+
         # Learning rate
         if self.tp.learning_rate_decay_rate != 0:
@@ -125,14 +125,14 @@ class DuelingQHead(QHead):
     def _build_module(self, input_layer):
         # state value tower - V
         with tf.variable_scope("state_value"):
-            state_value = tf.layers.dense(input_layer, 256, activation=tf.nn.relu)
-            state_value = tf.layers.dense(state_value, 1)
+            state_value = tf.layers.dense(input_layer, 256, activation=tf.nn.relu, name='fc1')
+            state_value = tf.layers.dense(state_value, 1, name='fc2')
             # state_value = tf.expand_dims(state_value, axis=-1)

         # action advantage tower - A
         with tf.variable_scope("action_advantage"):
-            action_advantage = tf.layers.dense(input_layer, 256, activation=tf.nn.relu)
-            action_advantage = tf.layers.dense(action_advantage, self.num_actions)
+            action_advantage = tf.layers.dense(input_layer, 256, activation=tf.nn.relu, name='fc1')
+            action_advantage = tf.layers.dense(action_advantage, self.num_actions, name='fc2')
             action_advantage = action_advantage - tf.reduce_mean(action_advantage)

         # merge to state-action value function Q
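For reference, the two towers feed the standard dueling merge, Q = V + (A - mean(A)); the reduce_mean subtraction above centres the advantages so that V alone carries the value level. With illustrative numbers:

    import numpy as np

    state_value = 1.0                        # "state_value" tower output
    action_advantage = np.array([2.0, 4.0])  # "action_advantage" tower output

    q_values = state_value + (action_advantage - action_advantage.mean())
    print(q_values)  # [0. 2.], same action ranking, level set by V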
@@ -177,7 +177,7 @@ class PolicyHead(Head):

         # Policy Head
         if self.discrete_controls:
-            policy_values = tf.layers.dense(input_layer, self.num_actions)
+            policy_values = tf.layers.dense(input_layer, self.num_actions, name='fc')
             self.policy_mean = tf.nn.softmax(policy_values, name="policy")

             # define the distributions for the policy and the old policy
@@ -186,7 +186,7 @@ class PolicyHead(Head):
             self.output = self.policy_mean
         else:
             # mean
-            policy_values_mean = tf.layers.dense(input_layer, self.num_actions, activation=tf.nn.tanh)
+            policy_values_mean = tf.layers.dense(input_layer, self.num_actions, activation=tf.nn.tanh, name='fc_mean')
             self.policy_mean = tf.multiply(policy_values_mean, self.output_scale, name='output_mean')

             self.output = [self.policy_mean]
@@ -194,7 +194,7 @@ class PolicyHead(Head):
             # std
             if self.exploration_policy == 'ContinuousEntropy':
                 policy_values_std = tf.layers.dense(input_layer, self.num_actions,
-                                                    kernel_initializer=normalized_columns_initializer(0.01))
+                                                    kernel_initializer=normalized_columns_initializer(0.01), name='fc_std')
                 self.policy_std = tf.nn.softplus(policy_values_std, name='output_variance') + eps

                 self.output.append(self.policy_std)
@@ -239,14 +239,15 @@ class MeasurementsPredictionHead(Head):
         # This is almost exactly the same as Dueling Network but we predict the future measurements for each action
         # actions expectation tower (expectation stream) - E
         with tf.variable_scope("expectation_stream"):
-            expectation_stream = tf.layers.dense(input_layer, 256, activation=tf.nn.elu)
-            expectation_stream = tf.layers.dense(expectation_stream, self.multi_step_measurements_size)
+            expectation_stream = tf.layers.dense(input_layer, 256, activation=tf.nn.elu, name='fc1')
+            expectation_stream = tf.layers.dense(expectation_stream, self.multi_step_measurements_size, name='output')
             expectation_stream = tf.expand_dims(expectation_stream, axis=1)

         # action fine differences tower (action stream) - A
         with tf.variable_scope("action_stream"):
-            action_stream = tf.layers.dense(input_layer, 256, activation=tf.nn.elu)
-            action_stream = tf.layers.dense(action_stream, self.num_actions * self.multi_step_measurements_size)
+            action_stream = tf.layers.dense(input_layer, 256, activation=tf.nn.elu, name='fc1')
+            action_stream = tf.layers.dense(action_stream, self.num_actions * self.multi_step_measurements_size,
+                                            name='output')
             action_stream = tf.reshape(action_stream,
                                        (tf.shape(action_stream)[0], self.num_actions, self.multi_step_measurements_size))
             action_stream = action_stream - tf.reduce_mean(action_stream, reduction_indices=1, keep_dims=True)
@@ -393,7 +394,7 @@ class PPOHead(Head):
         # Policy Head
         if self.discrete_controls:
            self.input = [self.actions, self.old_policy_mean]
-            policy_values = tf.layers.dense(input_layer, self.num_actions)
+            policy_values = tf.layers.dense(input_layer, self.num_actions, name='policy_fc')
             self.policy_mean = tf.nn.softmax(policy_values, name="policy")

             # define the distributions for the policy and the old policy
@@ -488,7 +489,7 @@ class CategoricalQHead(Head):
         self.actions = tf.placeholder(tf.int32, [None], name="actions")
         self.input = [self.actions]

-        values_distribution = tf.layers.dense(input_layer, self.num_actions * self.num_atoms)
+        values_distribution = tf.layers.dense(input_layer, self.num_actions * self.num_atoms, name='output')
         values_distribution = tf.reshape(values_distribution, (tf.shape(values_distribution)[0], self.num_actions, self.num_atoms))
         # softmax on atoms dimension
         self.output = tf.nn.softmax(values_distribution)
@@ -514,7 +515,7 @@ class QuantileRegressionQHead(Head):
         self.input = [self.actions, self.quantile_midpoints]

         # the output of the head is the N unordered quantile locations {theta_1, ..., theta_N}
-        quantiles_locations = tf.layers.dense(input_layer, self.num_actions * self.num_atoms)
+        quantiles_locations = tf.layers.dense(input_layer, self.num_actions * self.num_atoms, name='output')
         quantiles_locations = tf.reshape(quantiles_locations, (tf.shape(quantiles_locations)[0], self.num_actions, self.num_atoms))
         self.output = quantiles_locations
@@ -50,7 +50,7 @@ class LSTM_Embedder(MiddlewareEmbedder):
         which would definitely be wrong. need to double check the shape
         """

-        middleware = tf.layers.dense(self.input, 512, activation=self.activation_function)
+        middleware = tf.layers.dense(self.input, 512, activation=self.activation_function, name='fc1')
         lstm_cell = tf.contrib.rnn.BasicLSTMCell(256, state_is_tuple=True)
         self.c_init = np.zeros((1, lstm_cell.state_size.c), np.float32)
         self.h_init = np.zeros((1, lstm_cell.state_size.h), np.float32)
@@ -70,4 +70,4 @@ class LSTM_Embedder(MiddlewareEmbedder):

 class FC_Embedder(MiddlewareEmbedder):
     def _build_module(self):
-        self.output = tf.layers.dense(self.input, 512, activation=self.activation_function)
+        self.output = tf.layers.dense(self.input, 512, activation=self.activation_function, name='fc1')