mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Uploading checkpoint if crd provided (#191)
* Uploading checkpoint if crd provided * Changing the calculation of total steps because of a recent change in core_types Fixes #195
This commit is contained in:
committed by
Scott Leishman
parent
b3db9ce77d
commit
33dc29ee99
@@ -24,24 +24,30 @@ from rl_coach import core_types
|
||||
from rl_coach.logger import screen
|
||||
|
||||
|
||||
def data_store_ckpt_save(data_store, interval_seconds=10):
    """
    Call ``data_store.save_to_store()`` over and over, sleeping after each call.

    The loop has no exit condition of its own: it only stops if
    ``save_to_store()`` or ``time.sleep()`` raises.

    :param data_store: object exposing a ``save_to_store()`` method
    :param interval_seconds: pause, in seconds, after every save. Defaults to 10,
                             the value that used to be hard-coded here.
    """
    while True:
        data_store.save_to_store()
        # throttle the saves so the store is not hit back-to-back
        time.sleep(interval_seconds)
|
||||
def data_store_ckpt_load(data_store):
    """
    Invoke ``data_store.load_from_store()`` when a data store was supplied.

    :param data_store: object exposing a ``load_from_store()`` method; a falsy
                       value (e.g. ``None``) turns this call into a no-op
    """
    # nothing to load from when no data store was configured
    if not data_store:
        return
    data_store.load_from_store()
|
||||
|
||||
|
||||
def training_worker(graph_manager, task_parameters, is_multi_node_test):
|
||||
def training_worker(graph_manager, task_parameters, data_store, is_multi_node_test):
|
||||
"""
|
||||
restore a checkpoint then perform rollouts using the restored model
|
||||
:param graph_manager: An instance of the graph manager
|
||||
:param task_parameters: An instance of task parameters
|
||||
:param is_multi_node_test: If this is a multi node test instead of a normal run.
|
||||
"""
|
||||
# initialize graph
|
||||
graph_manager.create_graph(task_parameters)
|
||||
# Load checkpoint if provided
|
||||
if task_parameters.checkpoint_restore_path:
|
||||
data_store_ckpt_load(data_store)
|
||||
# initialize graph
|
||||
graph_manager.create_graph(task_parameters)
|
||||
|
||||
# save randomly initialized graph
|
||||
graph_manager.save_checkpoint()
|
||||
else:
|
||||
# initialize graph
|
||||
graph_manager.create_graph(task_parameters)
|
||||
|
||||
# save randomly initialized graph
|
||||
graph_manager.save_checkpoint()
|
||||
|
||||
# training loop
|
||||
steps = 0
|
||||
@@ -50,6 +56,7 @@ def training_worker(graph_manager, task_parameters, is_multi_node_test):
|
||||
eval_offset = 1
|
||||
|
||||
graph_manager.setup_memory_backend()
|
||||
graph_manager.signal_ready()
|
||||
|
||||
while steps < graph_manager.improve_steps.num_steps:
|
||||
|
||||
|
||||
Reference in New Issue
Block a user