Uploading checkpoint if crd provided (#191)
* Uploading checkpoint if crd provided
* Changing the calculation of total steps because of a recent change in core_types

Fixes #195
committed by Scott Leishman
parent b3db9ce77d
commit 33dc29ee99
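In outline, the first part of this change threads the --checkpoint_restore_dir value ("crd") from coach.py through Kubernetes.setup() into the data store, which uploads the directory's contents before any trainer or rollout worker starts. A minimal sketch of that call chain, using stand-in classes rather than the real Coach orchestrator and stores:

# Illustrative sketch only: Store and Orchestrator are stand-ins for Coach's
# S3/NFS data stores and the Kubernetes orchestrator changed in this commit.
class Store:
    def setup_checkpoint_dir(self, crd=None):
        # seed the backing store from a local restore directory, if provided
        if crd:
            print("uploading checkpoints from", crd)


class Orchestrator:
    def __init__(self, store):
        self.data_store = store

    def setup(self, crd=None) -> bool:
        # deploy memory backend / data store here, then seed the store
        self.data_store.setup_checkpoint_dir(crd)
        return True


# roughly what coach.py now does when --checkpoint_restore_dir is given:
checkpoint_restore_dir = "/tmp/my_checkpoints"  # hypothetical path
if not Orchestrator(Store()).setup(checkpoint_restore_dir):
    print("Could not setup.")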
@@ -103,21 +103,22 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
        data_store_params.checkpoint_dir = ckpt_inside_container
        graph_manager.data_store_params = data_store_params

    data_store = None
    if args.data_store_params:
        data_store = get_data_store(data_store_params)

    if args.distributed_coach_run_type == RunType.TRAINER:
        task_parameters.checkpoint_save_dir = ckpt_inside_container
        training_worker(
            graph_manager=graph_manager,
            task_parameters=task_parameters,
            data_store=data_store,
            is_multi_node_test=args.is_multi_node_test
        )

    if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
        task_parameters.checkpoint_restore_path = ckpt_inside_container

        data_store = None
        if args.data_store_params:
            data_store = get_data_store(data_store_params)

        rollout_worker(
            graph_manager=graph_manager,
            data_store=data_store,
@@ -169,7 +170,7 @@ def handle_distributed_coach_orchestrator(args):
        memory_backend_parameters=memory_backend_params,
        data_store_params=ds_params_instance)
    orchestrator = Kubernetes(orchestration_params)
    if not orchestrator.setup():
    if not orchestrator.setup(args.checkpoint_restore_dir):
        print("Could not setup.")
        return 1

@@ -394,7 +395,9 @@ class CoachLauncher(object):

        # validate the checkpoints args
        if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir):
            screen.error("The requested checkpoint folder to load from does not exist.")
            # If distributed trainer, the checkpoint dir is not yet available so skipping the check in that case.
            if not (args.distributed_coach and args.distributed_coach_run_type in [RunType.TRAINER, RunType.ROLLOUT_WORKER]):
                screen.error("The requested checkpoint folder to load from does not exist.")

        # validate the checkpoints args
        if args.checkpoint_restore_file is not None and not glob(args.checkpoint_restore_file + '*'):

@@ -44,7 +44,11 @@ class DataStore(object):
    def load_from_store(self):
        pass

    def setup_checkpoint_dir(self, crd=None):
        pass


class SyncFiles(Enum):
    FINISHED = ".finished"
    LOCKFILE = ".lock"
    TRAINER_READY = ".ready"

@@ -284,3 +284,8 @@ class NFSDataStore(DataStore):
            return False

        return True

    def setup_checkpoint_dir(self, crd=None):
        if crd:
            # TODO: find a way to upload this to the deployed nfs store.
            pass

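The base DataStore gains a no-op setup_checkpoint_dir(crd=None) hook, so each backend can decide how to seed itself from a restore directory; the NFS store leaves it as a TODO, while the S3 store below reuses its upload path. A sketch of what a hypothetical backend implementing the hook could look like (LocalDirDataStore is not part of Coach, it only illustrates the contract):

# Hypothetical data store mirroring the new setup_checkpoint_dir() contract.
import os
import shutil


class LocalDirDataStore:
    """Copies a provided restore directory into the store's own directory."""

    def __init__(self, store_dir):
        self.store_dir = store_dir

    def setup_checkpoint_dir(self, crd=None):
        # mirror the contents of crd into the store before workers start
        if crd and os.path.isdir(crd):
            os.makedirs(self.store_dir, exist_ok=True)
            for name in os.listdir(crd):
                src = os.path.join(crd, name)
                if os.path.isfile(src):
                    shutil.copy2(src, self.store_dir)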
@@ -77,6 +77,9 @@ class S3DataStore(DataStore):
        return True

    def save_to_store(self):
        self._save_to_store(self.params.checkpoint_dir)

    def _save_to_store(self, checkpoint_dir):
        """
        save_to_store() uploads the policy checkpoint, gifs and videos to the S3 data store. It reads the checkpoint state files and
        uploads only the latest checkpoint files to S3. It is used by the trainer in Coach when used in the distributed mode.
@@ -88,24 +91,32 @@ class S3DataStore(DataStore):
            # Acquire lock
            self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)

            state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
            state_file = CheckpointStateFile(os.path.abspath(checkpoint_dir))
            if state_file.exists():
                ckpt_state = state_file.read()
                checkpoint_file = None
                for root, dirs, files in os.walk(self.params.checkpoint_dir):
                for root, dirs, files in os.walk(checkpoint_dir):
                    for filename in files:
                        if filename == CheckpointStateFile.checkpoint_state_filename:
                            checkpoint_file = (root, filename)
                            continue
                        if filename.startswith(ckpt_state.name):
                            abs_name = os.path.abspath(os.path.join(root, filename))
                            rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
                            rel_name = os.path.relpath(abs_name, checkpoint_dir)
                            self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)

                abs_name = os.path.abspath(os.path.join(checkpoint_file[0], checkpoint_file[1]))
                rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
                rel_name = os.path.relpath(abs_name, checkpoint_dir)
                self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)

            # upload Finished if present
            if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value)):
                self.mc.put_object(self.params.bucket_name, SyncFiles.FINISHED.value, io.BytesIO(b''), 0)

            # upload Ready if present
            if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value)):
                self.mc.put_object(self.params.bucket_name, SyncFiles.TRAINER_READY.value, io.BytesIO(b''), 0)

            # release lock
            self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
@@ -121,6 +132,7 @@ class S3DataStore(DataStore):
            if self.params.expt_dir and os.path.exists(os.path.join(self.params.expt_dir, 'gifs')):
                for filename in os.listdir(os.path.join(self.params.expt_dir, 'gifs')):
                    self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'gifs', filename))

        except ResponseError as e:
            print("Got exception: %s\n while saving to S3", e)
@@ -157,6 +169,18 @@ class S3DataStore(DataStore):
                except Exception as e:
                    pass

            # Check if there's a ready file
            objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.TRAINER_READY.value)

            if next(objects, None) is not None:
                try:
                    self.mc.fget_object(
                        self.params.bucket_name, SyncFiles.TRAINER_READY.value,
                        os.path.abspath(os.path.join(self.params.checkpoint_dir, SyncFiles.TRAINER_READY.value))
                    )
                except Exception as e:
                    pass

            checkpoint_state = state_file.read()
            if checkpoint_state is not None:
                objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=checkpoint_state.name, recursive=True)
@@ -167,3 +191,7 @@ class S3DataStore(DataStore):

        except ResponseError as e:
            print("Got exception: %s\n while loading from S3", e)

    def setup_checkpoint_dir(self, crd=None):
        if crd:
            self._save_to_store(crd)

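For S3, the upload logic previously hard-wired to self.params.checkpoint_dir is factored into _save_to_store(checkpoint_dir), so the same code can push either the live checkpoint directory (save_to_store) or a user-supplied restore directory (setup_checkpoint_dir). The selection rule it applies (read the checkpoint-state file, then upload only the files that belong to the latest checkpoint) can be sketched without the minio client as follows. The state-file layout is simplified here; Coach parses the real one with CheckpointStateFile.

# Standalone sketch of the file-selection rule in _save_to_store(): pick the
# files that belong to the latest checkpoint named in the state file.
import os


def latest_checkpoint_files(checkpoint_dir, state_filename="checkpoint"):
    state_path = os.path.join(checkpoint_dir, state_filename)
    if not os.path.exists(state_path):
        return []
    with open(state_path) as f:
        latest = f.readline().strip()  # assumed: first line names the newest checkpoint
    selected = [state_path]
    for root, _, files in os.walk(checkpoint_dir):
        for filename in files:
            if filename.startswith(latest):
                selected.append(os.path.join(root, filename))
    return selected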
@@ -37,6 +37,7 @@ from rl_coach.utils import set_cpu, start_shell_command_and_wait
from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator
from rl_coach.memories.backend.memory_impl import get_memory_backend
from rl_coach.data_stores.data_store import SyncFiles
from rl_coach.checkpoint import CheckpointStateReader

from rl_coach.core_types import TimeTypes
@@ -589,6 +590,10 @@ class GraphManager(object):

        [manager.restore_checkpoint(checkpoint_restore_dir) for manager in self.level_managers]

        # Set the last checkpoint ID
        chkpt_state_reader = CheckpointStateReader(self.task_parameters.checkpoint_restore_path, checkpoint_state_optional=False)
        self.checkpoint_id = chkpt_state_reader.get_latest().num + 1

    def _get_checkpoint_state_tf(self, checkpoint_restore_dir):
        import tensorflow as tf
        return tf.train.get_checkpoint_state(checkpoint_restore_dir)
@@ -721,6 +726,13 @@ class GraphManager(object):

        return data_store_creator(param)

    def signal_ready(self):
        if self.task_parameters.checkpoint_save_dir and os.path.exists(self.task_parameters.checkpoint_save_dir):
            open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.TRAINER_READY.value), 'w').close()
        if hasattr(self, 'data_store_params'):
            data_store = self.get_data_store(self.data_store_params)
            data_store.save_to_store()

    def close(self) -> None:
        """
        Clean up to close environments.

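signal_ready() is the trainer-side half of a new handshake: once the graph and memory backend are up, the trainer drops a .ready marker (SyncFiles.TRAINER_READY) into its checkpoint directory and pushes it out through the data store, and rollout workers block on that marker (see wait_for_trainer_ready further down). A minimal local sketch of the handshake, with a plain shared directory standing in for the S3/NFS store:

# Minimal sketch of the ".ready" handshake; a shared local directory stands
# in for the data store, and READY_MARKER mirrors SyncFiles.TRAINER_READY.
import os
import time

READY_MARKER = ".ready"


def signal_ready(checkpoint_dir):
    # trainer side: create the marker once training is about to start
    open(os.path.join(checkpoint_dir, READY_MARKER), "w").close()


def wait_for_ready(checkpoint_dir, timeout=10):
    # worker side: poll for the marker before starting rollouts
    for _ in range(timeout):
        if os.path.exists(os.path.join(checkpoint_dir, READY_MARKER)):
            return True
        time.sleep(1)
    return False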
@@ -118,7 +118,7 @@ class Kubernetes(Deploy):
            self.s3_access_key = os.environ.get('ACCESS_KEY_ID')
            self.s3_secret_key = os.environ.get('SECRET_ACCESS_KEY')

    def setup(self) -> bool:
    def setup(self, crd=None) -> bool:
        """
        Deploys the memory backend and data stores if required.
        """
@@ -128,6 +128,9 @@ class Kubernetes(Deploy):
            return False
        if self.params.data_store_params.store_type == "nfs":
            self.nfs_pvc = self.data_store.get_info()

        # Upload checkpoints in checkpoint_restore_dir (if provided) to the data store
        self.data_store.setup_checkpoint_dir(crd)
        return True

    def deploy_trainer(self) -> bool:
@@ -141,7 +144,6 @@ class Kubernetes(Deploy):

        trainer_params.command += ['--memory_backend_params', json.dumps(self.params.memory_backend_parameters.__dict__)]
        trainer_params.command += ['--data_store_params', json.dumps(self.params.data_store_params.__dict__)]

        name = "{}-{}".format(trainer_params.run_type, uuid.uuid4())

        if self.params.data_store_params.store_type == "nfs":

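Because the new crd parameter defaults to None, the signature change to setup() stays backward compatible: existing call sites keep working, and only coach.py needs to know about --checkpoint_restore_dir. A tiny sketch of that compatibility point (FakeOrchestrator is a stand-in, not the Coach class):

# Stand-in showing why setup(crd=None) keeps old call sites working.
class FakeOrchestrator:
    def setup(self, crd=None) -> bool:
        if crd is not None:
            print("seeding data store from", crd)
        return True


FakeOrchestrator().setup()                    # old call site, still valid
FakeOrchestrator().setup("/tmp/checkpoints")  # new call site with a restore dir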
@@ -33,30 +33,50 @@ from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
from rl_coach.data_stores.data_store import SyncFiles


def wait_for(wait_func, data_store=None, timeout=10):
    """
    block until wait_func is true
    """
    for i in range(timeout):
        if data_store:
            data_store.load_from_store()

        if wait_func():
            return
        time.sleep(10)

    # one last time
    if wait_func():
        return

    raise ValueError((
        'Waited {timeout} seconds, but condition timed out'
    ).format(
        timeout=timeout,
    ))


def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
    """
    block until there is a checkpoint in checkpoint_dir
    """
    chkpt_state_file = CheckpointStateFile(checkpoint_dir)
    for i in range(timeout):
        if data_store:
            data_store.load_from_store()

        if chkpt_state_file.read() is not None:
            return
        time.sleep(10)
    def wait():
        return chkpt_state_file.read() is not None

    # one last time
    if chkpt_state_file.read() is not None:
        return
    wait_for(wait, data_store, timeout)

    raise ValueError((
        'Waited {timeout} seconds, but checkpoint never found in '
        '{checkpoint_dir}'
    ).format(
        timeout=timeout,
        checkpoint_dir=checkpoint_dir,
    ))

def wait_for_trainer_ready(checkpoint_dir, data_store=None, timeout=10):
    """
    Block until trainer is ready
    """

    def wait():
        return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value))

    wait_for(wait, data_store, timeout)


def should_stop(checkpoint_dir):
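The polling loop that used to live inside wait_for_checkpoint is generalized into wait_for(wait_func, data_store, timeout), so each wait condition reduces to a predicate and the data-store refresh happens in one place. The pattern in miniature (self-contained, with a shorter sleep than the real code):

# The wait_for refactor in miniature: one generic poller plus predicates.
import os
import time


def wait_for(wait_func, data_store=None, timeout=10):
    """Block until wait_func() is true, refreshing the data store on each try."""
    for _ in range(timeout):
        if data_store:
            data_store.load_from_store()
        if wait_func():
            return
        time.sleep(1)
    raise ValueError('Waited {} seconds, but condition timed out'.format(timeout))


# a trainer-ready wait then becomes a one-line predicate (".ready" mirrors
# SyncFiles.TRAINER_READY):
def wait_until_trainer_ready(checkpoint_dir, data_store=None, timeout=10):
    wait_for(lambda: os.path.exists(os.path.join(checkpoint_dir, ".ready")),
             data_store, timeout)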
@@ -69,17 +89,18 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
    """
    checkpoint_dir = task_parameters.checkpoint_restore_path
    wait_for_checkpoint(checkpoint_dir, data_store)
    wait_for_trainer_ready(checkpoint_dir, data_store)

    graph_manager.create_graph(task_parameters)
    with graph_manager.phase_context(RunPhase.TRAIN):

        chkpt_state_reader = CheckpointStateReader(checkpoint_dir, checkpoint_state_optional=False)
        last_checkpoint = 0
        last_checkpoint = chkpt_state_reader.get_latest().num

        # this worker should play a fraction of the total playing steps per rollout
        act_steps = graph_manager.agent_params.algorithm.num_consecutive_playing_steps / num_workers

        for i in range(graph_manager.improve_steps / act_steps):
        training_steps = (graph_manager.improve_steps / act_steps.num_steps).num_steps
        for i in range(training_steps):

            if should_stop(checkpoint_dir):
                break

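The second part of the commit message shows up here: the rollout loop bound used to be graph_manager.improve_steps / act_steps, which after the recent core_types change no longer yields a plain integer, so the new code divides by act_steps.num_steps and takes .num_steps of the result. A hedged sketch of the arithmetic, with a stand-in step class whose division-by-int semantics are assumed:

# Stand-in for rl_coach.core_types step classes; real semantics are assumed.
class EnvironmentSteps:
    def __init__(self, num_steps):
        self.num_steps = num_steps

    def __truediv__(self, other):
        # assumed behaviour: dividing by an int scales the step count
        return EnvironmentSteps(int(self.num_steps / other))


improve_steps = EnvironmentSteps(100000)  # total improvement budget
act_steps = EnvironmentSteps(400)         # steps this worker plays per rollout

# old bound: range(improve_steps / act_steps) fails, the result is not an int
training_steps = (improve_steps / act_steps.num_steps).num_steps
print(training_steps)                     # 250 rollout iterations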
@@ -24,24 +24,30 @@ from rl_coach import core_types
from rl_coach.logger import screen


def data_store_ckpt_save(data_store):
    while True:
        data_store.save_to_store()
        time.sleep(10)
def data_store_ckpt_load(data_store):
    if data_store:
        data_store.load_from_store()


def training_worker(graph_manager, task_parameters, is_multi_node_test):
def training_worker(graph_manager, task_parameters, data_store, is_multi_node_test):
    """
    restore a checkpoint then perform rollouts using the restored model
    :param graph_manager: An instance of the graph manager
    :param task_parameters: An instance of task parameters
    :param is_multi_node_test: If this is a multi node test insted of a normal run.
    """
    # initialize graph
    graph_manager.create_graph(task_parameters)
    # Load checkpoint if provided
    if task_parameters.checkpoint_restore_path:
        data_store_ckpt_load(data_store)
        # initialize graph
        graph_manager.create_graph(task_parameters)

        # save randomly initialized graph
        graph_manager.save_checkpoint()
    else:
        # initialize graph
        graph_manager.create_graph(task_parameters)

        # save randomly initialized graph
        graph_manager.save_checkpoint()

    # training loop
    steps = 0
@@ -50,6 +56,7 @@ def training_worker(graph_manager, task_parameters, is_multi_node_test):
    eval_offset = 1

    graph_manager.setup_memory_backend()
    graph_manager.signal_ready()

    while steps < graph_manager.improve_steps.num_steps: