mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
restoring from a checkpoint file (#247)
This commit is contained in:
@@ -33,6 +33,8 @@ from rl_coach.base_parameters import Frameworks, VisualizationParameters, TaskPa
|
||||
from multiprocessing import Process
|
||||
from multiprocessing.managers import BaseManager
|
||||
import subprocess
|
||||
from glob import glob
|
||||
|
||||
from rl_coach.graph_managers.graph_manager import HumanPlayScheduleParameters, GraphManager
|
||||
from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port, SharedMemoryScratchPad, get_base_dir
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
@@ -44,7 +46,7 @@ from rl_coach.data_stores.s3_data_store import S3DataStoreParameters
|
||||
from rl_coach.data_stores.nfs_data_store import NFSDataStoreParameters
|
||||
from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
|
||||
from rl_coach.training_worker import training_worker
|
||||
from rl_coach.rollout_worker import rollout_worker, wait_for_checkpoint
|
||||
from rl_coach.rollout_worker import rollout_worker
|
||||
|
||||
|
||||
if len(set(failed_imports)) > 0:
|
||||
@@ -110,7 +112,7 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
|
||||
)
|
||||
|
||||
if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
|
||||
task_parameters.checkpoint_restore_dir = ckpt_inside_container
|
||||
task_parameters.checkpoint_restore_path = ckpt_inside_container
|
||||
|
||||
data_store = None
|
||||
if args.data_store_params:
|
||||
@@ -394,6 +396,10 @@ class CoachLauncher(object):
|
||||
if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir):
|
||||
screen.error("The requested checkpoint folder to load from does not exist.")
|
||||
|
||||
# validate the checkpoints args
|
||||
if args.checkpoint_restore_file is not None and not glob(args.checkpoint_restore_file + '*'):
|
||||
screen.error("The requested checkpoint file to load from does not exist.")
|
||||
|
||||
# no preset was given. check if the user requested to play some environment on its own
|
||||
if args.preset is None and args.play and not args.environment_type:
|
||||
screen.error('When no preset is given for Coach to run, and the user requests human control over '
|
||||
@@ -493,6 +499,9 @@ class CoachLauncher(object):
|
||||
parser.add_argument('-crd', '--checkpoint_restore_dir',
|
||||
help='(string) Path to a folder containing a checkpoint to restore the model from.',
|
||||
type=str)
|
||||
parser.add_argument('-crf', '--checkpoint_restore_file',
|
||||
help='(string) Path to a checkpoint file to restore the model from.',
|
||||
type=str)
|
||||
parser.add_argument('-dg', '--dump_gifs',
|
||||
help="(flag) Enable the gif saving functionality.",
|
||||
action='store_true')
|
||||
@@ -607,6 +616,12 @@ class CoachLauncher(object):
|
||||
atexit.register(logger.summarize_experiment)
|
||||
screen.change_terminal_title(args.experiment_name)
|
||||
|
||||
if args.checkpoint_restore_dir is not None and args.checkpoint_restore_file is not None:
|
||||
raise ValueError("Only one of the checkpoint_restore_dir and checkpoint_restore_file arguments can be used"
|
||||
" simulatenously.")
|
||||
checkpoint_restore_path = args.checkpoint_restore_dir if args.checkpoint_restore_dir \
|
||||
else args.checkpoint_restore_file
|
||||
|
||||
task_parameters = TaskParameters(
|
||||
framework_type=args.framework,
|
||||
evaluate_only=args.evaluate,
|
||||
@@ -614,7 +629,7 @@ class CoachLauncher(object):
|
||||
seed=args.seed,
|
||||
use_cpu=args.use_cpu,
|
||||
checkpoint_save_secs=args.checkpoint_save_secs,
|
||||
checkpoint_restore_dir=args.checkpoint_restore_dir,
|
||||
checkpoint_restore_path=checkpoint_restore_path,
|
||||
checkpoint_save_dir=args.checkpoint_save_dir,
|
||||
export_onnx_graph=args.export_onnx_graph,
|
||||
apply_stop_condition=args.apply_stop_condition
|
||||
@@ -637,11 +652,13 @@ class CoachLauncher(object):
|
||||
else:
|
||||
self.start_multi_threaded(graph_manager, args)
|
||||
|
||||
def start_single_threaded(self, task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
|
||||
@staticmethod
|
||||
def start_single_threaded(task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
|
||||
# Start the training or evaluation
|
||||
start_graph(graph_manager=graph_manager, task_parameters=task_parameters)
|
||||
|
||||
def start_multi_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
|
||||
@staticmethod
|
||||
def start_multi_threaded(graph_manager: 'GraphManager', args: argparse.Namespace):
|
||||
total_tasks = args.num_workers
|
||||
if args.evaluation_worker:
|
||||
total_tasks += 1
|
||||
@@ -657,6 +674,10 @@ class CoachLauncher(object):
|
||||
comm_manager.start()
|
||||
shared_memory_scratchpad = comm_manager.SharedMemoryScratchPad()
|
||||
|
||||
if args.checkpoint_restore_file:
|
||||
raise ValueError("Multi-Process runs only support restoring checkpoints from a directory, "
|
||||
"and not from a file. ")
|
||||
|
||||
def start_distributed_task(job_type, task_index, evaluation_worker=False,
|
||||
shared_memory_scratchpad=shared_memory_scratchpad):
|
||||
task_parameters = DistributedTaskParameters(
|
||||
@@ -673,7 +694,7 @@ class CoachLauncher(object):
|
||||
shared_memory_scratchpad=shared_memory_scratchpad,
|
||||
seed=args.seed+task_index if args.seed is not None else None, # each worker gets a different seed
|
||||
checkpoint_save_secs=args.checkpoint_save_secs,
|
||||
checkpoint_restore_dir=args.checkpoint_restore_dir,
|
||||
checkpoint_restore_path=args.checkpoint_restore_dir, # MonitoredTrainingSession only supports a dir
|
||||
checkpoint_save_dir=args.checkpoint_save_dir,
|
||||
export_onnx_graph=args.export_onnx_graph,
|
||||
apply_stop_condition=args.apply_stop_condition
|
||||
|
||||
Reference in New Issue
Block a user