mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
bug fixes for clippedppo and checkpoints
This commit is contained in:
@@ -125,9 +125,10 @@ class TensorFlowArchitecture(Architecture):
|
||||
self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
|
||||
zip(self.weights_placeholders, self.trainable_weights), global_step=self.global_step)
|
||||
|
||||
current_scope_summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
|
||||
scope=tf.contrib.framework.get_name_scope())
|
||||
self.merged = tf.summary.merge(current_scope_summaries)
|
||||
if self.tp.visualization.tensorboard:
|
||||
current_scope_summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
|
||||
scope=tf.contrib.framework.get_name_scope())
|
||||
self.merged = tf.summary.merge(current_scope_summaries)
|
||||
|
||||
# initialize or restore model
|
||||
if not self.tp.distributed:
|
||||
@@ -197,7 +198,8 @@ class TensorFlowArchitecture(Architecture):
|
||||
feed_dict[self.middleware_embedder.c_in] = self.middleware_embedder.c_init
|
||||
feed_dict[self.middleware_embedder.h_in] = self.middleware_embedder.h_init
|
||||
|
||||
fetches += [self.merged]
|
||||
if self.tp.visualization.tensorboard:
|
||||
fetches += [self.merged]
|
||||
|
||||
# get grads
|
||||
result = self.tp.sess.run(fetches, feed_dict=feed_dict)
|
||||
|
||||
Reference in New Issue
Block a user