
restoring from a checkpoint file (#247)

Author: Gal Leibovich
Date: 2019-03-17 16:28:09 +02:00
Committed by: GitHub
parent f03bd7ad93
commit d6158a5cfc
6 changed files with 87 additions and 39 deletions

View File

@@ -43,8 +43,9 @@ class DNDQHead(QHead):
         self.shared_memory_scratchpad = self.ap.task_parameters.shared_memory_scratchpad

     def _build_module(self, input_layer):
-        if hasattr(self.ap.task_parameters, 'checkpoint_restore_dir') and self.ap.task_parameters.checkpoint_restore_dir:
-            self.DND = differentiable_neural_dictionary.load_dnd(self.ap.task_parameters.checkpoint_restore_dir)
+        if hasattr(self.ap.task_parameters, 'checkpoint_restore_path') and\
+                self.ap.task_parameters.checkpoint_restore_path:
+            self.DND = differentiable_neural_dictionary.load_dnd(self.ap.task_parameters.checkpoint_restore_path)
         else:
             self.DND = differentiable_neural_dictionary.QDND(
                 self.DND_size, input_layer.get_shape()[-1], self.num_actions, self.new_value_shift_coefficient,
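Note on this hunk: the DND head restores its dictionary itself, outside the TensorFlow saver, which is why only the attribute name changes here. As a rough illustration of the restore-or-create branch above (the helper, file name and pickle format below are made up and are not Coach's differentiable_neural_dictionary API):

import os
import pickle

def load_or_create_dnd(checkpoint_restore_path, dnd_size, key_width, num_actions):
    # hypothetical stand-in for load_dnd / QDND: reload a pickled DND if a restore
    # path is configured, otherwise start with a fresh (placeholder) dictionary
    if checkpoint_restore_path:
        dnd_file = (os.path.join(checkpoint_restore_path, 'dnd.pkl')
                    if os.path.isdir(checkpoint_restore_path) else checkpoint_restore_path)
        with open(dnd_file, 'rb') as f:
            return pickle.load(f)
    return {'size': dnd_size, 'key_width': key_width, 'num_actions': num_actions}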

View File

@@ -25,6 +25,7 @@ from typing import Dict, List, Union
 from rl_coach.core_types import TrainingSteps, EnvironmentSteps, GradientClippingMethod, RunPhase, \
     SelectedPhaseOnlyDumpFilter, MaxDumpFilter
 from rl_coach.filters.filter import NoInputFilter
+from rl_coach.logger import screen


 class Frameworks(Enum):

@@ -552,8 +553,8 @@ class AgentParameters(Parameters):
 class TaskParameters(Parameters):
     def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: int=None, use_cpu: bool=False,
                  experiment_path='/tmp', seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
-                 checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False,
-                 num_gpu: int=1):
+                 checkpoint_restore_path=None, checkpoint_save_dir=None, export_onnx_graph: bool=False,
+                 apply_stop_condition: bool=False, num_gpu: int=1):
         """
         :param framework_type: deep learning framework type. currently only tensorflow is supported
         :param evaluate_only: if not None, the task will be used only for evaluating the model for the given number of steps.

@@ -562,7 +563,10 @@ class TaskParameters(Parameters):
         :param experiment_path: the path to the directory which will store all the experiment outputs
         :param seed: a seed to use for the random numbers generator
         :param checkpoint_save_secs: the number of seconds between each checkpoint saving
-        :param checkpoint_restore_dir: the directory to restore the checkpoints from
+        :param checkpoint_restore_dir:
+            [DEPECRATED - will be removed in one of the next releases - switch to checkpoint_restore_path]
+            the dir to restore the checkpoints from
+        :param checkpoint_restore_path: the path to restore the checkpoints from
         :param checkpoint_save_dir: the directory to store the checkpoints in
         :param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved
         :param apply_stop_condition: If set to True, this will apply the stop condition defined by reaching a target success rate

@@ -574,7 +578,13 @@ class TaskParameters(Parameters):
         self.use_cpu = use_cpu
         self.experiment_path = experiment_path
         self.checkpoint_save_secs = checkpoint_save_secs
-        self.checkpoint_restore_dir = checkpoint_restore_dir
+        if checkpoint_restore_dir:
+            screen.warning('TaskParameters.checkpoint_restore_dir is DEPECRATED and will be removed in one of the next '
+                           'releases. Please switch to using TaskParameters.checkpoint_restore_path, with your '
+                           'directory path. ')
+            self.checkpoint_restore_path = checkpoint_restore_dir
+        else:
+            self.checkpoint_restore_path = checkpoint_restore_path
         self.checkpoint_save_dir = checkpoint_save_dir
         self.seed = seed
         self.export_onnx_graph = export_onnx_graph
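With the shim above, both keyword spellings end up on the same checkpoint_restore_path attribute; the old name only adds a deprecation warning before being copied over. A short usage sketch (the experiment path below is made up):

from rl_coach.base_parameters import TaskParameters

# new style: checkpoint_restore_path may point at a checkpoint directory or a checkpoint file
tp_new = TaskParameters(checkpoint_restore_path='./experiments/my_run/checkpoint')

# old style still works, but warns and is mapped onto checkpoint_restore_path internally
tp_old = TaskParameters(checkpoint_restore_dir='./experiments/my_run/checkpoint')

assert tp_old.checkpoint_restore_path == tp_new.checkpoint_restore_path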
@@ -586,7 +596,7 @@ class DistributedTaskParameters(TaskParameters):
     def __init__(self, framework_type: Frameworks, parameters_server_hosts: str, worker_hosts: str, job_type: str,
                  task_index: int, evaluate_only: int=None, num_tasks: int=None,
                  num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None,
-                 shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
+                 shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_path=None,
                  checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False):
         """
         :param framework_type: deep learning framework type. currently only tensorflow is supported

@@ -604,7 +614,7 @@ class DistributedTaskParameters(TaskParameters):
         :param dnd: an external DND to use for NEC. This is a workaround needed for a shared DND not using the scratchpad.
         :param seed: a seed to use for the random numbers generator
         :param checkpoint_save_secs: the number of seconds between each checkpoint saving
-        :param checkpoint_restore_dir: the directory to restore the checkpoints from
+        :param checkpoint_restore_path: the path to restore the checkpoints from
         :param checkpoint_save_dir: the directory to store the checkpoints in
         :param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved
         :param apply_stop_condition: If set to True, this will apply the stop condition defined by reaching a target success rate

@@ -612,7 +622,7 @@ class DistributedTaskParameters(TaskParameters):
         """
         super().__init__(framework_type=framework_type, evaluate_only=evaluate_only, use_cpu=use_cpu,
                          experiment_path=experiment_path, seed=seed, checkpoint_save_secs=checkpoint_save_secs,
-                         checkpoint_restore_dir=checkpoint_restore_dir, checkpoint_save_dir=checkpoint_save_dir,
+                         checkpoint_restore_path=checkpoint_restore_path, checkpoint_save_dir=checkpoint_save_dir,
                          export_onnx_graph=export_onnx_graph, apply_stop_condition=apply_stop_condition)
         self.parameters_server_hosts = parameters_server_hosts
         self.worker_hosts = worker_hosts

View File

@@ -33,6 +33,8 @@ from rl_coach.base_parameters import Frameworks, VisualizationParameters, TaskPa
 from multiprocessing import Process
 from multiprocessing.managers import BaseManager
 import subprocess
+from glob import glob
 from rl_coach.graph_managers.graph_manager import HumanPlayScheduleParameters, GraphManager
 from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port, SharedMemoryScratchPad, get_base_dir
 from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager

@@ -44,7 +46,7 @@ from rl_coach.data_stores.s3_data_store import S3DataStoreParameters
 from rl_coach.data_stores.nfs_data_store import NFSDataStoreParameters
 from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
 from rl_coach.training_worker import training_worker
-from rl_coach.rollout_worker import rollout_worker, wait_for_checkpoint
+from rl_coach.rollout_worker import rollout_worker

 if len(set(failed_imports)) > 0:

@@ -110,7 +112,7 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
     )

     if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
-        task_parameters.checkpoint_restore_dir = ckpt_inside_container
+        task_parameters.checkpoint_restore_path = ckpt_inside_container

     data_store = None
     if args.data_store_params:

@@ -394,6 +396,10 @@ class CoachLauncher(object):
         if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir):
             screen.error("The requested checkpoint folder to load from does not exist.")

+        # validate the checkpoints args
+        if args.checkpoint_restore_file is not None and not glob(args.checkpoint_restore_file + '*'):
+            screen.error("The requested checkpoint file to load from does not exist.")
+
         # no preset was given. check if the user requested to play some environment on its own
         if args.preset is None and args.play and not args.environment_type:
             screen.error('When no preset is given for Coach to run, and the user requests human control over '
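The new file check globs the argument instead of calling os.path.exists because a TensorFlow checkpoint "file" is really a prefix shared by several files (.index, .meta, .data-*). The same idea in isolation:

from glob import glob

def checkpoint_prefix_exists(prefix):
    # e.g. prefix '0_Step-0.ckpt' matches '0_Step-0.ckpt.index', '0_Step-0.ckpt.meta'
    # and '0_Step-0.ckpt.data-00000-of-00001'
    return bool(glob(prefix + '*'))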
@@ -493,6 +499,9 @@ class CoachLauncher(object):
         parser.add_argument('-crd', '--checkpoint_restore_dir',
                             help='(string) Path to a folder containing a checkpoint to restore the model from.',
                             type=str)
+        parser.add_argument('-crf', '--checkpoint_restore_file',
+                            help='(string) Path to a checkpoint file to restore the model from.',
+                            type=str)
         parser.add_argument('-dg', '--dump_gifs',
                             help="(flag) Enable the gif saving functionality.",
                             action='store_true')

@@ -607,6 +616,12 @@ class CoachLauncher(object):
         atexit.register(logger.summarize_experiment)
         screen.change_terminal_title(args.experiment_name)

+        if args.checkpoint_restore_dir is not None and args.checkpoint_restore_file is not None:
+            raise ValueError("Only one of the checkpoint_restore_dir and checkpoint_restore_file arguments can be used"
+                             " simulatenously.")
+        checkpoint_restore_path = args.checkpoint_restore_dir if args.checkpoint_restore_dir \
+            else args.checkpoint_restore_file
+
         task_parameters = TaskParameters(
             framework_type=args.framework,
             evaluate_only=args.evaluate,

@@ -614,7 +629,7 @@ class CoachLauncher(object):
             seed=args.seed,
             use_cpu=args.use_cpu,
             checkpoint_save_secs=args.checkpoint_save_secs,
-            checkpoint_restore_dir=args.checkpoint_restore_dir,
+            checkpoint_restore_path=checkpoint_restore_path,
             checkpoint_save_dir=args.checkpoint_save_dir,
             export_onnx_graph=args.export_onnx_graph,
             apply_stop_condition=args.apply_stop_condition

@@ -637,11 +652,13 @@ class CoachLauncher(object):
         else:
             self.start_multi_threaded(graph_manager, args)

-    def start_single_threaded(self, task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
+    @staticmethod
+    def start_single_threaded(task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
         # Start the training or evaluation
         start_graph(graph_manager=graph_manager, task_parameters=task_parameters)

-    def start_multi_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
+    @staticmethod
+    def start_multi_threaded(graph_manager: 'GraphManager', args: argparse.Namespace):
         total_tasks = args.num_workers
         if args.evaluation_worker:
             total_tasks += 1

@@ -657,6 +674,10 @@ class CoachLauncher(object):
         comm_manager.start()
         shared_memory_scratchpad = comm_manager.SharedMemoryScratchPad()

+        if args.checkpoint_restore_file:
+            raise ValueError("Multi-Process runs only support restoring checkpoints from a directory, "
+                             "and not from a file. ")
+
         def start_distributed_task(job_type, task_index, evaluation_worker=False,
                                    shared_memory_scratchpad=shared_memory_scratchpad):
             task_parameters = DistributedTaskParameters(

@@ -673,7 +694,7 @@ class CoachLauncher(object):
                 shared_memory_scratchpad=shared_memory_scratchpad,
                 seed=args.seed+task_index if args.seed is not None else None,  # each worker gets a different seed
                 checkpoint_save_secs=args.checkpoint_save_secs,
-                checkpoint_restore_dir=args.checkpoint_restore_dir,
+                checkpoint_restore_path=args.checkpoint_restore_dir,  # MonitoredTrainingSession only supports a dir
                 checkpoint_save_dir=args.checkpoint_save_dir,
                 export_onnx_graph=args.export_onnx_graph,
                 apply_stop_condition=args.apply_stop_condition
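Taken together, the launcher changes amount to two mutually exclusive CLI arguments that get merged into a single checkpoint_restore_path, with file-based restore rejected for multi-process runs. A condensed, self-contained sketch of that flow (the option names match the diff; the standalone parser and the sample invocation are illustrative only):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-crd', '--checkpoint_restore_dir', type=str)
parser.add_argument('-crf', '--checkpoint_restore_file', type=str)
parser.add_argument('-n', '--num_workers', type=int, default=1)
args = parser.parse_args(['-crf', './checkpoint/0_Step-0.ckpt'])  # sample invocation

if args.checkpoint_restore_dir is not None and args.checkpoint_restore_file is not None:
    raise ValueError("Only one of checkpoint_restore_dir and checkpoint_restore_file can be used simultaneously.")
if args.num_workers > 1 and args.checkpoint_restore_file:
    raise ValueError("Multi-process runs only support restoring checkpoints from a directory, not from a file.")

checkpoint_restore_path = args.checkpoint_restore_dir if args.checkpoint_restore_dir \
    else args.checkpoint_restore_file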

View File

@@ -25,7 +25,7 @@ import contextlib
 from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, Frameworks, \
     VisualizationParameters, \
     Parameters, PresetValidationParameters, RunType
-from rl_coach.checkpoint import CheckpointStateUpdater, get_checkpoint_state, SingleCheckpoint
+from rl_coach.checkpoint import CheckpointStateUpdater, get_checkpoint_state, SingleCheckpoint, CheckpointState
 from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
     EnvironmentSteps, \
     StepMethod, Transition

@@ -218,11 +218,13 @@ class GraphManager(object):
         if isinstance(task_parameters, DistributedTaskParameters):
             # the distributed tensorflow setting
             from rl_coach.architectures.tensorflow_components.distributed_tf_utils import create_monitored_session
-            if hasattr(self.task_parameters, 'checkpoint_restore_dir') and self.task_parameters.checkpoint_restore_dir:
+            if hasattr(self.task_parameters, 'checkpoint_restore_path') and self.task_parameters.checkpoint_restore_path:
                 checkpoint_dir = os.path.join(task_parameters.experiment_path, 'checkpoint')
                 if os.path.exists(checkpoint_dir):
                     remove_tree(checkpoint_dir)
-                copy_tree(task_parameters.checkpoint_restore_dir, checkpoint_dir)
+                # in the locally distributed case, checkpoints are always restored from a directory (and not from a
+                # file)
+                copy_tree(task_parameters.checkpoint_restore_path, checkpoint_dir)
             else:
                 checkpoint_dir = task_parameters.checkpoint_save_dir

@@ -547,30 +549,44 @@ class GraphManager(object):
         self.verify_graph_was_created()

         # TODO: find better way to load checkpoints that were saved with a global network into the online network
-        if self.task_parameters.checkpoint_restore_dir:
-            if self.task_parameters.framework_type == Frameworks.tensorflow and\
-                    'checkpoint' in os.listdir(self.task_parameters.checkpoint_restore_dir):
-                # TODO-fixme checkpointing
-                # MonitoredTrainingSession manages save/restore checkpoints autonomously. Doing so,
-                # it creates it own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt" filename
-                # pattern. The names used are maintained in a CheckpointState protobuf file named 'checkpoint'. Using
-                # Coach's '.coach_checkpoint' protobuf file, results in an error when trying to restore the model, as
-                # the checkpoint names defined do not match the actual checkpoint names.
-                checkpoint = self._get_checkpoint_state_tf()
+        if self.task_parameters.checkpoint_restore_path:
+            if os.path.isdir(self.task_parameters.checkpoint_restore_path):
+                # a checkpoint dir
+                if self.task_parameters.framework_type == Frameworks.tensorflow and\
+                        'checkpoint' in os.listdir(self.task_parameters.checkpoint_restore_path):
+                    # TODO-fixme checkpointing
+                    # MonitoredTrainingSession manages save/restore checkpoints autonomously. Doing so,
+                    # it creates it own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt"
+                    # filename pattern. The names used are maintained in a CheckpointState protobuf file named
+                    # 'checkpoint'. Using Coach's '.coach_checkpoint' protobuf file, results in an error when trying to
+                    # restore the model, as the checkpoint names defined do not match the actual checkpoint names.
+                    checkpoint = self._get_checkpoint_state_tf(self.task_parameters.checkpoint_restore_path)
+                else:
+                    checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_path)
+
+                if checkpoint is None:
+                    raise ValueError("No checkpoint to restore in: {}".format(
+                        self.task_parameters.checkpoint_restore_path))
+                model_checkpoint_path = checkpoint.model_checkpoint_path
+                checkpoint_restore_dir = self.task_parameters.checkpoint_restore_path
             else:
-                checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
+                # a checkpoint file
+                if self.task_parameters.framework_type == Frameworks.tensorflow:
+                    model_checkpoint_path = self.task_parameters.checkpoint_restore_path
+                    checkpoint_restore_dir = os.path.dirname(model_checkpoint_path)
+                else:
+                    raise ValueError("Currently restoring a checkpoint using the --checkpoint_restore_file argument is"
+                                     " only supported when with tensorflow.")

-            if checkpoint is None:
-                screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
-            else:
-                screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
-                self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
-                [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
+            screen.log_title("Loading checkpoint: {}".format(model_checkpoint_path))
+            self.checkpoint_saver.restore(self.sess, model_checkpoint_path)
+            [manager.restore_checkpoint(checkpoint_restore_dir) for manager in self.level_managers]

-    def _get_checkpoint_state_tf(self):
+    def _get_checkpoint_state_tf(self, checkpoint_restore_dir):
         import tensorflow as tf
-        return tf.train.get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
+        return tf.train.get_checkpoint_state(checkpoint_restore_dir)

     def occasionally_save_checkpoint(self):
         # only the chief process saves checkpoints
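The core of the change is the directory-vs-file branch in restore_checkpoint above: a directory still goes through the checkpoint-state lookup, while a file path is treated as a TensorFlow checkpoint prefix and its parent directory is handed to the level managers. A standalone sketch of just that resolution step (a plain function, no Coach types):

import os

def resolve_restore_target(checkpoint_restore_path):
    # returns (model_checkpoint_prefix, checkpoint_restore_dir) in the spirit of the hunk above
    if os.path.isdir(checkpoint_restore_path):
        # directory: the latest checkpoint inside it is looked up via the checkpoint state file,
        # so the concrete prefix is not known yet at this point
        return None, checkpoint_restore_path
    # file: use the given prefix directly, and its parent directory for the level managers
    return checkpoint_restore_path, os.path.dirname(checkpoint_restore_path)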

View File

@@ -67,7 +67,7 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
     """
     wait for first checkpoint then perform rollouts using the model
     """
-    checkpoint_dir = task_parameters.checkpoint_restore_dir
+    checkpoint_dir = task_parameters.checkpoint_restore_path
     wait_for_checkpoint(checkpoint_dir, data_store)

     graph_manager.create_graph(task_parameters)
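The rollout worker only needs the renamed attribute; wait_for_checkpoint itself is untouched by this commit. Purely as an illustration of the shape of that call (this is not Coach's implementation), a polling loop could look like:

import os
import time

def wait_for_first_checkpoint(checkpoint_dir, poll_secs=10, timeout_secs=600):
    # hypothetical sketch: block until a TensorFlow checkpoint index file shows up in the directory
    deadline = time.time() + timeout_secs
    while time.time() < deadline:
        if os.path.isdir(checkpoint_dir) and any(f.endswith('.index') for f in os.listdir(checkpoint_dir)):
            return
        time.sleep(poll_secs)
    raise TimeoutError("no checkpoint appeared in {}".format(checkpoint_dir))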

View File

@@ -56,7 +56,7 @@ def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restor
     # graph_manager.evaluate(EnvironmentSteps(1000))
     # graph_manager.save_checkpoint()
     #
-    # graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint"
+    # graph_manager.task_parameters.checkpoint_restore_path = "./experiments/test/checkpoint"
     # while True:
     #     graph_manager.restore_checkpoint()
     #     graph_manager.evaluate(EnvironmentSteps(1000))