
Adding TensorBoard support (#52)

* Bug fix in architecture.py where additional fetches would acquire more entries than they should
* Change in run_test to allow ignoring selected tests
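
During training, the merged summaries are written under <experiment_path>/tensorboard (see the FileWriter below) and can be inspected with the usual tensorboard --logdir <experiment_path>/tensorboard.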
Author: Gal Leibovich
Date: 2018-02-05 15:21:49 +02:00 (committed by GitHub)
Parent: a8d5fb7bdf
Commit: 7c8962c991
10 changed files with 107 additions and 36 deletions

architecture.py

@@ -21,6 +21,20 @@ from configurations import Preset, MiddlewareTypes
 import numpy as np
 import time
+
+def variable_summaries(var):
+    """Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
+    with tf.name_scope('summaries'):
+        layer_weight_name = '_'.join(var.name.split('/')[-3:])[:-2]
+        with tf.name_scope(layer_weight_name):
+            mean = tf.reduce_mean(var)
+            tf.summary.scalar('mean', mean)
+            with tf.name_scope('stddev'):
+                stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
+            tf.summary.scalar('stddev', stddev)
+            tf.summary.scalar('max', tf.reduce_max(var))
+            tf.summary.scalar('min', tf.reduce_min(var))
+            tf.summary.histogram('histogram', var)
 
 class TensorFlowArchitecture(Architecture):
     def __init__(self, tuning_parameters, name="", global_network=None, network_is_local=True):
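
For orientation, a minimal sketch of what variable_summaries does to a single weight tensor (TF 1.x graph mode; the scope and shape below are illustrative, not from this commit):

    import tensorflow as tf

    # A hypothetical weight variable; layer weights are typically named
    # '<scope>/<layer>/weights:0', which is the shape of name that the
    # name-splitting in variable_summaries expects.
    with tf.variable_scope('main/dense_0'):
        weights = tf.get_variable('weights', shape=[16, 4])

    # Attaches mean/stddev/max/min scalars and a histogram under the scope
    # 'summaries/main_dense_0_weights', derived from the variable name.
    variable_summaries(weights)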
@@ -44,6 +58,7 @@ class TensorFlowArchitecture(Architecture):
         self.curr_rnn_c_in = None
         self.curr_rnn_h_in = None
         self.gradients_wrt_inputs = []
+        self.train_writer = None
 
         self.optimizer_type = self.tp.agent.optimizer_type
         if self.tp.seed is not None:
@@ -75,6 +90,8 @@ class TensorFlowArchitecture(Architecture):
         for idx, var in enumerate(self.trainable_weights):
             placeholder = tf.placeholder(tf.float32, shape=var.get_shape(), name=str(idx) + '_holder')
             self.weights_placeholders.append(placeholder)
+            variable_summaries(var)
+
         self.update_weights_from_list = [weights.assign(holder) for holder, weights in
                                          zip(self.weights_placeholders, self.trainable_weights)]
@@ -106,12 +123,22 @@ class TensorFlowArchitecture(Architecture):
         self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
             zip(self.weights_placeholders, self.trainable_weights), global_step=self.global_step)
 
+        # Merge all the summaries
+        current_scope_summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
+                                                    scope=tf.contrib.framework.get_name_scope())
+        self.merged = tf.summary.merge(current_scope_summaries)
+
         # initialize or restore model
         if not self.tp.distributed:
             self.init_op = tf.global_variables_initializer()
             if self.sess:
-                self.sess.run(self.init_op)
+                if self.tp.visualization.tensorboard:
+                    # Write the merged summaries to the current experiment directory
+                    self.train_writer = tf.summary.FileWriter(self.tp.experiment_path + '/tensorboard',
+                                                              self.sess.graph)
+                self.sess.run(self.init_op)
 
         self.accumulated_gradients = None
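
Note that the merge is deliberately scoped rather than a plain tf.summary.merge_all(): a Coach graph can hold several networks (e.g. a local and a global copy, per the constructor signature above), and each should only fetch its own summaries. A rough sketch of the difference, with illustrative scope names:

    # Two networks in one graph; each gets its own scoped merge.
    with tf.name_scope('global_network'):
        w_global = tf.get_variable('global_weights', shape=[8])
        variable_summaries(w_global)

    with tf.name_scope('local_network'):
        w_local = tf.get_variable('local_weights', shape=[8])
        variable_summaries(w_local)
        # Collect only the summaries created inside 'local_network'
        local_summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                            scope=tf.contrib.framework.get_name_scope())
        merged_local = tf.summary.merge(local_summaries)

    # tf.summary.merge_all() would bundle both networks' summaries together.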
@@ -169,8 +196,12 @@ class TensorFlowArchitecture(Architecture):
             feed_dict[self.middleware_embedder.c_in] = self.middleware_embedder.c_init
             feed_dict[self.middleware_embedder.h_in] = self.middleware_embedder.h_init
 
+        fetches += [self.merged]
+
         # get grads
         result = self.tp.sess.run(fetches, feed_dict=feed_dict)
+        if hasattr(self, 'train_writer') and self.train_writer is not None:
+            self.train_writer.add_summary(result[-1], self.tp.current_episode)
 
         # extract the fetches
         norm_unclipped_grads, grads, total_loss, losses = result[:4]
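
The write path follows the standard TF 1.x pattern: the merged summary op rides along as the last fetch of the existing sess.run() call, and the returned Summary proto is written with the current episode as the global step. Stripped down (names illustrative):

    fetches += [merged]                 # summary op is always fetched last
    result = sess.run(fetches, feed_dict=feed_dict)
    if train_writer is not None:        # writer exists only when tensorboard is enabled
        train_writer.add_summary(result[-1], current_episode)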
@@ -178,7 +209,8 @@ class TensorFlowArchitecture(Architecture):
             (self.curr_rnn_c_in, self.curr_rnn_h_in) = result[4]
 
         fetched_tensors = []
         if len(additional_fetches) > 0:
-            fetched_tensors = result[additional_fetches_start_idx:]
+            fetched_tensors = result[additional_fetches_start_idx:additional_fetches_start_idx +
+                                     len(additional_fetches)]
 
         # accumulate the gradients
         for idx, grad in enumerate(grads):
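
The slicing fix above is a consequence of the new summary fetch: self.merged is now always the last entry of result, so the old open-ended slice would hand the serialized summary back to the caller as if it were a requested fetch. With illustrative values:

    # result layout: [norm_unclipped_grads, grads, total_loss, losses,
    #                 <additional fetches...>, merged_summary]
    result = ['norm', 'grads', 'loss', 'losses', 'extra_a', 'extra_b', b'summary']
    additional_fetches_start_idx = 4
    additional_fetches = ['op_a', 'op_b']

    result[additional_fetches_start_idx:]
    # -> ['extra_a', 'extra_b', b'summary']   one entry too many

    result[additional_fetches_start_idx:additional_fetches_start_idx + len(additional_fetches)]
    # -> ['extra_a', 'extra_b']               exactly the requested fetches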