Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 19:20:19 +01:00)
waiting for a new checkpoint if it's available
commit 7f00235ed5
parent 5eac0102de
committed by zach dwiel
@@ -81,7 +81,7 @@ class Agent(AgentInterface):
             self.memory_backend = get_memory_backend(self.ap.memory.memory_backend_params)

             if self.ap.memory.memory_backend_params.run_type == 'trainer':
-                self.memory_backend.subscribe(self.memory)
+                self.memory_backend.subscribe(self)
             else:
                 self.memory.set_memory_backend(self.memory_backend)
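With subscribe(self), the backend's subscriber thread gets a handle to the whole agent rather than just its replay memory, so it can keep the agent's step counters and episode bookkeeping in sync with the incoming data (see the RedisSub hunk below). For orientation only, a rough sketch of the agent-side surface the subscriber relies on; the real class is rl_coach's Agent, and the names below simply mirror the attributes used later in this diff:

class SubscribableAgentSketch:
    """Illustrative stand-in for the Agent interface used by the memory backend."""

    def __init__(self, memory):
        self.memory = memory
        self.total_steps_counter = 0
        self.current_episode_steps_counter = 0
        self.current_episode_buffer = None

    def call_memory(self, func, *args):
        # Forward a call such as 'store' or 'num_transitions' to the replay memory.
        return getattr(self.memory, func)(*args)

    def handle_episode_ended(self):
        # Per-episode bookkeeping once a full Episode has been received.
        self.current_episode_steps_counter = 0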
@@ -523,8 +523,7 @@ class Agent(AgentInterface):
         Determine if online weights should be copied to the target.
         :return: boolean: True if the online weights should be copied to the target.
         """
-        if hasattr(self.ap.memory, 'memory_backend_params'):
-            self.total_steps_counter = self.call_memory('num_transitions')

         # update the target network of every network that has a target network
         step_method = self.ap.algorithm.num_steps_between_copying_online_weights_to_target
         if step_method.__class__ == TrainingSteps:
@@ -546,7 +545,7 @@ class Agent(AgentInterface):
         :return: boolean: True if we should start a training phase
         """

-        should_update = self._should_train_helper(wait_for_full_episode)
+        should_update = self._should_train_helper(wait_for_full_episode=wait_for_full_episode)

         step_method = self.ap.algorithm.num_consecutive_playing_steps
@@ -560,10 +559,8 @@ class Agent(AgentInterface):

     def _should_train_helper(self, wait_for_full_episode=False):

-        if hasattr(self.ap.memory, 'memory_backend_params'):
-            self.total_steps_counter = self.call_memory('num_transitions')
-
         step_method = self.ap.algorithm.num_consecutive_playing_steps

         if step_method.__class__ == EnvironmentEpisodes:
             should_update = (self.current_episode - self.last_training_phase_step) >= step_method.num_steps
@@ -251,6 +251,9 @@ class ClippedPPOAgent(ActorCriticAgent):
         # clean memory
         self.call_memory('clean')

+    def _should_train_helper(self, wait_for_full_episode=True):
+        return super()._should_train_helper(True)
+
     def train(self):
         if self._should_train(wait_for_full_episode=True):
             for network in self.networks.values():
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation
+# Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -310,6 +310,9 @@ class PPOAgent(ActorCriticAgent):
         # clean memory
         self.call_memory('clean')

+    def _should_train_helper(self):
+        return super()._should_train_helper(True)
+
     def train(self):
         loss = 0
         if self._should_train(wait_for_full_episode=True):
@@ -6,7 +6,6 @@ import threading
 from kubernetes import client

 from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
-from rl_coach.memories.memory import Memory
 from rl_coach.core_types import Transition, Episode
@@ -126,8 +125,8 @@ class RedisPubSubBackend(MemoryBackend):
     def sample(self, size):
         pass

-    def subscribe(self, memory):
-        redis_sub = RedisSub(memory, redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
+    def subscribe(self, agent):
+        redis_sub = RedisSub(agent, redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
         redis_sub.daemon = True
         redis_sub.start()
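The publishing side of this pub/sub pair is not part of the diff; presumably the backend's store() pickles each Transition or Episode and publishes it on the same channel that RedisSub listens to. A minimal sketch of that counterpart, assuming redis-py and the parameter names used above (not the actual rl_coach implementation):

import pickle

import redis


class RedisPublisherSketch:
    """Hypothetical publisher counterpart to RedisSub; illustration only."""

    def __init__(self, redis_address="localhost", redis_port=6379, channel="PubsubChannel"):
        self.redis_connection = redis.Redis(redis_address, redis_port)
        self.channel = channel

    def store(self, obj):
        # Serialize the Transition or Episode and broadcast it to all subscribers.
        self.redis_connection.publish(self.channel, pickle.dumps(obj))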
@@ -138,12 +137,12 @@ class RedisPubSubBackend(MemoryBackend):

 class RedisSub(threading.Thread):

-    def __init__(self, memory: Memory, redis_address: str = "localhost", redis_port: int=6379, channel: str = "PubsubChannel"):
+    def __init__(self, agent, redis_address: str = "localhost", redis_port: int=6379, channel: str = "PubsubChannel"):
         super().__init__()
         self.redis_connection = redis.Redis(redis_address, redis_port)
         self.pubsub = self.redis_connection.pubsub()
         self.subscriber = None
-        self.memory = memory
+        self.agent = agent
         self.channel = channel
         self.subscriber = self.pubsub.subscribe(self.channel)
@@ -153,8 +152,13 @@ class RedisSub(threading.Thread):
             try:
                 obj = pickle.loads(message['data'])
                 if type(obj) == Transition:
-                    self.memory.store(obj)
+                    self.agent.total_steps_counter += 1
+                    self.agent.current_episode_steps_counter += 1
+                    self.agent.call_memory('store', obj)
                 elif type(obj) == Episode:
-                    self.memory.store_episode(obj)
+                    self.agent.current_episode_buffer = obj
+                    self.agent.total_steps_counter += len(obj.transitions)
+                    self.agent.current_episode_steps_counter += len(obj.transitions)
+                    self.agent.handle_episode_ended()
             except Exception:
                 continue
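The hunk above shows only the body of the subscriber's message handler; the enclosing run() loop is unchanged and not shown here. With redis-py it would typically look roughly like the sketch below, which is also why continue is a valid way to skip a malformed message (an assumption about the surrounding code, not part of this commit):

    def run(self):
        # Block on the pub/sub socket and feed each payload to the
        # Transition/Episode dispatch shown in the hunk above.
        for message in self.pubsub.listen():
            if message['type'] != 'message':
                continue
            # pickle.loads(message['data']) and agent bookkeeping as above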
@@ -184,7 +184,7 @@ class EpisodicExperienceReplay(Memory):
         :return: None
         """
         # Calling super.store() so that in case a memory backend is used, the memory backend can store this episode.
-        super().store(episode)
+        super().store_episode(episode)

         if lock:
             self.reader_writer_lock.lock_writing_and_reading()
@@ -53,8 +53,7 @@ class Memory(object):

     def store_episode(self, episode):
         if self.memory_backend:
-            for transition in episode:
-                self.memory_backend.store(transition)
+            self.memory_backend.store(episode)

     def get(self, index):
         raise NotImplementedError("")
@@ -20,6 +20,8 @@ from rl_coach.core_types import EnvironmentEpisodes, RunPhase
 from rl_coach.utils import short_dynamic_import
 from rl_coach.memories.backend.memory_impl import construct_memory_params
 from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
+from google.protobuf import text_format
+from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState


 # Q: specify alternative distributed memory, or should this go in the preset?
@@ -66,6 +68,20 @@ def data_store_ckpt_load(data_store):
         data_store.load_from_store()
         time.sleep(10)


+def check_for_new_checkpoint(checkpoint_dir, last_checkpoint):
+    if os.path.exists(os.path.join(checkpoint_dir, 'checkpoint')):
+        ckpt = CheckpointState()
+        contents = open(os.path.join(checkpoint_dir, 'checkpoint'), 'r').read()
+        text_format.Merge(contents, ckpt)
+        rel_path = os.path.relpath(ckpt.model_checkpoint_path, checkpoint_dir)
+        current_checkpoint = int(rel_path.split('_Step')[0])
+        if current_checkpoint > last_checkpoint:
+            last_checkpoint = current_checkpoint
+
+    return last_checkpoint
+
+
 def rollout_worker(graph_manager, checkpoint_dir):
     """
     wait for first checkpoint then perform rollouts using the model
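check_for_new_checkpoint() reads TensorFlow's 'checkpoint' state file, a text-format CheckpointState protobuf whose model_checkpoint_path points at the latest checkpoint. Coach-style checkpoint names start with an integer index followed by '_Step', which is what the split recovers. A small self-contained sketch of that parsing; the concrete filenames are made up for illustration:

from google.protobuf import text_format
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState

# Example contents of '<checkpoint_dir>/checkpoint'; '4_Step-1200.ckpt' is hypothetical.
example = '''
model_checkpoint_path: "4_Step-1200.ckpt"
all_model_checkpoint_paths: "3_Step-900.ckpt"
all_model_checkpoint_paths: "4_Step-1200.ckpt"
'''

ckpt = CheckpointState()
text_format.Merge(example, ckpt)
assert int(ckpt.model_checkpoint_path.split('_Step')[0]) == 4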
@@ -78,9 +94,16 @@ def rollout_worker(graph_manager, checkpoint_dir):
     graph_manager.create_graph(task_parameters)
     graph_manager.phase = RunPhase.TRAIN

+    last_checkpoint = 0
+
     for i in range(10000000):
-        graph_manager.restore_checkpoint()
-        graph_manager.act(EnvironmentEpisodes(num_steps=10))
+        graph_manager.act(EnvironmentEpisodes(num_steps=1))
+
+        new_checkpoint = check_for_new_checkpoint(checkpoint_dir, last_checkpoint)
+
+        if new_checkpoint > last_checkpoint:
+            last_checkpoint = new_checkpoint
+            graph_manager.restore_checkpoint()

     graph_manager.phase = RunPhase.UNDEFINED