1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Adding improvements

This commit is contained in:
Ajay Deshpande
2018-10-15 15:57:10 -07:00
committed by zach dwiel
parent 3ba0df7d07
commit 9a30c26469
7 changed files with 38 additions and 27 deletions

View File

@@ -111,7 +111,7 @@ RUN mkdir -p ~/.mujoco \
&& wget https://www.roboti.us/download/mjpro150_linux.zip -O mujoco.zip \
&& unzip mujoco.zip -d ~/.mujoco \
&& rm mujoco.zip
COPY ./mjkey.txt /root/.mujoco/
COPY ./README.md ./mjkey.txt /root/.mujoco/
ENV LD_LIBRARY_PATH /root/.mujoco/mjpro150/bin:$LD_LIBRARY_PATH
RUN curl -o /usr/local/bin/patchelf https://s3-us-west-2.amazonaws.com/openai-sci-artifacts/manual-builds/patchelf_0.9_amd64.elf \

View File

@@ -161,7 +161,6 @@ class Agent(AgentInterface):
self.discounted_return = self.register_signal('Discounted Return')
if isinstance(self.in_action_space, GoalsSpace):
self.distance_from_goal = self.register_signal('Distance From Goal', dump_one_value_per_step=True)
# Seed the random number generator when a seed was supplied, for reproducibility
if self.ap.task_parameters.seed is not None:
random.seed(self.ap.task_parameters.seed)

View File

@@ -144,7 +144,6 @@ class NFSDataStore(DataStore):
return True
def create_k8s_nfs_resources(self) -> bool:
pv_name = "nfs-ckpt-pv"
persistent_volume = k8sclient.V1PersistentVolume(
@@ -219,7 +218,6 @@ class NFSDataStore(DataStore):
return True
def delete_k8s_nfs_resources(self) -> bool:
del_options = k8sclient.V1DeleteOptions()
k8s_api_client = k8sclient.CoreV1Api()

View File

@@ -3,6 +3,7 @@ import redis
import pickle
import uuid
import threading
import time
from kubernetes import client
from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
@@ -31,6 +32,8 @@ class RedisPubSubBackend(MemoryBackend):
def __init__(self, params: RedisPubSubMemoryBackendParameters):
self.params = params
self.redis_connection = redis.Redis(self.params.redis_address, self.params.redis_port)
self.redis_server_name = 'redis-server-{}'.format(uuid.uuid4())
self.redis_service_name = 'redis-service-{}'.format(uuid.uuid4())
def store(self, obj):
self.redis_connection.publish(self.params.channel, pickle.dumps(obj))
@@ -39,7 +42,9 @@ class RedisPubSubBackend(MemoryBackend):
if not self.params.deployed:
if self.params.orchestrator_type == 'kubernetes':
self.deploy_kubernetes()
self.params.deployed = True
# Wait until subscribing to the channel is possible; otherwise it will cause delays in the trainer.
time.sleep(10)
def deploy_kubernetes(self):
@@ -47,11 +52,11 @@ class RedisPubSubBackend(MemoryBackend):
self.params.orchestrator_params['namespace'] = "default"
container = client.V1Container(
name="redis-server",
name=self.redis_server_name,
image='redis:4-alpine',
)
template = client.V1PodTemplateSpec(
metadata=client.V1ObjectMeta(labels={'app': 'redis-server'}),
metadata=client.V1ObjectMeta(labels={'app': self.redis_server_name}),
spec=client.V1PodSpec(
containers=[container]
)
@@ -60,14 +65,14 @@ class RedisPubSubBackend(MemoryBackend):
replicas=1,
template=template,
selector=client.V1LabelSelector(
match_labels={'app': 'redis-server'}
match_labels={'app': self.redis_server_name}
)
)
deployment = client.V1Deployment(
api_version='apps/v1',
kind='Deployment',
metadata=client.V1ObjectMeta(name='redis-server', labels={'app': 'redis-server'}),
metadata=client.V1ObjectMeta(name=self.redis_server_name, labels={'app': self.redis_server_name}),
spec=deployment_spec
)
@@ -84,10 +89,10 @@ class RedisPubSubBackend(MemoryBackend):
api_version='v1',
kind='Service',
metadata=client.V1ObjectMeta(
name='redis-service'
name=self.redis_service_name
),
spec=client.V1ServiceSpec(
selector={'app': 'redis-server'},
selector={'app': self.redis_server_name},
ports=[client.V1ServicePort(
protocol='TCP',
port=6379,
@@ -98,7 +103,9 @@ class RedisPubSubBackend(MemoryBackend):
try:
core_v1_api.create_namespaced_service(self.params.orchestrator_params['namespace'], service)
self.params.redis_address = 'redis-service.{}.svc'.format(self.params.orchestrator_params['namespace'])
self.params.redis_address = '{}.{}.svc'.format(
self.redis_service_name, self.params.orchestrator_params['namespace']
)
self.params.redis_port = 6379
return True
except client.rest.ApiException as e:
@@ -106,23 +113,21 @@ class RedisPubSubBackend(MemoryBackend):
return False
def undeploy(self):
if not self.params.deployed:
if self.params.deployed:
return
api_client = client.AppsV1Api()
delete_options = client.V1DeleteOptions()
try:
api_client.delete_namespaced_deployment('redis-server', self.params.orchestrator_params['namespace'], delete_options)
api_client.delete_namespaced_deployment(self.redis_server_name, self.params.orchestrator_params['namespace'], delete_options)
except client.rest.ApiException as e:
print("Got exception: %s\n while deleting redis-server", e)
api_client = client.CoreV1Api()
try:
api_client.delete_namespaced_service('redis-service', self.params.orchestrator_params['namespace'], delete_options)
api_client.delete_namespaced_service(self.redis_service_name, self.params.orchestrator_params['namespace'], delete_options)
except client.rest.ApiException as e:
print("Got exception: %s\n while deleting redis-server", e)
self.params.deployed = False
def sample(self, size):
pass
@@ -145,7 +150,9 @@ class RedisSub(threading.Thread):
self.subscriber = None
self.agent = agent
self.channel = channel
print('Before subscribe')
self.subscriber = self.pubsub.subscribe(self.channel)
print('After subscribe')
def run(self):
for message in self.pubsub.listen():

View File

@@ -44,13 +44,11 @@ class EpisodicExperienceReplay(Memory):
:param max_size: the maximum number of transitions or episodes to hold in the memory
"""
super().__init__(max_size)
self._buffer = [Episode()] # list of episodes
self.transitions = []
self._length = 1 # the episodic replay buffer starts with a single empty episode
self._num_transitions = 0
self._num_transitions_in_complete_episodes = 0
self.reader_writer_lock = ReaderWriterLock()
def length(self, lock: bool=False) -> int:

View File

@@ -116,7 +116,9 @@ class Kubernetes(Deploy):
volume_mounts=[k8sclient.V1VolumeMount(
name='nfs-pvc',
mount_path=trainer_params.checkpoint_dir
)]
)],
stdin=True,
tty=True
)
template = k8sclient.V1PodTemplateSpec(
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
@@ -136,7 +138,9 @@ class Kubernetes(Deploy):
args=trainer_params.arguments,
image_pull_policy='Always',
env=[k8sclient.V1EnvVar("ACCESS_KEY_ID", self.s3_access_key),
k8sclient.V1EnvVar("SECRET_ACCESS_KEY", self.s3_secret_key)]
k8sclient.V1EnvVar("SECRET_ACCESS_KEY", self.s3_secret_key)],
stdin=True,
tty=True
)
template = k8sclient.V1PodTemplateSpec(
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
@@ -191,7 +195,9 @@ class Kubernetes(Deploy):
volume_mounts=[k8sclient.V1VolumeMount(
name='nfs-pvc',
mount_path=worker_params.checkpoint_dir
)]
)],
stdin=True,
tty=True
)
template = k8sclient.V1PodTemplateSpec(
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
@@ -211,7 +217,9 @@ class Kubernetes(Deploy):
args=worker_params.arguments,
image_pull_policy='Always',
env=[k8sclient.V1EnvVar("ACCESS_KEY_ID", self.s3_access_key),
k8sclient.V1EnvVar("SECRET_ACCESS_KEY", self.s3_secret_key)]
k8sclient.V1EnvVar("SECRET_ACCESS_KEY", self.s3_secret_key)],
stdin=True,
tty=True
)
template = k8sclient.V1PodTemplateSpec(
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
@@ -273,9 +281,11 @@ class Kubernetes(Deploy):
time.sleep(10)
# Try to tail the pod logs
try:
print(corev1_api.read_namespaced_pod_log(
pod_name, self.params.namespace, follow=True
), flush=True)
for line in corev1_api.read_namespaced_pod_log(
pod_name, self.params.namespace, follow=True,
_preload_content=False
):
print(line.decode('utf-8'), flush=True, end='')
except k8sclient.rest.ApiException as e:
pass

View File

@@ -144,7 +144,6 @@ def main():
default='OFF')
args = parser.parse_args()
graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)
data_store = None