1
0
mirror of https://github.com/gryf/coach.git synced 2026-01-29 11:35:51 +01:00

Avoid Memory Leak in Rollout worker

ISSUE: When we restore checkpoints, we create new nodes in the
TensorFlow graph. This happens when we assign a new value (an op node)
to a RefVariable in GlobalVariableSaver. With every restore, the size of
the TF graph increases because new nodes are created and old, unused
nodes are not removed from the graph. This causes a memory leak in the
restore_checkpoint code path.

FIX: We use a TF placeholder to update the variables, which avoids the
memory leak.
This commit is contained in:
Gourav Roy
2019-01-02 23:06:44 -08:00
parent c377363e50
commit 619ea0944e
2 changed files with 34 additions and 2 deletions

View File

@@ -1,8 +1,10 @@
import gc
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import tensorflow as tf
from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks
from rl_coach.core_types import EnvironmentSteps
from rl_coach.utils import get_open_port
from multiprocessing import Process
from tensorflow import logging
@@ -41,6 +43,24 @@ def test_basic_rl_graph_manager_with_cartpole_dqn():
experiment_path="./experiments/test"))
# graph_manager.improve()
# Test for identifying memory leak in restore_checkpoint
@pytest.mark.unit_test
def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore():
    """Build the CartPole DQN graph as setup for a repeated checkpoint-restore check.

    Only the graph creation runs automatically. The save/restore loop that
    exercises restore_checkpoint is kept in the comments at the end of this
    test; it loops forever, so it has to be enabled by hand.
    """
    # The default TF graph is reset before the preset is imported.
    # NOTE(review): presumably so nodes from earlier tests are not carried
    # over into this one -- confirm.
    tf.reset_default_graph()
    from rl_coach.presets.CartPole_DQN import graph_manager
    assert graph_manager

    task_parameters = TaskParameters(
        framework_type=Frameworks.tensorflow,
        experiment_path="./experiments/test",
        apply_stop_condition=True,
    )
    graph_manager.create_graph(task_parameters=task_parameters)

    # Manual reproduction steps (disabled; the loop below never terminates):
    # graph_manager.improve()
    # graph_manager.evaluate(EnvironmentSteps(1000))
    # graph_manager.save_checkpoint()
    #
    # graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint"
    # while True:
    #     graph_manager.restore_checkpoint()
    #     graph_manager.evaluate(EnvironmentSteps(1000))
    #     gc.collect()
# Script entry point: running this file directly does nothing by default.
# The individual test calls that follow are kept commented out so one can be
# enabled by hand for a manual run.
if __name__ == '__main__':
    pass
@@ -48,5 +68,6 @@ if __name__ == '__main__':
# test_basic_rl_graph_manager_with_ant_a3c()
# test_basic_rl_graph_manager_with_pong_nec()
# test_basic_rl_graph_manager_with_cartpole_dqn()
# test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore()
#test_basic_rl_graph_manager_multithreaded_with_pong_a3c()
#test_basic_rl_graph_manager_with_doom_basic_dqn()