mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Adding improvements
This commit is contained in:
committed by
zach dwiel
parent
3ba0df7d07
commit
9a30c26469
@@ -111,7 +111,7 @@ RUN mkdir -p ~/.mujoco \
|
|||||||
&& wget https://www.roboti.us/download/mjpro150_linux.zip -O mujoco.zip \
|
&& wget https://www.roboti.us/download/mjpro150_linux.zip -O mujoco.zip \
|
||||||
&& unzip mujoco.zip -d ~/.mujoco \
|
&& unzip mujoco.zip -d ~/.mujoco \
|
||||||
&& rm mujoco.zip
|
&& rm mujoco.zip
|
||||||
COPY ./mjkey.txt /root/.mujoco/
|
COPY ./README.md ./mjkey.txt /root/.mujoco/
|
||||||
ENV LD_LIBRARY_PATH /root/.mujoco/mjpro150/bin:$LD_LIBRARY_PATH
|
ENV LD_LIBRARY_PATH /root/.mujoco/mjpro150/bin:$LD_LIBRARY_PATH
|
||||||
|
|
||||||
RUN curl -o /usr/local/bin/patchelf https://s3-us-west-2.amazonaws.com/openai-sci-artifacts/manual-builds/patchelf_0.9_amd64.elf \
|
RUN curl -o /usr/local/bin/patchelf https://s3-us-west-2.amazonaws.com/openai-sci-artifacts/manual-builds/patchelf_0.9_amd64.elf \
|
||||||
|
|||||||
@@ -161,7 +161,6 @@ class Agent(AgentInterface):
|
|||||||
self.discounted_return = self.register_signal('Discounted Return')
|
self.discounted_return = self.register_signal('Discounted Return')
|
||||||
if isinstance(self.in_action_space, GoalsSpace):
|
if isinstance(self.in_action_space, GoalsSpace):
|
||||||
self.distance_from_goal = self.register_signal('Distance From Goal', dump_one_value_per_step=True)
|
self.distance_from_goal = self.register_signal('Distance From Goal', dump_one_value_per_step=True)
|
||||||
|
|
||||||
# use seed
|
# use seed
|
||||||
if self.ap.task_parameters.seed is not None:
|
if self.ap.task_parameters.seed is not None:
|
||||||
random.seed(self.ap.task_parameters.seed)
|
random.seed(self.ap.task_parameters.seed)
|
||||||
|
|||||||
@@ -144,7 +144,6 @@ class NFSDataStore(DataStore):
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def create_k8s_nfs_resources(self) -> bool:
|
def create_k8s_nfs_resources(self) -> bool:
|
||||||
pv_name = "nfs-ckpt-pv"
|
pv_name = "nfs-ckpt-pv"
|
||||||
persistent_volume = k8sclient.V1PersistentVolume(
|
persistent_volume = k8sclient.V1PersistentVolume(
|
||||||
@@ -219,7 +218,6 @@ class NFSDataStore(DataStore):
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def delete_k8s_nfs_resources(self) -> bool:
|
def delete_k8s_nfs_resources(self) -> bool:
|
||||||
del_options = k8sclient.V1DeleteOptions()
|
del_options = k8sclient.V1DeleteOptions()
|
||||||
k8s_api_client = k8sclient.CoreV1Api()
|
k8s_api_client = k8sclient.CoreV1Api()
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import redis
|
|||||||
import pickle
|
import pickle
|
||||||
import uuid
|
import uuid
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
from kubernetes import client
|
from kubernetes import client
|
||||||
|
|
||||||
from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
|
from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
|
||||||
@@ -31,6 +32,8 @@ class RedisPubSubBackend(MemoryBackend):
|
|||||||
def __init__(self, params: RedisPubSubMemoryBackendParameters):
|
def __init__(self, params: RedisPubSubMemoryBackendParameters):
|
||||||
self.params = params
|
self.params = params
|
||||||
self.redis_connection = redis.Redis(self.params.redis_address, self.params.redis_port)
|
self.redis_connection = redis.Redis(self.params.redis_address, self.params.redis_port)
|
||||||
|
self.redis_server_name = 'redis-server-{}'.format(uuid.uuid4())
|
||||||
|
self.redis_service_name = 'redis-service-{}'.format(uuid.uuid4())
|
||||||
|
|
||||||
def store(self, obj):
|
def store(self, obj):
|
||||||
self.redis_connection.publish(self.params.channel, pickle.dumps(obj))
|
self.redis_connection.publish(self.params.channel, pickle.dumps(obj))
|
||||||
@@ -39,7 +42,9 @@ class RedisPubSubBackend(MemoryBackend):
|
|||||||
if not self.params.deployed:
|
if not self.params.deployed:
|
||||||
if self.params.orchestrator_type == 'kubernetes':
|
if self.params.orchestrator_type == 'kubernetes':
|
||||||
self.deploy_kubernetes()
|
self.deploy_kubernetes()
|
||||||
self.params.deployed = True
|
|
||||||
|
# Wait till subscribe to the channel is possible or else it will cause delays in the trainer.
|
||||||
|
time.sleep(10)
|
||||||
|
|
||||||
def deploy_kubernetes(self):
|
def deploy_kubernetes(self):
|
||||||
|
|
||||||
@@ -47,11 +52,11 @@ class RedisPubSubBackend(MemoryBackend):
|
|||||||
self.params.orchestrator_params['namespace'] = "default"
|
self.params.orchestrator_params['namespace'] = "default"
|
||||||
|
|
||||||
container = client.V1Container(
|
container = client.V1Container(
|
||||||
name="redis-server",
|
name=self.redis_server_name,
|
||||||
image='redis:4-alpine',
|
image='redis:4-alpine',
|
||||||
)
|
)
|
||||||
template = client.V1PodTemplateSpec(
|
template = client.V1PodTemplateSpec(
|
||||||
metadata=client.V1ObjectMeta(labels={'app': 'redis-server'}),
|
metadata=client.V1ObjectMeta(labels={'app': self.redis_server_name}),
|
||||||
spec=client.V1PodSpec(
|
spec=client.V1PodSpec(
|
||||||
containers=[container]
|
containers=[container]
|
||||||
)
|
)
|
||||||
@@ -60,14 +65,14 @@ class RedisPubSubBackend(MemoryBackend):
|
|||||||
replicas=1,
|
replicas=1,
|
||||||
template=template,
|
template=template,
|
||||||
selector=client.V1LabelSelector(
|
selector=client.V1LabelSelector(
|
||||||
match_labels={'app': 'redis-server'}
|
match_labels={'app': self.redis_server_name}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
deployment = client.V1Deployment(
|
deployment = client.V1Deployment(
|
||||||
api_version='apps/v1',
|
api_version='apps/v1',
|
||||||
kind='Deployment',
|
kind='Deployment',
|
||||||
metadata=client.V1ObjectMeta(name='redis-server', labels={'app': 'redis-server'}),
|
metadata=client.V1ObjectMeta(name=self.redis_server_name, labels={'app': self.redis_server_name}),
|
||||||
spec=deployment_spec
|
spec=deployment_spec
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -84,10 +89,10 @@ class RedisPubSubBackend(MemoryBackend):
|
|||||||
api_version='v1',
|
api_version='v1',
|
||||||
kind='Service',
|
kind='Service',
|
||||||
metadata=client.V1ObjectMeta(
|
metadata=client.V1ObjectMeta(
|
||||||
name='redis-service'
|
name=self.redis_service_name
|
||||||
),
|
),
|
||||||
spec=client.V1ServiceSpec(
|
spec=client.V1ServiceSpec(
|
||||||
selector={'app': 'redis-server'},
|
selector={'app': self.redis_server_name},
|
||||||
ports=[client.V1ServicePort(
|
ports=[client.V1ServicePort(
|
||||||
protocol='TCP',
|
protocol='TCP',
|
||||||
port=6379,
|
port=6379,
|
||||||
@@ -98,7 +103,9 @@ class RedisPubSubBackend(MemoryBackend):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
core_v1_api.create_namespaced_service(self.params.orchestrator_params['namespace'], service)
|
core_v1_api.create_namespaced_service(self.params.orchestrator_params['namespace'], service)
|
||||||
self.params.redis_address = 'redis-service.{}.svc'.format(self.params.orchestrator_params['namespace'])
|
self.params.redis_address = '{}.{}.svc'.format(
|
||||||
|
self.redis_service_name, self.params.orchestrator_params['namespace']
|
||||||
|
)
|
||||||
self.params.redis_port = 6379
|
self.params.redis_port = 6379
|
||||||
return True
|
return True
|
||||||
except client.rest.ApiException as e:
|
except client.rest.ApiException as e:
|
||||||
@@ -106,23 +113,21 @@ class RedisPubSubBackend(MemoryBackend):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def undeploy(self):
|
def undeploy(self):
|
||||||
if not self.params.deployed:
|
if self.params.deployed:
|
||||||
return
|
return
|
||||||
api_client = client.AppsV1Api()
|
api_client = client.AppsV1Api()
|
||||||
delete_options = client.V1DeleteOptions()
|
delete_options = client.V1DeleteOptions()
|
||||||
try:
|
try:
|
||||||
api_client.delete_namespaced_deployment('redis-server', self.params.orchestrator_params['namespace'], delete_options)
|
api_client.delete_namespaced_deployment(self.redis_server_name, self.params.orchestrator_params['namespace'], delete_options)
|
||||||
except client.rest.ApiException as e:
|
except client.rest.ApiException as e:
|
||||||
print("Got exception: %s\n while deleting redis-server", e)
|
print("Got exception: %s\n while deleting redis-server", e)
|
||||||
|
|
||||||
api_client = client.CoreV1Api()
|
api_client = client.CoreV1Api()
|
||||||
try:
|
try:
|
||||||
api_client.delete_namespaced_service('redis-service', self.params.orchestrator_params['namespace'], delete_options)
|
api_client.delete_namespaced_service(self.redis_service_name, self.params.orchestrator_params['namespace'], delete_options)
|
||||||
except client.rest.ApiException as e:
|
except client.rest.ApiException as e:
|
||||||
print("Got exception: %s\n while deleting redis-server", e)
|
print("Got exception: %s\n while deleting redis-server", e)
|
||||||
|
|
||||||
self.params.deployed = False
|
|
||||||
|
|
||||||
def sample(self, size):
|
def sample(self, size):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -145,7 +150,9 @@ class RedisSub(threading.Thread):
|
|||||||
self.subscriber = None
|
self.subscriber = None
|
||||||
self.agent = agent
|
self.agent = agent
|
||||||
self.channel = channel
|
self.channel = channel
|
||||||
|
print('Before subscribe')
|
||||||
self.subscriber = self.pubsub.subscribe(self.channel)
|
self.subscriber = self.pubsub.subscribe(self.channel)
|
||||||
|
print('After subscribe')
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
for message in self.pubsub.listen():
|
for message in self.pubsub.listen():
|
||||||
|
|||||||
@@ -44,13 +44,11 @@ class EpisodicExperienceReplay(Memory):
|
|||||||
:param max_size: the maximum number of transitions or episodes to hold in the memory
|
:param max_size: the maximum number of transitions or episodes to hold in the memory
|
||||||
"""
|
"""
|
||||||
super().__init__(max_size)
|
super().__init__(max_size)
|
||||||
|
|
||||||
self._buffer = [Episode()] # list of episodes
|
self._buffer = [Episode()] # list of episodes
|
||||||
self.transitions = []
|
self.transitions = []
|
||||||
self._length = 1 # the episodic replay buffer starts with a single empty episode
|
self._length = 1 # the episodic replay buffer starts with a single empty episode
|
||||||
self._num_transitions = 0
|
self._num_transitions = 0
|
||||||
self._num_transitions_in_complete_episodes = 0
|
self._num_transitions_in_complete_episodes = 0
|
||||||
|
|
||||||
self.reader_writer_lock = ReaderWriterLock()
|
self.reader_writer_lock = ReaderWriterLock()
|
||||||
|
|
||||||
def length(self, lock: bool=False) -> int:
|
def length(self, lock: bool=False) -> int:
|
||||||
|
|||||||
@@ -116,7 +116,9 @@ class Kubernetes(Deploy):
|
|||||||
volume_mounts=[k8sclient.V1VolumeMount(
|
volume_mounts=[k8sclient.V1VolumeMount(
|
||||||
name='nfs-pvc',
|
name='nfs-pvc',
|
||||||
mount_path=trainer_params.checkpoint_dir
|
mount_path=trainer_params.checkpoint_dir
|
||||||
)]
|
)],
|
||||||
|
stdin=True,
|
||||||
|
tty=True
|
||||||
)
|
)
|
||||||
template = k8sclient.V1PodTemplateSpec(
|
template = k8sclient.V1PodTemplateSpec(
|
||||||
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
|
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
|
||||||
@@ -136,7 +138,9 @@ class Kubernetes(Deploy):
|
|||||||
args=trainer_params.arguments,
|
args=trainer_params.arguments,
|
||||||
image_pull_policy='Always',
|
image_pull_policy='Always',
|
||||||
env=[k8sclient.V1EnvVar("ACCESS_KEY_ID", self.s3_access_key),
|
env=[k8sclient.V1EnvVar("ACCESS_KEY_ID", self.s3_access_key),
|
||||||
k8sclient.V1EnvVar("SECRET_ACCESS_KEY", self.s3_secret_key)]
|
k8sclient.V1EnvVar("SECRET_ACCESS_KEY", self.s3_secret_key)],
|
||||||
|
stdin=True,
|
||||||
|
tty=True
|
||||||
)
|
)
|
||||||
template = k8sclient.V1PodTemplateSpec(
|
template = k8sclient.V1PodTemplateSpec(
|
||||||
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
|
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
|
||||||
@@ -191,7 +195,9 @@ class Kubernetes(Deploy):
|
|||||||
volume_mounts=[k8sclient.V1VolumeMount(
|
volume_mounts=[k8sclient.V1VolumeMount(
|
||||||
name='nfs-pvc',
|
name='nfs-pvc',
|
||||||
mount_path=worker_params.checkpoint_dir
|
mount_path=worker_params.checkpoint_dir
|
||||||
)]
|
)],
|
||||||
|
stdin=True,
|
||||||
|
tty=True
|
||||||
)
|
)
|
||||||
template = k8sclient.V1PodTemplateSpec(
|
template = k8sclient.V1PodTemplateSpec(
|
||||||
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
|
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
|
||||||
@@ -211,7 +217,9 @@ class Kubernetes(Deploy):
|
|||||||
args=worker_params.arguments,
|
args=worker_params.arguments,
|
||||||
image_pull_policy='Always',
|
image_pull_policy='Always',
|
||||||
env=[k8sclient.V1EnvVar("ACCESS_KEY_ID", self.s3_access_key),
|
env=[k8sclient.V1EnvVar("ACCESS_KEY_ID", self.s3_access_key),
|
||||||
k8sclient.V1EnvVar("SECRET_ACCESS_KEY", self.s3_secret_key)]
|
k8sclient.V1EnvVar("SECRET_ACCESS_KEY", self.s3_secret_key)],
|
||||||
|
stdin=True,
|
||||||
|
tty=True
|
||||||
)
|
)
|
||||||
template = k8sclient.V1PodTemplateSpec(
|
template = k8sclient.V1PodTemplateSpec(
|
||||||
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
|
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
|
||||||
@@ -273,9 +281,11 @@ class Kubernetes(Deploy):
|
|||||||
time.sleep(10)
|
time.sleep(10)
|
||||||
# Try to tail the pod logs
|
# Try to tail the pod logs
|
||||||
try:
|
try:
|
||||||
print(corev1_api.read_namespaced_pod_log(
|
for line in corev1_api.read_namespaced_pod_log(
|
||||||
pod_name, self.params.namespace, follow=True
|
pod_name, self.params.namespace, follow=True,
|
||||||
), flush=True)
|
_preload_content=False
|
||||||
|
):
|
||||||
|
print(line.decode('utf-8'), flush=True, end='')
|
||||||
except k8sclient.rest.ApiException as e:
|
except k8sclient.rest.ApiException as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -144,7 +144,6 @@ def main():
|
|||||||
default='OFF')
|
default='OFF')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)
|
graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)
|
||||||
|
|
||||||
data_store = None
|
data_store = None
|
||||||
|
|||||||
Reference in New Issue
Block a user