mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Implement framework-agnostic rollout and training workers (#137)
* Added checkpoint state file to coach checkpointing. * Removed TF-specific code from rollout_worker, training_worker, and s3_data_store
This commit is contained in:
committed by
Balaji Subramaniam
parent
4a6c404070
commit
5332013bd1
@@ -2,8 +2,7 @@ from rl_coach.data_stores.data_store import DataStore, DataStoreParameters
|
||||
from minio import Minio
|
||||
from minio.error import ResponseError
|
||||
from configparser import ConfigParser, Error
|
||||
from google.protobuf import text_format
|
||||
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
|
||||
from rl_coach.checkpoint import CheckpointStateFile
|
||||
from rl_coach.data_stores.data_store import SyncFiles
|
||||
|
||||
import os
|
||||
@@ -24,6 +23,7 @@ class S3DataStoreParameters(DataStoreParameters):
|
||||
|
||||
class S3DataStore(DataStore):
|
||||
def __init__(self, params: S3DataStoreParameters):
|
||||
super(S3DataStore, self).__init__(params)
|
||||
self.params = params
|
||||
access_key = None
|
||||
secret_key = None
|
||||
@@ -51,14 +51,15 @@ class S3DataStore(DataStore):
|
||||
|
||||
def save_to_store(self):
|
||||
try:
|
||||
# remove lock file if it exists
|
||||
self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
|
||||
|
||||
# Acquire lock
|
||||
self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)
|
||||
|
||||
checkpoint_file = None
|
||||
for root, dirs, files in os.walk(self.params.checkpoint_dir):
|
||||
for filename in files:
|
||||
if filename == 'checkpoint':
|
||||
if filename == CheckpointStateFile.checkpoint_state_filename:
|
||||
checkpoint_file = (root, filename)
|
||||
continue
|
||||
abs_name = os.path.abspath(os.path.join(root, filename))
|
||||
@@ -69,6 +70,7 @@ class S3DataStore(DataStore):
|
||||
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
|
||||
self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
|
||||
|
||||
# release lock
|
||||
self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
|
||||
|
||||
except ResponseError as e:
|
||||
@@ -76,14 +78,16 @@ class S3DataStore(DataStore):
|
||||
|
||||
def load_from_store(self):
|
||||
try:
|
||||
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
|
||||
state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
|
||||
|
||||
# wait until lock is removed
|
||||
while True:
|
||||
objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.LOCKFILE.value)
|
||||
|
||||
if next(objects, None) is None:
|
||||
try:
|
||||
self.mc.fget_object(self.params.bucket_name, "checkpoint", filename)
|
||||
# fetch checkpoint state file from S3
|
||||
self.mc.fget_object(self.params.bucket_name, state_file.filename, state_file.path)
|
||||
except Exception as e:
|
||||
continue
|
||||
break
|
||||
@@ -101,13 +105,9 @@ class S3DataStore(DataStore):
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
ckpt = CheckpointState()
|
||||
if os.path.exists(filename):
|
||||
contents = open(filename, 'r').read()
|
||||
text_format.Merge(contents, ckpt)
|
||||
rel_path = os.path.relpath(ckpt.model_checkpoint_path, self.params.checkpoint_dir)
|
||||
|
||||
objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=rel_path, recursive=True)
|
||||
checkpoint_state = state_file.read()
|
||||
if checkpoint_state is not None:
|
||||
objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=checkpoint_state.name, recursive=True)
|
||||
for obj in objects:
|
||||
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name))
|
||||
if not os.path.exists(filename):
|
||||
|
||||
Reference in New Issue
Block a user