1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

waiting for a new checkpoint if it's available

This commit is contained in:
Ajay Deshpande
2018-10-05 19:08:24 -07:00
committed by zach dwiel
parent 5eac0102de
commit 7f00235ed5
7 changed files with 49 additions and 20 deletions

View File

@@ -20,6 +20,8 @@ from rl_coach.core_types import EnvironmentEpisodes, RunPhase
from rl_coach.utils import short_dynamic_import
from rl_coach.memories.backend.memory_impl import construct_memory_params
from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
from google.protobuf import text_format
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
# Q: specify alternative distributed memory, or should this go in the preset?
@@ -66,6 +68,20 @@ def data_store_ckpt_load(data_store):
data_store.load_from_store()
time.sleep(10)
def check_for_new_checkpoint(checkpoint_dir, last_checkpoint):
if os.path.exists(os.path.join(checkpoint_dir, 'checkpoint')):
ckpt = CheckpointState()
contents = open(os.path.join(checkpoint_dir, 'checkpoint'), 'r').read()
text_format.Merge(contents, ckpt)
rel_path = os.path.relpath(ckpt.model_checkpoint_path, checkpoint_dir)
current_checkpoint = int(rel_path.split('_Step')[0])
if current_checkpoint > last_checkpoint:
last_checkpoint = current_checkpoint
return last_checkpoint
def rollout_worker(graph_manager, checkpoint_dir):
"""
wait for first checkpoint then perform rollouts using the model
@@ -78,9 +94,16 @@ def rollout_worker(graph_manager, checkpoint_dir):
graph_manager.create_graph(task_parameters)
graph_manager.phase = RunPhase.TRAIN
last_checkpoint = 0
for i in range(10000000):
graph_manager.restore_checkpoint()
graph_manager.act(EnvironmentEpisodes(num_steps=10))
graph_manager.act(EnvironmentEpisodes(num_steps=1))
new_checkpoint = check_for_new_checkpoint(checkpoint_dir, last_checkpoint)
if new_checkpoint > last_checkpoint:
last_checkpoint = new_checkpoint
graph_manager.restore_checkpoint()
graph_manager.phase = RunPhase.UNDEFINED