1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Getting only the model_checkpoint_path files

This commit is contained in:
Ajay Deshpande
2018-10-05 13:48:10 -07:00
committed by zach dwiel
parent 052bbc8f19
commit a2e57a44f1

View File

@@ -3,6 +3,8 @@ from kubernetes import client as k8sclient
from minio import Minio from minio import Minio
from minio.error import ResponseError from minio.error import ResponseError
from configparser import ConfigParser, Error from configparser import ConfigParser, Error
from google.protobuf import text_format
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
import os import os
import time import time
import io import io
@@ -88,9 +90,16 @@ class S3DataStore(DataStore):
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint")) filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
self.mc.fget_object(self.params.bucket_name, "checkpoint", filename) self.mc.fget_object(self.params.bucket_name, "checkpoint", filename)
objects = self.mc.list_objects_v2(self.params.bucket_name, recursive=True) ckpt = CheckpointState()
for obj in objects: if os.path.exists(filename):
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name)) contents = open(filename, 'r').read()
self.mc.fget_object(obj.bucket_name, obj.object_name, filename) text_format.Merge(contents, ckpt)
rel_path = os.path.relpath(ckpt.model_checkpoint_path, self.params.checkpoint_dir)
objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=rel_path, recursive=True)
for obj in objects:
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name))
self.mc.fget_object(obj.bucket_name, obj.object_name, filename)
except ResponseError as e: except ResponseError as e:
print("Got exception: %s\n while loading from S3", e) print("Got exception: %s\n while loading from S3", e)