mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Avoid Memory Leak in Rollout worker
ISSUE: When we restore checkpoints, we create new nodes in the Tensorflow graph. This happens when we assign new value (op node) to RefVariable in GlobalVariableSaver. With every restore the size of TF graph increases as new nodes are created and old unused nodes are not removed from the graph. This causes the memory leak in restore_checkpoint codepath. FIX: We use TF placeholder to update the variables which avoids the memory leak.
This commit is contained in:
@@ -32,6 +32,14 @@ class GlobalVariableSaver(Saver):
|
||||
# target network is never saved or restored directly from checkpoint, so we are removing all its variables from the list
|
||||
# the target network would be synched back from the online network in graph_manager.improve(...), at the beginning of the run flow.
|
||||
self._variables = [v for v in self._variables if '/target' not in v.name]
|
||||
|
||||
# Using a placeholder to update the variable during restore to avoid memory leak.
|
||||
# Ref: https://github.com/tensorflow/tensorflow/issues/4151
|
||||
self._variable_placeholders = [tf.placeholder(v.dtype, shape=v.get_shape()) for v in self._variables]
|
||||
self._variable_update_ops = []
|
||||
for i in range(len(self._variables)):
|
||||
self._variable_update_ops.append(self._variables[i].assign(self._variable_placeholders[i]))
|
||||
|
||||
self._saver = tf.train.Saver(self._variables)
|
||||
|
||||
@property
|
||||
@@ -66,8 +74,11 @@ class GlobalVariableSaver(Saver):
|
||||
# TODO: Can this be more generic so that `global/` and `online/` are not hardcoded here?
|
||||
new_name = var_name.replace('global/', 'online/')
|
||||
variables[new_name] = reader.get_tensor(var_name)
|
||||
# Assign all variables
|
||||
sess.run([v.assign(variables[v.name.split(':')[0]]) for v in self._variables])
|
||||
|
||||
# Assign all variables using placeholder
|
||||
for i in range(len(self._variables)):
|
||||
variable_name = self._variables[i].name.split(':')[0]
|
||||
sess.run(self._variable_update_ops[i], {self._variable_placeholders[i]: variables[variable_name]})
|
||||
|
||||
def merge(self, other: 'Saver'):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user