mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
allow serializing from/to arrays/str from GlobalVariableSaver (#285)
This commit is contained in:
@@ -14,10 +14,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
|
import pickle
|
||||||
from typing import Any, List
|
from typing import Any, List, Dict
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from rl_coach.saver import Saver
|
from rl_coach.saver import Saver
|
||||||
|
|
||||||
@@ -28,10 +29,10 @@ class GlobalVariableSaver(Saver):
|
|||||||
# if graph is finalized, savers must have already been added. This happens
|
# if graph is finalized, savers must have already been added. This happens
|
||||||
# in the case of a MonitoredSession
|
# in the case of a MonitoredSession
|
||||||
self._variables = tf.global_variables()
|
self._variables = tf.global_variables()
|
||||||
|
|
||||||
# target network is never saved or restored directly from checkpoint, so we are removing all its variables from the list
|
# target network is never saved or restored directly from checkpoint, so we are removing all its variables from the list
|
||||||
# the target network would be synched back from the online network in graph_manager.improve(...), at the beginning of the run flow.
|
# the target network would be synched back from the online network in graph_manager.improve(...), at the beginning of the run flow.
|
||||||
self._variables = [v for v in self._variables if '/target' not in v.name]
|
self._variables = [v for v in self._variables if "/target" not in v.name]
|
||||||
|
|
||||||
# Using a placeholder to update the variable during restore to avoid memory leak.
|
# Using a placeholder to update the variable during restore to avoid memory leak.
|
||||||
# Ref: https://github.com/tensorflow/tensorflow/issues/4151
|
# Ref: https://github.com/tensorflow/tensorflow/issues/4151
|
||||||
@@ -61,31 +62,80 @@ class GlobalVariableSaver(Saver):
|
|||||||
save_path = self._saver.save(sess, save_path)
|
save_path = self._saver.save(sess, save_path)
|
||||||
return [save_path]
|
return [save_path]
|
||||||
|
|
||||||
|
def to_arrays(self, session: Any) -> Dict[str, np.ndarray]:
    """
    Save to dictionary of arrays
    :param session: active session used to evaluate the tracked variables
    :return: dictionary mapping each variable name (with the ":<index>"
        suffix removed) to the value returned by the session
    """
    # Evaluate all tracked variables in a single session call.
    values = session.run(self._variables)

    arrays = {}
    for variable, value in zip(self._variables, values):
        # Variable names look like "scope/name:0"; keep the part before ":".
        arrays[variable.name.split(":")[0]] = value
    return arrays
|
||||||
|
|
||||||
|
def from_arrays(self, session: Any, tensors: Any):
    """
    Restore variables from in-memory arrays
    :param session: active session for session-based frameworks (e.g. TF)
    :param tensors: {name: array} mapping, or an iterable of (name, array)
        pairs
    :raises KeyError: if a tracked variable has no matching entry in tensors
    """
    # Imported locally so this method does not depend on the module's
    # import block.
    from collections.abc import Mapping

    # Any mapping (not only an exact dict) is read through its items();
    # everything else is assumed to already be an iterable of pairs.
    if isinstance(tensors, Mapping):
        tensors = tensors.items()

    # if variable was saved using global network, re-map it to online
    # network
    # TODO: Can this be more generic so that `global/` and `online/` are not
    # hardcoded here?
    variables = {k.replace("global/", "online/"): v for k, v in tensors}

    # Pair every placeholder with the checkpoint-style name of its variable
    # (the variable name without the ":<index>" suffix).
    targets = [
        (ph, v.name.split(":")[0])
        for ph, v in zip(self._variable_placeholders, self._variables)
    ]

    # Fail with one readable error instead of a bare KeyError on the first
    # missing name.
    missing = [name for _, name in targets if name not in variables]
    if missing:
        raise KeyError(
            "no saved tensor for variable(s): {}".format(", ".join(missing))
        )

    # Assign all variables using placeholder
    placeholder_dict = {ph: variables[name] for ph, name in targets}
    session.run(self._variable_update_ops, placeholder_dict)
|
||||||
|
|
||||||
|
def to_string(self, session: Any) -> bytes:
    """
    Save to byte string
    :param session: active session
    :return: pickled form of the dictionary returned by to_arrays
    """
    # protocol=-1 asks pickle for the highest protocol it supports.
    return pickle.dumps(self.to_arrays(session), protocol=-1)
|
||||||
|
|
||||||
|
def from_string(self, session: Any, string: bytes):
    """
    Restore from byte string
    :param session: active session
    :param string: pickled byte string to restore from; it is unpickled and
        the result is handed to from_arrays
    """
    # NOTE(review): pickle.loads can execute arbitrary code while loading.
    # Only pass byte strings that come from a trusted source.
    tensors = pickle.loads(string)
    self.from_arrays(session, tensors)
|
||||||
|
|
||||||
|
def _read_tensors(self, restore_path: str):
    """
    Load tensors from a checkpoint
    :param restore_path: full path to load checkpoint from.
    :return: generator yielding (variable name, tensor) pairs, one per entry
        in the checkpoint's variable-to-shape map
    """
    # We don't use saver.restore() because checkpoint is loaded to online
    # network, but if the checkpoint is from the global network, a namespace
    # mismatch exists and variable name must be modified before loading.
    reader = tf.contrib.framework.load_checkpoint(restore_path)
    shape_map = reader.get_variable_to_shape_map()
    # Only the names are needed; the shapes in the map are ignored.
    for var_name in shape_map:
        yield var_name, reader.get_tensor(var_name)
|
||||||
|
|
||||||
def restore(self, sess: Any, restore_path: str):
|
def restore(self, sess: Any, restore_path: str):
|
||||||
"""
|
"""
|
||||||
Restore from restore_path
|
Restore from restore_path
|
||||||
:param sess: active session for session-based frameworks (e.g. TF)
|
:param sess: active session for session-based frameworks (e.g. TF)
|
||||||
:param restore_path: full path to load checkpoint from.
|
:param restore_path: full path to load checkpoint from.
|
||||||
"""
|
"""
|
||||||
# We don't use saver.restore() because checkpoint is loaded to online network, but if the checkpoint
|
self.from_arrays(sess, self._read_tensors(restore_path))
|
||||||
# is from the global network, a namespace mismatch exists and variable name must be modified before loading.
|
|
||||||
variables = dict()
|
|
||||||
reader = tf.contrib.framework.load_checkpoint(restore_path)
|
|
||||||
for var_name, _ in reader.get_variable_to_shape_map().items():
|
|
||||||
# if variable was saved using global network, re-map it to online network
|
|
||||||
# TODO: Can this be more generic so that `global/` and `online/` are not hardcoded here?
|
|
||||||
new_name = var_name.replace('global/', 'online/')
|
|
||||||
variables[new_name] = reader.get_tensor(var_name)
|
|
||||||
|
|
||||||
# Assign all variables using placeholder
|
def merge(self, other: "Saver"):
|
||||||
placeholder_dict = {ph: variables[v.name.split(':')[0]] for ph, v in zip(self._variable_placeholders, self._variables)}
|
|
||||||
sess.run(self._variable_update_ops, placeholder_dict)
|
|
||||||
|
|
||||||
def merge(self, other: 'Saver'):
|
|
||||||
"""
|
"""
|
||||||
Merge other saver into this saver
|
Merge other saver into this saver
|
||||||
:param other: saver to be merged into self
|
:param other: saver to be merged into self
|
||||||
"""
|
"""
|
||||||
assert isinstance(other, GlobalVariableSaver)
|
assert isinstance(other, GlobalVariableSaver)
|
||||||
self._names.extend(other._names)
|
self._names.extend(other._names)
|
||||||
# There is nothing else to do because variables must already be part of the global collection.
|
# There is nothing else to do because variables must already be part of
|
||||||
|
# the global collection.
|
||||||
|
|||||||
79
rl_coach/tests/test_global_variable_saver.py
Normal file
79
rl_coach/tests/test_global_variable_saver.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
import random
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import tensorflow as tf
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver
|
||||||
|
|
||||||
|
|
||||||
|
def random_name():
    """Return a random 32-character, zero-padded, lowercase hexadecimal string."""
    value = random.randrange(16 ** 32)
    return "{:032x}".format(value)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def name():
    """Fixture providing a random name obtained from random_name()."""
    generated = random_name()
    return generated
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def variable(shape, name):
    """
    Fixture that resets the default graph and then creates a variable in it,
    initialised to zeros of the requested shape and carrying the given name.
    """
    tf.reset_default_graph()
    initial_value = tf.zeros(shape)
    return tf.Variable(initial_value, name=name)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def shape():
    """Fixture providing the (rows, columns) shape used by the saver tests."""
    rows, columns = 3, 5
    return (rows, columns)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_arrays_ones_shape(arrays, shape, name):
    """
    Assert that arrays holds exactly one entry, keyed by name, whose value
    is an all-ones array of exactly the given shape
    :param arrays: dictionary of {name: array} to check
    :param shape: expected shape of the single array
    :param name: expected key of the single array
    """
    assert list(arrays.keys()) == [name]
    assert len(arrays) == 1
    # np.array_equal checks the shape and every element. Comparing only
    # arrays[name][0] would broadcast the first row against np.ones(shape)
    # and let wrong shapes or wrong values in later rows pass.
    assert np.array_equal(arrays[name], np.ones(shape))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit_test
def test_global_variable_saver_to_arrays(variable, name, shape):
    """to_arrays returns the variable's current (all-ones) value under its name."""
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Overwrite the zero-initialised variable with ones before saving.
        set_to_ones = variable.assign(tf.ones(shape))
        sess.run(set_to_ones)

        saver = GlobalVariableSaver("name")
        saved = saver.to_arrays(sess)
        assert_arrays_ones_shape(saved, shape, name)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit_test
def test_global_variable_saver_from_arrays(variable, name, shape):
    """from_arrays loads an all-ones array that to_arrays then reads back."""
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        saver = GlobalVariableSaver("name")
        ones_by_name = {name: np.ones(shape)}
        saver.from_arrays(sess, ones_by_name)
        # Read the variables back through the saver to confirm the restore.
        restored = saver.to_arrays(sess)
        assert_arrays_ones_shape(restored, shape, name)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit_test
def test_global_variable_saver_to_string(variable, name, shape):
    """to_string produces a pickle that decodes to the all-ones array dictionary."""
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Overwrite the zero-initialised variable with ones before saving.
        set_to_ones = variable.assign(tf.ones(shape))
        sess.run(set_to_ones)

        saver = GlobalVariableSaver("name")
        serialized = saver.to_string(sess)
        decoded = pickle.loads(serialized)
        assert_arrays_ones_shape(decoded, shape, name)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit_test
def test_global_variable_saver_from_string(variable, name, shape):
    """from_string loads a pickled all-ones dictionary that to_arrays then reads back."""
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        saver = GlobalVariableSaver("name")
        serialized = pickle.dumps({name: np.ones(shape)}, protocol=-1)
        saver.from_string(sess, serialized)
        # Read the variables back through the saver to confirm the restore.
        restored = saver.to_arrays(sess)
        assert_arrays_ones_shape(restored, shape, name)
|
||||||
Reference in New Issue
Block a user