restoring from a checkpoint file (#247)
@@ -25,7 +25,7 @@ import contextlib
 from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, Frameworks, \
     VisualizationParameters, \
     Parameters, PresetValidationParameters, RunType
-from rl_coach.checkpoint import CheckpointStateUpdater, get_checkpoint_state, SingleCheckpoint
+from rl_coach.checkpoint import CheckpointStateUpdater, get_checkpoint_state, SingleCheckpoint, CheckpointState
 from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
     EnvironmentSteps, \
     StepMethod, Transition
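The only code change in this first hunk is importing CheckpointState alongside Coach's own checkpoint helpers; the hunks that follow consume the renamed checkpoint_restore_path task parameter. Below is a hypothetical usage sketch (not part of the commit) of how a caller might point Coach at a checkpoint to restore; only the attribute name comes from the diff, while constructing TaskParameters with defaults and assigning the attribute directly are assumptions.

# Hypothetical usage sketch: only the checkpoint_restore_path attribute name is taken
# from this commit; default TaskParameters construction and the example paths are assumptions.
import os

from rl_coach.base_parameters import TaskParameters

task_parameters = TaskParameters()

# either a checkpoint *directory* ...
task_parameters.checkpoint_restore_path = './experiments/demo/checkpoint'
# ... or (TensorFlow only, per the last hunk) a single checkpoint file such as
# './experiments/demo/checkpoint/0_Step-0.ckpt'

if not os.path.exists(task_parameters.checkpoint_restore_path):
    raise ValueError("No checkpoint to restore in: {}".format(task_parameters.checkpoint_restore_path))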
@@ -218,11 +218,13 @@ class GraphManager(object):
         if isinstance(task_parameters, DistributedTaskParameters):
             # the distributed tensorflow setting
             from rl_coach.architectures.tensorflow_components.distributed_tf_utils import create_monitored_session
-            if hasattr(self.task_parameters, 'checkpoint_restore_dir') and self.task_parameters.checkpoint_restore_dir:
+            if hasattr(self.task_parameters, 'checkpoint_restore_path') and self.task_parameters.checkpoint_restore_path:
                 checkpoint_dir = os.path.join(task_parameters.experiment_path, 'checkpoint')
                 if os.path.exists(checkpoint_dir):
                     remove_tree(checkpoint_dir)
-                copy_tree(task_parameters.checkpoint_restore_dir, checkpoint_dir)
+                # in the locally distributed case, checkpoints are always restored from a directory (and not from a
+                # file)
+                copy_tree(task_parameters.checkpoint_restore_path, checkpoint_dir)
             else:
                 checkpoint_dir = task_parameters.checkpoint_save_dir
 
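For the distributed TensorFlow setting, the hunk above stages the checkpoints to restore by copying the given directory over the experiment's 'checkpoint' directory before the monitored session is created. A minimal standalone sketch of that staging step, assuming copy_tree and remove_tree are the distutils.dir_util helpers (the unqualified calls imply a module-level import not shown in this excerpt):

# Minimal sketch of the staging step above, under the assumption that copy_tree and
# remove_tree come from distutils.dir_util.
import os
from distutils.dir_util import copy_tree, remove_tree


def stage_restore_checkpoints(experiment_path, checkpoint_restore_path):
    """Copy a checkpoint directory into <experiment_path>/checkpoint, replacing any stale one."""
    checkpoint_dir = os.path.join(experiment_path, 'checkpoint')
    if os.path.exists(checkpoint_dir):
        remove_tree(checkpoint_dir)  # drop checkpoints left over from a previous run
    # in the locally distributed case the restore path is a directory, not a single file
    copy_tree(checkpoint_restore_path, checkpoint_dir)
    return checkpoint_dir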
@@ -547,30 +549,44 @@ class GraphManager(object):
         self.verify_graph_was_created()
 
         # TODO: find better way to load checkpoints that were saved with a global network into the online network
-        if self.task_parameters.checkpoint_restore_dir:
-            if self.task_parameters.framework_type == Frameworks.tensorflow and\
-                    'checkpoint' in os.listdir(self.task_parameters.checkpoint_restore_dir):
-                # TODO-fixme checkpointing
-                # MonitoredTrainingSession manages save/restore checkpoints autonomously. Doing so,
-                # it creates it own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt" filename
-                # pattern. The names used are maintained in a CheckpointState protobuf file named 'checkpoint'. Using
-                # Coach's '.coach_checkpoint' protobuf file, results in an error when trying to restore the model, as
-                # the checkpoint names defined do not match the actual checkpoint names.
-                checkpoint = self._get_checkpoint_state_tf()
-            else:
-                checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
+        if self.task_parameters.checkpoint_restore_path:
+            if os.path.isdir(self.task_parameters.checkpoint_restore_path):
+                # a checkpoint dir
+                if self.task_parameters.framework_type == Frameworks.tensorflow and\
+                        'checkpoint' in os.listdir(self.task_parameters.checkpoint_restore_path):
+                    # TODO-fixme checkpointing
+                    # MonitoredTrainingSession manages save/restore checkpoints autonomously. Doing so,
+                    # it creates it own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt"
+                    # filename pattern. The names used are maintained in a CheckpointState protobuf file named
+                    # 'checkpoint'. Using Coach's '.coach_checkpoint' protobuf file, results in an error when trying to
+                    # restore the model, as the checkpoint names defined do not match the actual checkpoint names.
+                    checkpoint = self._get_checkpoint_state_tf(self.task_parameters.checkpoint_restore_path)
+                else:
+                    checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_path)
 
-            if checkpoint is None:
-                screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
-            else:
-                screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
-                self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
+                if checkpoint is None:
+                    raise ValueError("No checkpoint to restore in: {}".format(
+                        self.task_parameters.checkpoint_restore_path))
+                model_checkpoint_path = checkpoint.model_checkpoint_path
+                checkpoint_restore_dir = self.task_parameters.checkpoint_restore_path
+            else:
+                # a checkpoint file
+                if self.task_parameters.framework_type == Frameworks.tensorflow:
+                    model_checkpoint_path = self.task_parameters.checkpoint_restore_path
+                    checkpoint_restore_dir = os.path.dirname(model_checkpoint_path)
+                else:
+                    raise ValueError("Currently restoring a checkpoint using the --checkpoint_restore_file argument is"
+                                     " only supported when with tensorflow.")
 
-            [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
-
-    def _get_checkpoint_state_tf(self):
+            screen.log_title("Loading checkpoint: {}".format(model_checkpoint_path))
+
+            self.checkpoint_saver.restore(self.sess, model_checkpoint_path)
+
+            [manager.restore_checkpoint(checkpoint_restore_dir) for manager in self.level_managers]
+
+    def _get_checkpoint_state_tf(self, checkpoint_restore_dir):
         import tensorflow as tf
-        return tf.train.get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
+        return tf.train.get_checkpoint_state(checkpoint_restore_dir)
 
     def occasionally_save_checkpoint(self):
         # only the chief process saves checkpoints
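The rewritten restore_checkpoint accepts either a checkpoint directory or a single checkpoint file. For a directory, the latest checkpoint is looked up through TensorFlow's CheckpointState protobuf file named 'checkpoint' (via _get_checkpoint_state_tf) or through Coach's own get_checkpoint_state; for a file, the path is used as-is, which is only supported with TensorFlow. The sketch below mirrors that resolution logic using only TensorFlow 1.x APIs; the helper name is illustrative, and the framework check and Coach's '.coach_checkpoint' fallback are left out.

# Standalone sketch of the directory-vs-file resolution introduced above; TF 1.x only,
# the framework check and Coach's own get_checkpoint_state fallback are omitted.
import os

import tensorflow as tf


def resolve_checkpoint(checkpoint_restore_path):
    """Return (model_checkpoint_path, checkpoint_restore_dir) for a dir or a file path."""
    if os.path.isdir(checkpoint_restore_path):
        # a checkpoint dir: the text-format CheckpointState proto in the 'checkpoint' file
        # records model_checkpoint_path and all_model_checkpoint_paths
        state = tf.train.get_checkpoint_state(checkpoint_restore_path)
        if state is None:
            raise ValueError("No checkpoint to restore in: {}".format(checkpoint_restore_path))
        return state.model_checkpoint_path, checkpoint_restore_path
    # a checkpoint file: use it directly and restore relative to its directory
    return checkpoint_restore_path, os.path.dirname(checkpoint_restore_path)


# illustrative usage: restore the resolved checkpoint into a session with a Saver
# model_path, ckpt_dir = resolve_checkpoint('./experiments/demo/checkpoint')
# saver = tf.train.Saver()
# with tf.Session() as sess:
#     saver.restore(sess, model_path)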