1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Add RedisDataStore (#295)

* GraphManager.set_session also sets self.sess

* make sure that GraphManager.fetch_from_worker uses training phase

* remove unnecessary phase setting in training worker

* reorganize rollout worker

* provide default name to GlobalVariableSaver.__init__ since it isn't really used anyway

* allow dividing TrainingSteps and EnvironmentSteps

* add timestamps to the log

* added redis data store

* conflict merge fix
This commit is contained in:
Zach Dwiel
2019-08-28 14:15:58 -04:00
committed by shadiendrawis
parent 34e1c04f29
commit 7b0fccb041
18 changed files with 528 additions and 120 deletions

View File

@@ -17,6 +17,10 @@
from rl_coach.data_stores.nfs_data_store import NFSDataStore, NFSDataStoreParameters
from rl_coach.data_stores.s3_data_store import S3DataStore, S3DataStoreParameters
from rl_coach.data_stores.redis_data_store import (
RedisDataStore,
RedisDataStoreParameters,
)
from rl_coach.data_stores.data_store import DataStoreParameters
@@ -26,19 +30,39 @@ def get_data_store(params):
data_store = NFSDataStore(params)
elif type(params) == S3DataStoreParameters:
data_store = S3DataStore(params)
elif type(params) == RedisDataStoreParameters:
data_store = RedisDataStore(params)
else:
raise ValueError("invalid params type {}".format(type(params)))
return data_store
def construct_data_store_params(json: dict):
ds_params_instance = None
ds_params = DataStoreParameters(json['store_type'], json['orchestrator_type'], json['orchestrator_params'])
if json['store_type'] == 'nfs':
ds_params_instance = NFSDataStoreParameters(ds_params)
elif json['store_type'] == 's3':
ds_params_instance = S3DataStoreParameters(ds_params=ds_params,
end_point=json['end_point'],
bucket_name=json['bucket_name'],
checkpoint_dir=json['checkpoint_dir'],
expt_dir=json['expt_dir'])
ds_params = DataStoreParameters(
json["store_type"], json["orchestrator_type"], json["orchestrator_params"]
)
if json["store_type"] == "nfs":
ds_params_instance = NFSDataStoreParameters(
ds_params, checkpoint_dir=json["checkpoint_dir"]
)
elif json["store_type"] == "s3":
ds_params_instance = S3DataStoreParameters(
ds_params=ds_params,
end_point=json["end_point"],
bucket_name=json["bucket_name"],
checkpoint_dir=json["checkpoint_dir"],
expt_dir=json["expt_dir"],
)
elif json["store_type"] == "redis":
ds_params_instance = RedisDataStoreParameters(
ds_params,
redis_address=json["redis_address"],
redis_port=json["redis_port"],
redis_channel=json["redis_channel"],
)
else:
raise ValueError("store_type {} was found, expected 'nfs', 'redis' or 's3'.")
return ds_params_instance