diff --git a/rl_coach/architectures/tensorflow_components/architecture.py b/rl_coach/architectures/tensorflow_components/architecture.py
index 03ad98c..3a3802f 100644
--- a/rl_coach/architectures/tensorflow_components/architecture.py
+++ b/rl_coach/architectures/tensorflow_components/architecture.py
@@ -21,6 +21,7 @@ import numpy as np
 import tensorflow as tf
 
 from rl_coach.architectures.architecture import Architecture
+from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver
 from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
 from rl_coach.core_types import GradientClippingMethod
 from rl_coach.saver import SaverCollection
@@ -645,8 +646,10 @@ class TensorFlowArchitecture(Architecture):
             (e.g. could be name of level manager plus name of agent)
         :return: checkpoint collection for the network
         """
-        # TODO implement returning checkpoints for tensorflow
-        return SaverCollection()
+        savers = SaverCollection()
+        if not self.distributed_training:
+            savers.add(GlobalVariableSaver(self.name))
+        return savers
 
 
 def save_onnx_graph(input_nodes, output_nodes, checkpoint_save_dir: str) -> None:
diff --git a/rl_coach/architectures/tensorflow_components/savers.py b/rl_coach/architectures/tensorflow_components/savers.py
new file mode 100644
index 0000000..bb3ab19
--- /dev/null
+++ b/rl_coach/architectures/tensorflow_components/savers.py
@@ -0,0 +1,58 @@
+from typing import Any, List
+
+import tensorflow as tf
+
+from rl_coach.saver import Saver
+
+
+class GlobalVariableSaver(Saver):
+    def __init__(self, name):
+        self._names = [name]
+        # if graph is finalized, savers must have already been added. This happens
+        # in the case of a MonitoredSession
+        self._variables = tf.global_variables()
+        self._saver = tf.train.Saver(self._variables)
+
+    @property
+    def path(self):
+        """
+        Relative path for save/load. If two checkpoint objects return the same path, they must be merge-able.
+        """
+        return ""  # use empty string for global file
+
+    def save(self, sess: None, save_path: str) -> List[str]:
+        """
+        Save to save_path
+        :param sess: active session
+        :param save_path: full path to save checkpoint (typically directory plus checkpoint prefix plus self.path)
+        :return: list of all saved paths
+        """
+        save_path = self._saver.save(sess, save_path)
+        return [save_path]
+
+    def restore(self, sess: Any, restore_path: str):
+        """
+        Restore from restore_path
+        :param sess: active session for session-based frameworks (e.g. TF)
+        :param restore_path: full path to load checkpoint from.
+        """
+        # We don't use saver.restore() because the checkpoint is loaded into the online network, but if the
+        # checkpoint is from the global network, a namespace mismatch exists and variable names must be modified before loading.
+        variables = dict()
+        reader = tf.contrib.framework.load_checkpoint(restore_path)
+        for var_name, _ in reader.get_variable_to_shape_map().items():
+            # if variable was saved using global network, re-map it to online network
+            # TODO: Can this be more generic so that `global/` and `online/` are not hardcoded here?
+            new_name = var_name.replace('global/', 'online/')
+            variables[new_name] = reader.get_tensor(var_name)
+        # Assign all variables
+        sess.run([v.assign(variables[v.name.split(':')[0]]) for v in self._variables])
+
+    def merge(self, other: 'Saver'):
+        """
+        Merge other saver into this saver
+        :param other: saver to be merged into self
+        """
+        assert isinstance(other, GlobalVariableSaver)
+        self._names.extend(other._names)
+        # There is nothing else to do because variables must already be part of the global collection.
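The new `GlobalVariableSaver` only becomes useful once it is attached to a `SaverCollection`, which is what `collect_savers()` above now does. Below is a minimal sketch of that composition, not part of the diff: it assumes a TF 1.x environment, assumes `SaverCollection.add()` merges savers that report the same `path` (as the docstrings suggest), and uses made-up variable names and a temporary checkpoint prefix purely for illustration.

```python
# Sketch only: two networks contribute GlobalVariableSaver objects; because both report
# path == "", the collection merges them and writes a single global checkpoint file.
import os
import tempfile

import tensorflow as tf

from rl_coach.saver import SaverCollection
from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver

with tf.Graph().as_default():
    # hypothetical variables standing in for network weights
    tf.Variable(0.0, name='online/w')
    tf.Variable(1.0, name='global/w')

    collection = SaverCollection()
    collection.add(GlobalVariableSaver('main_level/agent/online'))
    collection.add(GlobalVariableSaver('main_level/agent/global'))  # assumed to merge into the first saver

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        prefix = os.path.join(tempfile.mkdtemp(), '0_Step-0.ckpt')
        # GlobalVariableSaver.path is "", so the collection keeps the prefix unchanged and
        # everything lands in "<prefix>" rather than "<prefix>.<sub-path>".
        print(collection.save(sess, prefix))
```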
diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py
index 05153c1..1166be3 100644
--- a/rl_coach/graph_managers/graph_manager.py
+++ b/rl_coach/graph_managers/graph_manager.py
@@ -232,19 +232,11 @@ class GraphManager(object):
             # set the session for all the modules
             self.set_session(self.sess)
         else:
-            self.variables_to_restore = tf.global_variables()
-            # self.variables_to_restore = [v for v in self.variables_to_restore if '/online' in v.name] TODO: is this necessary?
-            self.checkpoint_saver = tf.train.Saver(self.variables_to_restore)
-
             # regular session
             self.sess = tf.Session(config=config)
-
             # set the session for all the modules
             self.set_session(self.sess)
 
-            # restore from checkpoint if given
-            self.restore_checkpoint()
-
         # the TF graph is static, and therefore is saved once - in the beginning of the experiment
         if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
             self.save_graph()
@@ -254,11 +246,6 @@ class GraphManager(object):
         Call set_session to initialize parameters and construct checkpoint_saver
         """
         self.set_session(sess=None)  # Initialize all modules
-        self.checkpoint_saver = SaverCollection()
-        for level in self.level_managers:
-            self.checkpoint_saver.update(level.collect_savers())
-        # restore from checkpoint if given
-        self.restore_checkpoint()
 
     def create_session(self, task_parameters: TaskParameters):
         if task_parameters.framework_type == Frameworks.tensorflow:
@@ -268,6 +255,13 @@ class GraphManager(object):
         else:
             raise ValueError('Invalid framework {}'.format(task_parameters.framework_type))
 
+        # Create parameter saver
+        self.checkpoint_saver = SaverCollection()
+        for level in self.level_managers:
+            self.checkpoint_saver.update(level.collect_savers())
+        # restore from checkpoint if given
+        self.restore_checkpoint()
+
     def save_graph(self) -> None:
         """
         Save the TF graph to a protobuf description file in the experiment directory
@@ -566,16 +560,12 @@ class GraphManager(object):
 
         # TODO: find better way to load checkpoints that were saved with a global network into the online network
         if self.task_parameters.checkpoint_restore_dir:
-
             checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
-            screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
-
-            if self.task_parameters.framework_type == Frameworks.tensorflow:
-                self._restore_checkpoint_tf(checkpoint.model_checkpoint_path)
-            elif self.task_parameters.framework_type == Frameworks.mxnet:
-                self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
+            if checkpoint is None:
+                screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
             else:
-                raise ValueError('Invalid framework {}'.format(self.task_parameters.framework_type))
+                screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
+                self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
 
             [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
@@ -598,10 +588,7 @@ class GraphManager(object):
         if not os.path.exists(os.path.dirname(checkpoint_path)):
             os.mkdir(os.path.dirname(checkpoint_path))  # Create directory structure
         if not isinstance(self.task_parameters, DistributedTaskParameters):
-            if self.checkpoint_saver is not None:
-                saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
-            else:
-                saved_checkpoint_path = ""
+            saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
         else:
             saved_checkpoint_path = checkpoint_path
diff --git a/rl_coach/saver.py b/rl_coach/saver.py
index ce22c1f..856e75e 100644
--- a/rl_coach/saver.py
+++ b/rl_coach/saver.py
@@ -88,7 +88,7 @@ class SaverCollection(object):
         """
         paths = list()
         for saver in self:
-            paths.extend(saver.save(sess, "{}.{}".format(save_path, saver.path)))
+            paths.extend(saver.save(sess, self._full_path(save_path, saver)))
         return paths
 
     def restore(self, sess: Any, restore_path: str) -> None:
@@ -98,8 +98,7 @@ class SaverCollection(object):
         :param restore_path: path for restoring checkpoint using savers.
         """
         for saver in self:
-            restore_path = "{}.{}".format(restore_path, saver.path)
-            saver.restore(sess, restore_path)
+            saver.restore(sess, self._full_path(restore_path, saver))
 
     def __iter__(self):
         """
@@ -108,5 +107,16 @@ class SaverCollection(object):
         :return: saver iterator
         """
         return (v for v in self._saver_dict.values())
 
+    @staticmethod
+    def _full_path(path_prefix: str, saver: Saver) -> str:
+        """
+        Concatenates path of the saver to parent prefix to create full save path
+        :param path_prefix: prefix of the path
+        :param saver: saver object to get unique path extension from
+        :return: full path
+        """
+        if saver.path == "":
+            return path_prefix
+        return "{}.{}".format(path_prefix, saver.path)
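For context, `_full_path()` special-cases the empty saver path because a saver like `GlobalVariableSaver` reports `path == ""`; without the special case the collection would produce checkpoint names with a dangling dot (`"<prefix>."`). A tiny illustration, not part of the diff; the stub classes and the prefix are hypothetical, and the stubs stand in for real `Saver` subclasses only because `_full_path()` reads nothing but `saver.path`:

```python
from rl_coach.saver import SaverCollection

class _GlobalStub:
    path = ""      # behaves like GlobalVariableSaver

class _OnnxStub:
    path = "onnx"  # a hypothetical saver that owns its own sub-path

prefix = '/tmp/checkpoint/3_Step-7.ckpt'
print(SaverCollection._full_path(prefix, _GlobalStub()))  # /tmp/checkpoint/3_Step-7.ckpt
print(SaverCollection._full_path(prefix, _OnnxStub()))    # /tmp/checkpoint/3_Step-7.ckpt.onnx
```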
diff --git a/rl_coach/tests/test_utils.py b/rl_coach/tests/test_utils.py
index 3ffd686..face47b 100644
--- a/rl_coach/tests/test_utils.py
+++ b/rl_coach/tests/test_utils.py
@@ -1,21 +1,17 @@
+import os
 import pytest
+import tempfile
 
 from rl_coach import utils
 
 
 @pytest.mark.unit_test
-def test_get_checkpoint_state_default():
-    files = ['4.test.ckpt.ext', '2.test.ckpt.ext', '3.test.ckpt.ext', '1.test.ckpt.ext']
-    checkpoint_state = utils.get_checkpoint_state(files)
-    assert checkpoint_state.model_checkpoint_path == '4.test.ckpt'
-    assert checkpoint_state.all_model_checkpoint_paths == [f[:-4] for f in sorted(files)]
-
-
-@pytest.mark.unit_test
-def test_get_checkpoint_state_custom():
-    files = ['prefix.4.test.ckpt.ext', 'prefix.2.test.ckpt.ext', 'prefix.3.test.ckpt.ext', 'prefix.1.test.ckpt.ext']
-    assert len(utils.get_checkpoint_state(files).all_model_checkpoint_paths) == 0  # doesn't match the default pattern
-    checkpoint_state = utils.get_checkpoint_state(files, filename_pattern=r'([0-9]+)[^0-9].*?\.ckpt')
-    assert checkpoint_state.model_checkpoint_path == '4.test.ckpt'
-    assert checkpoint_state.all_model_checkpoint_paths == [f[7:-4] for f in sorted(files)]
+def test_get_checkpoint_state():
+    files = ['4.test.ckpt.ext', '2.test.ckpt.ext', '3.test.ckpt.ext', '1.test.ckpt.ext', 'prefix.10.test.ckpt.ext']
+    with tempfile.TemporaryDirectory() as temp_dir:
+        [open(os.path.join(temp_dir, fn), 'a').close() for fn in files]
+        checkpoint_state = utils.get_checkpoint_state(temp_dir)
+        assert checkpoint_state.model_checkpoint_path == os.path.join(temp_dir, '4.test.ckpt')
+        assert checkpoint_state.all_model_checkpoint_paths == \
+            [os.path.join(temp_dir, f[:-4]) for f in sorted(files[:-1])]
diff --git a/rl_coach/utils.py b/rl_coach/utils.py
index 4a4d248..7b5b95d 100644
--- a/rl_coach/utils.py
+++ b/rl_coach/utils.py
@@ -574,24 +574,30 @@ class CheckpointState(object):
         return str(self._checkpoints)
 
 
-COACH_CHECKPOINT_PATTERN = r'\A([0-9]+)[^0-9].*?\.ckpt'
+COACH_CHECKPOINT_PATTERN = r'\A(([0-9]+)[^0-9])?.*?\.ckpt(-([0-9]+))?'
 
 
 def get_checkpoint_state(checkpoint_dir: Union[str, List[str]], filename_pattern: str=COACH_CHECKPOINT_PATTERN) ->\
-        CheckpointState:
+        Union[CheckpointState, None]:
     """
-    Finds the latest checkpoint file. It uses the first group of filename_pattern (i.e. group(1)) to sort
+    Finds the latest checkpoint file. It uses group(2) or group(4) of filename_pattern to sort
     the checkpoint names and find the latest checkpoint
     :param checkpoint_dir: directory where checkpoints are saved or list of all files in a directory
     :param filename_pattern: regex pattern for checkpoint filenames
-    :return: a CheckpointState for checkpoint_dir containing a sorted list of checkpoint names
+    :return: a CheckpointState for checkpoint_dir containing a sorted list of checkpoint names. If no matching
+        files are found, returns None.
     """
     prog = re.compile(filename_pattern)
     checkpoints = dict()
     filenames = os.listdir(checkpoint_dir) if isinstance(checkpoint_dir, str) else checkpoint_dir
     for name in filenames:
         m = prog.search(name)
-        if m is not None and m.group(1) is not None:
+        if m is not None and (m.group(2) is not None or m.group(4) is not None):
+            if m.group(2) is not None and m.group(4) is not None:
+                assert m.group(2) == m.group(4)
+            checkpoint_count = int(m.group(2) if m.group(2) is not None else m.group(4))
             full_path = os.path.join(checkpoint_dir, m.group(0)) if isinstance(checkpoint_dir, str) else m.group(0)
-            checkpoints[int(m.group(1))] = full_path
+            checkpoints[checkpoint_count] = full_path
+    if len(checkpoints) == 0:
+        return None
     return CheckpointState([checkpoints[k] for k in sorted(checkpoints.keys())])
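As a sanity check on the widened `COACH_CHECKPOINT_PATTERN` (a standalone snippet, not part of the diff): Coach-style names such as `4.test.ckpt.ext` carry the checkpoint index in group(2), a trailing `.ckpt-<step>` suffix (presumably to accommodate TensorFlow-style checkpoint names) carries it in group(4), and names that provide neither are now skipped by `get_checkpoint_state()`. The example file names below are illustrative only.

```python
import re

COACH_CHECKPOINT_PATTERN = r'\A(([0-9]+)[^0-9])?.*?\.ckpt(-([0-9]+))?'

for name in ['4.test.ckpt.ext', 'model.ckpt-10.index', 'prefix.10.test.ckpt.ext']:
    m = re.search(COACH_CHECKPOINT_PATTERN, name)
    print(name, m.group(2), m.group(4))
# 4.test.ckpt.ext         -> group(2) == '4',  group(4) is None
# model.ckpt-10.index     -> group(2) is None, group(4) == '10'
# prefix.10.test.ckpt.ext -> both are None, so the file is ignored
```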