1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Simulating the act on the trainer. (#65)

* Remove the use of daemon threads for Redis subscribe.
* Emulate act and observe on trainer side to update internal vars.
This commit is contained in:
Ajay Deshpande
2018-11-15 08:38:58 -08:00
committed by Balaji Subramaniam
parent fe6857eabd
commit fde73ced13
13 changed files with 221 additions and 55 deletions

View File

@@ -2,13 +2,11 @@
import redis
import pickle
import uuid
import threading
import time
from kubernetes import client
from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
from rl_coach.core_types import Transition, Episode
from rl_coach.core_types import RunPhase
class RedisPubSubMemoryBackendParameters(MemoryBackendParameters):
@@ -131,42 +129,40 @@ class RedisPubSubBackend(MemoryBackend):
def sample(self, size):
pass
def fetch(self, num_steps=0):
return RedisSub(redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel).run(num_steps=num_steps)
def subscribe(self, agent):
redis_sub = RedisSub(agent, redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
redis_sub.daemon = True
redis_sub.start()
redis_sub = RedisSub(redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
return redis_sub
def get_endpoint(self):
return {'redis_address': self.params.redis_address,
'redis_port': self.params.redis_port}
class RedisSub(threading.Thread):
def __init__(self, agent, redis_address: str = "localhost", redis_port: int=6379, channel: str = "PubsubChannel"):
class RedisSub(object):
def __init__(self, redis_address: str = "localhost", redis_port: int=6379, channel: str = "PubsubChannel"):
super().__init__()
self.redis_connection = redis.Redis(redis_address, redis_port)
self.pubsub = self.redis_connection.pubsub()
self.subscriber = None
self.agent = agent
self.channel = channel
self.subscriber = self.pubsub.subscribe(self.channel)
def run(self):
def run(self, num_steps=0):
steps = 0
for message in self.pubsub.listen():
if message and 'data' in message and self.agent.phase != RunPhase.TEST or self.agent.ap.task_parameters.evaluate_only:
if self.agent.phase == RunPhase.TEST:
print(self.agent.phase)
if message and 'data' in message:
try:
obj = pickle.loads(message['data'])
if type(obj) == Transition:
self.agent.total_steps_counter += 1
self.agent.current_episode_steps_counter += 1
self.agent.call_memory('store', obj)
steps += 1
yield obj
elif type(obj) == Episode:
self.agent.current_episode_buffer = obj
self.agent.total_steps_counter += len(obj.transitions)
self.agent.current_episode_steps_counter += len(obj.transitions)
self.agent.handle_episode_ended()
steps += len(obj.transitions)
yield from obj.transitions
except Exception:
continue
if num_steps > 0 and steps >= num_steps:
break