diff --git a/rl_coach/data_stores/s3_data_store.py b/rl_coach/data_stores/s3_data_store.py index 01eaf55..a691bb0 100644 --- a/rl_coach/data_stores/s3_data_store.py +++ b/rl_coach/data_stores/s3_data_store.py @@ -54,22 +54,27 @@ class S3DataStore(DataStore): try: # remove lock file if it exists self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value) + # Acquire lock self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0) - checkpoint_file = None - for root, dirs, files in os.walk(self.params.checkpoint_dir): - for filename in files: - if filename == CheckpointStateFile.checkpoint_state_filename: - checkpoint_file = (root, filename) - continue - abs_name = os.path.abspath(os.path.join(root, filename)) - rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir) - self.mc.fput_object(self.params.bucket_name, rel_name, abs_name) + state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir)) + if state_file.exists(): + ckpt_state = state_file.read() + checkpoint_file = None + for root, dirs, files in os.walk(self.params.checkpoint_dir): + for filename in files: + if filename == CheckpointStateFile.checkpoint_state_filename: + checkpoint_file = (root, filename) + continue + if filename.startswith(ckpt_state.name): + abs_name = os.path.abspath(os.path.join(root, filename)) + rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir) + self.mc.fput_object(self.params.bucket_name, rel_name, abs_name) - abs_name = os.path.abspath(os.path.join(checkpoint_file[0], checkpoint_file[1])) - rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir) - self.mc.fput_object(self.params.bucket_name, rel_name, abs_name) + abs_name = os.path.abspath(os.path.join(checkpoint_file[0], checkpoint_file[1])) + rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir) + self.mc.fput_object(self.params.bucket_name, rel_name, abs_name) # release lock self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)