1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Adding should_train helper and should_train in graph_manager

This commit is contained in:
Ajay Deshpande
2018-10-05 14:22:15 -07:00
committed by zach dwiel
parent a2e57a44f1
commit a7f5442015
7 changed files with 126 additions and 20 deletions

View File

@@ -5,6 +5,7 @@ from minio.error import ResponseError
from configparser import ConfigParser, Error
from google.protobuf import text_format
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from minio.error import ResponseError
import os
import time
import io
@@ -63,7 +64,7 @@ class S3DataStore(DataStore):
for filename in files:
if filename == 'checkpoint':
checkpoint_file = (root, filename)
pass
continue
abs_name = os.path.abspath(os.path.join(root, filename))
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
@@ -79,17 +80,21 @@ class S3DataStore(DataStore):
def load_from_store(self):
try:
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
while True:
objects = self.mc.list_objects_v2(self.params.bucket_name, self.params.lock_file)
time.sleep(10)
if next(objects, None) is None:
try:
self.mc.fget_object(self.params.bucket_name, "checkpoint", filename)
except ResponseError as e:
continue
break
time.sleep(10)
print("loading from s3")
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
self.mc.fget_object(self.params.bucket_name, "checkpoint", filename)
ckpt = CheckpointState()
if os.path.exists(filename):
contents = open(filename, 'r').read()