1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00
Files
coach/rl_coach/data_stores/data_store_impl.py
Zach Dwiel 7b0fccb041 Add RedisDataStore (#295)
* GraphManager.set_session also sets self.sess

* make sure that GraphManager.fetch_from_worker uses training phase

* remove unnecessary phase setting in training worker

* reorganize rollout worker

* provide default name to GlobalVariableSaver.__init__ since it isn't really used anyway

* allow dividing TrainingSteps and EnvironmentSteps

* add timestamps to the log

* added redis data store

* conflict merge fix
2019-08-28 21:15:58 +03:00

69 lines
2.4 KiB
Python

#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from rl_coach.data_stores.nfs_data_store import NFSDataStore, NFSDataStoreParameters
from rl_coach.data_stores.s3_data_store import S3DataStore, S3DataStoreParameters
from rl_coach.data_stores.redis_data_store import (
RedisDataStore,
RedisDataStoreParameters,
)
from rl_coach.data_stores.data_store import DataStoreParameters
def get_data_store(params):
    """Instantiate the data store matching the concrete type of ``params``.

    Dispatch is on the exact type of ``params`` (subclasses are not matched):

    * ``NFSDataStoreParameters``   -> ``NFSDataStore``
    * ``S3DataStoreParameters``    -> ``S3DataStore``
    * ``RedisDataStoreParameters`` -> ``RedisDataStore``

    Raises:
        ValueError: if ``type(params)`` is not one of the types above.
    """
    store_class_by_params_type = {
        NFSDataStoreParameters: NFSDataStore,
        S3DataStoreParameters: S3DataStore,
        RedisDataStoreParameters: RedisDataStore,
    }
    params_type = type(params)
    store_class = store_class_by_params_type.get(params_type)
    if store_class is None:
        raise ValueError("invalid params type {}".format(params_type))
    return store_class(params)
def construct_data_store_params(json: dict):
    """Build the concrete data-store parameters object described by ``json``.

    ``json`` must contain ``store_type``, ``orchestrator_type`` and
    ``orchestrator_params``. Each store type additionally requires:

    * ``"nfs"``   -- ``checkpoint_dir``
    * ``"s3"``    -- ``end_point``, ``bucket_name``, ``checkpoint_dir``,
      ``expt_dir``
    * ``"redis"`` -- ``redis_address``, ``redis_port``, ``redis_channel``

    Returns:
        An ``NFSDataStoreParameters``, ``S3DataStoreParameters`` or
        ``RedisDataStoreParameters`` instance wrapping a
        ``DataStoreParameters`` built from the common keys.

    Raises:
        ValueError: if ``store_type`` is not ``"nfs"``, ``"s3"`` or ``"redis"``.
        KeyError: if a required key is missing from ``json``.
    """
    # NOTE: the parameter name shadows the stdlib ``json`` module inside this
    # function; it is kept unchanged for callers passing it by keyword.
    store_type = json["store_type"]

    # Validate before reading any other key, so an unknown store type is
    # reported as such rather than as a KeyError for an unrelated field.
    if store_type not in ("nfs", "s3", "redis"):
        raise ValueError(
            "store_type {} was found, expected 'nfs', 'redis' or 's3'.".format(
                store_type
            )
        )

    ds_params = DataStoreParameters(
        store_type, json["orchestrator_type"], json["orchestrator_params"]
    )

    if store_type == "nfs":
        return NFSDataStoreParameters(
            ds_params, checkpoint_dir=json["checkpoint_dir"]
        )
    if store_type == "s3":
        return S3DataStoreParameters(
            ds_params=ds_params,
            end_point=json["end_point"],
            bucket_name=json["bucket_name"],
            checkpoint_dir=json["checkpoint_dir"],
            expt_dir=json["expt_dir"],
        )
    # Only "redis" can remain after the validation above.
    return RedisDataStoreParameters(
        ds_params,
        redis_address=json["redis_address"],
        redis_port=json["redis_port"],
        redis_channel=json["redis_channel"],
    )