mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
Uploading checkpoint if crd provided (#191)
* Uploading checkpoint if crd provided
* Changing the calculation of total steps because of a recent change in core_types

Fixes #195
This commit is contained in:
committed by
Scott Leishman
parent
b3db9ce77d
commit
33dc29ee99
@@ -77,6 +77,9 @@ class S3DataStore(DataStore):
|
||||
return True
|
||||
|
||||
def save_to_store(self):
    """
    Save the configured checkpoint directory to the data store.

    Thin public entry point: it forwards ``self.params.checkpoint_dir`` to
    ``_save_to_store``, which performs the actual upload. Returns ``None``.
    """
    # Keep the directory-parameterised logic in _save_to_store so other
    # callers can upload from a different directory.
    checkpoint_dir = self.params.checkpoint_dir
    self._save_to_store(checkpoint_dir)
|
||||
|
||||
def _save_to_store(self, checkpoint_dir):
|
||||
"""
|
||||
save_to_store() uploads the policy checkpoint, gifs and videos to the S3 data store. It reads the checkpoint state files and
|
||||
uploads only the latest checkpoint files to S3. It is used by the trainer in Coach when used in the distributed mode.
|
||||
@@ -88,24 +91,32 @@ class S3DataStore(DataStore):
|
||||
# Acquire lock
|
||||
self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)
|
||||
|
||||
state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
|
||||
state_file = CheckpointStateFile(os.path.abspath(checkpoint_dir))
|
||||
if state_file.exists():
|
||||
ckpt_state = state_file.read()
|
||||
checkpoint_file = None
|
||||
for root, dirs, files in os.walk(self.params.checkpoint_dir):
|
||||
for root, dirs, files in os.walk(checkpoint_dir):
|
||||
for filename in files:
|
||||
if filename == CheckpointStateFile.checkpoint_state_filename:
|
||||
checkpoint_file = (root, filename)
|
||||
continue
|
||||
if filename.startswith(ckpt_state.name):
|
||||
abs_name = os.path.abspath(os.path.join(root, filename))
|
||||
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
|
||||
rel_name = os.path.relpath(abs_name, checkpoint_dir)
|
||||
self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
|
||||
|
||||
abs_name = os.path.abspath(os.path.join(checkpoint_file[0], checkpoint_file[1]))
|
||||
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
|
||||
rel_name = os.path.relpath(abs_name, checkpoint_dir)
|
||||
self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
|
||||
|
||||
# upload Finished if present
|
||||
if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value)):
|
||||
self.mc.put_object(self.params.bucket_name, SyncFiles.FINISHED.value, io.BytesIO(b''), 0)
|
||||
|
||||
# upload Ready if present
|
||||
if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value)):
|
||||
self.mc.put_object(self.params.bucket_name, SyncFiles.TRAINER_READY.value, io.BytesIO(b''), 0)
|
||||
|
||||
# release lock
|
||||
self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
|
||||
|
||||
@@ -121,6 +132,7 @@ class S3DataStore(DataStore):
|
||||
if self.params.expt_dir and os.path.exists(os.path.join(self.params.expt_dir, 'gifs')):
|
||||
for filename in os.listdir(os.path.join(self.params.expt_dir, 'gifs')):
|
||||
self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'gifs', filename))
|
||||
|
||||
except ResponseError as e:
|
||||
print("Got exception: %s\n while saving to S3", e)
|
||||
|
||||
@@ -157,6 +169,18 @@ class S3DataStore(DataStore):
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# Check if there's a ready file
|
||||
objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.TRAINER_READY.value)
|
||||
|
||||
if next(objects, None) is not None:
|
||||
try:
|
||||
self.mc.fget_object(
|
||||
self.params.bucket_name, SyncFiles.TRAINER_READY.value,
|
||||
os.path.abspath(os.path.join(self.params.checkpoint_dir, SyncFiles.TRAINER_READY.value))
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
checkpoint_state = state_file.read()
|
||||
if checkpoint_state is not None:
|
||||
objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=checkpoint_state.name, recursive=True)
|
||||
@@ -167,3 +191,7 @@ class S3DataStore(DataStore):
|
||||
|
||||
except ResponseError as e:
|
||||
print("Got exception: %s\n while loading from S3", e)
|
||||
|
||||
def setup_checkpoint_dir(self, crd=None):
    """
    Upload the contents of ``crd`` to the data store when one is provided.

    :param crd: directory passed to ``_save_to_store``; presumably the
        checkpoint restore directory -- TODO confirm against callers.
        When it is ``None`` or otherwise falsy, nothing is uploaded.
    """
    if not crd:
        # No directory supplied: nothing to upload.
        return
    self._save_to_store(crd)
|
||||
|
||||
Reference in New Issue
Block a user