diff --git a/rl_coach/data_stores/s3_data_store.py b/rl_coach/data_stores/s3_data_store.py index d616d9c..2cb838b 100644 --- a/rl_coach/data_stores/s3_data_store.py +++ b/rl_coach/data_stores/s3_data_store.py @@ -3,6 +3,8 @@ from kubernetes import client as k8sclient from minio import Minio from minio.error import ResponseError from configparser import ConfigParser, Error +from google.protobuf import text_format +from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState import os import time import io @@ -88,9 +90,16 @@ class S3DataStore(DataStore): filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint")) self.mc.fget_object(self.params.bucket_name, "checkpoint", filename) - objects = self.mc.list_objects_v2(self.params.bucket_name, recursive=True) - for obj in objects: - filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name)) - self.mc.fget_object(obj.bucket_name, obj.object_name, filename) + ckpt = CheckpointState() + if os.path.exists(filename): + contents = open(filename, 'r').read() + text_format.Merge(contents, ckpt) + rel_path = os.path.relpath(ckpt.model_checkpoint_path, self.params.checkpoint_dir) + + objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=rel_path, recursive=True) + for obj in objects: + filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name)) + self.mc.fget_object(obj.bucket_name, obj.object_name, filename) + except ResponseError as e: print("Got exception: %s\n while loading from S3", e)