mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Enable setting the data store factory in Graph manager (#110)
* Enable setting the data store factory in Graph manager This change enables us to use custom data store for storing and retrieving models. We currently need this to have use a data store that loads temporary AWS credentials from disk before calling store or load operations. * Removed data store factory and introduced data store as a attribute
This commit is contained in:
committed by
Balaji Subramaniam
parent
67a90ee87e
commit
4da56b1ff2
@@ -32,7 +32,7 @@ from rl_coach.environments.environment import Environment
|
|||||||
from rl_coach.level_manager import LevelManager
|
from rl_coach.level_manager import LevelManager
|
||||||
from rl_coach.logger import screen, Logger
|
from rl_coach.logger import screen, Logger
|
||||||
from rl_coach.utils import set_cpu, start_shell_command_and_wait
|
from rl_coach.utils import set_cpu, start_shell_command_and_wait
|
||||||
from rl_coach.data_stores.data_store_impl import get_data_store
|
from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator
|
||||||
from rl_coach.memories.backend.memory_impl import get_memory_backend
|
from rl_coach.memories.backend.memory_impl import get_memory_backend
|
||||||
from rl_coach.data_stores.data_store import SyncFiles
|
from rl_coach.data_stores.data_store import SyncFiles
|
||||||
|
|
||||||
@@ -115,6 +115,7 @@ class GraphManager(object):
|
|||||||
|
|
||||||
self.checkpoint_saver = None
|
self.checkpoint_saver = None
|
||||||
self.graph_logger = Logger()
|
self.graph_logger = Logger()
|
||||||
|
self.data_store = None
|
||||||
|
|
||||||
def create_graph(self, task_parameters: TaskParameters=TaskParameters()):
|
def create_graph(self, task_parameters: TaskParameters=TaskParameters()):
|
||||||
self.graph_creation_time = time.time()
|
self.graph_creation_time = time.time()
|
||||||
@@ -408,7 +409,7 @@ class GraphManager(object):
|
|||||||
|
|
||||||
if hasattr(self, 'data_store_params') and hasattr(self.agent_params.memory, 'memory_backend_params'):
|
if hasattr(self, 'data_store_params') and hasattr(self.agent_params.memory, 'memory_backend_params'):
|
||||||
if self.agent_params.memory.memory_backend_params.run_type == str(RunType.ROLLOUT_WORKER):
|
if self.agent_params.memory.memory_backend_params.run_type == str(RunType.ROLLOUT_WORKER):
|
||||||
data_store = get_data_store(self.data_store_params)
|
data_store = self.get_data_store(self.data_store_params)
|
||||||
data_store.load_from_store()
|
data_store.load_from_store()
|
||||||
|
|
||||||
# perform several steps of playing
|
# perform several steps of playing
|
||||||
@@ -484,7 +485,7 @@ class GraphManager(object):
|
|||||||
if self.task_parameters.checkpoint_save_dir:
|
if self.task_parameters.checkpoint_save_dir:
|
||||||
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close()
|
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close()
|
||||||
if hasattr(self, 'data_store_params'):
|
if hasattr(self, 'data_store_params'):
|
||||||
data_store = get_data_store(self.data_store_params)
|
data_store = self.get_data_store(self.data_store_params)
|
||||||
data_store.save_to_store()
|
data_store.save_to_store()
|
||||||
|
|
||||||
screen.success("Reached required success rate. Exiting.")
|
screen.success("Reached required success rate. Exiting.")
|
||||||
@@ -597,7 +598,7 @@ class GraphManager(object):
|
|||||||
self.last_checkpoint_saving_time = time.time()
|
self.last_checkpoint_saving_time = time.time()
|
||||||
|
|
||||||
if hasattr(self, 'data_store_params'):
|
if hasattr(self, 'data_store_params'):
|
||||||
data_store = get_data_store(self.data_store_params)
|
data_store = self.get_data_store(self.data_store_params)
|
||||||
data_store.save_to_store()
|
data_store.save_to_store()
|
||||||
|
|
||||||
def verify_graph_was_created(self):
|
def verify_graph_was_created(self):
|
||||||
@@ -668,3 +669,9 @@ class GraphManager(object):
|
|||||||
|
|
||||||
def should_stop(self) -> bool:
|
def should_stop(self) -> bool:
|
||||||
return self.task_parameters.apply_stop_condition and all([manager.should_stop() for manager in self.level_managers])
|
return self.task_parameters.apply_stop_condition and all([manager.should_stop() for manager in self.level_managers])
|
||||||
|
|
||||||
|
def get_data_store(self, param):
|
||||||
|
if self.data_store:
|
||||||
|
return self.data_store
|
||||||
|
|
||||||
|
return data_store_creator(param)
|
||||||
|
|||||||
Reference in New Issue
Block a user