mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Enable setting the data store factory in Graph manager (#110)
* Enable setting the data store factory in Graph manager This change enables us to use custom data store for storing and retrieving models. We currently need this to have use a data store that loads temporary AWS credentials from disk before calling store or load operations. * Removed data store factory and introduced data store as a attribute
This commit is contained in:
committed by
Balaji Subramaniam
parent
67a90ee87e
commit
4da56b1ff2
@@ -32,7 +32,7 @@ from rl_coach.environments.environment import Environment
|
||||
from rl_coach.level_manager import LevelManager
|
||||
from rl_coach.logger import screen, Logger
|
||||
from rl_coach.utils import set_cpu, start_shell_command_and_wait
|
||||
from rl_coach.data_stores.data_store_impl import get_data_store
|
||||
from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator
|
||||
from rl_coach.memories.backend.memory_impl import get_memory_backend
|
||||
from rl_coach.data_stores.data_store import SyncFiles
|
||||
|
||||
@@ -115,6 +115,7 @@ class GraphManager(object):
|
||||
|
||||
self.checkpoint_saver = None
|
||||
self.graph_logger = Logger()
|
||||
self.data_store = None
|
||||
|
||||
def create_graph(self, task_parameters: TaskParameters=TaskParameters()):
|
||||
self.graph_creation_time = time.time()
|
||||
@@ -408,7 +409,7 @@ class GraphManager(object):
|
||||
|
||||
if hasattr(self, 'data_store_params') and hasattr(self.agent_params.memory, 'memory_backend_params'):
|
||||
if self.agent_params.memory.memory_backend_params.run_type == str(RunType.ROLLOUT_WORKER):
|
||||
data_store = get_data_store(self.data_store_params)
|
||||
data_store = self.get_data_store(self.data_store_params)
|
||||
data_store.load_from_store()
|
||||
|
||||
# perform several steps of playing
|
||||
@@ -484,7 +485,7 @@ class GraphManager(object):
|
||||
if self.task_parameters.checkpoint_save_dir:
|
||||
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close()
|
||||
if hasattr(self, 'data_store_params'):
|
||||
data_store = get_data_store(self.data_store_params)
|
||||
data_store = self.get_data_store(self.data_store_params)
|
||||
data_store.save_to_store()
|
||||
|
||||
screen.success("Reached required success rate. Exiting.")
|
||||
@@ -597,7 +598,7 @@ class GraphManager(object):
|
||||
self.last_checkpoint_saving_time = time.time()
|
||||
|
||||
if hasattr(self, 'data_store_params'):
|
||||
data_store = get_data_store(self.data_store_params)
|
||||
data_store = self.get_data_store(self.data_store_params)
|
||||
data_store.save_to_store()
|
||||
|
||||
def verify_graph_was_created(self):
|
||||
@@ -668,3 +669,9 @@ class GraphManager(object):
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
return self.task_parameters.apply_stop_condition and all([manager.should_stop() for manager in self.level_managers])
|
||||
|
||||
def get_data_store(self, param):
|
||||
if self.data_store:
|
||||
return self.data_store
|
||||
|
||||
return data_store_creator(param)
|
||||
|
||||
Reference in New Issue
Block a user