Mirror of https://github.com/gryf/coach.git

restoring from a checkpoint file (#247)

Gal Leibovich
2019-03-17 16:28:09 +02:00
committed by GitHub
parent f03bd7ad93
commit d6158a5cfc
6 changed files with 87 additions and 39 deletions
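Note: this commit generalizes checkpoint restoring. The checkpoint_restore_dir task parameter becomes checkpoint_restore_path, which may point either at a checkpoint directory (the previous behaviour) or, for TensorFlow, at a single checkpoint file. A minimal sketch of how the new parameter might be set, assuming a plain attribute assignment on TaskParameters is enough; the no-argument constructor call and the example paths are illustrative and not taken from this commit:

    # Illustrative only: the attribute name follows the diff below; the
    # constructor call and paths are hypothetical.
    from rl_coach.base_parameters import TaskParameters

    task_parameters = TaskParameters()

    # restore from a checkpoint directory (previous behaviour)
    task_parameters.checkpoint_restore_path = './experiments/my_run/checkpoint'

    # or, new in this commit, restore from a single TensorFlow checkpoint file;
    # the "{}_Step-{}.ckpt" file name pattern is the one Coach itself uses
    task_parameters.checkpoint_restore_path = './experiments/my_run/checkpoint/0_Step-1000.ckpt'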


@@ -25,7 +25,7 @@ import contextlib
 from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, Frameworks, \
     VisualizationParameters, \
     Parameters, PresetValidationParameters, RunType
-from rl_coach.checkpoint import CheckpointStateUpdater, get_checkpoint_state, SingleCheckpoint
+from rl_coach.checkpoint import CheckpointStateUpdater, get_checkpoint_state, SingleCheckpoint, CheckpointState
 from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
     EnvironmentSteps, \
     StepMethod, Transition
@@ -218,11 +218,13 @@ class GraphManager(object):
         if isinstance(task_parameters, DistributedTaskParameters):
             # the distributed tensorflow setting
             from rl_coach.architectures.tensorflow_components.distributed_tf_utils import create_monitored_session
 
-            if hasattr(self.task_parameters, 'checkpoint_restore_dir') and self.task_parameters.checkpoint_restore_dir:
+            if hasattr(self.task_parameters, 'checkpoint_restore_path') and self.task_parameters.checkpoint_restore_path:
                 checkpoint_dir = os.path.join(task_parameters.experiment_path, 'checkpoint')
                 if os.path.exists(checkpoint_dir):
                     remove_tree(checkpoint_dir)
-                copy_tree(task_parameters.checkpoint_restore_dir, checkpoint_dir)
+                # in the locally distributed case, checkpoints are always restored from a directory (and not from a
+                # file)
+                copy_tree(task_parameters.checkpoint_restore_path, checkpoint_dir)
             else:
                 checkpoint_dir = task_parameters.checkpoint_save_dir
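The distributed branch above refreshes the experiment's checkpoint folder before restoring: any stale copy is deleted, then the restore directory is copied in wholesale. A self-contained sketch of that refresh-then-copy step, assuming copy_tree and remove_tree are the distutils.dir_util helpers of the same names (as they are in Coach); the paths are hypothetical:

    import os
    from distutils.dir_util import copy_tree, remove_tree

    experiment_path = './experiments/my_run'   # hypothetical experiment directory
    restore_path = './pretrained/checkpoint'   # hypothetical directory to restore from

    checkpoint_dir = os.path.join(experiment_path, 'checkpoint')
    if os.path.exists(checkpoint_dir):
        # drop checkpoints left over from a previous run
        remove_tree(checkpoint_dir)
    # in the locally distributed case the restore path is always a directory,
    # never a single checkpoint file, so a recursive copy is sufficient
    copy_tree(restore_path, checkpoint_dir)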
@@ -547,30 +549,44 @@ class GraphManager(object):
         self.verify_graph_was_created()
 
         # TODO: find better way to load checkpoints that were saved with a global network into the online network
-        if self.task_parameters.checkpoint_restore_dir:
-            if self.task_parameters.framework_type == Frameworks.tensorflow and\
-                    'checkpoint' in os.listdir(self.task_parameters.checkpoint_restore_dir):
-                # TODO-fixme checkpointing
-                # MonitoredTrainingSession manages save/restore checkpoints autonomously. Doing so,
-                # it creates it own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt" filename
-                # pattern. The names used are maintained in a CheckpointState protobuf file named 'checkpoint'. Using
-                # Coach's '.coach_checkpoint' protobuf file, results in an error when trying to restore the model, as
-                # the checkpoint names defined do not match the actual checkpoint names.
-                checkpoint = self._get_checkpoint_state_tf()
-            else:
-                checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
+        if self.task_parameters.checkpoint_restore_path:
+            if os.path.isdir(self.task_parameters.checkpoint_restore_path):
+                # a checkpoint dir
+                if self.task_parameters.framework_type == Frameworks.tensorflow and\
+                        'checkpoint' in os.listdir(self.task_parameters.checkpoint_restore_path):
+                    # TODO-fixme checkpointing
+                    # MonitoredTrainingSession manages save/restore checkpoints autonomously. Doing so,
+                    # it creates it own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt"
+                    # filename pattern. The names used are maintained in a CheckpointState protobuf file named
+                    # 'checkpoint'. Using Coach's '.coach_checkpoint' protobuf file, results in an error when trying to
+                    # restore the model, as the checkpoint names defined do not match the actual checkpoint names.
+                    checkpoint = self._get_checkpoint_state_tf(self.task_parameters.checkpoint_restore_path)
+                else:
+                    checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_path)
 
-            if checkpoint is None:
-                screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
-            else:
-                screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
-                self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
+                if checkpoint is None:
+                    raise ValueError("No checkpoint to restore in: {}".format(
+                        self.task_parameters.checkpoint_restore_path))
+
+                model_checkpoint_path = checkpoint.model_checkpoint_path
+                checkpoint_restore_dir = self.task_parameters.checkpoint_restore_path
+            else:
+                # a checkpoint file
+                if self.task_parameters.framework_type == Frameworks.tensorflow:
+                    model_checkpoint_path = self.task_parameters.checkpoint_restore_path
+                    checkpoint_restore_dir = os.path.dirname(model_checkpoint_path)
+                else:
+                    raise ValueError("Currently restoring a checkpoint using the --checkpoint_restore_file argument is"
+                                     " only supported when with tensorflow.")
+
+            screen.log_title("Loading checkpoint: {}".format(model_checkpoint_path))
+            self.checkpoint_saver.restore(self.sess, model_checkpoint_path)
 
-            [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
+            [manager.restore_checkpoint(checkpoint_restore_dir) for manager in self.level_managers]
 
-    def _get_checkpoint_state_tf(self):
+    def _get_checkpoint_state_tf(self, checkpoint_restore_dir):
         import tensorflow as tf
-        return tf.train.get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
+        return tf.train.get_checkpoint_state(checkpoint_restore_dir)
 
     def occasionally_save_checkpoint(self):
         # only the chief process saves checkpoints
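Condensed, the new logic resolves checkpoint_restore_path into a concrete checkpoint file plus the directory that is handed to the level managers. A standalone paraphrase for the TensorFlow case only; the helper name resolve_restore_path is illustrative and not part of Coach:

    import os
    import tensorflow as tf

    def resolve_restore_path(checkpoint_restore_path):
        """Paraphrase of the dir-vs-file branch in the hunk above."""
        if os.path.isdir(checkpoint_restore_path):
            # a checkpoint dir: let TensorFlow pick the checkpoint it tracks
            state = tf.train.get_checkpoint_state(checkpoint_restore_path)
            if state is None:
                raise ValueError("No checkpoint to restore in: {}".format(checkpoint_restore_path))
            return state.model_checkpoint_path, checkpoint_restore_path
        # a checkpoint file: use it as-is and derive the directory from it
        return checkpoint_restore_path, os.path.dirname(checkpoint_restore_path)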