1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Refactored GlobalVariableSaver

This commit is contained in:
Gourav Roy
2019-01-03 15:08:34 -08:00
parent 619ea0944e
commit b1e9ea48d8

View File

@@ -35,10 +35,12 @@ class GlobalVariableSaver(Saver):
# Using a placeholder to update the variable during restore to avoid memory leak. # Using a placeholder to update the variable during restore to avoid memory leak.
# Ref: https://github.com/tensorflow/tensorflow/issues/4151 # Ref: https://github.com/tensorflow/tensorflow/issues/4151
self._variable_placeholders = [tf.placeholder(v.dtype, shape=v.get_shape()) for v in self._variables] self._variable_placeholders = []
self._variable_update_ops = [] self._variable_update_ops = []
for i in range(len(self._variables)): for v in self._variables:
self._variable_update_ops.append(self._variables[i].assign(self._variable_placeholders[i])) variable_placeholder = tf.placeholder(v.dtype, shape=v.get_shape())
self._variable_placeholders.append(variable_placeholder)
self._variable_update_ops.append(v.assign(variable_placeholder))
self._saver = tf.train.Saver(self._variables) self._saver = tf.train.Saver(self._variables)
@@ -76,9 +78,8 @@ class GlobalVariableSaver(Saver):
variables[new_name] = reader.get_tensor(var_name) variables[new_name] = reader.get_tensor(var_name)
# Assign all variables using placeholder # Assign all variables using placeholder
for i in range(len(self._variables)): placeholder_dict = {ph: variables[v.name.split(':')[0]] for ph, v in zip(self._variable_placeholders, self._variables)}
variable_name = self._variables[i].name.split(':')[0] sess.run(self._variable_update_ops, placeholder_dict)
sess.run(self._variable_update_ops[i], {self._variable_placeholders[i]: variables[variable_name]})
def merge(self, other: 'Saver'): def merge(self, other: 'Saver'):
""" """