Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 19:20:19 +01:00)
Add RedisDataStore (#295)
* GraphManager.set_session also sets self.sess
* make sure that GraphManager.fetch_from_worker uses training phase
* remove unnecessary phase setting in training worker
* reorganize rollout worker
* provide default name to GlobalVariableSaver.__init__ since it isn't really used anyway
* allow dividing TrainingSteps and EnvironmentSteps
* add timestamps to the log
* added redis data store
* conflict merge fix
Committed by shadiendrawis (parent 34e1c04f29, commit 7b0fccb041)
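
Taken together, these changes let distributed Coach coordinate policies through Redis instead of a shared filesystem or S3. A sketch of how the pieces are wired up, taken from the new `distributed' Makefile target and the distributed-coach.config file added below (run from the repository root):

    # distributed-coach.config (added by this commit)
    [coach]
    image = amr-registry.caas.intel.com/aipg/coach
    memory_backend = redispubsub
    data_store = redis

    # launch, as the new `distributed' Makefile target does
    python3 rl_coach/coach.py -p Mujoco_PPO -lvl humanoid \
        --distributed_coach \
        --distributed_coach_config_path distributed-coach.config \
        -e stop_asking --num_workers 8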
distributed-coach.config (new file, +4 lines):

@@ -0,0 +1,4 @@
+[coach]
+image = amr-registry.caas.intel.com/aipg/coach
+memory_backend = redispubsub
+data_store = redis
Makefile (Docker/Kubernetes targets):

@@ -1,7 +1,11 @@
-# REGISTRY=gcr.io
+REGISTRY=gcr.io
 REGISTRY=docker.io
 ORGANIZATION=nervana
-IMAGE=coach
+
+# REGISTRY=amr-registry.caas.intel.com
+# ORGANIZATION=aipg
+# IMAGE=coach
+
 CONTEXT = $(realpath ..)

 BUILD_ARGUMENTS=

@@ -111,16 +115,17 @@ bootstrap_kubernetes: build push
 	kubectl run -i --tty --attach --image=${REGISTRY}/${IMAGE} --restart=Never distributed-coach -- python3 rl_coach/orchestrators/start_training.py --preset CartPole_DQN_distributed --image ${IMAGE} -ns 10.63.249.182 -np /

 stop_kubernetes:
-	kubectl delete service --ignore-not-found redis-service
-	kubectl delete pv --ignore-not-found nfs-checkpoint-pv
-	kubectl delete pvc --ignore-not-found nfs-checkpoint-pvc
-	kubectl delete deployment --ignore-not-found redis-server
-	kubectl get jobs | grep train | awk "{print $$1}" | xargs kubectl delete jobs
-	kubectl get jobs | grep worker | awk "{print $$1}" | xargs kubectl delete jobs
+	kubectl get deployments | grep redis-server | awk "{print $$1}" | xargs kubectl delete deployments --ignore-not-found | true
+	kubectl get services | grep redis-service | awk "{print $$1}" | xargs kubectl delete services --ignore-not-found | true
+	kubectl get jobs | grep train | awk "{print $$1}" | xargs kubectl delete jobs --ignore-not-found | true
+	kubectl get jobs | grep worker | awk "{print $$1}" | xargs kubectl delete jobs --ignore-not-found | true

 kubernetes: stop_kubernetes
 	python3 ${CONTEXT}/rl_coach/orchestrators/start_training.py --preset CartPole_DQN_distributed --image ${IMAGE} -ns 10.63.249.182 -np /

+distributed: build push stop_kubernetes
+	python3 ${CONTEXT}/rl_coach/coach.py -p Mujoco_PPO -lvl humanoid --distributed_coach --distributed_coach_config_path ${CONTEXT}/distributed-coach.config -e stop_asking --num_workers 8
+
 push: build
 	${DOCKER} tag ${IMAGE} ${REGISTRY}/${ORGANIZATION}/${IMAGE}
 	${DOCKER} push ${REGISTRY}/${ORGANIZATION}/${IMAGE}
rl_coach/architectures/tensorflow_components/savers.py:

@@ -24,7 +24,7 @@ from rl_coach.saver import Saver


 class GlobalVariableSaver(Saver):
-    def __init__(self, name):
+    def __init__(self, name=""):
         self._names = [name]
         # if graph is finalized, savers must have already already been added. This happens
         # in the case of a MonitoredSession
rl_coach/coach.py:

@@ -45,6 +45,7 @@ from rl_coach.memories.backend.memory_impl import construct_memory_params
 from rl_coach.data_stores.data_store import DataStoreParameters
 from rl_coach.data_stores.s3_data_store import S3DataStoreParameters
 from rl_coach.data_stores.nfs_data_store import NFSDataStoreParameters
+from rl_coach.data_stores.redis_data_store import RedisDataStoreParameters
 from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
 from rl_coach.training_worker import training_worker
 from rl_coach.rollout_worker import rollout_worker

@@ -97,29 +98,25 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
     memory_backend_params['run_type'] = str(args.distributed_coach_run_type)
     graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(memory_backend_params))

+    data_store = None
     data_store_params = None
     if args.data_store_params:
         data_store_params = construct_data_store_params(json.loads(args.data_store_params))
         data_store_params.expt_dir = args.experiment_path
         data_store_params.checkpoint_dir = ckpt_inside_container
         graph_manager.data_store_params = data_store_params

-    data_store = None
-    if args.data_store_params:
         data_store = get_data_store(data_store_params)

     if args.distributed_coach_run_type == RunType.TRAINER:
         task_parameters.checkpoint_save_dir = ckpt_inside_container
         training_worker(
             graph_manager=graph_manager,
-            task_parameters=task_parameters,
             data_store=data_store,
+            task_parameters=task_parameters,
             is_multi_node_test=args.is_multi_node_test
         )

     if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
-        task_parameters.checkpoint_restore_path = ckpt_inside_container
-
         rollout_worker(
             graph_manager=graph_manager,
             data_store=data_store,

@@ -162,6 +159,11 @@ def handle_distributed_coach_orchestrator(args):
     elif args.data_store == "nfs":
         ds_params = DataStoreParameters("nfs", "kubernetes", "")
         ds_params_instance = NFSDataStoreParameters(ds_params)
+    elif args.data_store == "redis":
+        ds_params = DataStoreParameters("redis", "kubernetes", "")
+        ds_params_instance = RedisDataStoreParameters(ds_params)
+    else:
+        raise ValueError("data_store {} found. Expected 's3' or 'nfs'".format(args.data_store))

     worker_run_type_params = RunTypeParameters(args.image, rollout_command, run_type=str(RunType.ROLLOUT_WORKER), num_replicas=args.num_workers)
     trainer_run_type_params = RunTypeParameters(args.image, trainer_command, run_type=str(RunType.TRAINER))

@@ -375,7 +377,7 @@ class CoachLauncher(object):
         if args.image == '':
             screen.error("Image cannot be empty.")

-        data_store_choices = ['s3', 'nfs']
+        data_store_choices = ['s3', 'nfs', 'redis']
         if args.data_store not in data_store_choices:
             screen.warning("{} data store is unsupported.".format(args.data_store))
             screen.error("Supported data stores are {}.".format(data_store_choices))
rl_coach/core_types.py:

@@ -115,6 +115,12 @@ class TrainingSteps(StepMethod):
     def __init__(self, num_steps):
         super().__init__(num_steps)

+    def __truediv__(self, other):
+        if isinstance(other, EnvironmentSteps):
+            return math.ceil(self.num_steps / other.num_steps)
+        else:
+            super().__truediv__(self, other)
+

 class Time(StepMethod):
     def __init__(self, num_steps):
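
The new division is what lets the reorganized rollout worker (further down) derive its loop count straight from the schedule. A quick illustration with made-up numbers, assuming the usual Coach imports:

    from rl_coach.core_types import TrainingSteps, EnvironmentSteps

    improve_steps = TrainingSteps(100)  # total training iterations in the schedule
    act_steps = EnvironmentSteps(8)     # environment steps contributed per rollout
    improve_steps / act_steps           # math.ceil(100 / 8) == 13 iterations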
rl_coach/data_stores/checkpoint_data_store.py (new file, +96 lines):

@@ -0,0 +1,96 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import time
+import os
+
+from rl_coach.checkpoint import CheckpointStateReader
+from rl_coach.data_stores.data_store import SyncFiles
+
+
+class CheckpointDataStore(object):
+    """
+    A DataStore which relies on the GraphManager check pointing methods to communicate policies.
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.checkpoint_num = 0
+
+    def end_of_policies(self) -> bool:
+        """
+        Returns true if no new policies will be added to this DataStore. This typically happens
+        because training has completed and is used to signal to the rollout workers to stop.
+        """
+        return os.path.exists(
+            os.path.join(self.checkpoint_dir, SyncFiles.FINISHED.value)
+        )
+
+    def save_policy(self, graph_manager):
+        # TODO: it would be nice if restore_checkpoint accepted a checkpoint path as a
+        # parameter. as it is, one cannot distinguish between checkpoints used for coordination
+        # and checkpoints requested to a persistent disk for later use
+        graph_manager.task_parameters.checkpoint_restore_path = self.checkpoint_dir
+        graph_manager.save_checkpoint()
+
+    def load_policy(self, graph_manager, require_new_policy=True, timeout=None):
+        """
+        Load a checkpoint into the specified graph_manager. The expectation here is that
+        save_to_store() and load_from_store() will synchronize a checkpoint directory with a
+        central repository such as NFS or S3.
+
+        :param graph_manager: the graph_manager to load the policy into
+        :param require_new_policy: if True, only load a policy if it hasn't been loaded in this
+            process yet before.
+        :param timeout: Will only try to load the policy once if timeout is None, otherwise will
+            retry for timeout seconds
+        """
+        if self._new_policy_exists(require_new_policy, timeout):
+            # TODO: it would be nice if restore_checkpoint accepted a checkpoint path as a
+            # parameter. as it is, one cannot distinguish between checkpoints used for coordination
+            # and checkpoints requested to a persistent disk for later use
+            graph_manager.task_parameters.checkpoint_restore_path = self.checkpoint_dir
+            graph_manager.restore_checkpoint()
+
+    def _new_policy_exists(self, require_new_policy=True, timeout=None) -> bool:
+        """
+        :param require_new_policy: if True, only load a policy if it hasn't been loaded in this
+            process yet before.
+        :param timeout: Will only try to load the policy once if timeout is None, otherwise will
+            retry for timeout seconds
+        """
+        checkpoint_state_reader = CheckpointStateReader(
+            self.checkpoint_dir, checkpoint_state_optional=False
+        )
+        checkpoint = "first"
+        if timeout is None:
+            timeout = 0
+        timeout_ends = time.time() + timeout
+        while time.time() < timeout_ends or checkpoint == "first":
+            if self.end_of_policies():
+                return False
+
+            self.load_from_store()
+
+            checkpoint = checkpoint_state_reader.get_latest()
+            if checkpoint is not None:
+                if not require_new_policy or checkpoint.num > self.checkpoint_num:
+                    self.checkpoint_num = checkpoint.num
+                    return True
+
+        raise ValueError(
+            "Waited for {timeout} seconds, but no first policy was received.".format(
+                timeout=timeout
+            )
+        )
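
CheckpointDataStore acts as a mixin: the concrete stores below (NFS and S3) keep implementing deploy/save_to_store/load_from_store for the actual transfer and inherit save_policy, load_policy and end_of_policies from it. A minimal sketch of the pattern; the LocalDirDataStore name and its no-op transfer are hypothetical, purely to show the contract:

    from rl_coach.data_stores.checkpoint_data_store import CheckpointDataStore

    class LocalDirDataStore(CheckpointDataStore):
        # hypothetical store: trainer and workers share one local directory,
        # so synchronizing the checkpoint dir with a central repository is a no-op
        def __init__(self, checkpoint_dir):
            super().__init__()
            self.checkpoint_dir = checkpoint_dir  # read by the inherited methods

        def deploy(self) -> bool:
            return True

        def save_to_store(self):
            pass  # checkpoints are already where load_policy will look

        def load_from_store(self):
            pass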
rl_coach/data_stores/data_store.py:

@@ -26,23 +26,45 @@ class DataStoreParameters(object):


 class DataStore(object):
+    """
+    DataStores are used primarily to synchronize policies between training workers and rollout
+    workers. In the case of the S3DataStore, it is also being used to explicitly log artifacts such
+    as videos and logs into s3 for users to look at later. Artifact logging should be moved into a
+    separate instance of the DataStore class, or a different class altogether. It is possible that
+    users might be interested in logging artifacts through s3, but coordinating communication of
+    policies using something else like redis.
+    """
+
     def __init__(self, params: DataStoreParameters):
-        pass
+        """
+        The parameters provided in the constructor to a DataStore are expected to contain the
+        parameters necessary to serialize and deserialize this DataStore.
+        """
+        raise NotImplementedError()

     def deploy(self) -> bool:
-        pass
+        raise NotImplementedError()

     def get_info(self):
-        pass
+        raise NotImplementedError()

     def undeploy(self) -> bool:
-        pass
+        raise NotImplementedError()

     def save_to_store(self):
-        pass
+        raise NotImplementedError()

     def load_from_store(self):
-        pass
+        raise NotImplementedError()
+
+    def save_policy(self, graph_manager):
+        raise NotImplementedError()
+
+    def load_policy(self, graph_manager, timeout=-1):
+        raise NotImplementedError()
+
+    def end_of_policies(self) -> bool:
+        raise NotImplementedError()

     def setup_checkpoint_dir(self, crd=None):
         pass
rl_coach/data_stores/data_store_impl.py:

@@ -17,6 +17,10 @@

 from rl_coach.data_stores.nfs_data_store import NFSDataStore, NFSDataStoreParameters
 from rl_coach.data_stores.s3_data_store import S3DataStore, S3DataStoreParameters
+from rl_coach.data_stores.redis_data_store import (
+    RedisDataStore,
+    RedisDataStoreParameters,
+)
 from rl_coach.data_stores.data_store import DataStoreParameters


@@ -26,19 +30,39 @@ def get_data_store(params):
         data_store = NFSDataStore(params)
     elif type(params) == S3DataStoreParameters:
         data_store = S3DataStore(params)
+    elif type(params) == RedisDataStoreParameters:
+        data_store = RedisDataStore(params)
+    else:
+        raise ValueError("invalid params type {}".format(type(params)))
+
     return data_store


 def construct_data_store_params(json: dict):
     ds_params_instance = None
-    ds_params = DataStoreParameters(json['store_type'], json['orchestrator_type'], json['orchestrator_params'])
-    if json['store_type'] == 'nfs':
-        ds_params_instance = NFSDataStoreParameters(ds_params)
-    elif json['store_type'] == 's3':
-        ds_params_instance = S3DataStoreParameters(ds_params=ds_params,
-                                                   end_point=json['end_point'],
-                                                   bucket_name=json['bucket_name'],
-                                                   checkpoint_dir=json['checkpoint_dir'],
-                                                   expt_dir=json['expt_dir'])
+    ds_params = DataStoreParameters(
+        json["store_type"], json["orchestrator_type"], json["orchestrator_params"]
+    )
+    if json["store_type"] == "nfs":
+        ds_params_instance = NFSDataStoreParameters(
+            ds_params, checkpoint_dir=json["checkpoint_dir"]
+        )
+    elif json["store_type"] == "s3":
+        ds_params_instance = S3DataStoreParameters(
+            ds_params=ds_params,
+            end_point=json["end_point"],
+            bucket_name=json["bucket_name"],
+            checkpoint_dir=json["checkpoint_dir"],
+            expt_dir=json["expt_dir"],
+        )
+    elif json["store_type"] == "redis":
+        ds_params_instance = RedisDataStoreParameters(
+            ds_params,
+            redis_address=json["redis_address"],
+            redis_port=json["redis_port"],
+            redis_channel=json["redis_channel"],
+        )
+    else:
+        raise ValueError("store_type {} was found, expected 'nfs', 'redis' or 's3'.")

     return ds_params_instance
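
The orchestrator serializes the chosen DataStoreParameters to JSON and passes it to every container through the --data_store_params command-line flag, which is what construct_data_store_params parses back. For the new redis branch the payload carries the keys read above; the values here are illustrative (in a real run the address and port are filled in from the deployed Redis memory backend, as the orchestrator change further down shows):

    import json

    data_store_params = json.dumps({
        "store_type": "redis",
        "orchestrator_type": "kubernetes",
        "orchestrator_params": {},
        "redis_address": "redis-service",            # illustrative
        "redis_port": 6379,
        "redis_channel": "data-store-channel-1234",  # illustrative
    })
    # inside the container:
    # data_store = get_data_store(construct_data_store_params(json.loads(data_store_params)))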
rl_coach/data_stores/nfs_data_store.py:

@@ -17,15 +17,17 @@

 import uuid

-from rl_coach.data_stores.data_store import DataStore, DataStoreParameters
+from rl_coach.data_stores.data_store import DataStoreParameters
+from rl_coach.data_stores.checkpoint_data_store import CheckpointDataStore


 class NFSDataStoreParameters(DataStoreParameters):
-    def __init__(self, ds_params, deployed=False, server=None, path=None):
+    def __init__(self, ds_params, deployed=False, server=None, path=None, checkpoint_dir: str=""):
         super().__init__(ds_params.store_type, ds_params.orchestrator_type, ds_params.orchestrator_params)
         self.namespace = "default"
         if "namespace" in ds_params.orchestrator_params:
             self.namespace = ds_params.orchestrator_params["namespace"]
+        self.checkpoint_dir = checkpoint_dir
         self.name = None
         self.pvc_name = None
         self.pv_name = None

@@ -38,7 +40,7 @@ class NFSDataStoreParameters(DataStoreParameters):
         self.path = path


-class NFSDataStore(DataStore):
+class NFSDataStore(CheckpointDataStore):
     """
     An implementation of data store which uses NFS for storing policy checkpoints when using Coach in distributed mode.
     The policy checkpoints are written by the trainer and read by the rollout worker.
rl_coach/data_stores/redis_data_store.py (new file, +192 lines):

@@ -0,0 +1,192 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+import uuid
+
+import redis
+
+from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver
+from rl_coach.data_stores.data_store import DataStore, DataStoreParameters
+
+
+class RedisDataStoreParameters(DataStoreParameters):
+    def __init__(
+        self,
+        ds_params,
+        redis_address: str = "",
+        redis_port: int = 6379,
+        redis_channel: str = "data-store-channel-{}".format(uuid.uuid4()),
+    ):
+        super().__init__(
+            ds_params.store_type,
+            ds_params.orchestrator_type,
+            ds_params.orchestrator_params,
+        )
+        self.redis_address = redis_address
+        self.redis_port = redis_port
+        self.redis_channel = redis_channel
+
+
+class RedisDataStore(DataStore):
+    """
+    This DataStore sends policies over redis pubsub and get/set.
+
+    Deployment
+    ==========
+    It assumes that a redis server is already available. We make this assumption because during
+    multinode training at this time, redis is already used for communicating replay memories.
+
+    Communication
+    =============
+
+    A redis pubsub channel is used by the training worker to signal to the rollout workers that a
+    new policy is ready. When this occurs, a new policy is loaded from the redis key/value store
+    where key is the same as the pubsub channel. Originally, just the pubsub was used, but that
+    could result in a race condition where the master worker publishes the first policy and waits
+    for the rollout workers to submit all rollouts, while a delayed rollout worker waits for the
+    first policy since it subscribed to the channel after the initial policy was published.
+    """
+
+    def __init__(self, params: RedisDataStoreParameters):
+        self.params = params
+        self.saver = None
+        self._end_of_policies = False
+
+        # NOTE: a connection is not attempted at this stage because the address and port are likely
+        # not available yet. This is because of how the kubernetes orchestrator works. At the time
+        # of parameter construction, the address and port are not yet known since they are copied
+        # out of the redis memory backend after it is deployed. One improvement would be to use
+        # two separate redis deployments independently, and let this class deploy its own redis.
+
+    def _connect(self):
+        """
+        Connect to redis and subscribe to the pubsub channel
+        """
+        self.redis_connection = redis.Redis(
+            self.params.redis_address, self.params.redis_port
+        )
+        self.pubsub = self.redis_connection.pubsub(ignore_subscribe_messages=True)
+        self.pubsub.subscribe(self.params.redis_channel)
+
+        self._end_of_policies = False
+
+    def deploy(self):
+        """
+        For now, this data store does not handle its own deployment, it piggybacks off of the redis
+        memory backend
+        """
+        return True
+
+    def undeploy(self):
+        """
+        For now, this data store does not handle its own deployment, it piggybacks off of the redis
+        memory backend
+        """
+        pass
+
+    def save_to_store(self):
+        """
+        save_to_store and load_from_store are not used in the case where the data stored needs to
+        synchronize checkpoints saved to disk into a central file system, and not used here
+        """
+        pass
+
+    def load_from_store(self):
+        """
+        save_to_store and load_from_store are not used in the case where the data stored needs to
+        synchronize checkpoints saved to disk into a central file system, and not used here
+        """
+        pass
+
+    def save_policy(self, graph_manager):
+        """
+        Serialize the policy in graph_manager, set it as the latest policy and publish a new_policy
+        event
+        """
+        if self.saver is None:
+            self.saver = GlobalVariableSaver()
+
+            # TODO: only subscribe if this data store is being used to publish policies
+            self._connect()
+            self.pubsub.unsubscribe(self.params.redis_channel)
+
+        policy_string = self.saver.to_string(graph_manager.sess)
+        self.redis_connection.set(self.params.redis_channel, policy_string)
+        self.redis_connection.publish(self.params.redis_channel, "new_policy")
+
+    def _load_policy(self, graph_manager) -> bool:
+        """
+        Get the most recent policy from redis and loaded into the graph_manager
+        """
+        policy_string = self.redis_connection.get(self.params.redis_channel)
+        if policy_string is None:
+            return False
+
+        self.saver.from_string(graph_manager.sess, policy_string)
+        return True
+
+    def load_policy(self, graph_manager, require_new_policy=True, timeout=0):
+        """
+        :param graph_manager: the graph_manager to load the policy into
+        :param require_new_policy: if True, only load a policy if it hasn't been loaded in this
+            process yet before.
+        :param timeout: Will only try to load the policy once if timeout is None, otherwise will
+            retry for timeout seconds
+        """
+        if self.saver is None:
+            # the GlobalVariableSaver needs to be instantiated after the graph is created. For now,
+            # it can be instantiated here, but it might be nicer to have a more explicit
+            # on_graph_creation_end callback or similar to put it in
+            self.saver = GlobalVariableSaver()
+            self._connect()
+
+        if not require_new_policy:
+            # try just loading whatever policy is available most recently
+            if self._load_policy(graph_manager):
+                return
+
+        message = "first"
+        timeout_ends = time.time() + timeout
+        while time.time() < timeout_ends or message == "first":
+            message = self.pubsub.get_message()
+
+            if message and message["type"] == "message":
+                if message["data"] == b"end_of_policies":
+                    self._end_of_policies = True
+                    return
+                elif message["data"] == b"new_policy":
+                    if self._load_policy(graph_manager):
+                        return
+                    else:
+                        raise ValueError("'new_policy' message was sent, but no policy was found.")
+
+            time.sleep(1.0)
+
+        if require_new_policy:
+            raise ValueError(
+                "Waited for {timeout} seconds on channel {channel}, but no first policy was received.".format(
+                    timeout=timeout, channel=self.params.redis_channel
+                )
+            )
+
+    def end_of_policies(self) -> bool:
+        """
+        This is used by the rollout workers to detect a message from the training worker signaling
+        that training is complete.
+        """
+        return self._end_of_policies
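
End to end, the trainer and rollout workers use RedisDataStore like this. A hedged sketch: graph managers and parameter plumbing are elided, and the endpoint values are illustrative (in a real run the kubernetes orchestrator copies them from the deployed memory backend):

    params = RedisDataStoreParameters(ds_params, redis_address="10.0.0.5", redis_port=6379)
    store = RedisDataStore(params)

    # trainer: serialize the TF variables, SET them under the channel key,
    # then PUBLISH "new_policy" so subscribed workers wake up
    store.save_policy(graph_manager)

    # rollout worker: the first call takes whatever policy is already stored
    # (retrying for up to 60s); later calls insist on a strictly newer one
    store.load_policy(graph_manager, require_new_policy=False, timeout=60)
    store.load_policy(graph_manager, require_new_policy=True, timeout=10)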
rl_coach/data_stores/s3_data_store.py:

@@ -15,7 +15,8 @@
 #


-from rl_coach.data_stores.data_store import DataStore, DataStoreParameters
+from rl_coach.data_stores.data_store import DataStoreParameters
+from rl_coach.data_stores.checkpoint_data_store import CheckpointDataStore
 from minio import Minio
 from minio.error import ResponseError
 from configparser import ConfigParser, Error

@@ -39,7 +40,7 @@ class S3DataStoreParameters(DataStoreParameters):
         self.expt_dir = expt_dir


-class S3DataStore(DataStore):
+class S3DataStore(CheckpointDataStore):
     """
     An implementation of the data store using S3 for storing policy checkpoints when using Coach in distributed mode.
     The policy checkpoints are written by the trainer and read by the rollout worker.
rl_coach/graph_managers/graph_manager.py:

@@ -232,18 +232,14 @@ class GraphManager(object):
         else:
             checkpoint_dir = task_parameters.checkpoint_save_dir

-            self.sess = create_monitored_session(target=task_parameters.worker_target,
+            self.set_session(create_monitored_session(target=task_parameters.worker_target,
                                                  task_index=task_parameters.task_index,
                                                  checkpoint_dir=checkpoint_dir,
                                                  checkpoint_save_secs=task_parameters.checkpoint_save_secs,
-                                                 config=config)
-            # set the session for all the modules
-            self.set_session(self.sess)
+                                                 config=config))
         else:
             # regular session
-            self.sess = tf.Session(config=config)
-            # set the session for all the modules
-            self.set_session(self.sess)
+            self.set_session(tf.Session(config=config))

         # the TF graph is static, and therefore is saved once - in the beginning of the experiment
         if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:

@@ -366,6 +362,8 @@ class GraphManager(object):
         Set the deep learning framework session for all the modules in the graph
         :return: None
         """
+        self.sess = sess
+
         [manager.set_session(sess) for manager in self.level_managers]

     def heatup(self, steps: PlayingStepsType) -> None:

@@ -710,8 +708,9 @@ class GraphManager(object):

     def fetch_from_worker(self, num_consecutive_playing_steps=None):
         if hasattr(self, 'memory_backend'):
-            for transition in self.memory_backend.fetch(num_consecutive_playing_steps):
-                self.emulate_act_on_trainer(EnvironmentSteps(1), transition)
+            with self.phase_context(RunPhase.TRAIN):
+                for transition in self.memory_backend.fetch(num_consecutive_playing_steps):
+                    self.emulate_act_on_trainer(EnvironmentSteps(1), transition)

     def setup_memory_backend(self) -> None:
         if hasattr(self.agent_params.memory, 'memory_backend_params'):
rl_coach/logger.py:

@@ -87,13 +87,15 @@ class ScreenLogger(object):
         print(data)

     def log_dict(self, data, prefix=""):
+        timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S.%f') + ' '
         if self._use_colors:
-            str = "{}{}{} - ".format(Colors.PURPLE, prefix, Colors.END)
+            str = timestamp
+            str += "{}{}{} - ".format(Colors.PURPLE, prefix, Colors.END)
             for k, v in data.items():
                 str += "{}{}: {}{} ".format(Colors.BLUE, k, Colors.END, v)
             print(str)
         else:
-            logentries = []
+            logentries = [timestamp]
             for k, v in data.items():
                 logentries.append("{}={}".format(k, v))
             logline = "{}> {}".format(prefix, ", ".join(logentries))
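
With the timestamp prepended, a log_dict line in the plain (non-color) branch now comes out roughly like this. The prefix and keys are illustrative, and note that the timestamp is joined into the entry list, so a comma follows it:

    Trainer> 2019-03-01-12:00:00.123456 , Training iteration=100, Loss=0.02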
class RedisPubSubBackend (Redis pub/sub memory backend):

@@ -78,11 +78,18 @@ class RedisPubSubBackend(MemoryBackend):
         """
         if 'namespace' not in self.params.orchestrator_params:
             self.params.orchestrator_params['namespace'] = "default"
-        from kubernetes import client
+        from kubernetes import client, config

         container = client.V1Container(
             name=self.redis_server_name,
             image='redis:4-alpine',
+            resources=client.V1ResourceRequirements(
+                limits={
+                    "cpu": "8",
+                    "memory": "4Gi"
+                    # "nvidia.com/gpu": "0",
+                }
+            ),
         )
         template = client.V1PodTemplateSpec(
             metadata=client.V1ObjectMeta(labels={'app': self.redis_server_name}),

@@ -105,8 +112,10 @@ class RedisPubSubBackend(MemoryBackend):
             spec=deployment_spec
         )

+        config.load_kube_config()
         api_client = client.AppsV1Api()
         try:
+            print(self.params.orchestrator_params)
             api_client.create_namespaced_deployment(self.params.orchestrator_params['namespace'], deployment)
         except client.rest.ApiException as e:
             print("Got exception: %s\n while creating redis-server", e)
class Kubernetes (orchestrator):

@@ -124,6 +124,11 @@ class Kubernetes(Deploy):
         """

         self.memory_backend.deploy()
+
+        if self.params.data_store_params.store_type == "redis":
+            self.data_store.params.redis_address = self.memory_backend.params.redis_address
+            self.data_store.params.redis_port = self.memory_backend.params.redis_port
+
         if not self.data_store.deploy():
             return False
         if self.params.data_store_params.store_type == "nfs":

@@ -146,6 +151,8 @@ class Kubernetes(Deploy):
         trainer_params.command += ['--data_store_params', json.dumps(self.params.data_store_params.__dict__)]
         name = "{}-{}".format(trainer_params.run_type, uuid.uuid4())

+        # TODO: instead of defining each container and template spec from scratch, loaded default
+        # configuration and modify them as necessary depending on the store type
         if self.params.data_store_params.store_type == "nfs":
             container = k8sclient.V1Container(
                 name=name,

@@ -171,7 +178,7 @@ class Kubernetes(Deploy):
                     restart_policy='Never'
                 ),
             )
-        else:
+        elif self.params.data_store_params.store_type == "s3":
             container = k8sclient.V1Container(
                 name=name,
                 image=trainer_params.image,

@@ -190,6 +197,34 @@ class Kubernetes(Deploy):
                     restart_policy='Never'
                 ),
             )
+        elif self.params.data_store_params.store_type == "redis":
+            container = k8sclient.V1Container(
+                name=name,
+                image=trainer_params.image,
+                command=trainer_params.command,
+                args=trainer_params.arguments,
+                image_pull_policy='Always',
+                stdin=True,
+                tty=True,
+                resources=k8sclient.V1ResourceRequirements(
+                    limits={
+                        "cpu": "40",
+                        "memory": "4Gi",
+                        "nvidia.com/gpu": "1",
+                    }
+                ),
+            )
+            template = k8sclient.V1PodTemplateSpec(
+                metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
+                spec=k8sclient.V1PodSpec(
+                    containers=[container],
+                    restart_policy='Never'
+                ),
+            )
+        else:
+            raise ValueError("unexpected store_type {}. expected 's3', 'nfs', 'redis'".format(
+                self.params.data_store_params.store_type
+            ))

         job_spec = k8sclient.V1JobSpec(
             completions=1,

@@ -221,12 +256,17 @@ class Kubernetes(Deploy):
         if not worker_params:
             return False

+        # At this point, the memory backend and data store have been deployed and in the process,
+        # these parameters have been updated to include things like the hostname and port the
+        # service can be found at.
         worker_params.command += ['--memory_backend_params', json.dumps(self.params.memory_backend_parameters.__dict__)]
         worker_params.command += ['--data_store_params', json.dumps(self.params.data_store_params.__dict__)]
         worker_params.command += ['--num_workers', '{}'.format(worker_params.num_replicas)]

         name = "{}-{}".format(worker_params.run_type, uuid.uuid4())

+        # TODO: instead of defining each container and template spec from scratch, loaded default
+        # configuration and modify them as necessary depending on the store type
         if self.params.data_store_params.store_type == "nfs":
             container = k8sclient.V1Container(
                 name=name,

@@ -252,7 +292,7 @@ class Kubernetes(Deploy):
                     restart_policy='Never'
                 ),
             )
-        else:
+        elif self.params.data_store_params.store_type == "s3":
             container = k8sclient.V1Container(
                 name=name,
                 image=worker_params.image,

@@ -271,6 +311,32 @@ class Kubernetes(Deploy):
                     restart_policy='Never'
                 )
             )
+        elif self.params.data_store_params.store_type == "redis":
+            container = k8sclient.V1Container(
+                name=name,
+                image=worker_params.image,
+                command=worker_params.command,
+                args=worker_params.arguments,
+                image_pull_policy='Always',
+                stdin=True,
+                tty=True,
+                resources=k8sclient.V1ResourceRequirements(
+                    limits={
+                        "cpu": "8",
+                        "memory": "4Gi",
+                        # "nvidia.com/gpu": "0",
+                    }
+                ),
+            )
+            template = k8sclient.V1PodTemplateSpec(
+                metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
+                spec=k8sclient.V1PodSpec(
+                    containers=[container],
+                    restart_policy='Never'
+                )
+            )
+        else:
+            raise ValueError('unexpected store type {}'.format(self.params.data_store_params.store_type))

         job_spec = k8sclient.V1JobSpec(
             completions=worker_params.num_replicas,
Mujoco_PPO preset:

@@ -13,8 +13,8 @@ from rl_coach.graph_managers.graph_manager import ScheduleParameters
 # Graph Scheduling #
 ####################
 schedule_params = ScheduleParameters()
-schedule_params.improve_steps = TrainingSteps(10000000000)
-schedule_params.steps_between_evaluation_periods = EnvironmentSteps(2000)
+schedule_params.improve_steps = TrainingSteps(1e10)
+schedule_params.steps_between_evaluation_periods = EnvironmentSteps(4000)
 schedule_params.evaluation_steps = EnvironmentEpisodes(1)
 schedule_params.heatup_steps = EnvironmentSteps(0)

@@ -22,8 +22,8 @@ schedule_params.heatup_steps = EnvironmentSteps(0)
 # Agent #
 #########
 agent_params = PPOAgentParameters()
-agent_params.network_wrappers['actor'].learning_rate = 0.001
-agent_params.network_wrappers['critic'].learning_rate = 0.001
+agent_params.network_wrappers['actor'].learning_rate = 5e-5
+agent_params.network_wrappers['critic'].learning_rate = 5e-5

 agent_params.network_wrappers['actor'].input_embedders_parameters['observation'].scheme = [Dense(64)]
 agent_params.network_wrappers['actor'].middleware_parameters.scheme = [Dense(64)]

@@ -33,6 +33,9 @@ agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense(64)]
 agent_params.input_filter = InputFilter()
 agent_params.input_filter.add_observation_filter('observation', 'normalize', ObservationNormalizationFilter())

+agent_params.algorithm.initial_kl_coefficient = 0.2
+agent_params.algorithm.gae_lambda = 1.0
+
 # Distributed Coach synchronization type.
 agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC

@@ -55,5 +58,3 @@ preset_validation_params.trace_test_levels = ['inverted_pendulum', 'hopper']
 graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
                                     schedule_params=schedule_params, vis_params=VisualizationParameters(),
                                     preset_validation_params=preset_validation_params)
-
-
rl_coach/rollout_worker.py:

@@ -23,13 +23,13 @@ this rollout worker:
 - exits
 """

+
+
 import time
 import os
-import math

 from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
 from rl_coach.checkpoint import CheckpointStateFile, CheckpointStateReader
-from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
 from rl_coach.data_stores.data_store import SyncFiles


@@ -56,18 +56,6 @@ def wait_for(wait_func, data_store=None, timeout=10):
         ))


-def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
-    """
-    block until there is a checkpoint in checkpoint_dir
-    """
-    chkpt_state_file = CheckpointStateFile(checkpoint_dir)
-
-    def wait():
-        return chkpt_state_file.read() is not None
-
-    wait_for(wait, data_store, timeout)
-
-
 def wait_for_trainer_ready(checkpoint_dir, data_store=None, timeout=10):
     """
     Block until trainer is ready

@@ -79,48 +67,38 @@ def wait_for_trainer_ready(checkpoint_dir, data_store=None, timeout=10):
     wait_for(wait, data_store, timeout)


-def should_stop(checkpoint_dir):
-    return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value))
-
-
 def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
     """
     wait for first checkpoint then perform rollouts using the model
     """
-    checkpoint_dir = task_parameters.checkpoint_restore_path
-    wait_for_checkpoint(checkpoint_dir, data_store)
     wait_for_trainer_ready(checkpoint_dir, data_store)
+    if (
+        graph_manager.agent_params.algorithm.distributed_coach_synchronization_type
+        == DistributedCoachSynchronizationType.SYNC
+    ):
+        timeout = float("inf")
+    else:
+        timeout = None
+
+    # this could probably be moved up into coach.py
     graph_manager.create_graph(task_parameters)

+    data_store.load_policy(graph_manager, require_new_policy=False, timeout=60)
+
     with graph_manager.phase_context(RunPhase.TRAIN):
-        chkpt_state_reader = CheckpointStateReader(checkpoint_dir, checkpoint_state_optional=False)
-        last_checkpoint = chkpt_state_reader.get_latest().num
-
         # this worker should play a fraction of the total playing steps per rollout
-        act_steps = graph_manager.agent_params.algorithm.num_consecutive_playing_steps / num_workers
-        training_steps = (graph_manager.improve_steps / act_steps.num_steps).num_steps
-        for i in range(training_steps):
-
-            if should_stop(checkpoint_dir):
+        act_steps = (
+            graph_manager.agent_params.algorithm.num_consecutive_playing_steps
+            / num_workers
+        )
+        for i in range(graph_manager.improve_steps / act_steps):
+            if data_store.end_of_policies():
                 break

-            graph_manager.act(act_steps, wait_for_full_episodes=graph_manager.agent_params.algorithm.act_for_full_episodes)
+            graph_manager.act(
+                act_steps,
+                wait_for_full_episodes=graph_manager.agent_params.algorithm.act_for_full_episodes,
+            )

-            new_checkpoint = chkpt_state_reader.get_latest()
-            if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
-                while new_checkpoint is None or new_checkpoint.num < last_checkpoint + 1:
-                    if should_stop(checkpoint_dir):
-                        break
-                    if data_store:
-                        data_store.load_from_store()
-                    new_checkpoint = chkpt_state_reader.get_latest()
-
-                graph_manager.restore_checkpoint()
-
-            if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.ASYNC:
-                if new_checkpoint is not None and new_checkpoint.num > last_checkpoint:
-                    graph_manager.restore_checkpoint()
-
-            if new_checkpoint is not None:
-                last_checkpoint = new_checkpoint.num
+            data_store.load_policy(graph_manager, require_new_policy=True, timeout=timeout)
rl_coach/training_worker.py:

@@ -1,4 +1,4 @@
 #
 # Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

@@ -15,12 +15,7 @@
 #


-"""
-"""
-import time
-
-from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
-from rl_coach import core_types
+from rl_coach.base_parameters import DistributedCoachSynchronizationType
 from rl_coach.logger import screen


@@ -32,22 +27,26 @@ def data_store_ckpt_load(data_store):
 def training_worker(graph_manager, task_parameters, data_store, is_multi_node_test):
     """
     restore a checkpoint then perform rollouts using the restored model

     :param graph_manager: An instance of the graph manager
+    :param data_store: An instance of DataStore which can be used to communicate policies to roll out workers
     :param task_parameters: An instance of task parameters
     :param is_multi_node_test: If this is a multi node test insted of a normal run.
     """
     # Load checkpoint if provided
     if task_parameters.checkpoint_restore_path:
         data_store_ckpt_load(data_store)
+
         # initialize graph
         graph_manager.create_graph(task_parameters)
+
     else:
         # initialize graph
         graph_manager.create_graph(task_parameters)

         # save randomly initialized graph
-        graph_manager.save_checkpoint()
+        data_store.save_policy(graph_manager)
+

     # training loop
     steps = 0

@@ -60,21 +59,17 @@ def training_worker(graph_manager, task_parameters, data_store, is_multi_node_test):

     while steps < graph_manager.improve_steps.num_steps:

-        graph_manager.phase = core_types.RunPhase.TRAIN
         if is_multi_node_test and graph_manager.get_current_episodes_count() > graph_manager.preset_validation_params.max_episodes_to_achieve_reward:
             # Test failed as it has not reached the required success rate
             graph_manager.flush_finished()
             screen.error("Could not reach required success by {} episodes.".format(graph_manager.preset_validation_params.max_episodes_to_achieve_reward), crash=True)

         graph_manager.fetch_from_worker(graph_manager.agent_params.algorithm.num_consecutive_playing_steps)
-        graph_manager.phase = core_types.RunPhase.UNDEFINED

         if graph_manager.should_train():
             steps += 1

-            graph_manager.phase = core_types.RunPhase.TRAIN
             graph_manager.train()
-            graph_manager.phase = core_types.RunPhase.UNDEFINED

             if steps * graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps > graph_manager.steps_between_evaluation_periods.num_steps * eval_offset:
                 eval_offset += 1

@@ -82,6 +77,10 @@ def training_worker(graph_manager, task_parameters, data_store, is_multi_node_test):
                 break

             if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
-                graph_manager.save_checkpoint()
+                data_store.save_policy(graph_manager)
             else:
-                graph_manager.occasionally_save_checkpoint()
+                # NOTE: this implementation conflated occasionally saving checkpoints for later use
+                # in production with checkpoints saved for communication to rollout workers.
+                # TODO: this should be implemented with a new parameter: distributed_coach_synchronization_frequency or similar
+                # graph_manager.occasionally_save_checkpoint()
+                raise NotImplementedError()