1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

Revert "Changes to avoid memory leak in rollout worker"

This reverts commit 801aed5e10.
This commit is contained in:
Gourav Roy
2019-01-02 22:37:12 -08:00
parent 779d3694b4
commit c377363e50
2 changed files with 2 additions and 27 deletions

View File

@@ -562,8 +562,7 @@ class GraphManager(object):
screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
else:
screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
if not hasattr(self.agent_params.memory, 'memory_backend_params') or self.agent_params.memory.memory_backend_params.run_type != str(RunType.ROLLOUT_WORKER):
self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
[manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]

View File

@@ -1,10 +1,8 @@
import os
import sys
import gc
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import tensorflow as tf
from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks, RunType
from rl_coach.memories.backend.memory import MemoryBackendParameters
from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks
from rl_coach.utils import get_open_port
from multiprocessing import Process
from tensorflow import logging
@@ -43,34 +41,12 @@ def test_basic_rl_graph_manager_with_cartpole_dqn():
experiment_path="./experiments/test"))
# graph_manager.improve()
# Test for identifying memory leak in restore_checkpoint
@pytest.mark.unit_test
def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore():
    """Stress-test restore_checkpoint for memory leaks on a CartPole DQN graph.

    Builds and briefly trains a CartPole DQN graph, saves a checkpoint,
    registers rollout-worker memory-backend parameters, and then restores
    that checkpoint in an endless loop (forcing garbage collection each
    iteration) so a leak shows up as monotonically growing memory.

    NOTE(review): the `while True` loop never terminates — this test is
    meant for manual leak hunting, not for an automated suite.
    """
    tf.reset_default_graph()

    from rl_coach.presets.CartPole_DQN import graph_manager
    assert graph_manager

    # Build the graph and run a short improve cycle so there is a
    # checkpoint worth restoring.
    params = TaskParameters(framework_type=Frameworks.tensorflow,
                            experiment_path="./experiments/test",
                            apply_stop_condition=True)
    graph_manager.create_graph(task_parameters=params)
    graph_manager.improve()
    graph_manager.save_checkpoint()

    # Point the manager at the checkpoint it just wrote and pretend to be
    # a rollout worker, the code path the original leak lived in.
    graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint"
    backend_params = MemoryBackendParameters(store_type=None,
                                             orchestrator_type=None,
                                             run_type=str(RunType.ROLLOUT_WORKER))
    graph_manager.agent_params.memory.register_var('memory_backend_params', backend_params)

    # Restore forever; watch process RSS externally to detect a leak.
    while True:
        graph_manager.restore_checkpoint()
        gc.collect()
if __name__ == '__main__':
    # No test runs by default; uncomment one of the calls below to run a
    # single test directly without going through pytest.
    pass
    # test_basic_rl_graph_manager_with_pong_a3c()
    # test_basic_rl_graph_manager_with_ant_a3c()
    # test_basic_rl_graph_manager_with_pong_nec()
    # test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore()
    # test_basic_rl_graph_manager_with_cartpole_dqn()
    # test_basic_rl_graph_manager_multithreaded_with_pong_a3c()
    # test_basic_rl_graph_manager_with_doom_basic_dqn()