mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
waiting for a new checkpoint if it's available
This commit is contained in:
committed by
zach dwiel
parent
5eac0102de
commit
7f00235ed5
@@ -20,6 +20,8 @@ from rl_coach.core_types import EnvironmentEpisodes, RunPhase
|
||||
from rl_coach.utils import short_dynamic_import
|
||||
from rl_coach.memories.backend.memory_impl import construct_memory_params
|
||||
from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
|
||||
from google.protobuf import text_format
|
||||
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
|
||||
|
||||
|
||||
# Q: specify alternative distributed memory, or should this go in the preset?
|
||||
@@ -66,6 +68,20 @@ def data_store_ckpt_load(data_store):
|
||||
data_store.load_from_store()
|
||||
time.sleep(10)
|
||||
|
||||
|
||||
def check_for_new_checkpoint(checkpoint_dir, last_checkpoint):
    """
    look up the most recent checkpoint number recorded in checkpoint_dir

    The 'checkpoint' file inside checkpoint_dir is parsed as a text-format
    CheckpointState message. Its model_checkpoint_path is taken relative to
    checkpoint_dir, and the integer preceding '_Step' in that relative path is
    used as the checkpoint number.

    :param checkpoint_dir: directory expected to contain the 'checkpoint' state file
    :param last_checkpoint: the checkpoint number the caller already knows about
    :return: the number parsed from the state file when it is greater than
             last_checkpoint; otherwise last_checkpoint unchanged (this includes
             the case where no 'checkpoint' file exists yet)
    :raises ValueError: if the text before '_Step' is not an integer
    """
    checkpoint_state_path = os.path.join(checkpoint_dir, 'checkpoint')

    # nothing has been written yet - report the caller's value back unchanged
    if not os.path.exists(checkpoint_state_path):
        return last_checkpoint

    # this function is polled in a loop, so close the handle deterministically
    # instead of leaving it to the garbage collector
    with open(checkpoint_state_path, 'r') as state_file:
        contents = state_file.read()

    ckpt = CheckpointState()
    text_format.Merge(contents, ckpt)

    rel_path = os.path.relpath(ckpt.model_checkpoint_path, checkpoint_dir)
    current_checkpoint = int(rel_path.split('_Step')[0])

    if current_checkpoint > last_checkpoint:
        return current_checkpoint
    return last_checkpoint
|
||||
|
||||
|
||||
def rollout_worker(graph_manager, checkpoint_dir):
|
||||
"""
|
||||
wait for first checkpoint then perform rollouts using the model
|
||||
@@ -78,9 +94,16 @@ def rollout_worker(graph_manager, checkpoint_dir):
|
||||
graph_manager.create_graph(task_parameters)
|
||||
graph_manager.phase = RunPhase.TRAIN
|
||||
|
||||
last_checkpoint = 0
|
||||
|
||||
for i in range(10000000):
|
||||
graph_manager.restore_checkpoint()
|
||||
graph_manager.act(EnvironmentEpisodes(num_steps=10))
|
||||
graph_manager.act(EnvironmentEpisodes(num_steps=1))
|
||||
|
||||
new_checkpoint = check_for_new_checkpoint(checkpoint_dir, last_checkpoint)
|
||||
|
||||
if new_checkpoint > last_checkpoint:
|
||||
last_checkpoint = new_checkpoint
|
||||
graph_manager.restore_checkpoint()
|
||||
|
||||
graph_manager.phase = RunPhase.UNDEFINED
|
||||
|
||||
|
||||
Reference in New Issue
Block a user