diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py
index b249749..f8fcab6 100644
--- a/rl_coach/agents/agent.py
+++ b/rl_coach/agents/agent.py
@@ -81,7 +81,7 @@ class Agent(AgentInterface):
             self.memory_backend = get_memory_backend(self.ap.memory.memory_backend_params)
 
             if self.ap.memory.memory_backend_params.run_type == 'trainer':
-                self.memory_backend.subscribe(self.memory)
+                self.memory_backend.subscribe(self)
             else:
                 self.memory.set_memory_backend(self.memory_backend)
@@ -523,8 +523,7 @@ class Agent(AgentInterface):
         Determine if online weights should be copied to the target.
         :return: boolean: True if the online weights should be copied to the target.
         """
-        if hasattr(self.ap.memory, 'memory_backend_params'):
-            self.total_steps_counter = self.call_memory('num_transitions')
+
         # update the target network of every network that has a target network
         step_method = self.ap.algorithm.num_steps_between_copying_online_weights_to_target
         if step_method.__class__ == TrainingSteps:
@@ -546,7 +545,7 @@ class Agent(AgentInterface):
 
         :return: boolean: True if we should start a training phase
         """
-        should_update = self._should_train_helper(wait_for_full_episode)
+        should_update = self._should_train_helper(wait_for_full_episode=wait_for_full_episode)
 
         step_method = self.ap.algorithm.num_consecutive_playing_steps
@@ -560,10 +559,8 @@
     def _should_train_helper(self, wait_for_full_episode=False):
-        if hasattr(self.ap.memory, 'memory_backend_params'):
-            self.total_steps_counter = self.call_memory('num_transitions')
-
         step_method = self.ap.algorithm.num_consecutive_playing_steps
+
         if step_method.__class__ == EnvironmentEpisodes:
             should_update = (self.current_episode - self.last_training_phase_step) >= step_method.num_steps
diff --git a/rl_coach/agents/clipped_ppo_agent.py b/rl_coach/agents/clipped_ppo_agent.py
index 0091841..4414966 100644
--- a/rl_coach/agents/clipped_ppo_agent.py
+++ b/rl_coach/agents/clipped_ppo_agent.py
@@ -251,6 +251,9 @@ class ClippedPPOAgent(ActorCriticAgent):
         # clean memory
         self.call_memory('clean')
 
+    def _should_train_helper(self, wait_for_full_episode=True):
+        return super()._should_train_helper(True)
+
     def train(self):
         if self._should_train(wait_for_full_episode=True):
             for network in self.networks.values():
diff --git a/rl_coach/agents/ppo_agent.py b/rl_coach/agents/ppo_agent.py
index 24d2b9f..fdb175e 100644
--- a/rl_coach/agents/ppo_agent.py
+++ b/rl_coach/agents/ppo_agent.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation
+# Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -310,6 +310,9 @@ class PPOAgent(ActorCriticAgent):
         # clean memory
         self.call_memory('clean')
 
+    def _should_train_helper(self, wait_for_full_episode=True):
+        return super()._should_train_helper(True)
+
     def train(self):
         loss = 0
         if self._should_train(wait_for_full_episode=True):
diff --git a/rl_coach/memories/backend/redis.py b/rl_coach/memories/backend/redis.py
index f060d45..10faa72 100644
--- a/rl_coach/memories/backend/redis.py
+++ b/rl_coach/memories/backend/redis.py
@@ -6,7 +6,6 @@
 import threading
 
 from kubernetes import client
 from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
-from rl_coach.memories.memory import Memory
 from rl_coach.core_types import Transition, Episode
@@ -126,8 +125,8 @@ class RedisPubSubBackend(MemoryBackend):
     def sample(self, size):
         pass
 
-    def subscribe(self, memory):
-        redis_sub = RedisSub(memory, redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
+    def subscribe(self, agent):
+        redis_sub = RedisSub(agent, redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
         redis_sub.daemon = True
         redis_sub.start()
@@ -138,12 +137,12 @@ class RedisSub(threading.Thread):
-    def __init__(self, memory: Memory, redis_address: str = "localhost", redis_port: int=6379, channel: str = "PubsubChannel"):
+    def __init__(self, agent, redis_address: str = "localhost", redis_port: int=6379, channel: str = "PubsubChannel"):
         super().__init__()
         self.redis_connection = redis.Redis(redis_address, redis_port)
         self.pubsub = self.redis_connection.pubsub()
         self.subscriber = None
-        self.memory = memory
+        self.agent = agent
         self.channel = channel
         self.subscriber = self.pubsub.subscribe(self.channel)
@@ -153,8 +152,13 @@
             try:
                 obj = pickle.loads(message['data'])
                 if type(obj) == Transition:
-                    self.memory.store(obj)
+                    self.agent.total_steps_counter += 1
+                    self.agent.current_episode_steps_counter += 1
+                    self.agent.call_memory('store', obj)
                 elif type(obj) == Episode:
-                    self.memory.store_episode(obj)
+                    self.agent.current_episode_buffer = obj
+                    self.agent.total_steps_counter += len(obj.transitions)
+                    self.agent.current_episode_steps_counter += len(obj.transitions)
+                    self.agent.handle_episode_ended()
             except Exception:
                 continue
diff --git a/rl_coach/memories/episodic/episodic_experience_replay.py b/rl_coach/memories/episodic/episodic_experience_replay.py
index f174a74..ada96f4 100644
--- a/rl_coach/memories/episodic/episodic_experience_replay.py
+++ b/rl_coach/memories/episodic/episodic_experience_replay.py
@@ -184,7 +184,7 @@ class EpisodicExperienceReplay(Memory):
         :return: None
         """
         # Calling super.store() so that in case a memory backend is used, the memory backend can store this episode.
-        super().store(episode)
+        super().store_episode(episode)
 
         if lock:
             self.reader_writer_lock.lock_writing_and_reading()
diff --git a/rl_coach/memories/memory.py b/rl_coach/memories/memory.py
index 5c56cd0..75414d4 100644
--- a/rl_coach/memories/memory.py
+++ b/rl_coach/memories/memory.py
@@ -53,8 +53,7 @@ class Memory(object):
 
     def store_episode(self, episode):
         if self.memory_backend:
-            for transition in episode:
-                self.memory_backend.store(transition)
+            self.memory_backend.store(episode)
 
     def get(self, index):
         raise NotImplementedError("")
diff --git a/rl_coach/rollout_worker.py b/rl_coach/rollout_worker.py
index b2f98d8..a04d23c 100644
--- a/rl_coach/rollout_worker.py
+++ b/rl_coach/rollout_worker.py
@@ -20,6 +20,8 @@
 from rl_coach.core_types import EnvironmentEpisodes, RunPhase
 from rl_coach.utils import short_dynamic_import
 from rl_coach.memories.backend.memory_impl import construct_memory_params
 from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
+from google.protobuf import text_format
+from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
 
 # Q: specify alternative distributed memory, or should this go in the preset?
@@ -66,6 +68,20 @@ def data_store_ckpt_load(data_store):
         data_store.load_from_store()
         time.sleep(10)
 
+
+def check_for_new_checkpoint(checkpoint_dir, last_checkpoint):
+    if os.path.exists(os.path.join(checkpoint_dir, 'checkpoint')):
+        ckpt = CheckpointState()
+        contents = open(os.path.join(checkpoint_dir, 'checkpoint'), 'r').read()
+        text_format.Merge(contents, ckpt)
+        rel_path = os.path.relpath(ckpt.model_checkpoint_path, checkpoint_dir)
+        current_checkpoint = int(rel_path.split('_Step')[0])
+        if current_checkpoint > last_checkpoint:
+            last_checkpoint = current_checkpoint
+
+    return last_checkpoint
+
+
 def rollout_worker(graph_manager, checkpoint_dir):
     """
     wait for first checkpoint then perform rollouts using the model
@@ -78,9 +94,16 @@
     graph_manager.create_graph(task_parameters)
 
     graph_manager.phase = RunPhase.TRAIN
 
+    last_checkpoint = 0
+
     for i in range(10000000):
-        graph_manager.restore_checkpoint()
-        graph_manager.act(EnvironmentEpisodes(num_steps=10))
+        graph_manager.act(EnvironmentEpisodes(num_steps=1))
+
+        new_checkpoint = check_for_new_checkpoint(checkpoint_dir, last_checkpoint)
+
+        if new_checkpoint > last_checkpoint:
+            last_checkpoint = new_checkpoint
+            graph_manager.restore_checkpoint()
 
     graph_manager.phase = RunPhase.UNDEFINED
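
Note on the rollout_worker.py change: the worker now reloads weights only when a newer checkpoint has appeared, instead of calling restore_checkpoint() on every iteration. Below is a minimal, self-contained sketch of that polling pattern. It assumes TensorFlow-style checkpoint metadata (a text `checkpoint` file whose `model_checkpoint_path` points at files named like `3_Step-1200.ckpt`, matching the `_Step` parsing in the patch); `latest_checkpoint_number` and the usage comment are illustrative names, not part of the patch.

```python
# Illustrative sketch only -- not part of the patch. It mirrors the polling
# logic added to rollout_worker.py, assuming checkpoints are written with
# TensorFlow-style metadata: a text file named 'checkpoint' whose
# model_checkpoint_path points at files named like '3_Step-1200.ckpt'.
import os

from google.protobuf import text_format
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState


def latest_checkpoint_number(checkpoint_dir):
    """Return the numeric prefix of the newest checkpoint, or -1 if none exists."""
    ckpt_file = os.path.join(checkpoint_dir, 'checkpoint')
    if not os.path.exists(ckpt_file):
        return -1
    ckpt = CheckpointState()
    with open(ckpt_file, 'r') as f:
        text_format.Merge(f.read(), ckpt)
    # model_checkpoint_path may be absolute; reduce it to a path relative to the
    # checkpoint dir and strip the '_Step' suffix to recover the checkpoint index.
    rel_path = os.path.relpath(ckpt.model_checkpoint_path, checkpoint_dir)
    return int(rel_path.split('_Step')[0])


# Hypothetical usage inside a rollout loop (graph_manager is assumed to expose
# act() and restore_checkpoint(), as in the patch above):
#
#     last = 0
#     while True:
#         graph_manager.act(EnvironmentEpisodes(num_steps=1))
#         newest = latest_checkpoint_number(checkpoint_dir)
#         if newest > last:  # reload weights only when a newer checkpoint landed
#             last = newest
#             graph_manager.restore_checkpoint()
```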
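
Note on the redis.py change: on the trainer, the subscriber thread is now handed the agent rather than the memory, so incoming episodes can advance the agent's step counters and trigger its end-of-episode handling. The stripped-down sketch below shows that flow for episode messages only; it assumes a redis-py connection, pickled Episode payloads as published by RedisPubSubBackend, and an agent exposing the counters and handle_episode_ended() used above. `EpisodeSubscriber` is a hypothetical name, not the library's class.

```python
# Illustrative sketch only -- a stripped-down trainer-side subscriber, assuming
# the same channel/message framing as RedisPubSubBackend (pickled Episode objects).
import pickle
import threading

import redis


class EpisodeSubscriber(threading.Thread):
    def __init__(self, agent, host='localhost', port=6379, channel='PubsubChannel'):
        super().__init__(daemon=True)
        self.agent = agent
        self.pubsub = redis.Redis(host, port).pubsub()
        self.pubsub.subscribe(channel)

    def run(self):
        for message in self.pubsub.listen():
            if message['type'] != 'message':
                continue  # skip subscribe/unsubscribe notifications
            episode = pickle.loads(message['data'])
            # Mirror what the rollout worker would have done locally: advance the
            # trainer agent's step counters, then let it post-process the episode.
            self.agent.current_episode_buffer = episode
            self.agent.total_steps_counter += len(episode.transitions)
            self.agent.current_episode_steps_counter += len(episode.transitions)
            self.agent.handle_episode_ended()
```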