From 052bbc8f1932e886dfb767e4cd75b7433f69bbb9 Mon Sep 17 00:00:00 2001
From: Ajay Deshpande
Date: Fri, 5 Oct 2018 12:53:51 -0700
Subject: [PATCH] Adding lock in s3

---
 rl_coach/data_stores/s3_data_store.py | 30 +++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/rl_coach/data_stores/s3_data_store.py b/rl_coach/data_stores/s3_data_store.py
index 5ceb00a..d616d9c 100644
--- a/rl_coach/data_stores/s3_data_store.py
+++ b/rl_coach/data_stores/s3_data_store.py
@@ -4,6 +4,8 @@ from minio import Minio
 from minio.error import ResponseError
 from configparser import ConfigParser, Error
 import os
+import time
+import io
 
 
 class S3DataStoreParameters(DataStoreParameters):
@@ -15,6 +17,7 @@ class S3DataStoreParameters(DataStoreParameters):
         self.end_point = end_point
         self.bucket_name = bucket_name
         self.checkpoint_dir = checkpoint_dir
+        self.lock_file = ".lock"
 
 
 class S3DataStore(DataStore):
@@ -46,18 +49,45 @@ class S3DataStore(DataStore):
 
     def save_to_store(self):
         try:
+            print("Writing lock file")
+
+            self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
+
+            self.mc.put_object(self.params.bucket_name, self.params.lock_file, io.BytesIO(b''), 0)
+
             print("saving to s3")
+            checkpoint_file = None
             for root, dirs, files in os.walk(self.params.checkpoint_dir):
                 for filename in files:
+                    if filename == 'checkpoint':
+                        checkpoint_file = (root, filename)
+                        pass
                     abs_name = os.path.abspath(os.path.join(root, filename))
                     rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
                     self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
+
+            abs_name = os.path.abspath(os.path.join(checkpoint_file[0], checkpoint_file[1]))
+            rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
+            print("Deleting lock file")
+            self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
+
         except ResponseError as e:
             print("Got exception: %s\n while saving to S3", e)
 
     def load_from_store(self):
         try:
+
+            while True:
+                objects = self.mc.list_objects_v2(self.params.bucket_name, self.params.lock_file)
+                time.sleep(10)
+                if next(objects, None) is None:
+                    break
+
             print("loading from s3")
+
+            filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
+            self.mc.fget_object(self.params.bucket_name, "checkpoint", filename)
+
             objects = self.mc.list_objects_v2(self.params.bucket_name, recursive=True)
             for obj in objects:
                 filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name))