Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 19:20:19 +01:00)

Commit 16cdd9a9c1: Tf checkpointing using saver mechanism (#134)
Committed by: Gal Leibovich
Parent commit: dd18959e53
@@ -21,6 +21,7 @@ import numpy as np
 import tensorflow as tf

 from rl_coach.architectures.architecture import Architecture
+from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver
 from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
 from rl_coach.core_types import GradientClippingMethod
 from rl_coach.saver import SaverCollection
@@ -645,8 +646,10 @@ class TensorFlowArchitecture(Architecture):
         (e.g. could be name of level manager plus name of agent)
         :return: checkpoint collection for the network
         """
-        # TODO implement returning checkpoints for tensorflow
-        return SaverCollection()
+        savers = SaverCollection()
+        if not self.distributed_training:
+            savers.add(GlobalVariableSaver(self.name))
+        return savers


 def save_onnx_graph(input_nodes, output_nodes, checkpoint_save_dir: str) -> None:
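To make the distributed_training guard concrete, here is a hedged, standalone mirror of the collect_savers body added above; the free-function form and argument names are for illustration only, the real code is the method in the hunk, and GlobalVariableSaver is defined in the new savers.py below.

    from rl_coach.saver import SaverCollection
    from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver

    def collect_savers(name, distributed_training):
        savers = SaverCollection()
        if not distributed_training:
            # Distributed workers add no saver here; checkpointing is driven by the
            # global network. Note that constructing GlobalVariableSaver requires a
            # graph with at least one variable, since tf.train.Saver rejects an
            # empty variable list.
            savers.add(GlobalVariableSaver(name))
        return savers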
rl_coach/architectures/tensorflow_components/savers.py (new file, 58 lines)
@@ -0,0 +1,58 @@
+from typing import Any, List
+
+import tensorflow as tf
+
+from rl_coach.saver import Saver
+
+
+class GlobalVariableSaver(Saver):
+    def __init__(self, name):
+        self._names = [name]
+        # If the graph is finalized, savers must already have been added. This happens
+        # in the case of a MonitoredSession.
+        self._variables = tf.global_variables()
+        self._saver = tf.train.Saver(self._variables)
+
+    @property
+    def path(self):
+        """
+        Relative path for save/load. If two checkpoint objects return the same path, they must be merge-able.
+        """
+        return ""  # use empty string for the single global file
+
+    def save(self, sess: Any, save_path: str) -> List[str]:
+        """
+        Save to save_path
+        :param sess: active session
+        :param save_path: full path to save checkpoint (typically directory plus checkpoint prefix plus self.path)
+        :return: list of all saved paths
+        """
+        save_path = self._saver.save(sess, save_path)
+        return [save_path]
+
+    def restore(self, sess: Any, restore_path: str):
+        """
+        Restore from restore_path
+        :param sess: active session for session-based frameworks (e.g. TF)
+        :param restore_path: full path to load checkpoint from
+        """
+        # We don't use saver.restore() because the checkpoint is loaded into the online network; if the
+        # checkpoint came from the global network, there is a namespace mismatch, so variable names must
+        # be rewritten before loading.
+        variables = dict()
+        reader = tf.contrib.framework.load_checkpoint(restore_path)
+        for var_name, _ in reader.get_variable_to_shape_map().items():
+            # if a variable was saved under the global network, re-map it to the online network
+            # TODO: Can this be more generic so that `global/` and `online/` are not hardcoded here?
+            new_name = var_name.replace('global/', 'online/')
+            variables[new_name] = reader.get_tensor(var_name)
+        # Assign all variables
+        sess.run([v.assign(variables[v.name.split(':')[0]]) for v in self._variables])
+
+    def merge(self, other: 'Saver'):
+        """
+        Merge another saver into this saver
+        :param other: saver to be merged into self
+        """
+        assert isinstance(other, GlobalVariableSaver)
+        self._names.extend(other._names)
+        # There is nothing else to do, because the variables must already be part of the global collection.
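As a usage illustration, a minimal sketch of driving GlobalVariableSaver directly, assuming TF 1.x graph mode with rl_coach importable; the variable scope, saver name, and checkpoint path are placeholders, not part of this commit.

    import os
    import tensorflow as tf
    from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver

    # The saver snapshots tf.global_variables() at construction time, so at
    # least one variable must exist first.
    with tf.variable_scope('online'):
        v = tf.get_variable('w', shape=[2], initializer=tf.zeros_initializer())

    saver = GlobalVariableSaver('demo')  # 'demo' is an arbitrary name for this sketch
    os.makedirs('/tmp/coach_demo', exist_ok=True)  # hypothetical checkpoint directory

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        paths = saver.save(sess, '/tmp/coach_demo/0_Step-0.ckpt')
        saver.restore(sess, paths[0])  # round-trips through the global/ -> online/ remap

Because path returns the empty string, any two GlobalVariableSaver instances added to one SaverCollection land under the same key and are combined through merge(), which simply concatenates their name lists.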
@@ -232,19 +232,11 @@ class GraphManager(object):
             # set the session for all the modules
             self.set_session(self.sess)
         else:
-            self.variables_to_restore = tf.global_variables()
-            # self.variables_to_restore = [v for v in self.variables_to_restore if '/online' in v.name] TODO: is this necessary?
-            self.checkpoint_saver = tf.train.Saver(self.variables_to_restore)
-
             # regular session
             self.sess = tf.Session(config=config)

             # set the session for all the modules
             self.set_session(self.sess)
-
-            # restore from checkpoint if given
-            self.restore_checkpoint()
-
         # the TF graph is static, and therefore is saved once - in the beginning of the experiment
         if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
             self.save_graph()
@@ -254,11 +246,6 @@ class GraphManager(object):
         Call set_session to initialize parameters and construct checkpoint_saver
         """
         self.set_session(sess=None)  # Initialize all modules
-        self.checkpoint_saver = SaverCollection()
-        for level in self.level_managers:
-            self.checkpoint_saver.update(level.collect_savers())
-        # restore from checkpoint if given
-        self.restore_checkpoint()

     def create_session(self, task_parameters: TaskParameters):
         if task_parameters.framework_type == Frameworks.tensorflow:
@@ -268,6 +255,13 @@ class GraphManager(object):
         else:
             raise ValueError('Invalid framework {}'.format(task_parameters.framework_type))

+        # Create parameter saver
+        self.checkpoint_saver = SaverCollection()
+        for level in self.level_managers:
+            self.checkpoint_saver.update(level.collect_savers())
+        # restore from checkpoint if given
+        self.restore_checkpoint()
+
     def save_graph(self) -> None:
         """
         Save the TF graph to a protobuf description file in the experiment directory
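The net effect of the two hunks above is that saver construction and checkpoint restoring move out of the TensorFlow-specific session setup into framework-neutral code that runs after any session is created. A hedged sketch of the resulting aggregation step, with the level managers as placeholders:

    from rl_coach.saver import SaverCollection

    def build_checkpoint_saver(level_managers):
        # Gather every level's savers into one collection; update() is expected
        # to merge savers that report the same path, since same-path savers must
        # be merge-able (see the path docstring in savers.py above).
        savers = SaverCollection()
        for level in level_managers:
            savers.update(level.collect_savers())
        return savers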
@@ -566,16 +560,12 @@ class GraphManager(object):

         # TODO: find better way to load checkpoints that were saved with a global network into the online network
         if self.task_parameters.checkpoint_restore_dir:
-
             checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
-            screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
-
-            if self.task_parameters.framework_type == Frameworks.tensorflow:
-                self._restore_checkpoint_tf(checkpoint.model_checkpoint_path)
-            elif self.task_parameters.framework_type == Frameworks.mxnet:
-                self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
+            if checkpoint is None:
+                screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
             else:
-                raise ValueError('Invalid framework {}'.format(self.task_parameters.framework_type))
+                screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
+                self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)

             [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
@@ -598,10 +588,7 @@ class GraphManager(object):
         if not os.path.exists(os.path.dirname(checkpoint_path)):
             os.mkdir(os.path.dirname(checkpoint_path))  # Create directory structure
         if not isinstance(self.task_parameters, DistributedTaskParameters):
-            if self.checkpoint_saver is not None:
-                saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
-            else:
-                saved_checkpoint_path = "<Not Saved>"
+            saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
         else:
             saved_checkpoint_path = checkpoint_path

@@ -88,7 +88,7 @@ class SaverCollection(object):
         """
         paths = list()
         for saver in self:
-            paths.extend(saver.save(sess, "{}.{}".format(save_path, saver.path)))
+            paths.extend(saver.save(sess, self._full_path(save_path, saver)))
         return paths

     def restore(self, sess: Any, restore_path: str) -> None:
@@ -98,8 +98,7 @@ class SaverCollection(object):
         :param restore_path: path for restoring checkpoint using savers.
         """
         for saver in self:
-            restore_path = "{}.{}".format(restore_path, saver.path)
-            saver.restore(sess, restore_path)
+            saver.restore(sess, self._full_path(restore_path, saver))

     def __iter__(self):
         """
@@ -108,5 +107,16 @@ class SaverCollection(object):
         """
         return (v for v in self._saver_dict.values())

+    @staticmethod
+    def _full_path(path_prefix: str, saver: Saver) -> str:
+        """
+        Concatenates the path of the saver to the parent prefix to create the full save path
+        :param path_prefix: prefix of the path
+        :param saver: saver object to get unique path extension from
+        :return: full path
+        """
+        if saver.path == "":
+            return path_prefix
+        return "{}.{}".format(path_prefix, saver.path)
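The new helper keeps the empty-path case (used by GlobalVariableSaver for the single global checkpoint file) from producing a trailing dot. Note that the @@ -98 hunk also fixes a bug: restore_path was reassigned inside the loop, so each subsequent saver saw the previous saver's suffix compounded onto its path. A standalone sketch of the helper's logic, with hypothetical prefixes and suffixes:

    def full_path(path_prefix, saver_path):
        # mirrors SaverCollection._full_path above
        if saver_path == "":
            return path_prefix
        return "{}.{}".format(path_prefix, saver_path)

    print(full_path('ckpt/3_Step-120.ckpt', ''))       # -> ckpt/3_Step-120.ckpt
    print(full_path('ckpt/3_Step-120.ckpt', 'mxnet'))  # -> ckpt/3_Step-120.ckpt.mxnet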
@@ -1,21 +1,17 @@
+import os
 import pytest
+import tempfile

 from rl_coach import utils


 @pytest.mark.unit_test
-def test_get_checkpoint_state_default():
-    files = ['4.test.ckpt.ext', '2.test.ckpt.ext', '3.test.ckpt.ext', '1.test.ckpt.ext']
-    checkpoint_state = utils.get_checkpoint_state(files)
-    assert checkpoint_state.model_checkpoint_path == '4.test.ckpt'
-    assert checkpoint_state.all_model_checkpoint_paths == [f[:-4] for f in sorted(files)]
-
-
-@pytest.mark.unit_test
-def test_get_checkpoint_state_custom():
-    files = ['prefix.4.test.ckpt.ext', 'prefix.2.test.ckpt.ext', 'prefix.3.test.ckpt.ext', 'prefix.1.test.ckpt.ext']
-    assert len(utils.get_checkpoint_state(files).all_model_checkpoint_paths) == 0  # doesn't match the default pattern
-    checkpoint_state = utils.get_checkpoint_state(files, filename_pattern=r'([0-9]+)[^0-9].*?\.ckpt')
-    assert checkpoint_state.model_checkpoint_path == '4.test.ckpt'
-    assert checkpoint_state.all_model_checkpoint_paths == [f[7:-4] for f in sorted(files)]
+def test_get_checkpoint_state():
+    files = ['4.test.ckpt.ext', '2.test.ckpt.ext', '3.test.ckpt.ext', '1.test.ckpt.ext', 'prefix.10.test.ckpt.ext']
+    with tempfile.TemporaryDirectory() as temp_dir:
+        [open(os.path.join(temp_dir, fn), 'a').close() for fn in files]
+        checkpoint_state = utils.get_checkpoint_state(temp_dir)
+        assert checkpoint_state.model_checkpoint_path == os.path.join(temp_dir, '4.test.ckpt')
+        assert checkpoint_state.all_model_checkpoint_paths == \
+            [os.path.join(temp_dir, f[:-4]) for f in sorted(files[:-1])]
@@ -574,24 +574,30 @@ class CheckpointState(object):
         return str(self._checkpoints)


-COACH_CHECKPOINT_PATTERN = r'\A([0-9]+)[^0-9].*?\.ckpt'
+COACH_CHECKPOINT_PATTERN = r'\A(([0-9]+)[^0-9])?.*?\.ckpt(-([0-9]+))?'


 def get_checkpoint_state(checkpoint_dir: Union[str, List[str]], filename_pattern: str=COACH_CHECKPOINT_PATTERN) ->\
-        CheckpointState:
+        Union[CheckpointState, None]:
     """
-    Finds the latest checkpoint file. It uses the first group of filename_pattern (i.e. group(1)) to sort
+    Finds the latest checkpoint file. It uses group(2) or group(4) of filename_pattern to sort
     the checkpoint names and find the latest checkpoint
     :param checkpoint_dir: directory where checkpoints are saved or list of all files in a directory
     :param filename_pattern: regex pattern for checkpoint filenames
-    :return: a CheckpointState for checkpoint_dir containing a sorted list of checkpoint names
+    :return: a CheckpointState for checkpoint_dir containing a sorted list of checkpoint names. If no matching
+        files are found, returns None.
     """
     prog = re.compile(filename_pattern)
     checkpoints = dict()
     filenames = os.listdir(checkpoint_dir) if isinstance(checkpoint_dir, str) else checkpoint_dir
     for name in filenames:
         m = prog.search(name)
-        if m is not None and m.group(1) is not None:
+        if m is not None and (m.group(2) is not None or m.group(4) is not None):
+            if m.group(2) is not None and m.group(4) is not None:
+                assert m.group(2) == m.group(4)
+            checkpoint_count = int(m.group(2) if m.group(2) is not None else m.group(4))
             full_path = os.path.join(checkpoint_dir, m.group(0)) if isinstance(checkpoint_dir, str) else m.group(0)
-            checkpoints[int(m.group(1))] = full_path
+            checkpoints[checkpoint_count] = full_path
+    if len(checkpoints) == 0:
+        return None
     return CheckpointState([checkpoints[k] for k in sorted(checkpoints.keys())])
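This widened pattern is what lets restore_checkpoint above tolerate an empty directory (the None return) and accept both Coach-style names, where the counter leads the filename and lands in group(2), and TF-saver-style names, where it trails '.ckpt-' and lands in group(4). A quick sketch of the two capture groups, using hypothetical filenames:

    import re

    prog = re.compile(r'\A(([0-9]+)[^0-9])?.*?\.ckpt(-([0-9]+))?')

    m = prog.search('4.test.ckpt.ext')       # Coach-style: counter leads the name
    print(m.group(2), m.group(4))            # -> 4 None

    m = prog.search('model.ckpt-123.index')  # TF-saver-style: counter trails '.ckpt-'
    print(m.group(2), m.group(4))            # -> None 123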