mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Revert "Changes to avoid memory leak in rollout worker"
This reverts commit 801aed5e10.
This commit is contained in:
@@ -562,8 +562,7 @@ class GraphManager(object):
|
|||||||
screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
|
screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
|
||||||
else:
|
else:
|
||||||
screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
|
screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
|
||||||
if not hasattr(self.agent_params.memory, 'memory_backend_params') or self.agent_params.memory.memory_backend_params.run_type != str(RunType.ROLLOUT_WORKER):
|
self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
|
||||||
self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
|
|
||||||
|
|
||||||
[manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
|
[manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import gc
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks, RunType
|
from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks
|
||||||
from rl_coach.memories.backend.memory import MemoryBackendParameters
|
|
||||||
from rl_coach.utils import get_open_port
|
from rl_coach.utils import get_open_port
|
||||||
from multiprocessing import Process
|
from multiprocessing import Process
|
||||||
from tensorflow import logging
|
from tensorflow import logging
|
||||||
@@ -43,34 +41,12 @@ def test_basic_rl_graph_manager_with_cartpole_dqn():
|
|||||||
experiment_path="./experiments/test"))
|
experiment_path="./experiments/test"))
|
||||||
# graph_manager.improve()
|
# graph_manager.improve()
|
||||||
|
|
||||||
# Test for identifying memory leak in restore_checkpoint
|
|
||||||
@pytest.mark.unit_test
|
|
||||||
def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore():
|
|
||||||
tf.reset_default_graph()
|
|
||||||
from rl_coach.presets.CartPole_DQN import graph_manager
|
|
||||||
assert graph_manager
|
|
||||||
graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow,
|
|
||||||
experiment_path="./experiments/test",
|
|
||||||
apply_stop_condition=True))
|
|
||||||
graph_manager.improve()
|
|
||||||
graph_manager.save_checkpoint()
|
|
||||||
|
|
||||||
graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint"
|
|
||||||
graph_manager.agent_params.memory.register_var('memory_backend_params',
|
|
||||||
MemoryBackendParameters(store_type=None,
|
|
||||||
orchestrator_type=None,
|
|
||||||
run_type=str(RunType.ROLLOUT_WORKER)))
|
|
||||||
while True:
|
|
||||||
graph_manager.restore_checkpoint()
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pass
|
pass
|
||||||
# test_basic_rl_graph_manager_with_pong_a3c()
|
# test_basic_rl_graph_manager_with_pong_a3c()
|
||||||
# test_basic_rl_graph_manager_with_ant_a3c()
|
# test_basic_rl_graph_manager_with_ant_a3c()
|
||||||
# test_basic_rl_graph_manager_with_pong_nec()
|
# test_basic_rl_graph_manager_with_pong_nec()
|
||||||
# test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore()
|
|
||||||
# test_basic_rl_graph_manager_with_cartpole_dqn()
|
# test_basic_rl_graph_manager_with_cartpole_dqn()
|
||||||
#test_basic_rl_graph_manager_multithreaded_with_pong_a3c()
|
#test_basic_rl_graph_manager_multithreaded_with_pong_a3c()
|
||||||
#test_basic_rl_graph_manager_with_doom_basic_dqn()
|
#test_basic_rl_graph_manager_with_doom_basic_dqn()
|
||||||
Reference in New Issue
Block a user