1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Enable setting the data store factory in Graph manager (#110)

* Enable setting the data store factory in Graph manager

This change enables us to use custom data store for storing and retrieving models.
We currently need this to have use a data store that loads temporary AWS credentials
from disk before calling store or load operations.

* Removed data store factory and introduced data store as a attribute
This commit is contained in:
x77a1
2018-11-19 08:35:03 -08:00
committed by Balaji Subramaniam
parent 67a90ee87e
commit 4da56b1ff2

View File

@@ -32,7 +32,7 @@ from rl_coach.environments.environment import Environment
from rl_coach.level_manager import LevelManager
from rl_coach.logger import screen, Logger
from rl_coach.utils import set_cpu, start_shell_command_and_wait
from rl_coach.data_stores.data_store_impl import get_data_store
from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator
from rl_coach.memories.backend.memory_impl import get_memory_backend
from rl_coach.data_stores.data_store import SyncFiles
@@ -115,6 +115,7 @@ class GraphManager(object):
self.checkpoint_saver = None
self.graph_logger = Logger()
self.data_store = None
def create_graph(self, task_parameters: TaskParameters=TaskParameters()):
self.graph_creation_time = time.time()
@@ -408,7 +409,7 @@ class GraphManager(object):
if hasattr(self, 'data_store_params') and hasattr(self.agent_params.memory, 'memory_backend_params'):
if self.agent_params.memory.memory_backend_params.run_type == str(RunType.ROLLOUT_WORKER):
data_store = get_data_store(self.data_store_params)
data_store = self.get_data_store(self.data_store_params)
data_store.load_from_store()
# perform several steps of playing
@@ -484,7 +485,7 @@ class GraphManager(object):
if self.task_parameters.checkpoint_save_dir:
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close()
if hasattr(self, 'data_store_params'):
data_store = get_data_store(self.data_store_params)
data_store = self.get_data_store(self.data_store_params)
data_store.save_to_store()
screen.success("Reached required success rate. Exiting.")
@@ -597,7 +598,7 @@ class GraphManager(object):
self.last_checkpoint_saving_time = time.time()
if hasattr(self, 'data_store_params'):
data_store = get_data_store(self.data_store_params)
data_store = self.get_data_store(self.data_store_params)
data_store.save_to_store()
def verify_graph_was_created(self):
@@ -668,3 +669,9 @@ class GraphManager(object):
def should_stop(self) -> bool:
return self.task_parameters.apply_stop_condition and all([manager.should_stop() for manager in self.level_managers])
def get_data_store(self, param):
if self.data_store:
return self.data_store
return data_store_creator(param)