
Uploading checkpoint if crd provided (#191)

* Uploading checkpoint if crd provided
* Changing the calculation of total steps because of a recent change in core_types

Fixes #195
Authored by Ajay Deshpande on 2019-04-26 12:27:33 -07:00
Committed by Scott Leishman
parent b3db9ce77d
commit 33dc29ee99
8 changed files with 122 additions and 40 deletions
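
In practice the new hook is meant to be driven by the trainer when a checkpoint restore directory ("crd") is supplied at startup: the regular save path and the crd path now both go through a single upload helper. A minimal sketch of that call pattern, assuming an already constructed S3DataStore instance (the function and variable names below are illustrative, not part of this commit):

def upload_restored_checkpoint(data_store, crd=None):
    """Sketch only: how a caller is expected to exercise the new hook.

    `data_store` is assumed to be an S3DataStore instance and `crd` the
    checkpoint restore directory supplied by the user; both names are
    placeholders for illustration.
    """
    # Regular trainer path: uploads params.checkpoint_dir as before.
    data_store.save_to_store()

    # New path added by this commit: if a restore dir was provided, push its
    # contents to the bucket through the same locked upload helper.
    if crd:
        data_store.setup_checkpoint_dir(crd)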


@@ -77,6 +77,9 @@ class S3DataStore(DataStore):
         return True

     def save_to_store(self):
+        self._save_to_store(self.params.checkpoint_dir)
+
+    def _save_to_store(self, checkpoint_dir):
         """
         save_to_store() uploads the policy checkpoint, gifs and videos to the S3 data store. It reads the checkpoint state files and
         uploads only the latest checkpoint files to S3. It is used by the trainer in Coach when used in the distributed mode.
@@ -88,24 +91,32 @@ class S3DataStore(DataStore):
             # Acquire lock
             self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)

-            state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
+            state_file = CheckpointStateFile(os.path.abspath(checkpoint_dir))
             if state_file.exists():
                 ckpt_state = state_file.read()
                 checkpoint_file = None
-                for root, dirs, files in os.walk(self.params.checkpoint_dir):
+                for root, dirs, files in os.walk(checkpoint_dir):
                     for filename in files:
                         if filename == CheckpointStateFile.checkpoint_state_filename:
                             checkpoint_file = (root, filename)
                             continue
                         if filename.startswith(ckpt_state.name):
                             abs_name = os.path.abspath(os.path.join(root, filename))
-                            rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
+                            rel_name = os.path.relpath(abs_name, checkpoint_dir)
                             self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)

                 abs_name = os.path.abspath(os.path.join(checkpoint_file[0], checkpoint_file[1]))
-                rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
+                rel_name = os.path.relpath(abs_name, checkpoint_dir)
                 self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)

+            # upload Finished if present
+            if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value)):
+                self.mc.put_object(self.params.bucket_name, SyncFiles.FINISHED.value, io.BytesIO(b''), 0)
+
+            # upload Ready if present
+            if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value)):
+                self.mc.put_object(self.params.bucket_name, SyncFiles.TRAINER_READY.value, io.BytesIO(b''), 0)
+
             # release lock
             self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
@@ -121,6 +132,7 @@ class S3DataStore(DataStore):
             if self.params.expt_dir and os.path.exists(os.path.join(self.params.expt_dir, 'gifs')):
                 for filename in os.listdir(os.path.join(self.params.expt_dir, 'gifs')):
                     self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'gifs', filename))
+
         except ResponseError as e:
             print("Got exception: %s\n while saving to S3", e)
@@ -157,6 +169,18 @@ class S3DataStore(DataStore):
                 except Exception as e:
                     pass

+            # Check if there's a ready file
+            objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.TRAINER_READY.value)
+
+            if next(objects, None) is not None:
+                try:
+                    self.mc.fget_object(
+                        self.params.bucket_name, SyncFiles.TRAINER_READY.value,
+                        os.path.abspath(os.path.join(self.params.checkpoint_dir, SyncFiles.TRAINER_READY.value))
+                    )
+                except Exception as e:
+                    pass
+
             checkpoint_state = state_file.read()
             if checkpoint_state is not None:
                 objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=checkpoint_state.name, recursive=True)
@@ -167,3 +191,7 @@ class S3DataStore(DataStore):
         except ResponseError as e:
             print("Got exception: %s\n while loading from S3", e)
+
+    def setup_checkpoint_dir(self, crd=None):
+        if crd:
+            self._save_to_store(crd)
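
For reference, the upload path that setup_checkpoint_dir() now reuses follows a simple lock-file protocol: write an empty lock object, upload the checkpoint files keyed by their path relative to the checkpoint directory, then remove the lock so rollout workers calling load_from_store() can proceed. A condensed illustration of that sequence outside the S3DataStore class (client construction, bucket and lock-file names are placeholders, not values from this commit; the try/finally is an illustration-level nicety, the commit itself releases the lock at the end of the method):

import io
import os

from minio import Minio  # the same client library S3DataStore wraps as self.mc


def locked_upload(mc: Minio, bucket: str, checkpoint_dir: str, lockfile: str = ".lock"):
    """Sketch of the acquire-lock / upload / release-lock sequence used by _save_to_store."""
    # "Acquire" the lock: an empty object that readers poll for before downloading.
    mc.put_object(bucket, lockfile, io.BytesIO(b''), 0)
    try:
        # Upload every file under checkpoint_dir, keyed by its relative path.
        for root, _dirs, files in os.walk(checkpoint_dir):
            for filename in files:
                abs_name = os.path.abspath(os.path.join(root, filename))
                rel_name = os.path.relpath(abs_name, checkpoint_dir)
                mc.fput_object(bucket, rel_name, abs_name)
    finally:
        # Release the lock so workers waiting in load_from_store() can resume.
        mc.remove_object(bucket, lockfile)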