restoring from a checkpoint file (#247)
@@ -43,8 +43,9 @@ class DNDQHead(QHead):
         self.shared_memory_scratchpad = self.ap.task_parameters.shared_memory_scratchpad
 
     def _build_module(self, input_layer):
-        if hasattr(self.ap.task_parameters, 'checkpoint_restore_dir') and self.ap.task_parameters.checkpoint_restore_dir:
-            self.DND = differentiable_neural_dictionary.load_dnd(self.ap.task_parameters.checkpoint_restore_dir)
+        if hasattr(self.ap.task_parameters, 'checkpoint_restore_path') and\
+                self.ap.task_parameters.checkpoint_restore_path:
+            self.DND = differentiable_neural_dictionary.load_dnd(self.ap.task_parameters.checkpoint_restore_path)
         else:
             self.DND = differentiable_neural_dictionary.QDND(
                 self.DND_size, input_layer.get_shape()[-1], self.num_actions, self.new_value_shift_coefficient,
@@ -25,6 +25,7 @@ from typing import Dict, List, Union
 from rl_coach.core_types import TrainingSteps, EnvironmentSteps, GradientClippingMethod, RunPhase, \
     SelectedPhaseOnlyDumpFilter, MaxDumpFilter
 from rl_coach.filters.filter import NoInputFilter
+from rl_coach.logger import screen
 
 
 class Frameworks(Enum):
@@ -552,8 +553,8 @@ class AgentParameters(Parameters):
 class TaskParameters(Parameters):
     def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: int=None, use_cpu: bool=False,
                  experiment_path='/tmp', seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
-                 checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False,
-                 num_gpu: int=1):
+                 checkpoint_restore_path=None, checkpoint_save_dir=None, export_onnx_graph: bool=False,
+                 apply_stop_condition: bool=False, num_gpu: int=1):
         """
         :param framework_type: deep learning framework type. currently only tensorflow is supported
         :param evaluate_only: if not None, the task will be used only for evaluating the model for the given number of steps.
@@ -562,7 +563,10 @@ class TaskParameters(Parameters):
         :param experiment_path: the path to the directory which will store all the experiment outputs
         :param seed: a seed to use for the random numbers generator
         :param checkpoint_save_secs: the number of seconds between each checkpoint saving
-        :param checkpoint_restore_dir: the directory to restore the checkpoints from
+        :param checkpoint_restore_dir:
+            [DEPRECATED - will be removed in one of the next releases - switch to checkpoint_restore_path]
+            the dir to restore the checkpoints from
+        :param checkpoint_restore_path: the path to restore the checkpoints from
         :param checkpoint_save_dir: the directory to store the checkpoints in
         :param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved
         :param apply_stop_condition: If set to True, this will apply the stop condition defined by reaching a target success rate
@@ -574,7 +578,13 @@ class TaskParameters(Parameters):
         self.use_cpu = use_cpu
         self.experiment_path = experiment_path
         self.checkpoint_save_secs = checkpoint_save_secs
-        self.checkpoint_restore_dir = checkpoint_restore_dir
+        if checkpoint_restore_dir:
+            screen.warning('TaskParameters.checkpoint_restore_dir is DEPRECATED and will be removed in one of the next '
+                           'releases. Please switch to using TaskParameters.checkpoint_restore_path, with your '
+                           'directory path. ')
+            self.checkpoint_restore_path = checkpoint_restore_dir
+        else:
+            self.checkpoint_restore_path = checkpoint_restore_path
         self.checkpoint_save_dir = checkpoint_save_dir
         self.seed = seed
         self.export_onnx_graph = export_onnx_graph
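For call sites that construct TaskParameters directly, here is a minimal usage sketch (not part of the commit; it assumes rl_coach is importable and the paths are placeholders) showing that the deprecated keyword still works but is rerouted through the shim above:

from rl_coach.base_parameters import TaskParameters

# preferred: a single argument that may point at a checkpoint directory or, after this commit, a checkpoint file
params = TaskParameters(experiment_path='/tmp/my_experiment',
                        checkpoint_restore_path='/tmp/my_experiment/checkpoint')

# still accepted for now: the deprecated keyword is copied into checkpoint_restore_path
# after screen.warning() announces the upcoming removal
legacy = TaskParameters(experiment_path='/tmp/my_experiment',
                        checkpoint_restore_dir='/tmp/my_experiment/checkpoint')
assert legacy.checkpoint_restore_path == '/tmp/my_experiment/checkpoint'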
@@ -586,7 +596,7 @@ class DistributedTaskParameters(TaskParameters):
     def __init__(self, framework_type: Frameworks, parameters_server_hosts: str, worker_hosts: str, job_type: str,
                  task_index: int, evaluate_only: int=None, num_tasks: int=None,
                  num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None,
-                 shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
+                 shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_path=None,
                  checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False):
         """
         :param framework_type: deep learning framework type. currently only tensorflow is supported
@@ -604,7 +614,7 @@ class DistributedTaskParameters(TaskParameters):
         :param dnd: an external DND to use for NEC. This is a workaround needed for a shared DND not using the scratchpad.
         :param seed: a seed to use for the random numbers generator
         :param checkpoint_save_secs: the number of seconds between each checkpoint saving
-        :param checkpoint_restore_dir: the directory to restore the checkpoints from
+        :param checkpoint_restore_path: the path to restore the checkpoints from
         :param checkpoint_save_dir: the directory to store the checkpoints in
         :param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved
         :param apply_stop_condition: If set to True, this will apply the stop condition defined by reaching a target success rate
@@ -612,7 +622,7 @@ class DistributedTaskParameters(TaskParameters):
         """
         super().__init__(framework_type=framework_type, evaluate_only=evaluate_only, use_cpu=use_cpu,
                          experiment_path=experiment_path, seed=seed, checkpoint_save_secs=checkpoint_save_secs,
-                         checkpoint_restore_dir=checkpoint_restore_dir, checkpoint_save_dir=checkpoint_save_dir,
+                         checkpoint_restore_path=checkpoint_restore_path, checkpoint_save_dir=checkpoint_save_dir,
                          export_onnx_graph=export_onnx_graph, apply_stop_condition=apply_stop_condition)
         self.parameters_server_hosts = parameters_server_hosts
         self.worker_hosts = worker_hosts
@@ -33,6 +33,8 @@ from rl_coach.base_parameters import Frameworks, VisualizationParameters, TaskPa
 from multiprocessing import Process
 from multiprocessing.managers import BaseManager
 import subprocess
+from glob import glob
+
 from rl_coach.graph_managers.graph_manager import HumanPlayScheduleParameters, GraphManager
 from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port, SharedMemoryScratchPad, get_base_dir
 from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
@@ -44,7 +46,7 @@ from rl_coach.data_stores.s3_data_store import S3DataStoreParameters
 from rl_coach.data_stores.nfs_data_store import NFSDataStoreParameters
 from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
 from rl_coach.training_worker import training_worker
-from rl_coach.rollout_worker import rollout_worker, wait_for_checkpoint
+from rl_coach.rollout_worker import rollout_worker
 
 
 if len(set(failed_imports)) > 0:
@@ -110,7 +112,7 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
     )
 
     if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
-        task_parameters.checkpoint_restore_dir = ckpt_inside_container
+        task_parameters.checkpoint_restore_path = ckpt_inside_container
 
     data_store = None
     if args.data_store_params:
@@ -394,6 +396,10 @@ class CoachLauncher(object):
         if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir):
             screen.error("The requested checkpoint folder to load from does not exist.")
 
+        # validate the checkpoints args
+        if args.checkpoint_restore_file is not None and not glob(args.checkpoint_restore_file + '*'):
+            screen.error("The requested checkpoint file to load from does not exist.")
+
         # no preset was given. check if the user requested to play some environment on its own
         if args.preset is None and args.play and not args.environment_type:
             screen.error('When no preset is given for Coach to run, and the user requests human control over '
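A note on the glob-based validation above (a sketch, not part of the diff, with a placeholder path): a tensorflow checkpoint "file" such as 0_Step-0.ckpt is really a prefix for several files on disk (.index, .meta, .data-*), so an os.path.exists() check on the prefix alone would report it missing; globbing on prefix + '*' is what makes the check work.

import os
from glob import glob

checkpoint_prefix = './experiments/test/checkpoint/0_Step-0.ckpt'  # placeholder path

print(os.path.exists(checkpoint_prefix))    # usually False: only the companion files exist on disk
print(bool(glob(checkpoint_prefix + '*')))  # True when the checkpoint's .index/.meta/.data files are present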
@@ -493,6 +499,9 @@ class CoachLauncher(object):
         parser.add_argument('-crd', '--checkpoint_restore_dir',
                             help='(string) Path to a folder containing a checkpoint to restore the model from.',
                             type=str)
+        parser.add_argument('-crf', '--checkpoint_restore_file',
+                            help='(string) Path to a checkpoint file to restore the model from.',
+                            type=str)
         parser.add_argument('-dg', '--dump_gifs',
                             help="(flag) Enable the gif saving functionality.",
                             action='store_true')
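With the new flag registered, a run passes either the directory form or the file form, but not both (the check added below rejects the combination), and the file form is only honored by the tensorflow backend. A hypothetical invocation, with placeholder preset name and paths:

coach -p CartPole_DQN -crd ./experiments/cartpole/checkpoint
coach -p CartPole_DQN -crf ./experiments/cartpole/checkpoint/0_Step-0.ckpt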
@@ -607,6 +616,12 @@ class CoachLauncher(object):
         atexit.register(logger.summarize_experiment)
         screen.change_terminal_title(args.experiment_name)
 
+        if args.checkpoint_restore_dir is not None and args.checkpoint_restore_file is not None:
+            raise ValueError("Only one of the checkpoint_restore_dir and checkpoint_restore_file arguments can be used"
+                             " simultaneously.")
+        checkpoint_restore_path = args.checkpoint_restore_dir if args.checkpoint_restore_dir \
+            else args.checkpoint_restore_file
+
         task_parameters = TaskParameters(
             framework_type=args.framework,
             evaluate_only=args.evaluate,
@@ -614,7 +629,7 @@ class CoachLauncher(object):
             seed=args.seed,
             use_cpu=args.use_cpu,
             checkpoint_save_secs=args.checkpoint_save_secs,
-            checkpoint_restore_dir=args.checkpoint_restore_dir,
+            checkpoint_restore_path=checkpoint_restore_path,
             checkpoint_save_dir=args.checkpoint_save_dir,
            export_onnx_graph=args.export_onnx_graph,
            apply_stop_condition=args.apply_stop_condition
@@ -637,11 +652,13 @@ class CoachLauncher(object):
         else:
             self.start_multi_threaded(graph_manager, args)
 
-    def start_single_threaded(self, task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
+    @staticmethod
+    def start_single_threaded(task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
         # Start the training or evaluation
         start_graph(graph_manager=graph_manager, task_parameters=task_parameters)
 
-    def start_multi_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
+    @staticmethod
+    def start_multi_threaded(graph_manager: 'GraphManager', args: argparse.Namespace):
         total_tasks = args.num_workers
         if args.evaluation_worker:
             total_tasks += 1
@@ -657,6 +674,10 @@ class CoachLauncher(object):
         comm_manager.start()
         shared_memory_scratchpad = comm_manager.SharedMemoryScratchPad()
 
+        if args.checkpoint_restore_file:
+            raise ValueError("Multi-Process runs only support restoring checkpoints from a directory, "
+                             "and not from a file. ")
+
         def start_distributed_task(job_type, task_index, evaluation_worker=False,
                                    shared_memory_scratchpad=shared_memory_scratchpad):
             task_parameters = DistributedTaskParameters(
@@ -673,7 +694,7 @@ class CoachLauncher(object):
                 shared_memory_scratchpad=shared_memory_scratchpad,
                 seed=args.seed+task_index if args.seed is not None else None,  # each worker gets a different seed
                 checkpoint_save_secs=args.checkpoint_save_secs,
-                checkpoint_restore_dir=args.checkpoint_restore_dir,
+                checkpoint_restore_path=args.checkpoint_restore_dir,  # MonitoredTrainingSession only supports a dir
                 checkpoint_save_dir=args.checkpoint_save_dir,
                 export_onnx_graph=args.export_onnx_graph,
                 apply_stop_condition=args.apply_stop_condition
@@ -25,7 +25,7 @@ import contextlib
 from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, Frameworks, \
     VisualizationParameters, \
     Parameters, PresetValidationParameters, RunType
-from rl_coach.checkpoint import CheckpointStateUpdater, get_checkpoint_state, SingleCheckpoint
+from rl_coach.checkpoint import CheckpointStateUpdater, get_checkpoint_state, SingleCheckpoint, CheckpointState
 from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
     EnvironmentSteps, \
     StepMethod, Transition
@@ -218,11 +218,13 @@ class GraphManager(object):
         if isinstance(task_parameters, DistributedTaskParameters):
             # the distributed tensorflow setting
             from rl_coach.architectures.tensorflow_components.distributed_tf_utils import create_monitored_session
-            if hasattr(self.task_parameters, 'checkpoint_restore_dir') and self.task_parameters.checkpoint_restore_dir:
+            if hasattr(self.task_parameters, 'checkpoint_restore_path') and self.task_parameters.checkpoint_restore_path:
                 checkpoint_dir = os.path.join(task_parameters.experiment_path, 'checkpoint')
                 if os.path.exists(checkpoint_dir):
                     remove_tree(checkpoint_dir)
-                copy_tree(task_parameters.checkpoint_restore_dir, checkpoint_dir)
+                # in the locally distributed case, checkpoints are always restored from a directory (and not from a
+                # file)
+                copy_tree(task_parameters.checkpoint_restore_path, checkpoint_dir)
             else:
                 checkpoint_dir = task_parameters.checkpoint_save_dir
 
@@ -547,30 +549,44 @@ class GraphManager(object):
         self.verify_graph_was_created()
 
         # TODO: find better way to load checkpoints that were saved with a global network into the online network
-        if self.task_parameters.checkpoint_restore_dir:
-            if self.task_parameters.framework_type == Frameworks.tensorflow and\
-                    'checkpoint' in os.listdir(self.task_parameters.checkpoint_restore_dir):
-                # TODO-fixme checkpointing
-                # MonitoredTrainingSession manages save/restore checkpoints autonomously. Doing so,
-                # it creates it own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt" filename
-                # pattern. The names used are maintained in a CheckpointState protobuf file named 'checkpoint'. Using
-                # Coach's '.coach_checkpoint' protobuf file, results in an error when trying to restore the model, as
-                # the checkpoint names defined do not match the actual checkpoint names.
-                checkpoint = self._get_checkpoint_state_tf()
-            else:
-                checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
-
-            if checkpoint is None:
-                screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
-            else:
-                screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
-                self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
-
-            [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
+        if self.task_parameters.checkpoint_restore_path:
+            if os.path.isdir(self.task_parameters.checkpoint_restore_path):
+                # a checkpoint dir
+                if self.task_parameters.framework_type == Frameworks.tensorflow and\
+                        'checkpoint' in os.listdir(self.task_parameters.checkpoint_restore_path):
+                    # TODO-fixme checkpointing
+                    # MonitoredTrainingSession manages save/restore checkpoints autonomously. Doing so,
+                    # it creates its own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt"
+                    # filename pattern. The names used are maintained in a CheckpointState protobuf file named
+                    # 'checkpoint'. Using Coach's '.coach_checkpoint' protobuf file results in an error when trying to
+                    # restore the model, as the checkpoint names defined do not match the actual checkpoint names.
+                    checkpoint = self._get_checkpoint_state_tf(self.task_parameters.checkpoint_restore_path)
+                else:
+                    checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_path)
+
+                if checkpoint is None:
+                    raise ValueError("No checkpoint to restore in: {}".format(
+                        self.task_parameters.checkpoint_restore_path))
+                model_checkpoint_path = checkpoint.model_checkpoint_path
+                checkpoint_restore_dir = self.task_parameters.checkpoint_restore_path
+            else:
+                # a checkpoint file
+                if self.task_parameters.framework_type == Frameworks.tensorflow:
+                    model_checkpoint_path = self.task_parameters.checkpoint_restore_path
+                    checkpoint_restore_dir = os.path.dirname(model_checkpoint_path)
+                else:
+                    raise ValueError("Currently restoring a checkpoint using the --checkpoint_restore_file argument is"
+                                     " only supported with tensorflow.")
+
+            screen.log_title("Loading checkpoint: {}".format(model_checkpoint_path))
+
+            self.checkpoint_saver.restore(self.sess, model_checkpoint_path)
+
+            [manager.restore_checkpoint(checkpoint_restore_dir) for manager in self.level_managers]
 
-    def _get_checkpoint_state_tf(self):
+    def _get_checkpoint_state_tf(self, checkpoint_restore_dir):
         import tensorflow as tf
-        return tf.train.get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
+        return tf.train.get_checkpoint_state(checkpoint_restore_dir)
 
     def occasionally_save_checkpoint(self):
         # only the chief process saves checkpoints
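To make the directory-versus-file branch concrete, a minimal sketch (not part of the diff; the paths and step number are placeholders, and the file form assumes the tensorflow backend, as enforced above):

from rl_coach.base_parameters import TaskParameters

# directory form: unchanged behaviour, the checkpoint to load is resolved from the folder's checkpoint state
dir_params = TaskParameters(experiment_path='/tmp/my_experiment',
                            checkpoint_restore_path='/tmp/my_experiment/checkpoint')

# file form (new): point directly at one "{}_Step-{}.ckpt" prefix inside that folder
file_params = TaskParameters(experiment_path='/tmp/my_experiment',
                             checkpoint_restore_path='/tmp/my_experiment/checkpoint/0_Step-0.ckpt')

# either object is then handed to a graph manager built from a preset, e.g.:
#     graph_manager.create_graph(file_params)
#     graph_manager.restore_checkpoint()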
@@ -67,7 +67,7 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
     """
     wait for first checkpoint then perform rollouts using the model
     """
-    checkpoint_dir = task_parameters.checkpoint_restore_dir
+    checkpoint_dir = task_parameters.checkpoint_restore_path
     wait_for_checkpoint(checkpoint_dir, data_store)
 
     graph_manager.create_graph(task_parameters)
@@ -56,7 +56,7 @@ def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restor
     # graph_manager.evaluate(EnvironmentSteps(1000))
     # graph_manager.save_checkpoint()
     #
-    # graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint"
+    # graph_manager.task_parameters.checkpoint_restore_path = "./experiments/test/checkpoint"
     # while True:
     #     graph_manager.restore_checkpoint()
     #     graph_manager.evaluate(EnvironmentSteps(1000))