Mirror of https://github.com/gryf/coach.git (synced 2025-12-18 11:40:18 +01:00)
Merge pull request #161 from x77a1/master
Changes to avoid memory leak in Rollout worker
@@ -32,6 +32,16 @@ class GlobalVariableSaver(Saver):
         # target network is never saved or restored directly from checkpoint, so we are removing all its variables from the list
         # the target network would be synched back from the online network in graph_manager.improve(...), at the beginning of the run flow.
         self._variables = [v for v in self._variables if '/target' not in v.name]

+        # Using a placeholder to update the variable during restore to avoid memory leak.
+        # Ref: https://github.com/tensorflow/tensorflow/issues/4151
+        self._variable_placeholders = []
+        self._variable_update_ops = []
+        for v in self._variables:
+            variable_placeholder = tf.placeholder(v.dtype, shape=v.get_shape())
+            self._variable_placeholders.append(variable_placeholder)
+            self._variable_update_ops.append(v.assign(variable_placeholder))
+
         self._saver = tf.train.Saver(self._variables)

     @property
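Why the placeholder indirection matters (a minimal sketch, assuming TensorFlow 1.x graph mode, which this code targets): each call to `v.assign(value)` with an in-memory value silently adds a new `Const` node and a new `Assign` node to the default graph, and graph nodes are never freed, so a long-lived rollout worker that restores checkpoints in a loop grows its graph, and its memory, without bound. Building one placeholder-backed assign op per variable up front and feeding values through `feed_dict` keeps the graph at a fixed size:

# Standalone sketch of the leak and the fix (TF 1.x; names are illustrative).
import numpy as np
import tensorflow as tf

v = tf.Variable(np.zeros(4, dtype=np.float32))

def restore_leaky(sess, value):
    # Leaky: every call creates a fresh Const + Assign node in the graph.
    sess.run(v.assign(value))

# Fixed: build the placeholder and the assign op exactly once...
ph = tf.placeholder(v.dtype, shape=v.get_shape())
update_op = v.assign(ph)

def restore_fixed(sess, value):
    # ...and only feed data at run time; no new graph nodes are created.
    sess.run(update_op, feed_dict={ph: value})

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    n_ops = len(tf.get_default_graph().get_operations())
    for _ in range(100):
        restore_fixed(sess, np.ones(4, dtype=np.float32))
    assert len(tf.get_default_graph().get_operations()) == n_ops  # graph unchanged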
@@ -66,8 +76,10 @@ class GlobalVariableSaver(Saver):
             # TODO: Can this be more generic so that `global/` and `online/` are not hardcoded here?
             new_name = var_name.replace('global/', 'online/')
             variables[new_name] = reader.get_tensor(var_name)
-        # Assign all variables
-        sess.run([v.assign(variables[v.name.split(':')[0]]) for v in self._variables])
+        # Assign all variables using placeholder
+        placeholder_dict = {ph: variables[v.name.split(':')[0]] for ph, v in zip(self._variable_placeholders, self._variables)}
+        sess.run(self._variable_update_ops, placeholder_dict)

     def merge(self, other: 'Saver'):
         """
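On the restore side, `feed_dict` is the second positional argument of `Session.run`, so `sess.run(self._variable_update_ops, placeholder_dict)` feeds every checkpoint tensor through its pre-built placeholder rather than minting new assign ops per restore. A self-contained sketch of the same zip-and-feed pattern (the variable name and checkpoint values below are illustrative stand-ins, not Coach's actual state):

# Sketch of the patched restore path (TF 1.x).
import numpy as np
import tensorflow as tf

variables = [tf.Variable(np.zeros(2, dtype=np.float32), name='online/w')]

# Built once at saver-construction time (mirrors the first hunk).
placeholders = [tf.placeholder(v.dtype, shape=v.get_shape()) for v in variables]
update_ops = [v.assign(ph) for v, ph in zip(variables, placeholders)]

# Values as a checkpoint reader would return them, keyed without the ':0' suffix.
checkpoint_values = {'online/w': np.ones(2, dtype=np.float32)}

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = {ph: checkpoint_values[v.name.split(':')[0]]
            for ph, v in zip(placeholders, variables)}
    sess.run(update_ops, feed_dict=feed)
    print(sess.run(variables[0]))  # -> [1. 1.]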
@@ -1,8 +1,10 @@
+import gc
 import os
 import sys
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
 import tensorflow as tf
 from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks
+from rl_coach.core_types import EnvironmentSteps
 from rl_coach.utils import get_open_port
 from multiprocessing import Process
 from tensorflow import logging
@@ -41,6 +43,24 @@ def test_basic_rl_graph_manager_with_cartpole_dqn():
                                                    experiment_path="./experiments/test"))
     # graph_manager.improve()

+# Test for identifying memory leak in restore_checkpoint
+@pytest.mark.unit_test
+def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore():
+    tf.reset_default_graph()
+    from rl_coach.presets.CartPole_DQN import graph_manager
+    assert graph_manager
+    graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow,
+                                                              experiment_path="./experiments/test",
+                                                              apply_stop_condition=True))
+    # graph_manager.improve()
+    # graph_manager.evaluate(EnvironmentSteps(1000))
+    # graph_manager.save_checkpoint()
+    #
+    # graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint"
+    # while True:
+    #     graph_manager.restore_checkpoint()
+    #     graph_manager.evaluate(EnvironmentSteps(1000))
+    #     gc.collect()

 if __name__ == '__main__':
     pass
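The commented-out body documents the intended repro: save a checkpoint once, then restore and evaluate in a loop, with `gc.collect()` ruling out ordinary Python garbage as the culprit. To turn that loop into an actual regression check, one could sample the process's resident set size across iterations; the helper below uses `psutil`, an extra dependency the test itself does not pull in:

# Hypothetical helper: track process RSS across repeated restores (needs psutil).
import os
import psutil

def rss_mb():
    # Resident set size of the current process, in megabytes.
    return psutil.Process(os.getpid()).memory_info().rss / 1e6

# Inside the loop sketched above:
#     before = rss_mb()
#     graph_manager.restore_checkpoint()
#     graph_manager.evaluate(EnvironmentSteps(1000))
#     gc.collect()
#     assert rss_mb() - before < SOME_TOLERANCE_MB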
@@ -48,5 +68,6 @@ if __name__ == '__main__':
     # test_basic_rl_graph_manager_with_ant_a3c()
     # test_basic_rl_graph_manager_with_pong_nec()
     # test_basic_rl_graph_manager_with_cartpole_dqn()
+    # test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore()
     #test_basic_rl_graph_manager_multithreaded_with_pong_a3c()
     #test_basic_rl_graph_manager_with_doom_basic_dqn()
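Assuming the repository's pytest configuration registers the `unit_test` marker used in the decorator above, the new test can be selected by marker or by name, e.g.:

pytest -m unit_test -k repeated_checkpoint_restore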