Mirror of https://github.com/gryf/coach.git (synced 2026-02-14 21:15:53 +01:00)
Add RedisDataStore (#295)
* GraphManager.set_session also sets self.sess
* make sure that GraphManager.fetch_from_worker uses the training phase
* remove unnecessary phase setting in the training worker
* reorganize the rollout worker
* provide a default name to GlobalVariableSaver.__init__, since it isn't really used anyway
* allow dividing TrainingSteps and EnvironmentSteps
* add timestamps to the log
* add a Redis data store
* fix merge conflicts
Committed by: shadiendrawis
Parent: 34e1c04f29
Commit: 7b0fccb041
rl_coach/data_stores/checkpoint_data_store.py (new file, 96 lines)
@@ -0,0 +1,96 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import time
import os

from rl_coach.checkpoint import CheckpointStateReader
from rl_coach.data_stores.data_store import SyncFiles


class CheckpointDataStore(object):
    """
    A DataStore which relies on the GraphManager checkpointing methods to communicate policies.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.checkpoint_num = 0

    def end_of_policies(self) -> bool:
        """
        Returns True if no new policies will be added to this DataStore. This typically happens
        because training has completed, and is used to signal the rollout workers to stop.
        """
        return os.path.exists(
            os.path.join(self.checkpoint_dir, SyncFiles.FINISHED.value)
        )

    def save_policy(self, graph_manager):
        # TODO: it would be nice if restore_checkpoint accepted a checkpoint path as a
        # parameter. As it is, one cannot distinguish between checkpoints used for coordination
        # and checkpoints written to a persistent disk for later use.
        graph_manager.task_parameters.checkpoint_restore_path = self.checkpoint_dir
        graph_manager.save_checkpoint()

    def load_policy(self, graph_manager, require_new_policy=True, timeout=None):
        """
        Load a checkpoint into the specified graph_manager. The expectation here is that
        save_to_store() and load_from_store() will synchronize a checkpoint directory with a
        central repository such as NFS or S3.

        :param graph_manager: the graph_manager to load the policy into
        :param require_new_policy: if True, only load a policy if it hasn't already been loaded
                                   in this process
        :param timeout: will only try to load the policy once if timeout is None, otherwise will
                        retry for timeout seconds
        """
        if self._new_policy_exists(require_new_policy, timeout):
            # TODO: it would be nice if restore_checkpoint accepted a checkpoint path as a
            # parameter. As it is, one cannot distinguish between checkpoints used for coordination
            # and checkpoints written to a persistent disk for later use.
            graph_manager.task_parameters.checkpoint_restore_path = self.checkpoint_dir
            graph_manager.restore_checkpoint()

    def _new_policy_exists(self, require_new_policy=True, timeout=None) -> bool:
        """
        :param require_new_policy: if True, only report a policy if it hasn't already been loaded
                                   in this process
        :param timeout: will only check for a policy once if timeout is None, otherwise will
                        retry for timeout seconds
        """
        checkpoint_state_reader = CheckpointStateReader(
            self.checkpoint_dir, checkpoint_state_optional=False
        )
        checkpoint = "first"
        if timeout is None:
            timeout = 0
        timeout_ends = time.time() + timeout
        while time.time() < timeout_ends or checkpoint == "first":
            if self.end_of_policies():
                return False

            self.load_from_store()

            checkpoint = checkpoint_state_reader.get_latest()
            if checkpoint is not None:
                if not require_new_policy or checkpoint.num > self.checkpoint_num:
                    self.checkpoint_num = checkpoint.num
                    return True

        raise ValueError(
            "Waited for {timeout} seconds, but no first policy was received.".format(
                timeout=timeout
            )
        )
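
CheckpointDataStore is written as a mixin: it supplies the policy-level API (save_policy, load_policy, end_of_policies) and expects the concrete subclass to provide checkpoint_dir plus save_to_store()/load_from_store(). A minimal sketch of how a subclass might wire this together; LocalDirDataStore and its directory arguments are hypothetical, not part of this commit:

import shutil

from rl_coach.data_stores.checkpoint_data_store import CheckpointDataStore


class LocalDirDataStore(CheckpointDataStore):
    """Hypothetical store syncing checkpoints through a shared local directory."""

    def __init__(self, checkpoint_dir: str, shared_dir: str):
        super().__init__()
        self.checkpoint_dir = checkpoint_dir  # read by CheckpointDataStore
        self.shared_dir = shared_dir

    def save_to_store(self):
        # trainer side: push the checkpoint directory to the shared location
        shutil.copytree(self.checkpoint_dir, self.shared_dir, dirs_exist_ok=True)  # Python 3.8+

    def load_from_store(self):
        # rollout side: pull the latest checkpoints from the shared location
        shutil.copytree(self.shared_dir, self.checkpoint_dir, dirs_exist_ok=True)

NFSDataStore and S3DataStore below follow this exact pattern, with NFS and S3 as the shared location.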
rl_coach/data_stores/data_store.py

@@ -26,23 +26,45 @@ class DataStoreParameters(object):
 
 
 class DataStore(object):
+    """
+    DataStores are used primarily to synchronize policies between training workers and rollout
+    workers. In the case of the S3DataStore, it is also used to explicitly log artifacts such as
+    videos and logs into S3 for users to look at later. Artifact logging should be moved into a
+    separate instance of the DataStore class, or a different class altogether. It is possible
+    that users might be interested in logging artifacts through S3, but coordinating
+    communication of policies using something else, like Redis.
+    """
+
     def __init__(self, params: DataStoreParameters):
-        pass
+        """
+        The parameters provided in the constructor to a DataStore are expected to contain the
+        parameters necessary to serialize and deserialize this DataStore.
+        """
+        raise NotImplementedError()
 
     def deploy(self) -> bool:
-        pass
+        raise NotImplementedError()
 
     def get_info(self):
-        pass
+        raise NotImplementedError()
 
     def undeploy(self) -> bool:
-        pass
+        raise NotImplementedError()
 
     def save_to_store(self):
-        pass
+        raise NotImplementedError()
 
     def load_from_store(self):
-        pass
+        raise NotImplementedError()
+
+    def save_policy(self, graph_manager):
+        raise NotImplementedError()
+
+    def load_policy(self, graph_manager, timeout=-1):
+        raise NotImplementedError()
+
+    def end_of_policies(self) -> bool:
+        raise NotImplementedError()
 
     def setup_checkpoint_dir(self, crd=None):
         pass
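
Switching the base methods from pass to raise NotImplementedError() turns DataStore into an explicit interface: a store that forgets to override a method now fails loudly instead of silently doing nothing. A toy in-memory implementation of the contract might look like this (illustrative only, not part of the commit):

from rl_coach.data_stores.data_store import DataStore, DataStoreParameters


class InMemoryDataStore(DataStore):
    """Toy store keeping the latest policy in an attribute (e.g. for tests)."""

    def __init__(self, params: DataStoreParameters):
        # deliberately does not call super().__init__(), which now raises
        self.params = params
        self._policy = None
        self._done = False

    def deploy(self) -> bool:
        return True            # nothing external to deploy

    def undeploy(self) -> bool:
        return True

    def get_info(self):
        return "in-memory"

    def save_to_store(self):
        pass                   # no external storage to synchronize

    def load_from_store(self):
        pass

    def save_policy(self, graph_manager):
        self._policy = graph_manager  # stand-in for a serialized checkpoint

    def load_policy(self, graph_manager, timeout=-1):
        if self._policy is None:
            raise ValueError("no policy available yet")

    def end_of_policies(self) -> bool:
        return self._done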
rl_coach/data_stores/data_store_impl.py

@@ -17,6 +17,10 @@
 
 from rl_coach.data_stores.nfs_data_store import NFSDataStore, NFSDataStoreParameters
 from rl_coach.data_stores.s3_data_store import S3DataStore, S3DataStoreParameters
+from rl_coach.data_stores.redis_data_store import (
+    RedisDataStore,
+    RedisDataStoreParameters,
+)
 from rl_coach.data_stores.data_store import DataStoreParameters
@@ -26,19 +30,39 @@ def get_data_store(params):
         data_store = NFSDataStore(params)
     elif type(params) == S3DataStoreParameters:
         data_store = S3DataStore(params)
+    elif type(params) == RedisDataStoreParameters:
+        data_store = RedisDataStore(params)
     else:
         raise ValueError("invalid params type {}".format(type(params)))
 
     return data_store
 
 
 def construct_data_store_params(json: dict):
     ds_params_instance = None
-    ds_params = DataStoreParameters(json['store_type'], json['orchestrator_type'], json['orchestrator_params'])
-    if json['store_type'] == 'nfs':
-        ds_params_instance = NFSDataStoreParameters(ds_params)
-    elif json['store_type'] == 's3':
-        ds_params_instance = S3DataStoreParameters(ds_params=ds_params,
-                                                   end_point=json['end_point'],
-                                                   bucket_name=json['bucket_name'],
-                                                   checkpoint_dir=json['checkpoint_dir'],
-                                                   expt_dir=json['expt_dir'])
+    ds_params = DataStoreParameters(
+        json["store_type"], json["orchestrator_type"], json["orchestrator_params"]
+    )
+    if json["store_type"] == "nfs":
+        ds_params_instance = NFSDataStoreParameters(
+            ds_params, checkpoint_dir=json["checkpoint_dir"]
+        )
+    elif json["store_type"] == "s3":
+        ds_params_instance = S3DataStoreParameters(
+            ds_params=ds_params,
+            end_point=json["end_point"],
+            bucket_name=json["bucket_name"],
+            checkpoint_dir=json["checkpoint_dir"],
+            expt_dir=json["expt_dir"],
+        )
+    elif json["store_type"] == "redis":
+        ds_params_instance = RedisDataStoreParameters(
+            ds_params,
+            redis_address=json["redis_address"],
+            redis_port=json["redis_port"],
+            redis_channel=json["redis_channel"],
+        )
+    else:
+        raise ValueError(
+            "store_type {} was found, expected 'nfs', 'redis' or 's3'.".format(json["store_type"])
+        )
 
     return ds_params_instance
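
Assuming these helpers live in rl_coach.data_stores.data_store_impl as in the rest of the tree, a command dictionary carrying the keys the new 'redis' branch reads would round-trip through both functions; the address and channel values below are made up:

from rl_coach.data_stores.data_store_impl import construct_data_store_params, get_data_store

command = {
    "store_type": "redis",
    "orchestrator_type": "kubernetes",
    "orchestrator_params": {"namespace": "coach"},
    "redis_address": "redis-master.coach.svc",  # illustrative address
    "redis_port": 6379,
    "redis_channel": "data-store-channel-demo",
}

ds_params = construct_data_store_params(command)
data_store = get_data_store(ds_params)  # yields a RedisDataStore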
rl_coach/data_stores/nfs_data_store.py

@@ -17,15 +17,17 @@
 
 import uuid
 
-from rl_coach.data_stores.data_store import DataStore, DataStoreParameters
+from rl_coach.data_stores.data_store import DataStoreParameters
+from rl_coach.data_stores.checkpoint_data_store import CheckpointDataStore
 
 
 class NFSDataStoreParameters(DataStoreParameters):
-    def __init__(self, ds_params, deployed=False, server=None, path=None):
+    def __init__(self, ds_params, deployed=False, server=None, path=None, checkpoint_dir: str=""):
         super().__init__(ds_params.store_type, ds_params.orchestrator_type, ds_params.orchestrator_params)
         self.namespace = "default"
         if "namespace" in ds_params.orchestrator_params:
             self.namespace = ds_params.orchestrator_params["namespace"]
+        self.checkpoint_dir = checkpoint_dir
         self.name = None
         self.pvc_name = None
         self.pv_name = None
@@ -38,7 +40,7 @@ class NFSDataStoreParameters(DataStoreParameters):
         self.path = path
 
 
-class NFSDataStore(DataStore):
+class NFSDataStore(CheckpointDataStore):
     """
     An implementation of a data store which uses NFS for storing policy checkpoints when using Coach in distributed mode.
    The policy checkpoints are written by the trainer and read by the rollout worker.
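
NFSDataStoreParameters now threads a checkpoint_dir through to the store, which CheckpointDataStore reads. A hedged construction sketch; the namespace and directory values are made up:

from rl_coach.data_stores.data_store import DataStoreParameters
from rl_coach.data_stores.nfs_data_store import NFSDataStore, NFSDataStoreParameters

# Illustrative values; orchestrator_params follows the keys the class reads.
base = DataStoreParameters("nfs", "kubernetes", {"namespace": "coach"})
params = NFSDataStoreParameters(base, checkpoint_dir="/checkpoint")  # new argument
store = NFSDataStore(params)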
rl_coach/data_stores/redis_data_store.py (new file, 192 lines)
@@ -0,0 +1,192 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import time
import uuid

import redis

from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver
from rl_coach.data_stores.data_store import DataStore, DataStoreParameters


class RedisDataStoreParameters(DataStoreParameters):
    def __init__(
        self,
        ds_params,
        redis_address: str = "",
        redis_port: int = 6379,
        redis_channel: str = "data-store-channel-{}".format(uuid.uuid4()),
    ):
        super().__init__(
            ds_params.store_type,
            ds_params.orchestrator_type,
            ds_params.orchestrator_params,
        )
        self.redis_address = redis_address
        self.redis_port = redis_port
        self.redis_channel = redis_channel


class RedisDataStore(DataStore):
    """
    This DataStore sends policies over redis pub/sub and get/set.

    Deployment
    ==========
    It assumes that a redis server is already available. We make this assumption because, at
    this time, redis is already used during multi-node training for communicating replay
    memories.

    Communication
    =============

    A redis pub/sub channel is used by the training worker to signal to the rollout workers
    that a new policy is ready. When this occurs, the new policy is loaded from the redis
    key/value store, where the key is the same as the pub/sub channel. Originally, just the
    pub/sub was used, but that could result in a race condition where the master worker
    publishes the first policy and waits for the rollout workers to submit all rollouts, while
    a delayed rollout worker waits for the first policy because it subscribed to the channel
    after the initial policy was published.
    """

    def __init__(self, params: RedisDataStoreParameters):
        self.params = params
        self.saver = None
        self._end_of_policies = False

        # NOTE: a connection is not attempted at this stage because the address and port are
        # likely not available yet. This is because of how the kubernetes orchestrator works:
        # at the time of parameter construction, the address and port are not yet known, since
        # they are copied out of the redis memory backend after it is deployed. One improvement
        # would be to use two separate redis deployments independently, and let this class
        # deploy its own redis.

    def _connect(self):
        """
        Connect to redis and subscribe to the pub/sub channel.
        """
        self.redis_connection = redis.Redis(
            self.params.redis_address, self.params.redis_port
        )
        self.pubsub = self.redis_connection.pubsub(ignore_subscribe_messages=True)
        self.pubsub.subscribe(self.params.redis_channel)

        self._end_of_policies = False

    def deploy(self):
        """
        For now, this data store does not handle its own deployment; it piggybacks off of the
        redis memory backend.
        """
        return True

    def undeploy(self):
        """
        For now, this data store does not handle its own deployment; it piggybacks off of the
        redis memory backend.
        """
        pass

    def save_to_store(self):
        """
        save_to_store and load_from_store are only used when the data store needs to
        synchronize checkpoints saved to disk with a central file system, so they are not
        used here.
        """
        pass

    def load_from_store(self):
        """
        save_to_store and load_from_store are only used when the data store needs to
        synchronize checkpoints saved to disk with a central file system, so they are not
        used here.
        """
        pass

    def save_policy(self, graph_manager):
        """
        Serialize the policy in graph_manager, set it as the latest policy, and publish a
        new_policy event.
        """
        if self.saver is None:
            self.saver = GlobalVariableSaver()

            # TODO: only subscribe if this data store is being used to publish policies
            self._connect()
            self.pubsub.unsubscribe(self.params.redis_channel)

        policy_string = self.saver.to_string(graph_manager.sess)
        self.redis_connection.set(self.params.redis_channel, policy_string)
        self.redis_connection.publish(self.params.redis_channel, "new_policy")

    def _load_policy(self, graph_manager) -> bool:
        """
        Get the most recent policy from redis and load it into the graph_manager.
        """
        policy_string = self.redis_connection.get(self.params.redis_channel)
        if policy_string is None:
            return False

        self.saver.from_string(graph_manager.sess, policy_string)
        return True

    def load_policy(self, graph_manager, require_new_policy=True, timeout=0):
        """
        :param graph_manager: the graph_manager to load the policy into
        :param require_new_policy: if True, only load a policy if it hasn't already been loaded
                                   in this process
        :param timeout: will only try to load the policy once if timeout is None, otherwise
                        will retry for timeout seconds
        """
        if self.saver is None:
            # the GlobalVariableSaver needs to be instantiated after the graph is created. For
            # now, it can be instantiated here, but it might be nicer to have a more explicit
            # on_graph_creation_end callback or similar to put it in.
            self.saver = GlobalVariableSaver()
            self._connect()

        if not require_new_policy:
            # try just loading the most recently available policy
            if self._load_policy(graph_manager):
                return

        message = "first"
        timeout_ends = time.time() + timeout
        while time.time() < timeout_ends or message == "first":
            message = self.pubsub.get_message()

            if message and message["type"] == "message":
                if message["data"] == b"end_of_policies":
                    self._end_of_policies = True
                    return
                elif message["data"] == b"new_policy":
                    if self._load_policy(graph_manager):
                        return
                    else:
                        raise ValueError("'new_policy' message was sent, but no policy was found.")

            time.sleep(1.0)

        if require_new_policy:
            raise ValueError(
                "Waited for {timeout} seconds on channel {channel}, but no first policy was "
                "received.".format(timeout=timeout, channel=self.params.redis_channel)
            )

    def end_of_policies(self) -> bool:
        """
        This is used by the rollout workers to detect a message from the training worker
        signaling that training is complete.
        """
        return self._end_of_policies
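
The race described in the class docstring is avoided by ordering the writes: the policy is SET under the channel's key before the notification is PUBLISHed, so a subscriber that joins late can still fetch the payload with GET. A standalone redis-py sketch of that pattern; the server address and payload are made up:

import redis

r = redis.Redis("localhost", 6379)     # illustrative server
channel = "data-store-channel-demo"

# publisher (training worker): write the payload, then notify
r.set(channel, b"<serialized policy>")
r.publish(channel, "new_policy")

# late subscriber (rollout worker): subscribed after the publish above,
# so the pub/sub message is missed, but the payload is still readable
p = r.pubsub(ignore_subscribe_messages=True)
p.subscribe(channel)
assert r.get(channel) == b"<serialized policy>"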
rl_coach/data_stores/s3_data_store.py

@@ -15,7 +15,8 @@
 #
 
 
-from rl_coach.data_stores.data_store import DataStore, DataStoreParameters
+from rl_coach.data_stores.data_store import DataStoreParameters
+from rl_coach.data_stores.checkpoint_data_store import CheckpointDataStore
 from minio import Minio
 from minio.error import ResponseError
 from configparser import ConfigParser, Error
@@ -39,7 +40,7 @@ class S3DataStoreParameters(DataStoreParameters):
         self.expt_dir = expt_dir
 
 
-class S3DataStore(DataStore):
+class S3DataStore(CheckpointDataStore):
     """
     An implementation of the data store using S3 for storing policy checkpoints when using Coach in distributed mode.
     The policy checkpoints are written by the trainer and read by the rollout worker.
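
S3DataStore now inherits the same checkpoint-based policy API as NFSDataStore. A hedged construction sketch mirroring the factory call in construct_data_store_params above; all values are illustrative:

from rl_coach.data_stores.data_store import DataStoreParameters
from rl_coach.data_stores.s3_data_store import S3DataStore, S3DataStoreParameters

base = DataStoreParameters("s3", "kubernetes", {"namespace": "coach"})
params = S3DataStoreParameters(
    ds_params=base,
    end_point="s3.example.com",       # minio-compatible endpoint, made up
    bucket_name="coach-checkpoints",
    checkpoint_dir="/checkpoint",
    expt_dir="/experiment",
)
store = S3DataStore(params)           # now a CheckpointDataStore subclass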