mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Handle both Environment Steps and Episodes on the subscriber side. (#99)
This commit is contained in:
committed by
Scott Leishman
parent
3358e04a6a
commit
101c55d37d
@@ -6,7 +6,7 @@ import time
|
||||
from kubernetes import client
|
||||
|
||||
from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
|
||||
from rl_coach.core_types import Transition, Episode
|
||||
from rl_coach.core_types import Transition, Episode, EnvironmentSteps, EnvironmentEpisodes
|
||||
|
||||
|
||||
class RedisPubSubMemoryBackendParameters(MemoryBackendParameters):
|
||||
@@ -129,8 +129,8 @@ class RedisPubSubBackend(MemoryBackend):
|
||||
def sample(self, size):
|
||||
pass
|
||||
|
||||
def fetch(self, num_steps=0):
|
||||
return RedisSub(redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel).run(num_steps=num_steps)
|
||||
def fetch(self, num_consecutive_playing_steps=None):
|
||||
return RedisSub(redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel).run(num_consecutive_playing_steps)
|
||||
|
||||
def subscribe(self, agent):
|
||||
redis_sub = RedisSub(redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
|
||||
@@ -150,19 +150,30 @@ class RedisSub(object):
|
||||
self.channel = channel
|
||||
self.subscriber = self.pubsub.subscribe(self.channel)
|
||||
|
||||
def run(self, num_steps=0):
|
||||
def run(self, num_consecutive_playing_steps):
|
||||
transitions = 0
|
||||
episodes = 0
|
||||
steps = 0
|
||||
for message in self.pubsub.listen():
|
||||
if message and 'data' in message:
|
||||
try:
|
||||
obj = pickle.loads(message['data'])
|
||||
if type(obj) == Transition:
|
||||
steps += 1
|
||||
transitions += 1
|
||||
if obj.game_over:
|
||||
episodes += 1
|
||||
yield obj
|
||||
elif type(obj) == Episode:
|
||||
steps += len(obj.transitions)
|
||||
episodes += 1
|
||||
transitions += len(obj.transitions)
|
||||
yield from obj.transitions
|
||||
except Exception:
|
||||
continue
|
||||
if num_steps > 0 and steps >= num_steps:
|
||||
break
|
||||
|
||||
if type(num_consecutive_playing_steps) == EnvironmentSteps:
|
||||
steps = transitions
|
||||
if type(num_consecutive_playing_steps) == EnvironmentEpisodes:
|
||||
steps = episodes
|
||||
|
||||
if steps >= num_consecutive_playing_steps.num_steps:
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user