mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
waiting for a new checkpoint if it's available
This commit is contained in:
committed by
zach dwiel
parent
5eac0102de
commit
7f00235ed5
@@ -6,7 +6,6 @@ import threading
|
||||
from kubernetes import client
|
||||
|
||||
from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
|
||||
from rl_coach.memories.memory import Memory
|
||||
from rl_coach.core_types import Transition, Episode
|
||||
|
||||
|
||||
@@ -126,8 +125,8 @@ class RedisPubSubBackend(MemoryBackend):
|
||||
def sample(self, size):
|
||||
pass
|
||||
|
||||
def subscribe(self, memory):
|
||||
redis_sub = RedisSub(memory, redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
|
||||
def subscribe(self, agent):
|
||||
redis_sub = RedisSub(agent, redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
|
||||
redis_sub.daemon = True
|
||||
redis_sub.start()
|
||||
|
||||
@@ -138,12 +137,12 @@ class RedisPubSubBackend(MemoryBackend):
|
||||
|
||||
class RedisSub(threading.Thread):
|
||||
|
||||
def __init__(self, memory: Memory, redis_address: str = "localhost", redis_port: int=6379, channel: str = "PubsubChannel"):
|
||||
def __init__(self, agent, redis_address: str = "localhost", redis_port: int=6379, channel: str = "PubsubChannel"):
|
||||
super().__init__()
|
||||
self.redis_connection = redis.Redis(redis_address, redis_port)
|
||||
self.pubsub = self.redis_connection.pubsub()
|
||||
self.subscriber = None
|
||||
self.memory = memory
|
||||
self.agent = agent
|
||||
self.channel = channel
|
||||
self.subscriber = self.pubsub.subscribe(self.channel)
|
||||
|
||||
@@ -153,8 +152,13 @@ class RedisSub(threading.Thread):
|
||||
try:
|
||||
obj = pickle.loads(message['data'])
|
||||
if type(obj) == Transition:
|
||||
self.memory.store(obj)
|
||||
self.agent.total_steps_counter += 1
|
||||
self.agent.current_episode_steps_counter += 1
|
||||
self.agent.call_memory('store', obj)
|
||||
elif type(obj) == Episode:
|
||||
self.memory.store_episode(obj)
|
||||
self.agent.current_episode_buffer = obj
|
||||
self.agent.total_steps_counter += len(obj.transitions)
|
||||
self.agent.current_episode_steps_counter += len(obj.transitions)
|
||||
self.agent.handle_episode_ended()
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user