mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Merge pull request #161 from x77a1/master
Changes to avoid memory leak in Rollout worker
This commit is contained in:
@@ -32,6 +32,16 @@ class GlobalVariableSaver(Saver):
|
||||
# target network is never saved or restored directly from checkpoint, so we are removing all its variables from the list
|
||||
# the target network would be synched back from the online network in graph_manager.improve(...), at the beginning of the run flow.
|
||||
self._variables = [v for v in self._variables if '/target' not in v.name]
|
||||
|
||||
# Using a placeholder to update the variable during restore to avoid memory leak.
|
||||
# Ref: https://github.com/tensorflow/tensorflow/issues/4151
|
||||
self._variable_placeholders = []
|
||||
self._variable_update_ops = []
|
||||
for v in self._variables:
|
||||
variable_placeholder = tf.placeholder(v.dtype, shape=v.get_shape())
|
||||
self._variable_placeholders.append(variable_placeholder)
|
||||
self._variable_update_ops.append(v.assign(variable_placeholder))
|
||||
|
||||
self._saver = tf.train.Saver(self._variables)
|
||||
|
||||
@property
|
||||
@@ -66,8 +76,10 @@ class GlobalVariableSaver(Saver):
|
||||
# TODO: Can this be more generic so that `global/` and `online/` are not hardcoded here?
|
||||
new_name = var_name.replace('global/', 'online/')
|
||||
variables[new_name] = reader.get_tensor(var_name)
|
||||
# Assign all variables
|
||||
sess.run([v.assign(variables[v.name.split(':')[0]]) for v in self._variables])
|
||||
|
||||
# Assign all variables using placeholder
|
||||
placeholder_dict = {ph: variables[v.name.split(':')[0]] for ph, v in zip(self._variable_placeholders, self._variables)}
|
||||
sess.run(self._variable_update_ops, placeholder_dict)
|
||||
|
||||
def merge(self, other: 'Saver'):
|
||||
"""
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
import tensorflow as tf
|
||||
from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks
|
||||
from rl_coach.core_types import EnvironmentSteps
|
||||
from rl_coach.utils import get_open_port
|
||||
from multiprocessing import Process
|
||||
from tensorflow import logging
|
||||
@@ -41,6 +43,24 @@ def test_basic_rl_graph_manager_with_cartpole_dqn():
|
||||
experiment_path="./experiments/test"))
|
||||
# graph_manager.improve()
|
||||
|
||||
# Test for identifying memory leak in restore_checkpoint
|
||||
@pytest.mark.unit_test
|
||||
def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore():
|
||||
tf.reset_default_graph()
|
||||
from rl_coach.presets.CartPole_DQN import graph_manager
|
||||
assert graph_manager
|
||||
graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow,
|
||||
experiment_path="./experiments/test",
|
||||
apply_stop_condition=True))
|
||||
# graph_manager.improve()
|
||||
# graph_manager.evaluate(EnvironmentSteps(1000))
|
||||
# graph_manager.save_checkpoint()
|
||||
#
|
||||
# graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint"
|
||||
# while True:
|
||||
# graph_manager.restore_checkpoint()
|
||||
# graph_manager.evaluate(EnvironmentSteps(1000))
|
||||
# gc.collect()
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
@@ -48,5 +68,6 @@ if __name__ == '__main__':
|
||||
# test_basic_rl_graph_manager_with_ant_a3c()
|
||||
# test_basic_rl_graph_manager_with_pong_nec()
|
||||
# test_basic_rl_graph_manager_with_cartpole_dqn()
|
||||
# test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore()
|
||||
#test_basic_rl_graph_manager_multithreaded_with_pong_a3c()
|
||||
#test_basic_rl_graph_manager_with_doom_basic_dqn()
|
||||
Reference in New Issue
Block a user