1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

Add RedisDataStore (#295)

* GraphManager.set_session also sets self.sess

* make sure that GraphManager.fetch_from_worker uses training phase

* remove unnecessary phase setting in training worker

* reorganize rollout worker

* provide default name to GlobalVariableSaver.__init__ since it isn't really used anyway

* allow dividing TrainingSteps and EnvironmentSteps

* add timestamps to the log

* added redis data store

* conflict merge fix
This commit is contained in:
Zach Dwiel
2019-08-28 14:15:58 -04:00
committed by shadiendrawis
parent 34e1c04f29
commit 7b0fccb041
18 changed files with 528 additions and 120 deletions

View File

@@ -26,23 +26,45 @@ class DataStoreParameters(object):
class DataStore(object):
"""
DataStores are used primarily to synchronize policies between training workers and rollout
workers. In the case of the S3DataStore, it is also being used to explicitly log artifacts such
as videos and logs into s3 for users to look at later. Artifact logging should be moved into a
separate instance of the DataStore class, or a different class altogether. It is possible that
users might be interested in logging artifacts through s3, but coordinating communication of
policies using something else like redis.
"""
def __init__(self, params: DataStoreParameters):
pass
"""
The parameters provided in the constructor to a DataStore are expected to contain the
parameters necessary to serialize and deserialize this DataStore.
"""
raise NotImplementedError()
def deploy(self) -> bool:
pass
raise NotImplementedError()
def get_info(self):
pass
raise NotImplementedError()
def undeploy(self) -> bool:
pass
raise NotImplementedError()
def save_to_store(self):
pass
raise NotImplementedError()
def load_from_store(self):
pass
raise NotImplementedError()
def save_policy(self, graph_manager):
raise NotImplementedError()
def load_policy(self, graph_manager, timeout=-1):
raise NotImplementedError()
def end_of_policies(self) -> bool:
raise NotImplementedError()
def setup_checkpoint_dir(self, crd=None):
pass