From 2291cee2c6664da156bf7a1c1bdf99144e020f8b Mon Sep 17 00:00:00 2001 From: Zach Dwiel Date: Thu, 4 Apr 2019 11:09:19 -0400 Subject: [PATCH] allow serializing from/to arrays/str from GlobalVariableSaver (#285) --- .../tensorflow_components/savers.py | 88 +++++++++++++++---- rl_coach/tests/test_global_variable_saver.py | 79 +++++++++++++++++ 2 files changed, 148 insertions(+), 19 deletions(-) create mode 100644 rl_coach/tests/test_global_variable_saver.py diff --git a/rl_coach/architectures/tensorflow_components/savers.py b/rl_coach/architectures/tensorflow_components/savers.py index ea5b1b8..531c523 100644 --- a/rl_coach/architectures/tensorflow_components/savers.py +++ b/rl_coach/architectures/tensorflow_components/savers.py @@ -14,10 +14,11 @@ # limitations under the License. # - -from typing import Any, List +import pickle +from typing import Any, List, Dict import tensorflow as tf +import numpy as np from rl_coach.saver import Saver @@ -28,10 +29,10 @@ class GlobalVariableSaver(Saver): # if graph is finalized, savers must have already already been added. This happens # in the case of a MonitoredSession self._variables = tf.global_variables() - + # target network is never saved or restored directly from checkpoint, so we are removing all its variables from the list # the target network would be synched back from the online network in graph_manager.improve(...), at the beginning of the run flow. - self._variables = [v for v in self._variables if '/target' not in v.name] + self._variables = [v for v in self._variables if "/target" not in v.name] # Using a placeholder to update the variable during restore to avoid memory leak. # Ref: https://github.com/tensorflow/tensorflow/issues/4151 @@ -61,31 +62,80 @@ class GlobalVariableSaver(Saver): save_path = self._saver.save(sess, save_path) return [save_path] + def to_arrays(self, session: Any) -> Dict[str, np.ndarray]: + """ + Save to dictionary of arrays + :param sess: active session + :return: dictionary of arrays + """ + return { + k.name.split(":")[0]: v for k, v in zip(self._variables, session.run(self._variables)) + } + + def from_arrays(self, session: Any, tensors: Any): + """ + Restore from restore_path + :param sess: active session for session-based frameworks (e.g. TF) + :param tensors: {name: array} + """ + # if variable was saved using global network, re-map it to online + # network + # TODO: Can this be more generic so that `global/` and `online/` are not + # hardcoded here? + if isinstance(tensors, dict): + tensors = tensors.items() + + variables = {k.replace("global/", "online/"): v for k, v in tensors} + + # Assign all variables using placeholder + placeholder_dict = { + ph: variables[v.name.split(":")[0]] + for ph, v in zip(self._variable_placeholders, self._variables) + } + session.run(self._variable_update_ops, placeholder_dict) + + def to_string(self, session: Any) -> str: + """ + Save to byte string + :param session: active session + :return: serialized byte string + """ + return pickle.dumps(self.to_arrays(session), protocol=-1) + + def from_string(self, session: Any, string: str): + """ + Restore from byte string + :param session: active session + :param string: byte string to restore from + """ + self.from_arrays(session, pickle.loads(string)) + + def _read_tensors(self, restore_path: str): + """ + Load tensors from a checkpoint + :param restore_path: full path to load checkpoint from. + """ + # We don't use saver.restore() because checkpoint is loaded to online + # network, but if the checkpoint is from the global network, a namespace + # mismatch exists and variable name must be modified before loading. + reader = tf.contrib.framework.load_checkpoint(restore_path) + for var_name, _ in reader.get_variable_to_shape_map().items(): + yield var_name, reader.get_tensor(var_name) + def restore(self, sess: Any, restore_path: str): """ Restore from restore_path :param sess: active session for session-based frameworks (e.g. TF) :param restore_path: full path to load checkpoint from. """ - # We don't use saver.restore() because checkpoint is loaded to online network, but if the checkpoint - # is from the global network, a namespace mismatch exists and variable name must be modified before loading. - variables = dict() - reader = tf.contrib.framework.load_checkpoint(restore_path) - for var_name, _ in reader.get_variable_to_shape_map().items(): - # if variable was saved using global network, re-map it to online network - # TODO: Can this be more generic so that `global/` and `online/` are not hardcoded here? - new_name = var_name.replace('global/', 'online/') - variables[new_name] = reader.get_tensor(var_name) + self.from_arrays(sess, self._read_tensors(restore_path)) - # Assign all variables using placeholder - placeholder_dict = {ph: variables[v.name.split(':')[0]] for ph, v in zip(self._variable_placeholders, self._variables)} - sess.run(self._variable_update_ops, placeholder_dict) - - def merge(self, other: 'Saver'): + def merge(self, other: "Saver"): """ Merge other saver into this saver :param other: saver to be merged into self """ assert isinstance(other, GlobalVariableSaver) self._names.extend(other._names) - # There is nothing else to do because variables must already be part of the global collection. + # There is nothing else to do because variables must already be part of + # the global collection. diff --git a/rl_coach/tests/test_global_variable_saver.py b/rl_coach/tests/test_global_variable_saver.py new file mode 100644 index 0000000..19da034 --- /dev/null +++ b/rl_coach/tests/test_global_variable_saver.py @@ -0,0 +1,79 @@ +import random +import pickle + +import pytest +import tensorflow as tf +import numpy as np + +from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver + + +def random_name(): + return "%032x" % random.randrange(16 ** 32) + + +@pytest.fixture +def name(): + return random_name() + + +@pytest.fixture +def variable(shape, name): + tf.reset_default_graph() + return tf.Variable(tf.zeros(shape), name=name) + + +@pytest.fixture +def shape(): + return (3, 5) + + +def assert_arrays_ones_shape(arrays, shape, name): + assert list(arrays.keys()) == [name] + assert len(arrays) == 1 + assert np.all(list(arrays[name][0]) == np.ones(shape)) + + +@pytest.mark.unit_test +def test_global_variable_saver_to_arrays(variable, name, shape): + with tf.Session() as session: + session.run(tf.global_variables_initializer()) + session.run(variable.assign(tf.ones(shape))) + + saver = GlobalVariableSaver("name") + arrays = saver.to_arrays(session) + assert_arrays_ones_shape(arrays, shape, name) + + +@pytest.mark.unit_test +def test_global_variable_saver_from_arrays(variable, name, shape): + with tf.Session() as session: + session.run(tf.global_variables_initializer()) + + saver = GlobalVariableSaver("name") + saver.from_arrays(session, {name: np.ones(shape)}) + arrays = saver.to_arrays(session) + assert_arrays_ones_shape(arrays, shape, name) + + +@pytest.mark.unit_test +def test_global_variable_saver_to_string(variable, name, shape): + with tf.Session() as session: + session.run(tf.global_variables_initializer()) + session.run(variable.assign(tf.ones(shape))) + + saver = GlobalVariableSaver("name") + string = saver.to_string(session) + arrays = pickle.loads(string) + assert_arrays_ones_shape(arrays, shape, name) + + +@pytest.mark.unit_test +def test_global_variable_saver_from_string(variable, name, shape): + with tf.Session() as session: + session.run(tf.global_variables_initializer()) + + saver = GlobalVariableSaver("name") + saver.from_string(session, pickle.dumps({name: np.ones(shape)}, protocol=-1)) + arrays = saver.to_arrays(session) + assert_arrays_ones_shape(arrays, shape, name)