1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 19:50:17 +01:00

Merge pull request #161 from x77a1/master

Changes to avoid memory leak in Rollout worker
This commit is contained in:
Ajay Deshpande
2019-01-03 21:15:04 -08:00
committed by GitHub
2 changed files with 35 additions and 2 deletions

View File

@@ -32,6 +32,16 @@ class GlobalVariableSaver(Saver):
# target network is never saved or restored directly from checkpoint, so we are removing all its variables from the list # target network is never saved or restored directly from checkpoint, so we are removing all its variables from the list
# the target network would be synched back from the online network in graph_manager.improve(...), at the beginning of the run flow. # the target network would be synched back from the online network in graph_manager.improve(...), at the beginning of the run flow.
self._variables = [v for v in self._variables if '/target' not in v.name] self._variables = [v for v in self._variables if '/target' not in v.name]
# Using a placeholder to update the variable during restore to avoid memory leak.
# Ref: https://github.com/tensorflow/tensorflow/issues/4151
self._variable_placeholders = []
self._variable_update_ops = []
for v in self._variables:
variable_placeholder = tf.placeholder(v.dtype, shape=v.get_shape())
self._variable_placeholders.append(variable_placeholder)
self._variable_update_ops.append(v.assign(variable_placeholder))
self._saver = tf.train.Saver(self._variables) self._saver = tf.train.Saver(self._variables)
@property @property
@@ -66,8 +76,10 @@ class GlobalVariableSaver(Saver):
# TODO: Can this be more generic so that `global/` and `online/` are not hardcoded here? # TODO: Can this be more generic so that `global/` and `online/` are not hardcoded here?
new_name = var_name.replace('global/', 'online/') new_name = var_name.replace('global/', 'online/')
variables[new_name] = reader.get_tensor(var_name) variables[new_name] = reader.get_tensor(var_name)
# Assign all variables
sess.run([v.assign(variables[v.name.split(':')[0]]) for v in self._variables]) # Assign all variables using placeholder
placeholder_dict = {ph: variables[v.name.split(':')[0]] for ph, v in zip(self._variable_placeholders, self._variables)}
sess.run(self._variable_update_ops, placeholder_dict)
def merge(self, other: 'Saver'): def merge(self, other: 'Saver'):
""" """

View File

@@ -1,8 +1,10 @@
import gc
import os import os
import sys import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import tensorflow as tf import tensorflow as tf
from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks
from rl_coach.core_types import EnvironmentSteps
from rl_coach.utils import get_open_port from rl_coach.utils import get_open_port
from multiprocessing import Process from multiprocessing import Process
from tensorflow import logging from tensorflow import logging
@@ -41,6 +43,24 @@ def test_basic_rl_graph_manager_with_cartpole_dqn():
experiment_path="./experiments/test")) experiment_path="./experiments/test"))
# graph_manager.improve() # graph_manager.improve()
# Test for identifying memory leak in restore_checkpoint
@pytest.mark.unit_test
def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore():
tf.reset_default_graph()
from rl_coach.presets.CartPole_DQN import graph_manager
assert graph_manager
graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow,
experiment_path="./experiments/test",
apply_stop_condition=True))
# graph_manager.improve()
# graph_manager.evaluate(EnvironmentSteps(1000))
# graph_manager.save_checkpoint()
#
# graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint"
# while True:
# graph_manager.restore_checkpoint()
# graph_manager.evaluate(EnvironmentSteps(1000))
# gc.collect()
if __name__ == '__main__': if __name__ == '__main__':
pass pass
@@ -48,5 +68,6 @@ if __name__ == '__main__':
# test_basic_rl_graph_manager_with_ant_a3c() # test_basic_rl_graph_manager_with_ant_a3c()
# test_basic_rl_graph_manager_with_pong_nec() # test_basic_rl_graph_manager_with_pong_nec()
# test_basic_rl_graph_manager_with_cartpole_dqn() # test_basic_rl_graph_manager_with_cartpole_dqn()
# test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore()
#test_basic_rl_graph_manager_multithreaded_with_pong_a3c() #test_basic_rl_graph_manager_multithreaded_with_pong_a3c()
#test_basic_rl_graph_manager_with_doom_basic_dqn() #test_basic_rl_graph_manager_with_doom_basic_dqn()