mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Changes to avoid memory leak in rollout worker
Currently in rollout worker, we call restore_checkpoint repeatedly to load the latest model in memory. The restore checkpoint functions calls checkpoint_saver. Checkpoint saver uses GlobalVariablesSaver which does not release the references of the previous model variables. This leads to the situation where the memory keeps on growing before crashing the rollout worker. This change avoid using the checkpoint saver in the rollout worker as I believe it is not needed in this code path. Also added a test to easily reproduce the issue using CartPole example. We were also seeing this issue with the AWS DeepRacer implementation and the current implementation avoid the memory leak there as well.
This commit is contained in:
@@ -562,7 +562,8 @@ class GraphManager(object):
|
|||||||
screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
|
screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
|
||||||
else:
|
else:
|
||||||
screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
|
screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
|
||||||
self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
|
if not hasattr(self.agent_params.memory, 'memory_backend_params') or self.agent_params.memory.memory_backend_params.run_type != str(RunType.ROLLOUT_WORKER):
|
||||||
|
self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
|
||||||
|
|
||||||
[manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
|
[manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import gc
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks
|
from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks, RunType
|
||||||
|
from rl_coach.memories.backend.memory import MemoryBackendParameters
|
||||||
from rl_coach.utils import get_open_port
|
from rl_coach.utils import get_open_port
|
||||||
from multiprocessing import Process
|
from multiprocessing import Process
|
||||||
from tensorflow import logging
|
from tensorflow import logging
|
||||||
@@ -41,12 +43,34 @@ def test_basic_rl_graph_manager_with_cartpole_dqn():
|
|||||||
experiment_path="./experiments/test"))
|
experiment_path="./experiments/test"))
|
||||||
# graph_manager.improve()
|
# graph_manager.improve()
|
||||||
|
|
||||||
|
# Test for identifying memory leak in restore_checkpoint
|
||||||
|
@pytest.mark.unit_test
|
||||||
|
def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore():
|
||||||
|
tf.reset_default_graph()
|
||||||
|
from rl_coach.presets.CartPole_DQN import graph_manager
|
||||||
|
assert graph_manager
|
||||||
|
graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow,
|
||||||
|
experiment_path="./experiments/test",
|
||||||
|
apply_stop_condition=True))
|
||||||
|
graph_manager.improve()
|
||||||
|
graph_manager.save_checkpoint()
|
||||||
|
|
||||||
|
graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint"
|
||||||
|
graph_manager.agent_params.memory.register_var('memory_backend_params',
|
||||||
|
MemoryBackendParameters(store_type=None,
|
||||||
|
orchestrator_type=None,
|
||||||
|
run_type=str(RunType.ROLLOUT_WORKER)))
|
||||||
|
while True:
|
||||||
|
graph_manager.restore_checkpoint()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pass
|
pass
|
||||||
# test_basic_rl_graph_manager_with_pong_a3c()
|
# test_basic_rl_graph_manager_with_pong_a3c()
|
||||||
# test_basic_rl_graph_manager_with_ant_a3c()
|
# test_basic_rl_graph_manager_with_ant_a3c()
|
||||||
# test_basic_rl_graph_manager_with_pong_nec()
|
# test_basic_rl_graph_manager_with_pong_nec()
|
||||||
|
# test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore()
|
||||||
# test_basic_rl_graph_manager_with_cartpole_dqn()
|
# test_basic_rl_graph_manager_with_cartpole_dqn()
|
||||||
#test_basic_rl_graph_manager_multithreaded_with_pong_a3c()
|
#test_basic_rl_graph_manager_multithreaded_with_pong_a3c()
|
||||||
#test_basic_rl_graph_manager_with_doom_basic_dqn()
|
#test_basic_rl_graph_manager_with_doom_basic_dqn()
|
||||||
Reference in New Issue
Block a user