1
0
mirror of https://github.com/gryf/coach.git synced 2026-02-14 21:15:53 +01:00

Add RedisDataStore (#295)

* GraphManager.set_session also sets self.sess

* make sure that GraphManager.fetch_from_worker uses training phase

* remove unnecessary phase setting in training worker

* reorganize rollout worker

* provide default name to GlobalVariableSaver.__init__ since it isn't really used anyway

* allow dividing TrainingSteps and EnvironmentSteps

* add timestamps to the log

* added redis data store

* conflict merge fix
This commit is contained in:
Zach Dwiel
2019-08-28 14:15:58 -04:00
committed by shadiendrawis
parent 34e1c04f29
commit 7b0fccb041
18 changed files with 528 additions and 120 deletions

View File

@@ -0,0 +1,96 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import os
from rl_coach.checkpoint import CheckpointStateReader
from rl_coach.data_stores.data_store import SyncFiles
class CheckpointDataStore(object):
    """
    A DataStore which relies on the GraphManager check pointing methods to communicate policies.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Number of the checkpoint most recently loaded by this process; used by
        # _new_policy_exists() to tell a brand-new policy from one already seen.
        self.checkpoint_num = 0

    def end_of_policies(self) -> bool:
        """
        Returns true if no new policies will be added to this DataStore. This typically happens
        because training has completed and is used to signal to the rollout workers to stop.
        """
        finished_marker = os.path.join(self.checkpoint_dir, SyncFiles.FINISHED.value)
        return os.path.exists(finished_marker)

    def save_policy(self, graph_manager):
        """
        Checkpoint the policy held by graph_manager into self.checkpoint_dir.

        :param graph_manager: the graph_manager whose policy should be saved
        """
        # TODO: it would be nice if restore_checkpoint accepted a checkpoint path as a
        # parameter. as it is, one cannot distinguish between checkpoints used for coordination
        # and checkpoints requested to a persistent disk for later use
        graph_manager.task_parameters.checkpoint_restore_path = self.checkpoint_dir
        graph_manager.save_checkpoint()

    def load_policy(self, graph_manager, require_new_policy=True, timeout=None):
        """
        Load a checkpoint into the specified graph_manager. The expectation here is that
        save_to_store() and load_from_store() will synchronize a checkpoint directory with a
        central repository such as NFS or S3.

        :param graph_manager: the graph_manager to load the policy into
        :param require_new_policy: if True, only load a policy if it hasn't been loaded in this
            process yet before.
        :param timeout: Will only try to load the policy once if timeout is None, otherwise will
            retry for timeout seconds
        """
        if not self._new_policy_exists(require_new_policy, timeout):
            return
        # TODO: it would be nice if restore_checkpoint accepted a checkpoint path as a
        # parameter. as it is, one cannot distinguish between checkpoints used for coordination
        # and checkpoints requested to a persistent disk for later use
        graph_manager.task_parameters.checkpoint_restore_path = self.checkpoint_dir
        graph_manager.restore_checkpoint()

    def _new_policy_exists(self, require_new_policy=True, timeout=None) -> bool:
        """
        Poll the store until a suitable checkpoint shows up or the timeout elapses.

        :param require_new_policy: if True, only report a policy if it hasn't been loaded in this
            process yet before.
        :param timeout: Will only try to load the policy once if timeout is None, otherwise will
            retry for timeout seconds
        :return: True when a suitable checkpoint was found, False when training has finished
        :raises ValueError: if the timeout elapses without a policy being found
        """
        state_reader = CheckpointStateReader(
            self.checkpoint_dir, checkpoint_state_optional=False
        )
        if timeout is None:
            timeout = 0
        deadline = time.time() + timeout

        # The first pass always runs, even with a zero timeout.
        first_attempt = True
        while first_attempt or time.time() < deadline:
            first_attempt = False
            if self.end_of_policies():
                return False
            self.load_from_store()
            latest = state_reader.get_latest()
            if latest is not None and (
                not require_new_policy or latest.num > self.checkpoint_num
            ):
                self.checkpoint_num = latest.num
                return True

        raise ValueError(
            "Waited for {timeout} seconds, but no first policy was received.".format(
                timeout=timeout
            )
        )

View File

@@ -26,23 +26,45 @@ class DataStoreParameters(object):
class DataStore(object):
    """
    DataStores are used primarily to synchronize policies between training workers and rollout
    workers. In the case of the S3DataStore, it is also being used to explicitly log artifacts such
    as videos and logs into s3 for users to look at later. Artifact logging should be moved into a
    separate instance of the DataStore class, or a different class altogether. It is possible that
    users might be interested in logging artifacts through s3, but coordinating communication of
    policies using something else like redis.
    """
    # NOTE(review): the merge left dead `pass` statements in front of each
    # `raise NotImplementedError()` below; they have been removed here.

    def __init__(self, params: DataStoreParameters):
        """
        The parameters provided in the constructor to a DataStore are expected to contain the
        parameters necessary to serialize and deserialize this DataStore.
        """
        raise NotImplementedError()

    def deploy(self) -> bool:
        """Stand up whatever infrastructure this store needs; return True on success."""
        raise NotImplementedError()

    def get_info(self):
        """Return implementation-specific information about the deployed store."""
        raise NotImplementedError()

    def undeploy(self) -> bool:
        """Tear down whatever deploy() created; return True on success."""
        raise NotImplementedError()

    def save_to_store(self):
        """Upload the local checkpoint directory to the central repository."""
        raise NotImplementedError()

    def load_from_store(self):
        """Download checkpoints from the central repository to the local directory."""
        raise NotImplementedError()

    def save_policy(self, graph_manager):
        """Publish the policy currently held by graph_manager through this store."""
        raise NotImplementedError()

    def load_policy(self, graph_manager, timeout=-1):
        """Load the most recent policy from this store into graph_manager."""
        raise NotImplementedError()

    def end_of_policies(self) -> bool:
        """Return True once training is finished and no further policies will arrive."""
        raise NotImplementedError()

    def setup_checkpoint_dir(self, crd=None):
        # Optional hook: subclasses that coordinate through a checkpoint
        # directory override this; by default there is nothing to set up.
        pass

View File

@@ -17,6 +17,10 @@
from rl_coach.data_stores.nfs_data_store import NFSDataStore, NFSDataStoreParameters
from rl_coach.data_stores.s3_data_store import S3DataStore, S3DataStoreParameters
from rl_coach.data_stores.redis_data_store import (
RedisDataStore,
RedisDataStoreParameters,
)
from rl_coach.data_stores.data_store import DataStoreParameters
@@ -26,19 +30,39 @@ def get_data_store(params):
data_store = NFSDataStore(params)
elif type(params) == S3DataStoreParameters:
data_store = S3DataStore(params)
elif type(params) == RedisDataStoreParameters:
data_store = RedisDataStore(params)
else:
raise ValueError("invalid params type {}".format(type(params)))
return data_store
def construct_data_store_params(json: dict):
    """
    Build a concrete data-store parameters object from a JSON-style dict.

    :param json: dict with at least 'store_type', 'orchestrator_type' and
        'orchestrator_params' keys, plus the extra keys required by the chosen
        store type ('nfs', 's3' or 'redis').
    :return: the populated NFS/S3/Redis data store parameters instance
    :raises ValueError: if 'store_type' names an unsupported store
    """
    ds_params = DataStoreParameters(
        json["store_type"], json["orchestrator_type"], json["orchestrator_params"]
    )
    if json["store_type"] == "nfs":
        ds_params_instance = NFSDataStoreParameters(
            ds_params, checkpoint_dir=json["checkpoint_dir"]
        )
    elif json["store_type"] == "s3":
        ds_params_instance = S3DataStoreParameters(
            ds_params=ds_params,
            end_point=json["end_point"],
            bucket_name=json["bucket_name"],
            checkpoint_dir=json["checkpoint_dir"],
            expt_dir=json["expt_dir"],
        )
    elif json["store_type"] == "redis":
        ds_params_instance = RedisDataStoreParameters(
            ds_params,
            redis_address=json["redis_address"],
            redis_port=json["redis_port"],
            redis_channel=json["redis_channel"],
        )
    else:
        # Bug fix: the original never called .format(), so the error message
        # contained a literal '{}' instead of the offending store_type.
        raise ValueError(
            "store_type {} was found, expected 'nfs', 'redis' or 's3'.".format(
                json["store_type"]
            )
        )
    return ds_params_instance

View File

@@ -17,15 +17,17 @@
import uuid
from rl_coach.data_stores.data_store import DataStore, DataStoreParameters
from rl_coach.data_stores.data_store import DataStoreParameters
from rl_coach.data_stores.checkpoint_data_store import CheckpointDataStore
class NFSDataStoreParameters(DataStoreParameters):
def __init__(self, ds_params, deployed=False, server=None, path=None):
def __init__(self, ds_params, deployed=False, server=None, path=None, checkpoint_dir: str=""):
super().__init__(ds_params.store_type, ds_params.orchestrator_type, ds_params.orchestrator_params)
self.namespace = "default"
if "namespace" in ds_params.orchestrator_params:
self.namespace = ds_params.orchestrator_params["namespace"]
self.checkpoint_dir = checkpoint_dir
self.name = None
self.pvc_name = None
self.pv_name = None
@@ -38,7 +40,7 @@ class NFSDataStoreParameters(DataStoreParameters):
self.path = path
class NFSDataStore(DataStore):
class NFSDataStore(CheckpointDataStore):
"""
An implementation of data store which uses NFS for storing policy checkpoints when using Coach in distributed mode.
The policy checkpoints are written by the trainer and read by the rollout worker.

View File

@@ -0,0 +1,192 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import uuid
import redis
from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver
from rl_coach.data_stores.data_store import DataStore, DataStoreParameters
class RedisDataStoreParameters(DataStoreParameters):
    """
    Connection parameters for the RedisDataStore: server address/port and the
    pubsub channel name (also used as the key in the key/value store) over
    which policies are exchanged.

    NOTE(review): the default ``redis_channel`` is evaluated once, at module
    import time, so every instance created in the same process without an
    explicit channel shares the same channel name -- presumably intended so
    that trainer and rollout workers agree on a channel; confirm.
    """

    def __init__(
        self,
        ds_params,
        redis_address: str = "",
        redis_port: int = 6379,
        redis_channel: str = "data-store-channel-{}".format(uuid.uuid4()),
    ):
        # Mirror the generic store/orchestrator settings onto the base class.
        super().__init__(
            ds_params.store_type,
            ds_params.orchestrator_type,
            ds_params.orchestrator_params,
        )
        self.redis_address = redis_address
        self.redis_port = redis_port
        self.redis_channel = redis_channel
class RedisDataStore(DataStore):
    """
    This DataStore sends policies over redis pubsub and get/set.

    Deployment
    ==========

    It assumes that a redis server is already available. We make this assumption because during
    multinode training at this time, redis is already used for communicating replay memories.

    Communication
    =============

    A redis pubsub channel is used by the training worker to signal to the rollout workers that a
    new policy is ready. When this occurs, a new policy is loaded from the redis key/value store
    where key is the same as the pubsub channel. Originally, just the pubsub was used, but that
    could result in a race condition where the master worker publishes the first policy and waits
    for the rollout workers to submit all rollouts, while a delayed rollout worker waits for the
    first policy since it subscribed to the channel after the initial policy was published.
    """

    def __init__(self, params: RedisDataStoreParameters):
        self.params = params
        # Created lazily on first save_policy/load_policy call, because the TF
        # graph must exist before a GlobalVariableSaver can be built.
        self.saver = None
        # Set to True once an "end_of_policies" message has been received.
        self._end_of_policies = False

        # NOTE: a connection is not attempted at this stage because the address and port are likely
        # not available yet. This is because of how the kubernetes orchestrator works. At the time
        # of parameter construction, the address and port are not yet known since they are copied
        # out of the redis memory backend after it is deployed. One improvement would be to use
        # two separate redis deployments independently, and let this class deploy its own redis.

    def _connect(self):
        """
        Connect to redis and subscribe to the pubsub channel
        """
        self.redis_connection = redis.Redis(
            self.params.redis_address, self.params.redis_port
        )
        self.pubsub = self.redis_connection.pubsub(ignore_subscribe_messages=True)
        self.pubsub.subscribe(self.params.redis_channel)
        # A fresh connection starts in the "policies still coming" state.
        self._end_of_policies = False

    def deploy(self):
        """
        For now, this data store does not handle its own deployment, it piggybacks off of the redis
        memory backend
        """
        return True

    def undeploy(self):
        """
        For now, this data store does not handle its own deployment, it piggybacks off of the redis
        memory backend
        """
        pass

    def save_to_store(self):
        """
        save_to_store and load_from_store are not used in the case where the data stored needs to
        synchronize checkpoints saved to disk into a central file system, and not used here
        """
        pass

    def load_from_store(self):
        """
        save_to_store and load_from_store are not used in the case where the data stored needs to
        synchronize checkpoints saved to disk into a central file system, and not used here
        """
        pass

    def save_policy(self, graph_manager):
        """
        Serialize the policy in graph_manager, set it as the latest policy and publish a new_policy
        event

        :param graph_manager: the graph_manager whose policy should be published
        """
        if self.saver is None:
            self.saver = GlobalVariableSaver()
            # TODO: only subscribe if this data store is being used to publish policies
            self._connect()
            # The publisher never consumes messages, so drop the subscription
            # that _connect() just made.
            self.pubsub.unsubscribe(self.params.redis_channel)
        # Store the serialized policy under the channel name, then notify
        # subscribers that a new one is available.
        policy_string = self.saver.to_string(graph_manager.sess)
        self.redis_connection.set(self.params.redis_channel, policy_string)
        self.redis_connection.publish(self.params.redis_channel, "new_policy")

    def _load_policy(self, graph_manager) -> bool:
        """
        Get the most recent policy from redis and load it into the graph_manager.

        :return: True if a policy was present and loaded, False otherwise
        """
        policy_string = self.redis_connection.get(self.params.redis_channel)
        if policy_string is None:
            return False
        self.saver.from_string(graph_manager.sess, policy_string)
        return True

    def load_policy(self, graph_manager, require_new_policy=True, timeout=0):
        """
        :param graph_manager: the graph_manager to load the policy into
        :param require_new_policy: if True, only load a policy if it hasn't been loaded in this
            process yet before.
        :param timeout: Will only try to load the policy once if timeout is None, otherwise will
            retry for timeout seconds
        """
        if self.saver is None:
            # the GlobalVariableSaver needs to be instantiated after the graph is created. For now,
            # it can be instantiated here, but it might be nicer to have a more explicit
            # on_graph_creation_end callback or similar to put it in
            self.saver = GlobalVariableSaver()
            self._connect()

        if not require_new_policy:
            # try just loading whatever policy is available most recently
            # (this is the get() fallback described in the class docstring,
            # covering the subscribed-too-late race)
            if self._load_policy(graph_manager):
                return

        # The "first" sentinel guarantees at least one poll of the channel,
        # even when timeout is 0.
        message = "first"
        timeout_ends = time.time() + timeout
        while time.time() < timeout_ends or message == "first":
            message = self.pubsub.get_message()
            if message and message["type"] == "message":
                if message["data"] == b"end_of_policies":
                    self._end_of_policies = True
                    return
                elif message["data"] == b"new_policy":
                    if self._load_policy(graph_manager):
                        return
                    else:
                        raise ValueError("'new_policy' message was sent, but no policy was found.")
            time.sleep(1.0)

        if require_new_policy:
            raise ValueError(
                "Waited for {timeout} seconds on channel {channel}, but no first policy was received.".format(
                    timeout=timeout, channel=self.params.redis_channel
                )
            )

    def end_of_policies(self) -> bool:
        """
        This is used by the rollout workers to detect a message from the training worker signaling
        that training is complete.
        """
        return self._end_of_policies

View File

@@ -15,7 +15,8 @@
#
from rl_coach.data_stores.data_store import DataStore, DataStoreParameters
from rl_coach.data_stores.data_store import DataStoreParameters
from rl_coach.data_stores.checkpoint_data_store import CheckpointDataStore
from minio import Minio
from minio.error import ResponseError
from configparser import ConfigParser, Error
@@ -39,7 +40,7 @@ class S3DataStoreParameters(DataStoreParameters):
self.expt_dir = expt_dir
class S3DataStore(DataStore):
class S3DataStore(CheckpointDataStore):
"""
An implementation of the data store using S3 for storing policy checkpoints when using Coach in distributed mode.
The policy checkpoints are written by the trainer and read by the rollout worker.