From 4da56b1ff2354f316ef2afd9c650a37e4e198fdd Mon Sep 17 00:00:00 2001 From: x77a1 Date: Mon, 19 Nov 2018 08:35:03 -0800 Subject: [PATCH] Enable setting the data store factory in Graph manager (#110) * Enable setting the data store factory in Graph manager This change enables us to use custom data store for storing and retrieving models. We currently need this to have use a data store that loads temporary AWS credentials from disk before calling store or load operations. * Removed data store factory and introduced data store as a attribute --- rl_coach/graph_managers/graph_manager.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index 91eccc8..94a1250 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -32,7 +32,7 @@ from rl_coach.environments.environment import Environment from rl_coach.level_manager import LevelManager from rl_coach.logger import screen, Logger from rl_coach.utils import set_cpu, start_shell_command_and_wait -from rl_coach.data_stores.data_store_impl import get_data_store +from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator from rl_coach.memories.backend.memory_impl import get_memory_backend from rl_coach.data_stores.data_store import SyncFiles @@ -115,6 +115,7 @@ class GraphManager(object): self.checkpoint_saver = None self.graph_logger = Logger() + self.data_store = None def create_graph(self, task_parameters: TaskParameters=TaskParameters()): self.graph_creation_time = time.time() @@ -408,7 +409,7 @@ class GraphManager(object): if hasattr(self, 'data_store_params') and hasattr(self.agent_params.memory, 'memory_backend_params'): if self.agent_params.memory.memory_backend_params.run_type == str(RunType.ROLLOUT_WORKER): - data_store = get_data_store(self.data_store_params) + data_store = self.get_data_store(self.data_store_params) data_store.load_from_store() # perform several steps of playing @@ -484,7 +485,7 @@ class GraphManager(object): if self.task_parameters.checkpoint_save_dir: open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close() if hasattr(self, 'data_store_params'): - data_store = get_data_store(self.data_store_params) + data_store = self.get_data_store(self.data_store_params) data_store.save_to_store() screen.success("Reached required success rate. Exiting.") @@ -597,7 +598,7 @@ class GraphManager(object): self.last_checkpoint_saving_time = time.time() if hasattr(self, 'data_store_params'): - data_store = get_data_store(self.data_store_params) + data_store = self.get_data_store(self.data_store_params) data_store.save_to_store() def verify_graph_was_created(self): @@ -668,3 +669,9 @@ class GraphManager(object): def should_stop(self) -> bool: return self.task_parameters.apply_stop_condition and all([manager.should_stop() for manager in self.level_managers]) + + def get_data_store(self, param): + if self.data_store: + return self.data_store + + return data_store_creator(param)