1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

Changes to avoid memory leak in rollout worker

Currently in the rollout worker, we call restore_checkpoint repeatedly to load the latest model in memory. The restore_checkpoint function calls checkpoint_saver. The checkpoint saver uses GlobalVariablesSaver, which does not release the references to the previous model's variables. This leads to a situation where memory keeps growing until it crashes the rollout worker.

This change avoids using the checkpoint saver in the rollout worker, as I believe it is not needed in this code path.

Also added a test to easily reproduce the issue using the CartPole example. We were also seeing this issue with the AWS DeepRacer implementation, and the current implementation avoids the memory leak there as well.
This commit is contained in:
gouravr
2018-12-15 12:26:31 -08:00
parent e08accdc22
commit 801aed5e10
2 changed files with 27 additions and 2 deletions

View File

@@ -562,7 +562,8 @@ class GraphManager(object):
screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir)) screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
else: else:
screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path)) screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path) if not hasattr(self.agent_params.memory, 'memory_backend_params') or self.agent_params.memory.memory_backend_params.run_type != str(RunType.ROLLOUT_WORKER):
self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
[manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers] [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]

View File

@@ -1,8 +1,10 @@
import os import os
import sys import sys
import gc
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import tensorflow as tf import tensorflow as tf
from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks, RunType
from rl_coach.memories.backend.memory import MemoryBackendParameters
from rl_coach.utils import get_open_port from rl_coach.utils import get_open_port
from multiprocessing import Process from multiprocessing import Process
from tensorflow import logging from tensorflow import logging
@@ -41,12 +43,34 @@ def test_basic_rl_graph_manager_with_cartpole_dqn():
experiment_path="./experiments/test")) experiment_path="./experiments/test"))
# graph_manager.improve() # graph_manager.improve()
# Test for identifying memory leak in restore_checkpoint
@pytest.mark.unit_test
def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore():
    """Reproduce the rollout-worker memory leak in restore_checkpoint.

    Trains CartPole-DQN briefly, saves a checkpoint, marks the graph as a
    ROLLOUT_WORKER via the memory backend parameters, and then restores the
    checkpoint repeatedly. Before the fix, each restore went through the
    checkpoint saver and retained references to the previous model's
    variables, so memory grew without bound.
    """
    tf.reset_default_graph()
    from rl_coach.presets.CartPole_DQN import graph_manager
    assert graph_manager
    graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow,
                                                              experiment_path="./experiments/test",
                                                              apply_stop_condition=True))
    graph_manager.improve()
    graph_manager.save_checkpoint()
    graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint"
    # Mark this graph as a rollout worker so restore_checkpoint takes the
    # code path under test (the one that must skip the checkpoint saver).
    graph_manager.agent_params.memory.register_var(
        'memory_backend_params',
        MemoryBackendParameters(store_type=None,
                                orchestrator_type=None,
                                run_type=str(RunType.ROLLOUT_WORKER)))
    # Bounded loop: the original repro used `while True`, which hangs the
    # pytest run forever. A fixed number of iterations is enough to observe
    # memory growth under a profiler while still letting the suite finish.
    for _ in range(100):
        graph_manager.restore_checkpoint()
        # Force a collection each iteration so any growth seen is from
        # uncollectable references, not merely deferred garbage.
        gc.collect()
if __name__ == '__main__': if __name__ == '__main__':
pass pass
# test_basic_rl_graph_manager_with_pong_a3c() # test_basic_rl_graph_manager_with_pong_a3c()
# test_basic_rl_graph_manager_with_ant_a3c() # test_basic_rl_graph_manager_with_ant_a3c()
# test_basic_rl_graph_manager_with_pong_nec() # test_basic_rl_graph_manager_with_pong_nec()
# test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore()
# test_basic_rl_graph_manager_with_cartpole_dqn() # test_basic_rl_graph_manager_with_cartpole_dqn()
#test_basic_rl_graph_manager_multithreaded_with_pong_a3c() #test_basic_rl_graph_manager_multithreaded_with_pong_a3c()
#test_basic_rl_graph_manager_with_doom_basic_dqn() #test_basic_rl_graph_manager_with_doom_basic_dqn()