mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Getting only the model_checkpoint_path files
This commit is contained in:
committed by
zach dwiel
parent
052bbc8f19
commit
a2e57a44f1
@@ -3,6 +3,8 @@ from kubernetes import client as k8sclient
|
|||||||
from minio import Minio
|
from minio import Minio
|
||||||
from minio.error import ResponseError
|
from minio.error import ResponseError
|
||||||
from configparser import ConfigParser, Error
|
from configparser import ConfigParser, Error
|
||||||
|
from google.protobuf import text_format
|
||||||
|
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import io
|
import io
|
||||||
@@ -88,9 +90,16 @@ class S3DataStore(DataStore):
|
|||||||
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
|
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
|
||||||
self.mc.fget_object(self.params.bucket_name, "checkpoint", filename)
|
self.mc.fget_object(self.params.bucket_name, "checkpoint", filename)
|
||||||
|
|
||||||
objects = self.mc.list_objects_v2(self.params.bucket_name, recursive=True)
|
ckpt = CheckpointState()
|
||||||
for obj in objects:
|
if os.path.exists(filename):
|
||||||
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name))
|
contents = open(filename, 'r').read()
|
||||||
self.mc.fget_object(obj.bucket_name, obj.object_name, filename)
|
text_format.Merge(contents, ckpt)
|
||||||
|
rel_path = os.path.relpath(ckpt.model_checkpoint_path, self.params.checkpoint_dir)
|
||||||
|
|
||||||
|
objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=rel_path, recursive=True)
|
||||||
|
for obj in objects:
|
||||||
|
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name))
|
||||||
|
self.mc.fget_object(obj.bucket_name, obj.object_name, filename)
|
||||||
|
|
||||||
except ResponseError as e:
|
except ResponseError as e:
|
||||||
print("Got exception: %s\n while loading from S3", e)
|
print("Got exception: %s\n while loading from S3", e)
|
||||||
|
|||||||
Reference in New Issue
Block a user