mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Refactored GlobalVariableSaver
This commit is contained in:
@@ -35,10 +35,12 @@ class GlobalVariableSaver(Saver):
|
||||
|
||||
# Using a placeholder to update the variable during restore to avoid memory leak.
|
||||
# Ref: https://github.com/tensorflow/tensorflow/issues/4151
|
||||
self._variable_placeholders = [tf.placeholder(v.dtype, shape=v.get_shape()) for v in self._variables]
|
||||
self._variable_placeholders = []
|
||||
self._variable_update_ops = []
|
||||
for i in range(len(self._variables)):
|
||||
self._variable_update_ops.append(self._variables[i].assign(self._variable_placeholders[i]))
|
||||
for v in self._variables:
|
||||
variable_placeholder = tf.placeholder(v.dtype, shape=v.get_shape())
|
||||
self._variable_placeholders.append(variable_placeholder)
|
||||
self._variable_update_ops.append(v.assign(variable_placeholder))
|
||||
|
||||
self._saver = tf.train.Saver(self._variables)
|
||||
|
||||
@@ -76,9 +78,8 @@ class GlobalVariableSaver(Saver):
|
||||
variables[new_name] = reader.get_tensor(var_name)
|
||||
|
||||
# Assign all variables using placeholder
|
||||
for i in range(len(self._variables)):
|
||||
variable_name = self._variables[i].name.split(':')[0]
|
||||
sess.run(self._variable_update_ops[i], {self._variable_placeholders[i]: variables[variable_name]})
|
||||
placeholder_dict = {ph: variables[v.name.split(':')[0]] for ph, v in zip(self._variable_placeholders, self._variables)}
|
||||
sess.run(self._variable_update_ops, placeholder_dict)
|
||||
|
||||
def merge(self, other: 'Saver'):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user