1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

allow serializing from/to arrays/str from GlobalVariableSaver (#285)

This commit is contained in:
Zach Dwiel
2019-04-04 11:09:19 -04:00
committed by GitHub
parent cdb8d9e518
commit 2291cee2c6
2 changed files with 148 additions and 19 deletions

View File

@@ -14,10 +14,11 @@
# limitations under the License.
#
from typing import Any, List
import pickle
from typing import Any, List, Dict
import tensorflow as tf
import numpy as np
from rl_coach.saver import Saver
@@ -31,7 +32,7 @@ class GlobalVariableSaver(Saver):
# target network is never saved or restored directly from checkpoint, so we are removing all its variables from the list
# the target network would be synched back from the online network in graph_manager.improve(...), at the beginning of the run flow.
self._variables = [v for v in self._variables if '/target' not in v.name]
self._variables = [v for v in self._variables if "/target" not in v.name]
# Using a placeholder to update the variable during restore to avoid memory leak.
# Ref: https://github.com/tensorflow/tensorflow/issues/4151
@@ -61,31 +62,80 @@ class GlobalVariableSaver(Saver):
save_path = self._saver.save(sess, save_path)
return [save_path]
def to_arrays(self, session: Any) -> Dict[str, np.ndarray]:
    """
    Export the current value of every tracked variable
    :param session: active session used to evaluate the variables
    :return: dictionary mapping each variable name (without the ":<index>"
        suffix) to the value returned by session.run
    """
    # Evaluate all variables in a single run call, then pair values with names.
    values = session.run(self._variables)
    arrays = {}
    for variable, value in zip(self._variables, values):
        arrays[variable.name.split(":")[0]] = value
    return arrays
def from_arrays(self, session: Any, tensors: Any):
    """
    Restore variables from in-memory arrays
    :param session: active session for session-based frameworks (e.g. TF)
    :param tensors: {name: array}, or an iterable of (name, array) pairs
    :raises KeyError: if any tracked variable has no entry in tensors
    """
    # if variable was saved using global network, re-map it to online
    # network
    # TODO: Can this be more generic so that `global/` and `online/` are not
    # hardcoded here?
    if isinstance(tensors, dict):
        tensors = tensors.items()
    variables = {k.replace("global/", "online/"): v for k, v in tensors}

    # Resolve the lookup name of every variable once, so that all missing
    # entries can be reported together instead of failing on the first one.
    targets = [
        (ph, v.name.split(":")[0])
        for ph, v in zip(self._variable_placeholders, self._variables)
    ]
    missing = [var_name for _, var_name in targets if var_name not in variables]
    if missing:
        raise KeyError(
            "no value provided for variable(s): {}".format(", ".join(missing))
        )

    # Assign all variables using placeholder
    placeholder_dict = {ph: variables[var_name] for ph, var_name in targets}
    session.run(self._variable_update_ops, placeholder_dict)
def to_string(self, session: Any) -> bytes:
    """
    Save to byte string
    :param session: active session
    :return: pickled {name: array} dictionary produced by to_arrays
    """
    # HIGHEST_PROTOCOL is what a negative protocol number selects; spelling it
    # out makes the intent explicit.
    return pickle.dumps(self.to_arrays(session), protocol=pickle.HIGHEST_PROTOCOL)
def from_string(self, session: Any, string: bytes):
    """
    Restore from byte string
    :param session: active session
    :param string: pickled byte string to restore from, as produced by to_string
    """
    # SECURITY: pickle.loads can execute arbitrary code while unpickling.
    # Only pass data that comes from a trusted source.
    self.from_arrays(session, pickle.loads(string))
def _read_tensors(self, restore_path: str):
    """
    Load tensors from a checkpoint
    :param restore_path: full path to load checkpoint from.
    :return: generator yielding (variable name, tensor value) pairs, with the
        names exactly as stored in the checkpoint
    """
    # We don't use saver.restore() because checkpoint is loaded to online
    # network, but if the checkpoint is from the global network, a namespace
    # mismatch exists and variable name must be modified before loading.
    reader = tf.contrib.framework.load_checkpoint(restore_path)
    # Only the names are needed; the shapes in the map are not used.
    for var_name in reader.get_variable_to_shape_map():
        yield var_name, reader.get_tensor(var_name)
def restore(self, sess: Any, restore_path: str):
    """
    Restore from restore_path
    :param sess: active session for session-based frameworks (e.g. TF)
    :param restore_path: full path to load checkpoint from.
    """
    # We don't use saver.restore() because checkpoint is loaded to online network, but if the checkpoint
    # is from the global network, a namespace mismatch exists and variable name must be modified before loading.
    # Reading the checkpoint is done by _read_tensors; renaming and the
    # placeholder-based assignment are done by from_arrays, so nothing is
    # repeated here.
    self.from_arrays(sess, self._read_tensors(restore_path))
def merge(self, other: "Saver"):
    """
    Merge other saver into this saver
    :param other: saver to be merged into self; must be a GlobalVariableSaver
    """
    assert isinstance(other, GlobalVariableSaver)
    self._names.extend(other._names)
    # There is nothing else to do because variables must already be part of
    # the global collection.

View File

@@ -0,0 +1,79 @@
import random
import pickle
import pytest
import tensorflow as tf
import numpy as np
from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver
def random_name():
    """Return a random integer below 16 ** 32 as a 32-digit, zero-padded lowercase hex string."""
    value = random.randrange(16 ** 32)
    return format(value, "032x")
@pytest.fixture
def name():
    """Fixture: a freshly generated random name."""
    generated = random_name()
    return generated
@pytest.fixture
def variable(shape, name):
    """Fixture: reset the default graph, then create a zero-initialized variable with the given shape and name."""
    tf.reset_default_graph()
    initial_value = tf.zeros(shape)
    return tf.Variable(initial_value, name=name)
@pytest.fixture
def shape():
    """Fixture: the 2-D shape used for the test variable and the expected arrays."""
    dims = (3, 5)
    return dims
def assert_arrays_ones_shape(arrays, shape, name):
    """
    Assert that arrays holds exactly one entry, keyed by name, equal to an
    all-ones array of the given shape
    :param arrays: {name: array} dictionary to check
    :param shape: expected shape of the single array
    :param name: expected key of the single array
    """
    # A key list equal to [name] also implies the dictionary has one entry.
    assert list(arrays.keys()) == [name]
    # Compare the whole array: array_equal checks both the shape and every
    # element, whereas comparing only the first row would broadcast and let
    # wrong values in later rows (or a wrong shape) pass unnoticed.
    assert np.array_equal(arrays[name], np.ones(shape))
@pytest.mark.unit_test
def test_global_variable_saver_to_arrays(variable, name, shape):
    """to_arrays exports the variable, after it was assigned ones, under its own name."""
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(variable.assign(tf.ones(shape)))

        exported = GlobalVariableSaver("name").to_arrays(sess)
        assert_arrays_ones_shape(exported, shape, name)
@pytest.mark.unit_test
def test_global_variable_saver_from_arrays(variable, name, shape):
    """from_arrays loads ones into the variable; the result is read back through to_arrays."""
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        global_saver = GlobalVariableSaver("name")
        ones_by_name = {name: np.ones(shape)}
        global_saver.from_arrays(sess, ones_by_name)

        exported = global_saver.to_arrays(sess)
        assert_arrays_ones_shape(exported, shape, name)
@pytest.mark.unit_test
def test_global_variable_saver_to_string(variable, name, shape):
    """to_string output unpickles to the {name: ones} dictionary."""
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(variable.assign(tf.ones(shape)))

        serialized = GlobalVariableSaver("name").to_string(sess)
        restored = pickle.loads(serialized)
        assert_arrays_ones_shape(restored, shape, name)
@pytest.mark.unit_test
def test_global_variable_saver_from_string(variable, name, shape):
    """from_string loads pickled ones into the variable; the result is read back through to_arrays."""
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        global_saver = GlobalVariableSaver("name")
        payload = pickle.dumps({name: np.ones(shape)}, protocol=-1)
        global_saver.from_string(sess, payload)

        exported = global_saver.to_arrays(sess)
        assert_arrays_ones_shape(exported, shape, name)