1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

imporved API for getting / setting variables within the graph

This commit is contained in:
Itai Caspi
2017-10-25 16:07:58 +03:00
parent e33b0e8534
commit 1918f16079
5 changed files with 45 additions and 17 deletions

View File

@@ -288,3 +288,21 @@ class TensorFlowArchitecture(Architecture):
"""
summary_writer = tf.summary.FileWriter(summary_dir)
summary_writer.add_graph(self.sess.graph)
def get_variable_value(self, variable):
"""
Get the value of a variable from the graph
:param variable: the variable
:return: the value of the variable
"""
return self.sess.run(variable)
def set_variable_value(self, assign_op, value, placeholder=None):
"""
Updates the value of a variable.
This requires having an assign operation for the variable, and a placeholder which will provide the value
:param assign_op: an assign operation for the variable
:param value: a value to set the variable to
:param placeholder: a placeholder to hold the given value for injecting it into the variable
"""
self.sess.run(assign_op, feed_dict={placeholder: value})