1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Uploading checkpoint if crd provided (#191)

* Uploading checkpoint if crd provided
* Changing the calculation of total steps because of a recent change in core_types

Fixes #195
This commit is contained in:
Ajay Deshpande
2019-04-26 12:27:33 -07:00
committed by Scott Leishman
parent b3db9ce77d
commit 33dc29ee99
8 changed files with 122 additions and 40 deletions

View File

@@ -24,24 +24,30 @@ from rl_coach import core_types
from rl_coach.logger import screen
def data_store_ckpt_save(data_store, interval=10):
    """
    Repeatedly save to the data store, pausing between consecutive saves.
    This loop never terminates on its own; it only stops if save_to_store()
    or the sleep raises.
    :param data_store: An object exposing a save_to_store() method
    :param interval: Seconds to sleep after each save. Defaults to 10, the
                     value that used to be hard-coded here.
    """
    while True:
        data_store.save_to_store()
        # throttle the loop so the store is not hit back-to-back
        time.sleep(interval)
def data_store_ckpt_load(data_store):
    """
    Load from the data store when one was supplied.
    :param data_store: An object exposing a load_from_store() method, or a
                       falsy value (e.g. None), in which case nothing happens
    """
    # nothing to load from when no data store was configured
    if not data_store:
        return
    data_store.load_from_store()
def training_worker(graph_manager, task_parameters, is_multi_node_test):
def training_worker(graph_manager, task_parameters, data_store, is_multi_node_test):
"""
restore a checkpoint then perform rollouts using the restored model
:param graph_manager: An instance of the graph manager
:param task_parameters: An instance of task parameters
:param is_multi_node_test: If this is a multi node test instead of a normal run.
"""
# initialize graph
graph_manager.create_graph(task_parameters)
# Load checkpoint if provided
if task_parameters.checkpoint_restore_path:
data_store_ckpt_load(data_store)
# initialize graph
graph_manager.create_graph(task_parameters)
# save randomly initialized graph
graph_manager.save_checkpoint()
else:
# initialize graph
graph_manager.create_graph(task_parameters)
# save randomly initialized graph
graph_manager.save_checkpoint()
# training loop
steps = 0
@@ -50,6 +56,7 @@ def training_worker(graph_manager, task_parameters, is_multi_node_test):
eval_offset = 1
graph_manager.setup_memory_backend()
graph_manager.signal_ready()
while steps < graph_manager.improve_steps.num_steps: