From b1e9ea48d86807382c5feca0d18a6bf71f5caa03 Mon Sep 17 00:00:00 2001 From: Gourav Roy Date: Thu, 3 Jan 2019 15:08:34 -0800 Subject: [PATCH] Refactored GlobalVariableSaver --- .../architectures/tensorflow_components/savers.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/rl_coach/architectures/tensorflow_components/savers.py b/rl_coach/architectures/tensorflow_components/savers.py index 0f7ddbf..67c0c8b 100644 --- a/rl_coach/architectures/tensorflow_components/savers.py +++ b/rl_coach/architectures/tensorflow_components/savers.py @@ -35,10 +35,12 @@ class GlobalVariableSaver(Saver): # Using a placeholder to update the variable during restore to avoid memory leak. # Ref: https://github.com/tensorflow/tensorflow/issues/4151 - self._variable_placeholders = [tf.placeholder(v.dtype, shape=v.get_shape()) for v in self._variables] + self._variable_placeholders = [] self._variable_update_ops = [] - for i in range(len(self._variables)): - self._variable_update_ops.append(self._variables[i].assign(self._variable_placeholders[i])) + for v in self._variables: + variable_placeholder = tf.placeholder(v.dtype, shape=v.get_shape()) + self._variable_placeholders.append(variable_placeholder) + self._variable_update_ops.append(v.assign(variable_placeholder)) self._saver = tf.train.Saver(self._variables) @@ -76,9 +78,8 @@ class GlobalVariableSaver(Saver): variables[new_name] = reader.get_tensor(var_name) # Assign all variables using placeholder - for i in range(len(self._variables)): - variable_name = self._variables[i].name.split(':')[0] - sess.run(self._variable_update_ops[i], {self._variable_placeholders[i]: variables[variable_name]}) + placeholder_dict = {ph: variables[v.name.split(':')[0]] for ph, v in zip(self._variable_placeholders, self._variables)} + sess.run(self._variable_update_ops, placeholder_dict) def merge(self, other: 'Saver'): """