1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

Tf checkpointing using saver mechanism (#134)

This commit is contained in:
Sina Afrooze
2018-11-22 04:08:10 -08:00
committed by Gal Leibovich
parent dd18959e53
commit 16cdd9a9c1
6 changed files with 110 additions and 50 deletions

View File

@@ -0,0 +1,58 @@
from typing import Any, List
import tensorflow as tf
from rl_coach.saver import Saver
class GlobalVariableSaver(Saver):
def __init__(self, name):
self._names = [name]
# if graph is finalized, savers must have already already been added. This happens
# in the case of a MonitoredSession
self._variables = tf.global_variables()
self._saver = tf.train.Saver(self._variables)
@property
def path(self):
"""
Relative path for save/load. If two checkpoint objects return the same path, they must be merge-able.
"""
return "" # use empty string for global file
def save(self, sess: None, save_path: str) -> List[str]:
"""
Save to save_path
:param sess: active session
:param save_path: full path to save checkpoint (typically directory plus checkpoint prefix plus self.path)
:return: list of all saved paths
"""
save_path = self._saver.save(sess, save_path)
return [save_path]
def restore(self, sess: Any, restore_path: str):
"""
Restore from restore_path
:param sess: active session for session-based frameworks (e.g. TF)
:param restore_path: full path to load checkpoint from.
"""
# We don't use saver.restore() because checkpoint is loaded to online network, but if the checkpoint
# is from the global network, a namespace mismatch exists and variable name must be modified before loading.
variables = dict()
reader = tf.contrib.framework.load_checkpoint(restore_path)
for var_name, _ in reader.get_variable_to_shape_map().items():
# if variable was saved using global network, re-map it to online network
# TODO: Can this be more generic so that `global/` and `online/` are not hardcoded here?
new_name = var_name.replace('global/', 'online/')
variables[new_name] = reader.get_tensor(var_name)
# Assign all variables
sess.run([v.assign(variables[v.name.split(':')[0]]) for v in self._variables])
def merge(self, other: 'Saver'):
"""
Merge other saver into this saver
:param other: saver to be merged into self
"""
assert isinstance(other, GlobalVariableSaver)
self._names.extend(other._names)
# There is nothing else to do because variables must already be part of the global collection.