1
0
mirror of https://github.com/gryf/coach.git synced 2026-02-16 22:25:47 +01:00

waiting for a new checkpoint if it's available

This commit is contained in:
Ajay Deshpande
2018-10-05 19:08:24 -07:00
committed by zach dwiel
parent 5eac0102de
commit 7f00235ed5
7 changed files with 49 additions and 20 deletions

View File

@@ -6,7 +6,6 @@ import threading
from kubernetes import client
from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
from rl_coach.memories.memory import Memory
from rl_coach.core_types import Transition, Episode
@@ -126,8 +125,8 @@ class RedisPubSubBackend(MemoryBackend):
def sample(self, size):
pass
def subscribe(self, memory):
redis_sub = RedisSub(memory, redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
def subscribe(self, agent):
redis_sub = RedisSub(agent, redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
redis_sub.daemon = True
redis_sub.start()
@@ -138,12 +137,12 @@ class RedisPubSubBackend(MemoryBackend):
class RedisSub(threading.Thread):
def __init__(self, memory: Memory, redis_address: str = "localhost", redis_port: int=6379, channel: str = "PubsubChannel"):
def __init__(self, agent, redis_address: str = "localhost", redis_port: int=6379, channel: str = "PubsubChannel"):
super().__init__()
self.redis_connection = redis.Redis(redis_address, redis_port)
self.pubsub = self.redis_connection.pubsub()
self.subscriber = None
self.memory = memory
self.agent = agent
self.channel = channel
self.subscriber = self.pubsub.subscribe(self.channel)
@@ -153,8 +152,13 @@ class RedisSub(threading.Thread):
try:
obj = pickle.loads(message['data'])
if type(obj) == Transition:
self.memory.store(obj)
self.agent.total_steps_counter += 1
self.agent.current_episode_steps_counter += 1
self.agent.call_memory('store', obj)
elif type(obj) == Episode:
self.memory.store_episode(obj)
self.agent.current_episode_buffer = obj
self.agent.total_steps_counter += len(obj.transitions)
self.agent.current_episode_steps_counter += len(obj.transitions)
self.agent.handle_episode_ended()
except Exception:
continue

View File

@@ -184,7 +184,7 @@ class EpisodicExperienceReplay(Memory):
:return: None
"""
# Calling super.store() so that in case a memory backend is used, the memory backend can store this episode.
super().store(episode)
super().store_episode(episode)
if lock:
self.reader_writer_lock.lock_writing_and_reading()

View File

@@ -53,8 +53,7 @@ class Memory(object):
def store_episode(self, episode):
if self.memory_backend:
for transition in episode:
self.memory_backend.store(transition)
self.memory_backend.store(episode)
def get(self, index):
raise NotImplementedError("")