# # Copyright (c) 2017 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from rl_coach.data_stores.data_store import DataStoreParameters from rl_coach.data_stores.checkpoint_data_store import CheckpointDataStore from minio import Minio from minio.error import ResponseError from configparser import ConfigParser, Error from rl_coach.checkpoint import CheckpointStateFile from rl_coach.data_stores.data_store import SyncFiles import os import time import io class S3DataStoreParameters(DataStoreParameters): def __init__(self, ds_params, creds_file: str = None, end_point: str = None, bucket_name: str = None, checkpoint_dir: str = None, expt_dir: str = None): super().__init__(ds_params.store_type, ds_params.orchestrator_type, ds_params.orchestrator_params) self.creds_file = creds_file self.end_point = end_point self.bucket_name = bucket_name self.checkpoint_dir = checkpoint_dir self.expt_dir = expt_dir class S3DataStore(CheckpointDataStore): """ An implementation of the data store using S3 for storing policy checkpoints when using Coach in distributed mode. The policy checkpoints are written by the trainer and read by the rollout worker. """ def __init__(self, params: S3DataStoreParameters): """ :param params: The parameters required to use the S3 data store. """ super(S3DataStore, self).__init__(params) self.params = params access_key = None secret_key = None if params.creds_file: config = ConfigParser() config.read(params.creds_file) try: access_key = config.get('default', 'aws_access_key_id') secret_key = config.get('default', 'aws_secret_access_key') except Error as e: print("Error when reading S3 credentials file: %s", e) else: access_key = os.environ.get('ACCESS_KEY_ID') secret_key = os.environ.get('SECRET_ACCESS_KEY') self.mc = Minio(self.params.end_point, access_key=access_key, secret_key=secret_key) def deploy(self) -> bool: return True def get_info(self): return "s3://{}/{}".format(self.params.bucket_name) def undeploy(self) -> bool: return True def save_to_store(self): self._save_to_store(self.params.checkpoint_dir) def _save_to_store(self, checkpoint_dir): """ save_to_store() uploads the policy checkpoint, gifs and videos to the S3 data store. It reads the checkpoint state files and uploads only the latest checkpoint files to S3. It is used by the trainer in Coach when used in the distributed mode. """ try: # remove lock file if it exists self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value) # Acquire lock self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0) state_file = CheckpointStateFile(os.path.abspath(checkpoint_dir)) if state_file.exists(): ckpt_state = state_file.read() checkpoint_file = None for root, dirs, files in os.walk(checkpoint_dir): for filename in files: if filename == CheckpointStateFile.checkpoint_state_filename: checkpoint_file = (root, filename) continue if filename.startswith(ckpt_state.name): abs_name = os.path.abspath(os.path.join(root, filename)) rel_name = os.path.relpath(abs_name, checkpoint_dir) self.mc.fput_object(self.params.bucket_name, rel_name, abs_name) abs_name = os.path.abspath(os.path.join(checkpoint_file[0], checkpoint_file[1])) rel_name = os.path.relpath(abs_name, checkpoint_dir) self.mc.fput_object(self.params.bucket_name, rel_name, abs_name) # upload Finished if present if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value)): self.mc.put_object(self.params.bucket_name, SyncFiles.FINISHED.value, io.BytesIO(b''), 0) # upload Ready if present if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value)): self.mc.put_object(self.params.bucket_name, SyncFiles.TRAINER_READY.value, io.BytesIO(b''), 0) # release lock self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value) if self.params.expt_dir and os.path.exists(self.params.expt_dir): for filename in os.listdir(self.params.expt_dir): if filename.endswith((".csv", ".json")): self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, filename)) if self.params.expt_dir and os.path.exists(os.path.join(self.params.expt_dir, 'videos')): for filename in os.listdir(os.path.join(self.params.expt_dir, 'videos')): self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'videos', filename)) if self.params.expt_dir and os.path.exists(os.path.join(self.params.expt_dir, 'gifs')): for filename in os.listdir(os.path.join(self.params.expt_dir, 'gifs')): self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'gifs', filename)) except ResponseError as e: print("Got exception: %s\n while saving to S3", e) def load_from_store(self): """ load_from_store() downloads a new checkpoint from the S3 data store when it is not available locally. It is used by the rollout workers when using Coach in distributed mode. """ try: state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir)) # wait until lock is removed while True: objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.LOCKFILE.value) if next(objects, None) is None: try: # fetch checkpoint state file from S3 self.mc.fget_object(self.params.bucket_name, state_file.filename, state_file.path) except Exception as e: continue break time.sleep(10) # Check if there's a finished file objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.FINISHED.value) if next(objects, None) is not None: try: self.mc.fget_object( self.params.bucket_name, SyncFiles.FINISHED.value, os.path.abspath(os.path.join(self.params.checkpoint_dir, SyncFiles.FINISHED.value)) ) except Exception as e: pass # Check if there's a ready file objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.TRAINER_READY.value) if next(objects, None) is not None: try: self.mc.fget_object( self.params.bucket_name, SyncFiles.TRAINER_READY.value, os.path.abspath(os.path.join(self.params.checkpoint_dir, SyncFiles.TRAINER_READY.value)) ) except Exception as e: pass checkpoint_state = state_file.read() if checkpoint_state is not None: objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=checkpoint_state.name, recursive=True) for obj in objects: filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name)) if not os.path.exists(filename): self.mc.fget_object(obj.bucket_name, obj.object_name, filename) except ResponseError as e: print("Got exception: %s\n while loading from S3", e) def setup_checkpoint_dir(self, crd=None): if crd: self._save_to_store(crd)