mirror of https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00

Tf checkpointing using saver mechanism (#134)

committed by Gal Leibovich
parent dd18959e53
commit 16cdd9a9c1
@@ -21,6 +21,7 @@ import numpy as np
 import tensorflow as tf
 
 from rl_coach.architectures.architecture import Architecture
+from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver
 from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
 from rl_coach.core_types import GradientClippingMethod
 from rl_coach.saver import SaverCollection
@@ -645,8 +646,10 @@ class TensorFlowArchitecture(Architecture):
         (e.g. could be name of level manager plus name of agent)
         :return: checkpoint collection for the network
         """
-        # TODO implement returning checkpoints for tensorflow
-        return SaverCollection()
+        savers = SaverCollection()
+        if not self.distributed_training:
+            savers.add(GlobalVariableSaver(self.name))
+        return savers
 
 
 def save_onnx_graph(input_nodes, output_nodes, checkpoint_save_dir: str) -> None:
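Note: below is a minimal sketch (not part of this commit) of how a caller might drain the SaverCollection returned above. The diff only shows SaverCollection.add() and the Saver path/save() interface; the save_all helper, the checkpoint_prefix argument, and the assumption that the collection is iterable are illustrative, not confirmed by this change.

# Hypothetical helper, not from this commit: exercises the Saver API
# visible in this diff. Assumes SaverCollection is iterable over its
# (merged) savers, which this diff does not confirm.
import os
from typing import Any, List


def save_all(savers, sess: Any, checkpoint_dir: str, checkpoint_prefix: str) -> List[str]:
    """Save every saver in the collection under a common checkpoint prefix."""
    saved_paths = []
    for saver in savers:
        # Each saver contributes a relative path; GlobalVariableSaver returns
        # "" so its variables land in the single global checkpoint file.
        full_path = os.path.join(checkpoint_dir, checkpoint_prefix + saver.path)
        saved_paths.extend(saver.save(sess, full_path))
    return saved_paths

Because GlobalVariableSaver.path is the empty string, all of its variables end up in one global checkpoint file under the shared prefix.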
rl_coach/architectures/tensorflow_components/savers.py (new file, 58 lines)
@@ -0,0 +1,58 @@
+from typing import Any, List
+
+import tensorflow as tf
+
+from rl_coach.saver import Saver
+
+
+class GlobalVariableSaver(Saver):
+    def __init__(self, name):
+        self._names = [name]
+        # If the graph is finalized, savers must have already been added. This happens
+        # in the case of a MonitoredSession.
+        self._variables = tf.global_variables()
+        self._saver = tf.train.Saver(self._variables)
+
+    @property
+    def path(self):
+        """
+        Relative path for save/load. If two checkpoint objects return the same path, they must be merge-able.
+        """
+        return ""  # use empty string for global file
+
+    def save(self, sess: Any, save_path: str) -> List[str]:
+        """
+        Save to save_path
+        :param sess: active session
+        :param save_path: full path to save checkpoint (typically directory plus checkpoint prefix plus self.path)
+        :return: list of all saved paths
+        """
+        save_path = self._saver.save(sess, save_path)
+        return [save_path]
+
+    def restore(self, sess: Any, restore_path: str):
+        """
+        Restore from restore_path
+        :param sess: active session for session-based frameworks (e.g. TF)
+        :param restore_path: full path to load checkpoint from.
+        """
+        # We don't use saver.restore() because the checkpoint is loaded into the online network; if it
+        # came from the global network, there is a namespace mismatch and variable names must be remapped.
+        variables = dict()
+        reader = tf.contrib.framework.load_checkpoint(restore_path)
+        for var_name, _ in reader.get_variable_to_shape_map().items():
+            # If the variable was saved by the global network, re-map it to the online network.
+            # TODO: Can this be more generic so that `global/` and `online/` are not hardcoded here?
+            new_name = var_name.replace('global/', 'online/')
+            variables[new_name] = reader.get_tensor(var_name)
+        # Assign all variables in a single run call.
+        sess.run([v.assign(variables[v.name.split(':')[0]]) for v in self._variables])
+
+    def merge(self, other: 'Saver'):
+        """
+        Merge other saver into this saver
+        :param other: saver to be merged into self
+        """
+        assert isinstance(other, GlobalVariableSaver)
+        self._names.extend(other._names)
+        # There is nothing else to do because the variables must already be part of the global collection.
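Note: below is a minimal end-to-end usage sketch (not part of this commit) of GlobalVariableSaver with the TF1 session API. The 'online' scope mirrors the hardcoded remapping in restore() above; the variable, the saver label, and the /tmp paths are invented for illustration.

# Hypothetical usage sketch, not from this commit.
import os

import tensorflow as tf

from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver

# Build a toy graph under the 'online/' scope that restore() remaps to.
with tf.variable_scope('online'):
    weights = tf.get_variable('weights', shape=[2, 2],
                              initializer=tf.zeros_initializer())

# The name is only recorded for merge() bookkeeping; any label works here.
saver = GlobalVariableSaver('main_level/agent')

os.makedirs('/tmp/coach_ckpt', exist_ok=True)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # save() wraps tf.train.Saver.save() and returns the written path(s).
    paths = saver.save(sess, '/tmp/coach_ckpt/0_Step-0.ckpt')

with tf.Session() as sess:
    # restore() reads the checkpoint, remaps any 'global/' names to
    # 'online/', and assigns the stored tensors directly.
    saver.restore(sess, paths[0])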