Mirror of https://github.com/gryf/coach.git, synced 2026-02-21 01:05:50 +01:00
Avoid Memory Leak in Rollout worker
ISSUE: When we restore checkpoints, we create new nodes in the TensorFlow graph. This happens when we assign a new value (an assign op node) to a RefVariable in GlobalVariableSaver. With every restore, the TF graph grows: new nodes are created, and the old, unused nodes are never removed from the graph. This causes a memory leak in the restore_checkpoint code path.

FIX: We reset the TensorFlow graph and recreate the Global, Online and Target networks on every restore. This ensures that the old, unused nodes in the TF graph are dropped.
@@ -953,6 +953,12 @@ class Agent(AgentInterface):
         self.input_filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
         self.pre_network_filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
 
+        if self.ap.task_parameters.framework_type == Frameworks.tensorflow:
+            import tensorflow as tf
+            tf.reset_default_graph()
+            # Recreate all the networks of the agent
+            self.networks = self.create_networks()
+
         # no output filters currently have an internal state to restore
         # self.output_filter.restore_state_from_checkpoint(checkpoint_dir)