
Wait for a new checkpoint if one is available

Ajay Deshpande
2018-10-05 19:08:24 -07:00
committed by zach dwiel
parent 5eac0102de
commit 7f00235ed5
7 changed files with 49 additions and 20 deletions

View File

@@ -81,7 +81,7 @@ class Agent(AgentInterface):
self.memory_backend = get_memory_backend(self.ap.memory.memory_backend_params)
if self.ap.memory.memory_backend_params.run_type == 'trainer':
self.memory_backend.subscribe(self.memory)
self.memory_backend.subscribe(self)
else:
self.memory.set_memory_backend(self.memory_backend)
@@ -523,8 +523,7 @@ class Agent(AgentInterface):
Determine if online weights should be copied to the target.
:return: boolean: True if the online weights should be copied to the target.
"""
if hasattr(self.ap.memory, 'memory_backend_params'):
self.total_steps_counter = self.call_memory('num_transitions')
# update the target network of every network that has a target network
step_method = self.ap.algorithm.num_steps_between_copying_online_weights_to_target
if step_method.__class__ == TrainingSteps:
@@ -546,7 +545,7 @@ class Agent(AgentInterface):
:return: boolean: True if we should start a training phase
"""
should_update = self._should_train_helper(wait_for_full_episode)
should_update = self._should_train_helper(wait_for_full_episode=wait_for_full_episode)
step_method = self.ap.algorithm.num_consecutive_playing_steps
@@ -560,10 +559,8 @@ class Agent(AgentInterface):
def _should_train_helper(self, wait_for_full_episode=False):
if hasattr(self.ap.memory, 'memory_backend_params'):
self.total_steps_counter = self.call_memory('num_transitions')
step_method = self.ap.algorithm.num_consecutive_playing_steps
if step_method.__class__ == EnvironmentEpisodes:
should_update = (self.current_episode - self.last_training_phase_step) >= step_method.num_steps
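
Taken together, the agent.py hunks move the step bookkeeping into _should_train_helper: when a memory backend is configured, the number of stored transitions is read back from the (possibly remote) memory before deciding whether to train, and _should_train now forwards wait_for_full_episode by keyword so subclasses can change its default. A condensed sketch of that decision, assuming the simplified attributes shown here (the real Agent class already carries them):

from rl_coach.core_types import EnvironmentEpisodes


class AgentSketch(object):
    def _should_train_helper(self, wait_for_full_episode=False):
        # Transitions arrive asynchronously over the backend, so the local
        # counter is refreshed from the replay buffer rather than incremented here.
        if hasattr(self.ap.memory, 'memory_backend_params'):
            self.total_steps_counter = self.call_memory('num_transitions')
        step_method = self.ap.algorithm.num_consecutive_playing_steps
        if step_method.__class__ == EnvironmentEpisodes:
            # other step_method classes are handled analogously in the real class
            return (self.current_episode - self.last_training_phase_step) >= step_method.num_steps
        return False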

View File

@@ -251,6 +251,9 @@ class ClippedPPOAgent(ActorCriticAgent):
# clean memory
self.call_memory('clean')
def _should_train_helper(self, wait_for_full_episode=True):
return super()._should_train_helper(True)
def train(self):
if self._should_train(wait_for_full_episode=True):
for network in self.networks.values():

View File

@@ -310,6 +310,9 @@ class PPOAgent(ActorCriticAgent):
# clean memory
self.call_memory('clean')
def _should_train_helper(self):
return super()._should_train_helper(True)
def train(self):
loss = 0
if self._should_train(wait_for_full_episode=True):
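
The same override is added to both PPO agents so that, whatever the caller requests, training waits for complete episodes, since these agents compute their advantages over whole trajectories. Restated as a fragment (import path per the rl_coach package layout; only the override is shown):

from rl_coach.agents.actor_critic_agent import ActorCriticAgent


class ClippedPPOAgentSketch(ActorCriticAgent):
    def _should_train_helper(self, wait_for_full_episode=True):
        # Force full-episode training no matter what was passed in.
        return super()._should_train_helper(True)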

View File

@@ -6,7 +6,6 @@ import threading
from kubernetes import client
from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
from rl_coach.memories.memory import Memory
from rl_coach.core_types import Transition, Episode
@@ -126,8 +125,8 @@ class RedisPubSubBackend(MemoryBackend):
def sample(self, size):
pass
def subscribe(self, memory):
redis_sub = RedisSub(memory, redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
def subscribe(self, agent):
redis_sub = RedisSub(agent, redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
redis_sub.daemon = True
redis_sub.start()
@@ -138,12 +137,12 @@ class RedisPubSubBackend(MemoryBackend):
class RedisSub(threading.Thread):
def __init__(self, memory: Memory, redis_address: str = "localhost", redis_port: int=6379, channel: str = "PubsubChannel"):
def __init__(self, agent, redis_address: str = "localhost", redis_port: int=6379, channel: str = "PubsubChannel"):
super().__init__()
self.redis_connection = redis.Redis(redis_address, redis_port)
self.pubsub = self.redis_connection.pubsub()
self.subscriber = None
self.memory = memory
self.agent = agent
self.channel = channel
self.subscriber = self.pubsub.subscribe(self.channel)
@@ -153,8 +152,13 @@ class RedisSub(threading.Thread):
try:
obj = pickle.loads(message['data'])
if type(obj) == Transition:
self.memory.store(obj)
self.agent.total_steps_counter += 1
self.agent.current_episode_steps_counter += 1
self.agent.call_memory('store', obj)
elif type(obj) == Episode:
self.memory.store_episode(obj)
self.agent.current_episode_buffer = obj
self.agent.total_steps_counter += len(obj.transitions)
self.agent.current_episode_steps_counter += len(obj.transitions)
self.agent.handle_episode_ended()
except Exception:
continue
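
Restated outside the diff: the subscriber thread is now handed the whole agent instead of a bare Memory, so every transition or episode that arrives over the channel also updates the agent's step counters and episode lifecycle. A self-contained sketch under that assumption (class name and error handling are illustrative, not the exact committed code):

import pickle
import threading

import redis

from rl_coach.core_types import Transition, Episode


class RedisSubSketch(threading.Thread):
    def __init__(self, agent, redis_address='localhost', redis_port=6379,
                 channel='PubsubChannel'):
        super().__init__()
        self.agent = agent
        self.pubsub = redis.Redis(redis_address, redis_port).pubsub()
        self.pubsub.subscribe(channel)

    def run(self):
        for message in self.pubsub.listen():
            if message['type'] != 'message':
                continue
            try:
                obj = pickle.loads(message['data'])
            except Exception:
                continue
            if isinstance(obj, Transition):
                # single step: bump the counters and store it in the agent's memory
                self.agent.total_steps_counter += 1
                self.agent.current_episode_steps_counter += 1
                self.agent.call_memory('store', obj)
            elif isinstance(obj, Episode):
                # whole episode: adopt it and run the agent's end-of-episode logic once
                self.agent.current_episode_buffer = obj
                self.agent.total_steps_counter += len(obj.transitions)
                self.agent.current_episode_steps_counter += len(obj.transitions)
                self.agent.handle_episode_ended()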

View File

@@ -184,7 +184,7 @@ class EpisodicExperienceReplay(Memory):
:return: None
"""
# Calling super.store() so that in case a memory backend is used, the memory backend can store this episode.
super().store(episode)
super().store_episode(episode)
if lock:
self.reader_writer_lock.lock_writing_and_reading()

View File

@@ -53,8 +53,7 @@ class Memory(object):
def store_episode(self, episode):
if self.memory_backend:
for transition in episode:
self.memory_backend.store(transition)
self.memory_backend.store(episode)
def get(self, index):
raise NotImplementedError("")
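
The two memory hunks complete the path: EpisodicExperienceReplay.store_episode() now calls the base-class store_episode(), which publishes the whole episode to the backend as a single object instead of one message per transition, so the subscriber above can rebuild it and finish the episode exactly once. A simplified sketch of that delegation (class names shortened; locking and local storage elided):

class MemorySketch(object):
    def __init__(self):
        self.memory_backend = None

    def set_memory_backend(self, memory_backend):
        self.memory_backend = memory_backend

    def store_episode(self, episode):
        # Publish the episode as one pickled message on the backend channel.
        if self.memory_backend:
            self.memory_backend.store(episode)


class EpisodicExperienceReplaySketch(MemorySketch):
    def store_episode(self, episode, lock=True):
        # Let the backend see the whole episode first, then store it locally
        # under the reader/writer lock (local bookkeeping elided here).
        super().store_episode(episode)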

View File

@@ -20,6 +20,8 @@ from rl_coach.core_types import EnvironmentEpisodes, RunPhase
from rl_coach.utils import short_dynamic_import
from rl_coach.memories.backend.memory_impl import construct_memory_params
from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
from google.protobuf import text_format
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
# Q: specify alternative distributed memory, or should this go in the preset?
@@ -66,6 +68,20 @@ def data_store_ckpt_load(data_store):
data_store.load_from_store()
time.sleep(10)
def check_for_new_checkpoint(checkpoint_dir, last_checkpoint):
if os.path.exists(os.path.join(checkpoint_dir, 'checkpoint')):
ckpt = CheckpointState()
contents = open(os.path.join(checkpoint_dir, 'checkpoint'), 'r').read()
text_format.Merge(contents, ckpt)
rel_path = os.path.relpath(ckpt.model_checkpoint_path, checkpoint_dir)
current_checkpoint = int(rel_path.split('_Step')[0])
if current_checkpoint > last_checkpoint:
last_checkpoint = current_checkpoint
return last_checkpoint
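
For context, the 'checkpoint' file that check_for_new_checkpoint() reads is TensorFlow's text-format CheckpointState proto, and the split('_Step') relies on Coach naming its checkpoints '<index>_Step-<steps>.ckpt'. A small stand-alone sketch with made-up paths and step numbers:

# The 'checkpoint' index file is a text-format CheckpointState proto, e.g.
#
#   model_checkpoint_path: "/checkpoint/3_Step-1200.ckpt"
#   all_model_checkpoint_paths: "/checkpoint/2_Step-800.ckpt"
#   all_model_checkpoint_paths: "/checkpoint/3_Step-1200.ckpt"
#
# The integer before '_Step' is the checkpoint index the rollout worker
# compares against last_checkpoint.
import os

from google.protobuf import text_format
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState


def latest_checkpoint_index(checkpoint_dir):
    with open(os.path.join(checkpoint_dir, 'checkpoint'), 'r') as f:
        ckpt = CheckpointState()
        text_format.Merge(f.read(), ckpt)
    rel_path = os.path.relpath(ckpt.model_checkpoint_path, checkpoint_dir)
    return int(rel_path.split('_Step')[0])   # -> 3 for the example file above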
def rollout_worker(graph_manager, checkpoint_dir):
"""
wait for first checkpoint then perform rollouts using the model
@@ -78,9 +94,16 @@ def rollout_worker(graph_manager, checkpoint_dir):
graph_manager.create_graph(task_parameters)
graph_manager.phase = RunPhase.TRAIN
last_checkpoint = 0
for i in range(10000000):
graph_manager.restore_checkpoint()
graph_manager.act(EnvironmentEpisodes(num_steps=10))
graph_manager.act(EnvironmentEpisodes(num_steps=1))
new_checkpoint = check_for_new_checkpoint(checkpoint_dir, last_checkpoint)
if new_checkpoint > last_checkpoint:
last_checkpoint = new_checkpoint
graph_manager.restore_checkpoint()
graph_manager.phase = RunPhase.UNDEFINED
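
The net effect on the rollout loop: act a single episode per iteration so new weights are picked up promptly, and reload only when the checkpoint index has actually advanced, rather than restoring on every pass. Condensed (variable names as in the hunk above):

last_checkpoint = 0
for _ in range(10000000):
    graph_manager.act(EnvironmentEpisodes(num_steps=1))
    new_checkpoint = check_for_new_checkpoint(checkpoint_dir, last_checkpoint)
    if new_checkpoint > last_checkpoint:
        last_checkpoint = new_checkpoint
        graph_manager.restore_checkpoint()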