
Tf checkpointing using saver mechanism (#134)

Sina Afrooze
2018-11-22 04:08:10 -08:00
committed by Gal Leibovich
parent dd18959e53
commit 16cdd9a9c1
6 changed files with 110 additions and 50 deletions

View File

@@ -21,6 +21,7 @@ import numpy as np
import tensorflow as tf
from rl_coach.architectures.architecture import Architecture
from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver
from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
from rl_coach.core_types import GradientClippingMethod
from rl_coach.saver import SaverCollection
@@ -645,8 +646,10 @@ class TensorFlowArchitecture(Architecture):
(e.g. could be name of level manager plus name of agent)
:return: checkpoint collection for the network
"""
# TODO implement returning checkpoints for tensorflow
return SaverCollection()
savers = SaverCollection()
if not self.distributed_training:
savers.add(GlobalVariableSaver(self.name))
return savers
def save_onnx_graph(input_nodes, output_nodes, checkpoint_save_dir: str) -> None:
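
For reference, a minimal usage sketch of the two pieces this hunk wires together (SaverCollection plus the new GlobalVariableSaver), assuming TF 1.x and that rl_coach with this commit is importable; the variable name and checkpoint prefix below are made up:

import os
import tempfile
import tensorflow as tf
from rl_coach.saver import SaverCollection
from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver

with tf.Graph().as_default():
    tf.Variable(0.0, name='step')  # dummy variable so the underlying tf.train.Saver has something to track
    savers = SaverCollection()
    savers.add(GlobalVariableSaver('main/online'))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        prefix = os.path.join(tempfile.mkdtemp(), '0_Step-0.ckpt')
        print(savers.save(sess, prefix))  # -> [prefix]; GlobalVariableSaver.path is "", so the prefix is used as-is

In distributed training the collection is returned empty, per the guard in the hunk above.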

View File

@@ -0,0 +1,58 @@
from typing import Any, List
import tensorflow as tf
from rl_coach.saver import Saver
class GlobalVariableSaver(Saver):
def __init__(self, name):
self._names = [name]
# if the graph is finalized, savers must have already been added. This happens
# in the case of a MonitoredSession
self._variables = tf.global_variables()
self._saver = tf.train.Saver(self._variables)
@property
def path(self):
"""
Relative path for save/load. If two checkpoint objects return the same path, they must be merge-able.
"""
return "" # use empty string for global file
def save(self, sess: Any, save_path: str) -> List[str]:
"""
Save to save_path
:param sess: active session
:param save_path: full path to save checkpoint (typically directory plus checkpoint prefix plus self.path)
:return: list of all saved paths
"""
save_path = self._saver.save(sess, save_path)
return [save_path]
def restore(self, sess: Any, restore_path: str):
"""
Restore from restore_path
:param sess: active session for session-based frameworks (e.g. TF)
:param restore_path: full path to load checkpoint from.
"""
# We don't use saver.restore() because the checkpoint is loaded into the online network, but if the
# checkpoint was saved from the global network, there is a namespace mismatch and variable names
# must be remapped before loading.
variables = dict()
reader = tf.contrib.framework.load_checkpoint(restore_path)
for var_name, _ in reader.get_variable_to_shape_map().items():
# if variable was saved using global network, re-map it to online network
# TODO: Can this be more generic so that `global/` and `online/` are not hardcoded here?
new_name = var_name.replace('global/', 'online/')
variables[new_name] = reader.get_tensor(var_name)
# Assign all variables
sess.run([v.assign(variables[v.name.split(':')[0]]) for v in self._variables])
def merge(self, other: 'Saver'):
"""
Merge other saver into this saver
:param other: saver to be merged into self
"""
assert isinstance(other, GlobalVariableSaver)
self._names.extend(other._names)
# There is nothing else to do because variables must already be part of the global collection.
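
To make the remapping in restore() concrete, here is a self-contained sketch of the same global/ -> online/ rename done with raw TF 1.x (tf.contrib required); the variable and path names are illustrative only:

import os
import tempfile
import tensorflow as tf

ckpt = os.path.join(tempfile.mkdtemp(), 'remap_demo.ckpt')

# Save a variable under the "global/" namespace.
with tf.Graph().as_default(), tf.Session() as sess:
    with tf.variable_scope('global'):
        w = tf.get_variable('w', initializer=3.0)
    sess.run(tf.global_variables_initializer())
    tf.train.Saver([w]).save(sess, ckpt, write_meta_graph=False)

# Load it into an "online/" variable by remapping checkpoint names, as restore() above does.
with tf.Graph().as_default(), tf.Session() as sess:
    with tf.variable_scope('online'):
        w = tf.get_variable('w', initializer=0.0)
    reader = tf.contrib.framework.load_checkpoint(ckpt)
    values = {name.replace('global/', 'online/'): reader.get_tensor(name)
              for name in reader.get_variable_to_shape_map()}
    sess.run(w.assign(values[w.name.split(':')[0]]))
    print(sess.run(w))  # 3.0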

View File

@@ -232,19 +232,11 @@ class GraphManager(object):
# set the session for all the modules
self.set_session(self.sess)
else:
self.variables_to_restore = tf.global_variables()
# self.variables_to_restore = [v for v in self.variables_to_restore if '/online' in v.name] TODO: is this necessary?
self.checkpoint_saver = tf.train.Saver(self.variables_to_restore)
# regular session
self.sess = tf.Session(config=config)
# set the session for all the modules
self.set_session(self.sess)
# restore from checkpoint if given
self.restore_checkpoint()
# the TF graph is static, and therefore is saved once - in the beginning of the experiment
if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
self.save_graph()
@@ -254,11 +246,6 @@ class GraphManager(object):
Call set_session to initialize parameters and construct checkpoint_saver
"""
self.set_session(sess=None) # Initialize all modules
self.checkpoint_saver = SaverCollection()
for level in self.level_managers:
self.checkpoint_saver.update(level.collect_savers())
# restore from checkpoint if given
self.restore_checkpoint()
def create_session(self, task_parameters: TaskParameters):
if task_parameters.framework_type == Frameworks.tensorflow:
@@ -268,6 +255,13 @@ class GraphManager(object):
else:
raise ValueError('Invalid framework {}'.format(task_parameters.framework_type))
# Create parameter saver
self.checkpoint_saver = SaverCollection()
for level in self.level_managers:
self.checkpoint_saver.update(level.collect_savers())
# restore from checkpoint if given
self.restore_checkpoint()
def save_graph(self) -> None:
"""
Save the TF graph to a protobuf description file in the experiment directory
@@ -566,16 +560,12 @@ class GraphManager(object):
# TODO: find better way to load checkpoints that were saved with a global network into the online network
if self.task_parameters.checkpoint_restore_dir:
checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
if self.task_parameters.framework_type == Frameworks.tensorflow:
self._restore_checkpoint_tf(checkpoint.model_checkpoint_path)
elif self.task_parameters.framework_type == Frameworks.mxnet:
self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
if checkpoint is None:
screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
else:
raise ValueError('Invalid framework {}'.format(self.task_parameters.framework_type))
screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
[manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
@@ -598,10 +588,7 @@ class GraphManager(object):
if not os.path.exists(os.path.dirname(checkpoint_path)):
os.mkdir(os.path.dirname(checkpoint_path)) # Create directory structure
if not isinstance(self.task_parameters, DistributedTaskParameters):
if self.checkpoint_saver is not None:
saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
else:
saved_checkpoint_path = "<Not Saved>"
else:
saved_checkpoint_path = checkpoint_path
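
The restore path above now relies on get_checkpoint_state() returning None when nothing matches, instead of branching per framework. A minimal sketch of that behaviour, assuming rl_coach with this commit is importable (the directory is a throwaway temp dir):

import tempfile
from rl_coach.utils import get_checkpoint_state

with tempfile.TemporaryDirectory() as empty_dir:
    checkpoint = get_checkpoint_state(empty_dir)
    if checkpoint is None:
        print("No checkpoint to restore in: {}".format(empty_dir))  # mirrors the new warning branch
    else:
        print("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))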

View File

@@ -88,7 +88,7 @@ class SaverCollection(object):
"""
paths = list()
for saver in self:
paths.extend(saver.save(sess, "{}.{}".format(save_path, saver.path)))
paths.extend(saver.save(sess, self._full_path(save_path, saver)))
return paths
def restore(self, sess: Any, restore_path: str) -> None:
@@ -98,8 +98,7 @@ class SaverCollection(object):
:param restore_path: path for restoring checkpoint using savers.
"""
for saver in self:
restore_path = "{}.{}".format(restore_path, saver.path)
saver.restore(sess, restore_path)
saver.restore(sess, self._full_path(restore_path, saver))
def __iter__(self):
"""
@@ -108,5 +107,16 @@ class SaverCollection(object):
"""
return (v for v in self._saver_dict.values())
@staticmethod
def _full_path(path_prefix: str, saver: Saver) -> str:
"""
Concatenates the saver's path to the parent prefix to create the full save path
:param path_prefix: prefix of the path
:param saver: saver object to get unique path extension from
:return: full path
"""
if saver.path == "":
return path_prefix
return "{}.{}".format(path_prefix, saver.path)

View File

@@ -1,21 +1,17 @@
import os
import pytest
import tempfile
from rl_coach import utils
@pytest.mark.unit_test
def test_get_checkpoint_state_default():
files = ['4.test.ckpt.ext', '2.test.ckpt.ext', '3.test.ckpt.ext', '1.test.ckpt.ext']
checkpoint_state = utils.get_checkpoint_state(files)
assert checkpoint_state.model_checkpoint_path == '4.test.ckpt'
assert checkpoint_state.all_model_checkpoint_paths == [f[:-4] for f in sorted(files)]
@pytest.mark.unit_test
def test_get_checkpoint_state_custom():
files = ['prefix.4.test.ckpt.ext', 'prefix.2.test.ckpt.ext', 'prefix.3.test.ckpt.ext', 'prefix.1.test.ckpt.ext']
assert len(utils.get_checkpoint_state(files).all_model_checkpoint_paths) == 0 # doesn't match the default pattern
checkpoint_state = utils.get_checkpoint_state(files, filename_pattern=r'([0-9]+)[^0-9].*?\.ckpt')
assert checkpoint_state.model_checkpoint_path == '4.test.ckpt'
assert checkpoint_state.all_model_checkpoint_paths == [f[7:-4] for f in sorted(files)]
def test_get_checkpoint_state():
files = ['4.test.ckpt.ext', '2.test.ckpt.ext', '3.test.ckpt.ext', '1.test.ckpt.ext', 'prefix.10.test.ckpt.ext']
with tempfile.TemporaryDirectory() as temp_dir:
[open(os.path.join(temp_dir, fn), 'a').close() for fn in files]
checkpoint_state = utils.get_checkpoint_state(temp_dir)
assert checkpoint_state.model_checkpoint_path == os.path.join(temp_dir, '4.test.ckpt')
assert checkpoint_state.all_model_checkpoint_paths == \
[os.path.join(temp_dir, f[:-4]) for f in sorted(files[:-1])]

View File

@@ -574,24 +574,30 @@ class CheckpointState(object):
return str(self._checkpoints)
COACH_CHECKPOINT_PATTERN = r'\A([0-9]+)[^0-9].*?\.ckpt'
COACH_CHECKPOINT_PATTERN = r'\A(([0-9]+)[^0-9])?.*?\.ckpt(-([0-9]+))?'
def get_checkpoint_state(checkpoint_dir: Union[str, List[str]], filename_pattern: str=COACH_CHECKPOINT_PATTERN) ->\
CheckpointState:
Union[CheckpointState, None]:
"""
Finds the latest checkpoint file. It uses the first group of filename_pattern (i.e. group(1)) to sort
Finds the latest checkpoint file. It uses the numeric capture groups of filename_pattern (i.e. group(2) or group(4)) to sort
the checkpoint names and find the latest checkpoint
:param checkpoint_dir: directory where checkpoints are saved or list of all files in a directory
:param filename_pattern: regex pattern for checkpoint filenames
:return: a CheckpointState for checkpoint_dir containing a sorted list of checkpoint names
:return: a CheckpointState for checkpoint_dir containing a sorted list of checkpoint names. If no matching
files are found, returns None.
"""
prog = re.compile(filename_pattern)
checkpoints = dict()
filenames = os.listdir(checkpoint_dir) if isinstance(checkpoint_dir, str) else checkpoint_dir
for name in filenames:
m = prog.search(name)
if m is not None and m.group(1) is not None:
if m is not None and (m.group(2) is not None or m.group(4) is not None):
if m.group(2) is not None and m.group(4) is not None:
assert m.group(2) == m.group(4)
checkpoint_count = int(m.group(2) if m.group(2) is not None else m.group(4))
full_path = os.path.join(checkpoint_dir, m.group(0)) if isinstance(checkpoint_dir, str) else m.group(0)
checkpoints[int(m.group(1))] = full_path
checkpoints[checkpoint_count] = full_path
if len(checkpoints) == 0:
return None
return CheckpointState([checkpoints[k] for k in sorted(checkpoints.keys())])
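
A worked example of the widened pattern: group(2) captures a leading Coach-style counter, while group(4) captures a trailing TensorFlow-style one (the filenames below are made up but follow those two conventions):

import re

pattern = r'\A(([0-9]+)[^0-9])?.*?\.ckpt(-([0-9]+))?'
for name in ['13_Step-2000.ckpt.meta', 'model.ckpt-13.index']:
    m = re.search(pattern, name)
    print(name, m.group(2), m.group(4))
# 13_Step-2000.ckpt.meta 13 None
# model.ckpt-13.index None 13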