mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
97 lines
3.7 KiB
Python
97 lines
3.7 KiB
Python
from rl_coach.data_stores.data_store import DataStore, DataStoreParameters
|
|
from kubernetes import client as k8sclient
|
|
from minio import Minio
|
|
from minio.error import ResponseError
|
|
from configparser import ConfigParser, Error
|
|
import os
|
|
import time
|
|
import io
|
|
|
|
|
|
class S3DataStoreParameters(DataStoreParameters):
|
|
def __init__(self, ds_params, creds_file: str = None, end_point: str = None, bucket_name: str = None,
|
|
checkpoint_dir: str = None):
|
|
|
|
super().__init__(ds_params.store_type, ds_params.orchestrator_type, ds_params.orchestrator_params)
|
|
self.creds_file = creds_file
|
|
self.end_point = end_point
|
|
self.bucket_name = bucket_name
|
|
self.checkpoint_dir = checkpoint_dir
|
|
self.lock_file = ".lock"
|
|
|
|
|
|
class S3DataStore(DataStore):
|
|
def __init__(self, params: S3DataStoreParameters):
|
|
self.params = params
|
|
access_key = None
|
|
secret_key = None
|
|
if params.creds_file:
|
|
config = ConfigParser()
|
|
config.read(params.creds_file)
|
|
try:
|
|
access_key = config.get('default', 'aws_access_key_id')
|
|
secret_key = config.get('default', 'aws_secret_access_key')
|
|
except Error as e:
|
|
print("Error when reading S3 credentials file: %s", e)
|
|
else:
|
|
access_key = os.environ.get('ACCESS_KEY_ID')
|
|
secret_key = os.environ.get('SECRET_ACCESS_KEY')
|
|
self.mc = Minio(self.params.end_point, access_key=access_key, secret_key=secret_key)
|
|
|
|
def deploy(self) -> bool:
|
|
return True
|
|
|
|
def get_info(self):
|
|
return "s3://{}/{}".format(self.params.bucket_name)
|
|
|
|
def undeploy(self) -> bool:
|
|
return True
|
|
|
|
def save_to_store(self):
|
|
try:
|
|
print("Writing lock file")
|
|
|
|
self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
|
|
|
|
self.mc.put_object(self.params.bucket_name, self.params.lock_file, io.BytesIO(b''), 0)
|
|
|
|
print("saving to s3")
|
|
checkpoint_file = None
|
|
for root, dirs, files in os.walk(self.params.checkpoint_dir):
|
|
for filename in files:
|
|
if filename == 'checkpoint':
|
|
checkpoint_file = (root, filename)
|
|
pass
|
|
abs_name = os.path.abspath(os.path.join(root, filename))
|
|
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
|
|
self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
|
|
|
|
abs_name = os.path.abspath(os.path.join(checkpoint_file[0], checkpoint_file[1]))
|
|
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
|
|
print("Deleting lock file")
|
|
self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
|
|
|
|
except ResponseError as e:
|
|
print("Got exception: %s\n while saving to S3", e)
|
|
|
|
def load_from_store(self):
|
|
try:
|
|
|
|
while True:
|
|
objects = self.mc.list_objects_v2(self.params.bucket_name, self.params.lock_file)
|
|
time.sleep(10)
|
|
if next(objects, None) is None:
|
|
break
|
|
|
|
print("loading from s3")
|
|
|
|
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
|
|
self.mc.fget_object(self.params.bucket_name, "checkpoint", filename)
|
|
|
|
objects = self.mc.list_objects_v2(self.params.bucket_name, recursive=True)
|
|
for obj in objects:
|
|
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name))
|
|
self.mc.fget_object(obj.bucket_name, obj.object_name, filename)
|
|
except ResponseError as e:
|
|
print("Got exception: %s\n while loading from S3", e)
|