diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index 2ec5085..10ee60a 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -659,9 +659,9 @@ class GraphManager(object): self.handle_episode_ended() self.reset_required = True - def fetch_from_worker(self, num_steps=0): + def fetch_from_worker(self, num_consecutive_playing_steps=None): if hasattr(self, 'memory_backend'): - for transition in self.memory_backend.fetch(num_steps): + for transition in self.memory_backend.fetch(num_consecutive_playing_steps): self.emulate_act_on_trainer(EnvironmentSteps(1), transition) def setup_memory_backend(self) -> None: diff --git a/rl_coach/memories/backend/redis.py b/rl_coach/memories/backend/redis.py index 0e405a6..38afbb7 100644 --- a/rl_coach/memories/backend/redis.py +++ b/rl_coach/memories/backend/redis.py @@ -6,7 +6,7 @@ import time from kubernetes import client from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters -from rl_coach.core_types import Transition, Episode +from rl_coach.core_types import Transition, Episode, EnvironmentSteps, EnvironmentEpisodes class RedisPubSubMemoryBackendParameters(MemoryBackendParameters): @@ -129,8 +129,8 @@ class RedisPubSubBackend(MemoryBackend): def sample(self, size): pass - def fetch(self, num_steps=0): - return RedisSub(redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel).run(num_steps=num_steps) + def fetch(self, num_consecutive_playing_steps=None): + return RedisSub(redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel).run(num_consecutive_playing_steps) def subscribe(self, agent): redis_sub = RedisSub(redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel) @@ -150,19 +150,30 @@ class RedisSub(object): self.channel = channel self.subscriber = self.pubsub.subscribe(self.channel) - def run(self, num_steps=0): + def run(self, num_consecutive_playing_steps): + transitions = 0 + episodes = 0 steps = 0 for message in self.pubsub.listen(): if message and 'data' in message: try: obj = pickle.loads(message['data']) if type(obj) == Transition: - steps += 1 + transitions += 1 + if obj.game_over: + episodes += 1 yield obj elif type(obj) == Episode: - steps += len(obj.transitions) + episodes += 1 + transitions += len(obj.transitions) yield from obj.transitions except Exception: continue - if num_steps > 0 and steps >= num_steps: - break + + if type(num_consecutive_playing_steps) == EnvironmentSteps: + steps = transitions + if type(num_consecutive_playing_steps) == EnvironmentEpisodes: + steps = episodes + + if steps >= num_consecutive_playing_steps.num_steps: + break diff --git a/rl_coach/rollout_worker.py b/rl_coach/rollout_worker.py index b53d290..76d55ea 100644 --- a/rl_coach/rollout_worker.py +++ b/rl_coach/rollout_worker.py @@ -95,7 +95,7 @@ def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers): break if type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentSteps: - graph_manager.act(EnvironmentSteps(num_steps=act_steps), wait_for_full_episode=graph_manager.agent_params.algorithm.act_for_full_episodes) + graph_manager.act(EnvironmentSteps(num_steps=act_steps), wait_for_full_episodes=graph_manager.agent_params.algorithm.act_for_full_episodes) elif type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentEpisodes: graph_manager.act(EnvironmentEpisodes(num_steps=act_steps)) diff --git a/rl_coach/training_worker.py b/rl_coach/training_worker.py index ac2923a..a21d3ea 100644 --- a/rl_coach/training_worker.py +++ b/rl_coach/training_worker.py @@ -36,7 +36,7 @@ def training_worker(graph_manager, checkpoint_dir): while(steps < graph_manager.improve_steps.num_steps): graph_manager.phase = core_types.RunPhase.TRAIN - graph_manager.fetch_from_worker(num_steps=graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps) + graph_manager.fetch_from_worker(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) graph_manager.phase = core_types.RunPhase.UNDEFINED if graph_manager.should_train():