Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 19:20:19 +01:00)
Uploading checkpoint if crd provided (#191)
* Uploading checkpoint if crd (checkpoint_restore_dir) provided
* Changing the calculation of total steps because of a recent change in core_types

Fixes #195
Committed by: Scott Leishman
Parent: b3db9ce77d
Commit: 33dc29ee99
@@ -103,21 +103,22 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
         data_store_params.checkpoint_dir = ckpt_inside_container
         graph_manager.data_store_params = data_store_params
 
+    data_store = None
+    if args.data_store_params:
+        data_store = get_data_store(data_store_params)
+
     if args.distributed_coach_run_type == RunType.TRAINER:
         task_parameters.checkpoint_save_dir = ckpt_inside_container
         training_worker(
             graph_manager=graph_manager,
             task_parameters=task_parameters,
+            data_store=data_store,
             is_multi_node_test=args.is_multi_node_test
         )
 
     if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
         task_parameters.checkpoint_restore_path = ckpt_inside_container
 
-        data_store = None
-        if args.data_store_params:
-            data_store = get_data_store(data_store_params)
-
         rollout_worker(
             graph_manager=graph_manager,
             data_store=data_store,
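For orientation: the hunk above builds the data store once, ahead of the run-type branch, so the trainer now receives the same store the rollout workers already used (the trainer needs it to upload checkpoints and the new .ready marker). A minimal, hypothetical sketch of that wiring, with a stub store standing in for the real S3/NFS implementations (StubStore and handle_tasks are illustrative names, not part of the commit):

    class StubStore:
        """Illustrative stand-in for an rl_coach data store."""

        def save_to_store(self):
            print("trainer: upload latest checkpoint")

        def load_from_store(self):
            print("rollout worker: download latest checkpoint")


    def handle_tasks(run_type, have_store_params):
        # Build the store once, before branching on the run type,
        # mirroring the refactor in the hunk above.
        data_store = StubStore() if have_store_params else None
        worker = "training_worker" if run_type == "trainer" else "rollout_worker"
        print("{}(data_store={})".format(worker, data_store))


    handle_tasks("trainer", True)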
@@ -169,7 +170,7 @@ def handle_distributed_coach_orchestrator(args):
                                                 memory_backend_parameters=memory_backend_params,
                                                 data_store_params=ds_params_instance)
     orchestrator = Kubernetes(orchestration_params)
-    if not orchestrator.setup():
+    if not orchestrator.setup(args.checkpoint_restore_dir):
         print("Could not setup.")
         return 1
 
@@ -394,7 +395,9 @@ class CoachLauncher(object):
 
         # validate the checkpoints args
         if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir):
-            screen.error("The requested checkpoint folder to load from does not exist.")
+            # If distributed trainer, the checkpoint dir is not yet available so skipping the check in that case.
+            if not (args.distributed_coach and args.distributed_coach_run_type in [RunType.TRAINER, RunType.ROLLOUT_WORKER]):
+                screen.error("The requested checkpoint folder to load from does not exist.")
 
         # validate the checkpoints args
         if args.checkpoint_restore_file is not None and not glob(args.checkpoint_restore_file + '*'):
@@ -44,7 +44,11 @@ class DataStore(object):
     def load_from_store(self):
         pass
 
+    def setup_checkpoint_dir(self, crd=None):
+        pass
+
 
 class SyncFiles(Enum):
     FINISHED = ".finished"
     LOCKFILE = ".lock"
+    TRAINER_READY = ".ready"
@@ -284,3 +284,8 @@ class NFSDataStore(DataStore):
             return False
 
        return True
+
+    def setup_checkpoint_dir(self, crd=None):
+        if crd:
+            # TODO: find a way to upload this to the deployed nfs store.
+            pass
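The base-class hook above defaults to a no-op, and the NFS store keeps it as a stub for now (see the TODO). A hedged sketch of what a concrete implementation could look like for a purely local store; LocalDataStore and its copy logic are hypothetical and not part of rl_coach:

    import os
    import shutil


    class LocalDataStore:
        """Hypothetical store that mirrors checkpoints into a local directory."""

        def __init__(self, target_dir):
            self.target_dir = target_dir

        def setup_checkpoint_dir(self, crd=None):
            # crd is the --checkpoint_restore_dir handed over by the orchestrator;
            # copy its files so workers can later restore from the store.
            if crd and os.path.isdir(crd):
                os.makedirs(self.target_dir, exist_ok=True)
                for name in os.listdir(crd):
                    src = os.path.join(crd, name)
                    if os.path.isfile(src):
                        shutil.copy2(src, self.target_dir)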
@@ -77,6 +77,9 @@ class S3DataStore(DataStore):
         return True
 
     def save_to_store(self):
+        self._save_to_store(self.params.checkpoint_dir)
+
+    def _save_to_store(self, checkpoint_dir):
         """
         save_to_store() uploads the policy checkpoint, gifs and videos to the S3 data store. It reads the checkpoint state files and
         uploads only the latest checkpoint files to S3. It is used by the trainer in Coach when used in the distributed mode.
@@ -88,24 +91,32 @@ class S3DataStore(DataStore):
             # Acquire lock
             self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)
 
-            state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
+            state_file = CheckpointStateFile(os.path.abspath(checkpoint_dir))
             if state_file.exists():
                 ckpt_state = state_file.read()
                 checkpoint_file = None
-                for root, dirs, files in os.walk(self.params.checkpoint_dir):
+                for root, dirs, files in os.walk(checkpoint_dir):
                     for filename in files:
                         if filename == CheckpointStateFile.checkpoint_state_filename:
                             checkpoint_file = (root, filename)
                             continue
                         if filename.startswith(ckpt_state.name):
                             abs_name = os.path.abspath(os.path.join(root, filename))
-                            rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
+                            rel_name = os.path.relpath(abs_name, checkpoint_dir)
                             self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
 
                 abs_name = os.path.abspath(os.path.join(checkpoint_file[0], checkpoint_file[1]))
-                rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
+                rel_name = os.path.relpath(abs_name, checkpoint_dir)
                 self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
 
+            # upload Finished if present
+            if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value)):
+                self.mc.put_object(self.params.bucket_name, SyncFiles.FINISHED.value, io.BytesIO(b''), 0)
+
+            # upload Ready if present
+            if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value)):
+                self.mc.put_object(self.params.bucket_name, SyncFiles.TRAINER_READY.value, io.BytesIO(b''), 0)
+
             # release lock
             self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
 
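The hunk above keeps the existing lock-file protocol around uploads: create an empty .lock object, push only the files referenced by the latest checkpoint state, then delete the lock. A stripped-down sketch of that pattern against a MinIO client; the endpoint, credentials, bucket and the upload_checkpoint helper below are placeholders, not the rl_coach implementation:

    import io
    import os

    from minio import Minio

    mc = Minio("s3.example.com", access_key="KEY", secret_key="SECRET")  # placeholder endpoint/credentials


    def upload_checkpoint(bucket, checkpoint_dir, latest_prefix):
        # "Acquire" the lock by creating an empty marker object.
        mc.put_object(bucket, ".lock", io.BytesIO(b""), 0)
        try:
            for root, _, files in os.walk(checkpoint_dir):
                for filename in files:
                    if filename.startswith(latest_prefix):
                        abs_name = os.path.join(root, filename)
                        rel_name = os.path.relpath(abs_name, checkpoint_dir)
                        mc.fput_object(bucket, rel_name, abs_name)
        finally:
            # Release the lock even if an upload fails.
            mc.remove_object(bucket, ".lock")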
@@ -121,6 +132,7 @@ class S3DataStore(DataStore):
             if self.params.expt_dir and os.path.exists(os.path.join(self.params.expt_dir, 'gifs')):
                 for filename in os.listdir(os.path.join(self.params.expt_dir, 'gifs')):
                     self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'gifs', filename))
 
         except ResponseError as e:
             print("Got exception: %s\n while saving to S3", e)
+
@@ -157,6 +169,18 @@ class S3DataStore(DataStore):
                 except Exception as e:
                     pass
 
+            # Check if there's a ready file
+            objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.TRAINER_READY.value)
+
+            if next(objects, None) is not None:
+                try:
+                    self.mc.fget_object(
+                        self.params.bucket_name, SyncFiles.TRAINER_READY.value,
+                        os.path.abspath(os.path.join(self.params.checkpoint_dir, SyncFiles.TRAINER_READY.value))
+                    )
+                except Exception as e:
+                    pass
+
             checkpoint_state = state_file.read()
             if checkpoint_state is not None:
                 objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=checkpoint_state.name, recursive=True)
@@ -167,3 +191,7 @@ class S3DataStore(DataStore):
 
         except ResponseError as e:
             print("Got exception: %s\n while loading from S3", e)
+
+    def setup_checkpoint_dir(self, crd=None):
+        if crd:
+            self._save_to_store(crd)
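With setup_checkpoint_dir in place, the S3 store can pre-seed its bucket: the orchestrator forwards --checkpoint_restore_dir, and the same _save_to_store path is reused for both the pre-seed and the trainer's regular uploads. A compressed, illustrative mirror of that refactor (DemoStore and the paths below are made up):

    class DemoStore:
        """Illustrative mirror of the S3DataStore refactor, not the real class."""

        def __init__(self, default_dir):
            self.default_dir = default_dir

        def _save_to_store(self, checkpoint_dir):
            print("uploading checkpoints from", checkpoint_dir)

        def save_to_store(self):
            # Trainer path: upload from the in-container checkpoint dir.
            self._save_to_store(self.default_dir)

        def setup_checkpoint_dir(self, crd=None):
            # Orchestrator path: pre-seed the store from --checkpoint_restore_dir.
            if crd:
                self._save_to_store(crd)


    store = DemoStore("/checkpoint")
    store.setup_checkpoint_dir("/experiments/old_run/checkpoint")  # before deployment
    store.save_to_store()                                          # during training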
@@ -37,6 +37,7 @@ from rl_coach.utils import set_cpu, start_shell_command_and_wait
 from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator
 from rl_coach.memories.backend.memory_impl import get_memory_backend
 from rl_coach.data_stores.data_store import SyncFiles
+from rl_coach.checkpoint import CheckpointStateReader
 
 from rl_coach.core_types import TimeTypes
 
@@ -589,6 +590,10 @@ class GraphManager(object):
 
             [manager.restore_checkpoint(checkpoint_restore_dir) for manager in self.level_managers]
 
+            # Set the last checkpoint ID
+            chkpt_state_reader = CheckpointStateReader(self.task_parameters.checkpoint_restore_path, checkpoint_state_optional=False)
+            self.checkpoint_id = chkpt_state_reader.get_latest().num + 1
+
     def _get_checkpoint_state_tf(self, checkpoint_restore_dir):
         import tensorflow as tf
         return tf.train.get_checkpoint_state(checkpoint_restore_dir)
@@ -721,6 +726,13 @@ class GraphManager(object):
 
         return data_store_creator(param)
 
+    def signal_ready(self):
+        if self.task_parameters.checkpoint_save_dir and os.path.exists(self.task_parameters.checkpoint_save_dir):
+            open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.TRAINER_READY.value), 'w').close()
+            if hasattr(self, 'data_store_params'):
+                data_store = self.get_data_store(self.data_store_params)
+                data_store.save_to_store()
+
     def close(self) -> None:
         """
         Clean up to close environments.
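signal_ready gives rollout workers an explicit handshake: the trainer drops a .ready marker next to its checkpoints and pushes it through the data store, and workers block until the marker appears locally (see wait_for_trainer_ready below). A self-contained sketch of the file-level part of that handshake; the helper names are illustrative and the store round-trip is left out:

    import os
    import tempfile

    READY_MARKER = ".ready"  # same value as SyncFiles.TRAINER_READY


    def signal_ready(checkpoint_save_dir):
        # Trainer side: create the marker (in Coach this is followed by
        # data_store.save_to_store() so the marker reaches S3/NFS).
        open(os.path.join(checkpoint_save_dir, READY_MARKER), 'w').close()


    def trainer_is_ready(checkpoint_dir):
        # Rollout-worker side predicate, as used by wait_for_trainer_ready().
        return os.path.exists(os.path.join(checkpoint_dir, READY_MARKER))


    with tempfile.TemporaryDirectory() as ckpt_dir:
        assert not trainer_is_ready(ckpt_dir)
        signal_ready(ckpt_dir)
        assert trainer_is_ready(ckpt_dir)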
@@ -118,7 +118,7 @@ class Kubernetes(Deploy):
             self.s3_access_key = os.environ.get('ACCESS_KEY_ID')
             self.s3_secret_key = os.environ.get('SECRET_ACCESS_KEY')
 
-    def setup(self) -> bool:
+    def setup(self, crd=None) -> bool:
         """
         Deploys the memory backend and data stores if required.
         """
@@ -128,6 +128,9 @@ class Kubernetes(Deploy):
                 return False
         if self.params.data_store_params.store_type == "nfs":
             self.nfs_pvc = self.data_store.get_info()
+
+        # Upload checkpoints in checkpoint_restore_dir (if provided) to the data store
+        self.data_store.setup_checkpoint_dir(crd)
         return True
 
     def deploy_trainer(self) -> bool:
@@ -141,7 +144,6 @@ class Kubernetes(Deploy):
 
         trainer_params.command += ['--memory_backend_params', json.dumps(self.params.memory_backend_parameters.__dict__)]
         trainer_params.command += ['--data_store_params', json.dumps(self.params.data_store_params.__dict__)]
-
         name = "{}-{}".format(trainer_params.run_type, uuid.uuid4())
 
         if self.params.data_store_params.store_type == "nfs":
@@ -33,30 +33,50 @@ from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
 from rl_coach.data_stores.data_store import SyncFiles
 
 
+def wait_for(wait_func, data_store=None, timeout=10):
+    """
+    block until wait_func is true
+    """
+    for i in range(timeout):
+        if data_store:
+            data_store.load_from_store()
+
+        if wait_func():
+            return
+        time.sleep(10)
+
+    # one last time
+    if wait_func():
+        return
+
+    raise ValueError((
+        'Waited {timeout} seconds, but condition timed out'
+    ).format(
+        timeout=timeout,
+    ))
+
+
 def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
     """
     block until there is a checkpoint in checkpoint_dir
     """
     chkpt_state_file = CheckpointStateFile(checkpoint_dir)
-    for i in range(timeout):
-        if data_store:
-            data_store.load_from_store()
 
-        if chkpt_state_file.read() is not None:
-            return
-        time.sleep(10)
+    def wait():
+        return chkpt_state_file.read() is not None
 
-    # one last time
-    if chkpt_state_file.read() is not None:
-        return
+    wait_for(wait, data_store, timeout)
 
-    raise ValueError((
-        'Waited {timeout} seconds, but checkpoint never found in '
-        '{checkpoint_dir}'
-    ).format(
-        timeout=timeout,
-        checkpoint_dir=checkpoint_dir,
-    ))
+
+def wait_for_trainer_ready(checkpoint_dir, data_store=None, timeout=10):
+    """
+    Block until trainer is ready
+    """
+
+    def wait():
+        return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value))
+
+    wait_for(wait, data_store, timeout)
 
 
 def should_stop(checkpoint_dir):
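wait_for factors the polling loop out of wait_for_checkpoint so any boolean condition can reuse it; note that timeout counts polls of roughly 10 seconds each, not seconds. A usage sketch with a hypothetical marker-file predicate (the helper is re-declared locally so the snippet runs on its own):

    import os
    import tempfile
    import time


    def wait_for(wait_func, data_store=None, timeout=10):
        # Local mirror of the helper added above.
        for _ in range(timeout):
            if data_store:
                data_store.load_from_store()
            if wait_func():
                return
            time.sleep(10)
        if wait_func():
            return
        raise ValueError('Waited {} polls, but condition timed out'.format(timeout))


    # Hypothetical usage: block until another process drops a marker file.
    marker = os.path.join(tempfile.gettempdir(), "example.finished")
    open(marker, 'w').close()                  # pretend the other side already finished
    wait_for(lambda: os.path.exists(marker))   # returns immediately; the marker exists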
@@ -69,17 +89,18 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
     """
     checkpoint_dir = task_parameters.checkpoint_restore_path
     wait_for_checkpoint(checkpoint_dir, data_store)
+    wait_for_trainer_ready(checkpoint_dir, data_store)
 
     graph_manager.create_graph(task_parameters)
     with graph_manager.phase_context(RunPhase.TRAIN):
 
         chkpt_state_reader = CheckpointStateReader(checkpoint_dir, checkpoint_state_optional=False)
-        last_checkpoint = 0
+        last_checkpoint = chkpt_state_reader.get_latest().num
 
         # this worker should play a fraction of the total playing steps per rollout
         act_steps = graph_manager.agent_params.algorithm.num_consecutive_playing_steps / num_workers
-
-        for i in range(graph_manager.improve_steps / act_steps):
+        training_steps = (graph_manager.improve_steps / act_steps.num_steps).num_steps
+        for i in range(training_steps):
 
             if should_stop(checkpoint_dir):
                 break
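The last two changes here compute the rollout loop length explicitly: act_steps is the per-worker share of the playing steps, and training_steps is improve_steps divided by that share, extracted as a plain integer via num_steps to match the newer core_types arithmetic. A worked example with plain integers standing in for EnvironmentSteps; the concrete numbers are made up:

    # Say the run improves for 10,000 environment steps in total, the agent
    # plays 400 consecutive steps per rollout, and 4 rollout workers share them.
    improve_steps = 10_000
    num_consecutive_playing_steps = 400
    num_workers = 4

    act_steps = num_consecutive_playing_steps // num_workers   # 100 steps per worker per rollout
    training_steps = improve_steps // act_steps                # 100 rollout iterations per worker

    print(act_steps, training_steps)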
@@ -24,24 +24,30 @@ from rl_coach import core_types
 from rl_coach.logger import screen
 
 
-def data_store_ckpt_save(data_store):
-    while True:
-        data_store.save_to_store()
-        time.sleep(10)
+def data_store_ckpt_load(data_store):
+    if data_store:
+        data_store.load_from_store()
 
 
-def training_worker(graph_manager, task_parameters, is_multi_node_test):
+def training_worker(graph_manager, task_parameters, data_store, is_multi_node_test):
     """
     restore a checkpoint then perform rollouts using the restored model
     :param graph_manager: An instance of the graph manager
     :param task_parameters: An instance of task parameters
     :param is_multi_node_test: If this is a multi node test insted of a normal run.
     """
-    # initialize graph
-    graph_manager.create_graph(task_parameters)
+    # Load checkpoint if provided
+    if task_parameters.checkpoint_restore_path:
+        data_store_ckpt_load(data_store)
+        # initialize graph
+        graph_manager.create_graph(task_parameters)
 
-    # save randomly initialized graph
-    graph_manager.save_checkpoint()
+    else:
+        # initialize graph
+        graph_manager.create_graph(task_parameters)
+
+        # save randomly initialized graph
+        graph_manager.save_checkpoint()
 
     # training loop
     steps = 0
@@ -50,6 +56,7 @@ def training_worker(graph_manager, task_parameters, is_multi_node_test):
     eval_offset = 1
 
     graph_manager.setup_memory_backend()
+    graph_manager.signal_ready()
 
     while steps < graph_manager.improve_steps.num_steps:
 