allow serializing from/to arrays/str from GlobalVariableSaver (#285)

rl_coach/architectures/tensorflow_components/savers.py
@@ -14,10 +14,11 @@
 # limitations under the License.
 #
 
-from typing import Any, List
+import pickle
+from typing import Any, List, Dict
 
 import tensorflow as tf
+import numpy as np
 
 from rl_coach.saver import Saver
 
@@ -28,10 +29,10 @@ class GlobalVariableSaver(Saver):
         # if graph is finalized, savers must have already been added. This happens
         # in the case of a MonitoredSession
         self._variables = tf.global_variables()
 
         # target network is never saved or restored directly from checkpoint, so we are removing all its variables from the list
         # the target network would be synced back from the online network in graph_manager.improve(...), at the beginning of the run flow.
-        self._variables = [v for v in self._variables if '/target' not in v.name]
+        self._variables = [v for v in self._variables if "/target" not in v.name]
 
         # Using a placeholder to update the variable during restore to avoid memory leak.
         # Ref: https://github.com/tensorflow/tensorflow/issues/4151
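
For context (not part of the diff): the "placeholder" the comment above refers to is the standard TF1 pattern for restoring variable values without growing the graph. Calling var.assign(value) with a fresh value tensor on every restore adds new graph nodes each time (the memory leak in the linked issue); building one placeholder and one assign op per variable up front, and feeding values through feed_dict, keeps the graph fixed. A minimal sketch of the pattern, with illustrative names:

    # built once, when the saver is constructed:
    ph = tf.placeholder(var.dtype, shape=var.get_shape())  # one placeholder per variable
    update_op = var.assign(ph)                             # one assign op, reused forever

    # on every restore, only feed_dict changes; no new graph nodes are created:
    sess.run(update_op, feed_dict={ph: new_value})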
@@ -61,31 +62,80 @@
         save_path = self._saver.save(sess, save_path)
         return [save_path]
 
+    def to_arrays(self, session: Any) -> Dict[str, np.ndarray]:
+        """
+        Save to a dictionary of arrays
+        :param session: active session
+        :return: dictionary of {variable name: array}
+        """
+        return {
+            k.name.split(":")[0]: v for k, v in zip(self._variables, session.run(self._variables))
+        }
+
+    def from_arrays(self, session: Any, tensors: Any):
+        """
+        Restore variables from a collection of arrays
+        :param session: active session for session-based frameworks (e.g. TF)
+        :param tensors: a {name: array} dict, or an iterable of (name, array) pairs
+        """
+        # if variable was saved using global network, re-map it to online
+        # network
+        # TODO: Can this be more generic so that `global/` and `online/` are not
+        # hardcoded here?
+        if isinstance(tensors, dict):
+            tensors = tensors.items()
+
+        variables = {k.replace("global/", "online/"): v for k, v in tensors}
+
+        # Assign all variables using placeholder
+        placeholder_dict = {
+            ph: variables[v.name.split(":")[0]]
+            for ph, v in zip(self._variable_placeholders, self._variables)
+        }
+        session.run(self._variable_update_ops, placeholder_dict)
+
+    def to_string(self, session: Any) -> str:
+        """
+        Save to byte string
+        :param session: active session
+        :return: serialized byte string
+        """
+        return pickle.dumps(self.to_arrays(session), protocol=-1)
+
+    def from_string(self, session: Any, string: str):
+        """
+        Restore from byte string
+        :param session: active session
+        :param string: byte string to restore from
+        """
+        self.from_arrays(session, pickle.loads(string))
+
+    def _read_tensors(self, restore_path: str):
+        """
+        Load tensors from a checkpoint
+        :param restore_path: full path to load checkpoint from.
+        """
+        # We don't use saver.restore() because the checkpoint is loaded into the online
+        # network, but if the checkpoint is from the global network, a namespace
+        # mismatch exists and variable names must be modified before loading.
+        reader = tf.contrib.framework.load_checkpoint(restore_path)
+        for var_name, _ in reader.get_variable_to_shape_map().items():
+            yield var_name, reader.get_tensor(var_name)
+
     def restore(self, sess: Any, restore_path: str):
         """
         Restore from restore_path
         :param sess: active session for session-based frameworks (e.g. TF)
         :param restore_path: full path to load checkpoint from.
         """
-        # We don't use saver.restore() because checkpoint is loaded to online network, but if the checkpoint
-        # is from the global network, a namespace mismatch exists and variable name must be modified before loading.
-        variables = dict()
-        reader = tf.contrib.framework.load_checkpoint(restore_path)
-        for var_name, _ in reader.get_variable_to_shape_map().items():
-            # if variable was saved using global network, re-map it to online network
-            # TODO: Can this be more generic so that `global/` and `online/` are not hardcoded here?
-            new_name = var_name.replace('global/', 'online/')
-            variables[new_name] = reader.get_tensor(var_name)
+        self.from_arrays(sess, self._read_tensors(restore_path))
-
-        # Assign all variables using placeholder
-        placeholder_dict = {ph: variables[v.name.split(':')[0]] for ph, v in zip(self._variable_placeholders, self._variables)}
-        sess.run(self._variable_update_ops, placeholder_dict)
 
-    def merge(self, other: 'Saver'):
+    def merge(self, other: "Saver"):
         """
         Merge other saver into this saver
         :param other: saver to be merged into self
         """
         assert isinstance(other, GlobalVariableSaver)
         self._names.extend(other._names)
-        # There is nothing else to do because variables must already be part of the global collection.
+        # There is nothing else to do because variables must already be part of
+        # the global collection.
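
For orientation (not part of the commit): the four new methods give two interchangeable serialization surfaces, arrays for in-process use and pickled byte strings for the wire. A minimal usage sketch, assuming a session whose graph already contains the variables the saver should track (the "demo" name is illustrative):

    import tensorflow as tf
    from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = GlobalVariableSaver("demo")

        blob = saver.to_string(sess)   # pickled {variable name: np.ndarray} dict
        # ... ship `blob` to another worker, then on the receiving side:
        saver.from_string(sess, blob)  # feeds values back in through the cached placeholders

Note that pickle.dumps(..., protocol=-1) in to_string selects the highest pickle protocol available to the running interpreter, trading backward compatibility for smaller, faster serialization.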
rl_coach/tests/test_global_variable_saver.py (new file, 79 lines)
@@ -0,0 +1,79 @@
+import random
+import pickle
+
+import pytest
+import tensorflow as tf
+import numpy as np
+
+from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver
+
+
+def random_name():
+    return "%032x" % random.randrange(16 ** 32)
+
+
+@pytest.fixture
+def name():
+    return random_name()
+
+
+@pytest.fixture
+def variable(shape, name):
+    tf.reset_default_graph()
+    return tf.Variable(tf.zeros(shape), name=name)
+
+
+@pytest.fixture
+def shape():
+    return (3, 5)
+
+
+def assert_arrays_ones_shape(arrays, shape, name):
+    assert list(arrays.keys()) == [name]
+    assert len(arrays) == 1
+    assert np.all(arrays[name] == np.ones(shape))
+
+
+@pytest.mark.unit_test
+def test_global_variable_saver_to_arrays(variable, name, shape):
+    with tf.Session() as session:
+        session.run(tf.global_variables_initializer())
+        session.run(variable.assign(tf.ones(shape)))
+
+        saver = GlobalVariableSaver("name")
+        arrays = saver.to_arrays(session)
+        assert_arrays_ones_shape(arrays, shape, name)
+
+
+@pytest.mark.unit_test
+def test_global_variable_saver_from_arrays(variable, name, shape):
+    with tf.Session() as session:
+        session.run(tf.global_variables_initializer())
+
+        saver = GlobalVariableSaver("name")
+        saver.from_arrays(session, {name: np.ones(shape)})
+        arrays = saver.to_arrays(session)
+        assert_arrays_ones_shape(arrays, shape, name)
+
+
+@pytest.mark.unit_test
+def test_global_variable_saver_to_string(variable, name, shape):
+    with tf.Session() as session:
+        session.run(tf.global_variables_initializer())
+        session.run(variable.assign(tf.ones(shape)))
+
+        saver = GlobalVariableSaver("name")
+        string = saver.to_string(session)
+        arrays = pickle.loads(string)
+        assert_arrays_ones_shape(arrays, shape, name)
+
+
+@pytest.mark.unit_test
+def test_global_variable_saver_from_string(variable, name, shape):
+    with tf.Session() as session:
+        session.run(tf.global_variables_initializer())
+
+        saver = GlobalVariableSaver("name")
+        saver.from_string(session, pickle.dumps({name: np.ones(shape)}, protocol=-1))
+        arrays = saver.to_arrays(session)
+        assert_arrays_ones_shape(arrays, shape, name)
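
A side note (not part of the commit): the tests only exercise the dict form of from_arrays, while the rewritten restore() feeds it a generator of (name, array) pairs from _read_tensors; the isinstance(tensors, dict) branch is what lets both input shapes work. A hypothetical extra test for the pairs form, reusing the same fixtures (the test name is illustrative):

    @pytest.mark.unit_test
    def test_global_variable_saver_from_pairs(variable, name, shape):
        with tf.Session() as session:
            session.run(tf.global_variables_initializer())

            saver = GlobalVariableSaver("name")
            # an iterable of (name, array) pairs, like _read_tensors yields
            saver.from_arrays(session, iter([(name, np.ones(shape))]))
            arrays = saver.to_arrays(session)
            assert_arrays_ones_shape(arrays, shape, name)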