diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index 576ac8d..3fe5ed9 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -376,7 +376,7 @@ class GraphManager(object): if self.agent_params.memory.memory_backend_params.run_type == "worker": data_store = get_data_store(self.data_store_params) data_store.load_from_store() - + # perform several steps of playing result = None @@ -435,7 +435,7 @@ class GraphManager(object): if steps.num_steps > 0: self.phase = RunPhase.TRAIN self.reset_internal_state(force_environment_reset=True) - #TODO - the below while loop should end with full episodes, so to avoid situations where we have partial + # TODO - the below while loop should end with full episodes, so to avoid situations where we have partial # episodes in memory count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps while self.total_steps_counters[self.phase][steps.__class__] < count_end: diff --git a/rl_coach/memories/backend/redis.py b/rl_coach/memories/backend/redis.py index 10faa72..80cbae5 100644 --- a/rl_coach/memories/backend/redis.py +++ b/rl_coach/memories/backend/redis.py @@ -7,6 +7,7 @@ from kubernetes import client from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters from rl_coach.core_types import Transition, Episode +from rl_coach.core_types import RunPhase class RedisPubSubMemoryBackendParameters(MemoryBackendParameters): @@ -148,7 +149,9 @@ class RedisSub(threading.Thread): def run(self): for message in self.pubsub.listen(): - if message and 'data' in message: + if message and 'data' in message and self.agent.phase != RunPhase.TEST or self.agent.ap.task_parameters.evaluate_only: + if self.agent.phase == RunPhase.TEST: + print(self.agent.phase) try: obj = pickle.loads(message['data']) if type(obj) == Transition: diff --git a/rl_coach/training_worker.py b/rl_coach/training_worker.py index c5664a8..b68bee7 100644 --- a/rl_coach/training_worker.py +++ b/rl_coach/training_worker.py @@ -38,7 +38,6 @@ def training_worker(graph_manager, checkpoint_dir): graph_manager.phase = core_types.RunPhase.UNDEFINED graph_manager.evaluate(graph_manager.evaluation_steps) graph_manager.save_checkpoint() - time.sleep(10) def main():