mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
adding support for TensorBoard (#52)
* bug-fix in architecture.py where additional fetches would acquire more entries than they should
* change in run_test to allow ignoring some test(s)
This commit is contained in:
@@ -21,6 +21,20 @@ from configurations import Preset, MiddlewareTypes
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
def variable_summaries(var):
    """Register TensorBoard summaries describing a single tensor.

    Scalar summaries for the mean, standard deviation, maximum and minimum of
    ``var`` are created, together with a histogram of its values. They are all
    built under a ``summaries/<name>`` name scope, where ``<name>`` is the last
    three '/'-separated components of ``var.name`` joined with underscores,
    minus its final two characters.

    :param var: the tensor to summarize; its ``name`` attribute is read to
                build the scope name
    """
    # NOTE(review): the two trimmed characters are presumably the ':0' output
    # suffix of the tensor name - confirm against the callers.
    trailing_parts = var.name.split('/')[-3:]
    scope_label = '_'.join(trailing_parts)[:-2]

    with tf.name_scope('summaries'), tf.name_scope(scope_label):
        average = tf.reduce_mean(var)
        tf.summary.scalar('mean', average)

        # The intermediate ops of the deviation get their own sub-scope; the
        # summary itself is registered outside of it.
        with tf.name_scope('stddev'):
            squared_offsets = tf.square(var - average)
            deviation = tf.sqrt(tf.reduce_mean(squared_offsets))
        tf.summary.scalar('stddev', deviation)

        largest = tf.reduce_max(var)
        tf.summary.scalar('max', largest)

        smallest = tf.reduce_min(var)
        tf.summary.scalar('min', smallest)

        tf.summary.histogram('histogram', var)
|
||||
|
||||
class TensorFlowArchitecture(Architecture):
|
||||
def __init__(self, tuning_parameters, name="", global_network=None, network_is_local=True):
|
||||
@@ -44,6 +58,7 @@ class TensorFlowArchitecture(Architecture):
|
||||
self.curr_rnn_c_in = None
|
||||
self.curr_rnn_h_in = None
|
||||
self.gradients_wrt_inputs = []
|
||||
self.train_writer = None
|
||||
|
||||
self.optimizer_type = self.tp.agent.optimizer_type
|
||||
if self.tp.seed is not None:
|
||||
@@ -75,6 +90,8 @@ class TensorFlowArchitecture(Architecture):
|
||||
for idx, var in enumerate(self.trainable_weights):
|
||||
placeholder = tf.placeholder(tf.float32, shape=var.get_shape(), name=str(idx) + '_holder')
|
||||
self.weights_placeholders.append(placeholder)
|
||||
variable_summaries(var)
|
||||
|
||||
self.update_weights_from_list = [weights.assign(holder) for holder, weights in
|
||||
zip(self.weights_placeholders, self.trainable_weights)]
|
||||
|
||||
@@ -106,12 +123,22 @@ class TensorFlowArchitecture(Architecture):
|
||||
self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
|
||||
zip(self.weights_placeholders, self.trainable_weights), global_step=self.global_step)
|
||||
|
||||
current_scope_summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
|
||||
scope=tf.contrib.framework.get_name_scope())
|
||||
self.merged = tf.summary.merge(current_scope_summaries)
|
||||
|
||||
# initialize or restore model
|
||||
if not self.tp.distributed:
|
||||
# Merge all the summaries
|
||||
|
||||
self.init_op = tf.global_variables_initializer()
|
||||
|
||||
if self.sess:
|
||||
self.sess.run(self.init_op)
|
||||
if self.tp.visualization.tensorboard:
|
||||
# Write the merged summaries to the current experiment directory
|
||||
self.train_writer = tf.summary.FileWriter(self.tp.experiment_path + '/tensorboard',
|
||||
self.sess.graph)
|
||||
self.sess.run(self.init_op)
|
||||
|
||||
self.accumulated_gradients = None
|
||||
|
||||
@@ -169,8 +196,12 @@ class TensorFlowArchitecture(Architecture):
|
||||
feed_dict[self.middleware_embedder.c_in] = self.middleware_embedder.c_init
|
||||
feed_dict[self.middleware_embedder.h_in] = self.middleware_embedder.h_init
|
||||
|
||||
fetches += [self.merged]
|
||||
|
||||
# get grads
|
||||
result = self.tp.sess.run(fetches, feed_dict=feed_dict)
|
||||
if hasattr(self, 'train_writer') and self.train_writer is not None:
|
||||
self.train_writer.add_summary(result[-1], self.tp.current_episode)
|
||||
|
||||
# extract the fetches
|
||||
norm_unclipped_grads, grads, total_loss, losses = result[:4]
|
||||
@@ -178,7 +209,8 @@ class TensorFlowArchitecture(Architecture):
|
||||
(self.curr_rnn_c_in, self.curr_rnn_h_in) = result[4]
|
||||
fetched_tensors = []
|
||||
if len(additional_fetches) > 0:
|
||||
fetched_tensors = result[additional_fetches_start_idx:]
|
||||
fetched_tensors = result[additional_fetches_start_idx:additional_fetches_start_idx +
|
||||
len(additional_fetches)]
|
||||
|
||||
# accumulate the gradients
|
||||
for idx, grad in enumerate(grads):
|
||||
|
||||
Reference in New Issue
Block a user