Mirror of https://github.com/gryf/coach.git
Revert "Changes to avoid memory leak in rollout worker"
This reverts commit 801aed5e10.
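In short, the revert restores the unconditional TensorFlow saver call in GraphManager.restore_checkpoint (first hunk) and removes the repeated-checkpoint-restore unit test, along with the gc, RunType, and MemoryBackendParameters imports that only that test used (remaining hunks).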
@@ -562,8 +562,7 @@ class GraphManager(object):
                     screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
                 else:
                     screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
-                    if not hasattr(self.agent_params.memory, 'memory_backend_params') or self.agent_params.memory.memory_backend_params.run_type != str(RunType.ROLLOUT_WORKER):
-                        self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
+                    self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
 
             [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
 
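The condition removed by this hunk is dense; read on its own, it means "skip the TensorFlow saver restore only on rollout workers". A minimal paraphrase, using only names visible in the diff (the helper name below is hypothetical, not part of the codebase):

from rl_coach.base_parameters import RunType


def is_rollout_worker(memory_params):
    # Paraphrase of the guard this revert removes: a process counts as a
    # rollout worker only if its memory carries memory_backend_params whose
    # run_type equals RunType.ROLLOUT_WORKER (the comparison is on strings).
    return (hasattr(memory_params, 'memory_backend_params')
            and memory_params.memory_backend_params.run_type == str(RunType.ROLLOUT_WORKER))

Before the revert, self.checkpoint_saver.restore(...) ran only when this predicate was false; after it, the restore runs unconditionally, as in the added line above.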
@@ -1,10 +1,8 @@
 import os
 import sys
-import gc
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
 import tensorflow as tf
-from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks, RunType
-from rl_coach.memories.backend.memory import MemoryBackendParameters
+from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks
 from rl_coach.utils import get_open_port
 from multiprocessing import Process
 from tensorflow import logging
@@ -43,34 +41,12 @@ def test_basic_rl_graph_manager_with_cartpole_dqn():
                                                               experiment_path="./experiments/test"))
     # graph_manager.improve()
 
-# Test for identifying memory leak in restore_checkpoint
-@pytest.mark.unit_test
-def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore():
-    tf.reset_default_graph()
-    from rl_coach.presets.CartPole_DQN import graph_manager
-    assert graph_manager
-    graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow,
-                                                              experiment_path="./experiments/test",
-                                                              apply_stop_condition=True))
-    graph_manager.improve()
-    graph_manager.save_checkpoint()
-
-    graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint"
-    graph_manager.agent_params.memory.register_var('memory_backend_params',
-                                                   MemoryBackendParameters(store_type=None,
-                                                                           orchestrator_type=None,
-                                                                           run_type=str(RunType.ROLLOUT_WORKER)))
-    while True:
-        graph_manager.restore_checkpoint()
-        gc.collect()
-
 
 if __name__ == '__main__':
     pass
     # test_basic_rl_graph_manager_with_pong_a3c()
     # test_basic_rl_graph_manager_with_ant_a3c()
     # test_basic_rl_graph_manager_with_pong_nec()
-    # test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore()
     # test_basic_rl_graph_manager_with_cartpole_dqn()
     #test_basic_rl_graph_manager_multithreaded_with_pong_a3c()
     #test_basic_rl_graph_manager_with_doom_basic_dqn()
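The deleted test loops forever and relies on someone watching process memory by hand. A bounded variant could assert on growth directly. The sketch below is only an illustration of that idea, not part of this commit: it keeps the create/improve/save/restore calls shown in the removed test, while the function name, iteration count, growth threshold, and use of tracemalloc are assumptions. Note that tracemalloc only sees Python-level allocations; growth in native TensorFlow memory would need a different probe (e.g. resource.getrusage).

import gc
import tracemalloc

import tensorflow as tf

from rl_coach.base_parameters import TaskParameters, Frameworks


def repeated_restore_stays_bounded(iterations=50, max_growth_mb=50):
    # Build, train, and checkpoint the CartPole_DQN graph once, as the removed test did.
    tf.reset_default_graph()
    from rl_coach.presets.CartPole_DQN import graph_manager
    graph_manager.create_graph(task_parameters=TaskParameters(
        framework_type=Frameworks.tensorflow,
        experiment_path="./experiments/test",
        apply_stop_condition=True))
    graph_manager.improve()
    graph_manager.save_checkpoint()
    graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint"

    tracemalloc.start()
    baseline, _ = tracemalloc.get_traced_memory()
    for _ in range(iterations):          # bounded, unlike the removed `while True`
        graph_manager.restore_checkpoint()
        gc.collect()
    current, _ = tracemalloc.get_traced_memory()
    tracemalloc.stop()

    growth_mb = (current - baseline) / (1024 * 1024)
    return growth_mb < max_growth_mb     # False suggests leak-like growth across restores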