mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Add documentation on distributed Coach. (#158)
* Added documentation on distributed Coach.
This commit is contained in:
committed by
Gal Novik
parent
e3ecf445e2
commit
d06197f663
@@ -23,7 +23,16 @@ class S3DataStoreParameters(DataStoreParameters):
|
||||
|
||||
|
||||
class S3DataStore(DataStore):
|
||||
"""
|
||||
A data store implementation that uses S3 to store policy checkpoints when Coach runs in distributed mode.
|
||||
The policy checkpoints are written by the trainer and read by the rollout worker.
|
||||
"""
|
||||
|
||||
def __init__(self, params: S3DataStoreParameters):
|
||||
"""
|
||||
:param params: The parameters required to use the S3 data store.
|
||||
"""
|
||||
|
||||
super(S3DataStore, self).__init__(params)
|
||||
self.params = params
|
||||
access_key = None
|
||||
@@ -51,6 +60,10 @@ class S3DataStore(DataStore):
|
||||
return True
|
||||
|
||||
def save_to_store(self):
|
||||
"""
|
||||
save_to_store() uploads the policy checkpoint, gifs and videos to the S3 data store. It reads the checkpoint state files and
|
||||
uploads only the latest checkpoint files to S3. It is used by the trainer when Coach runs in distributed mode.
|
||||
"""
|
||||
try:
|
||||
# remove lock file if it exists
|
||||
self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
|
||||
@@ -95,6 +108,10 @@ class S3DataStore(DataStore):
|
||||
print("Got exception: %s\n while saving to S3", e)
|
||||
|
||||
def load_from_store(self):
|
||||
"""
|
||||
load_from_store() downloads a new checkpoint from the S3 data store when it is not available locally. It is used
|
||||
by the rollout workers when Coach runs in distributed mode.
|
||||
"""
|
||||
try:
|
||||
state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user