mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Uploading checkpoint if crd provided (#191)
* Uploading checkpoint if crd provided
* Changing the calculation of total steps because of a recent change in core_types

Fixes #195
This commit is contained in:
committed by
Scott Leishman
parent
b3db9ce77d
commit
33dc29ee99
@@ -33,30 +33,50 @@ from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
|
||||
from rl_coach.data_stores.data_store import SyncFiles
|
||||
|
||||
|
||||
def wait_for(wait_func, data_store=None, timeout=10, poll_interval=10):
    """
    Block until wait_func returns a truthy value.

    :param wait_func: zero-argument callable that is polled for the condition
    :param data_store: optional object; when truthy, its load_from_store() is
        called before every poll so the condition is checked on fresh data
    :param timeout: number of polling attempts that are each followed by a
        sleep (one extra, final attempt is made after the last sleep)
    :param poll_interval: seconds to sleep between attempts
    :raises ValueError: if the condition is still false after the final attempt
    """
    for _ in range(timeout):
        if data_store:
            data_store.load_from_store()

        if wait_func():
            return
        time.sleep(poll_interval)

    # one last time; sync again first, otherwise the final sleep is wasted and
    # the condition is re-evaluated against the same stale state
    if data_store:
        data_store.load_from_store()
    if wait_func():
        return

    # report the time actually spent sleeping, not the number of attempts
    raise ValueError((
        'Waited {waited} seconds, but condition timed out'
    ).format(
        waited=timeout * poll_interval,
    ))
|
||||
|
||||
|
||||
def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
|
||||
"""
|
||||
block until there is a checkpoint in checkpoint_dir
|
||||
"""
|
||||
chkpt_state_file = CheckpointStateFile(checkpoint_dir)
|
||||
for i in range(timeout):
|
||||
if data_store:
|
||||
data_store.load_from_store()
|
||||
|
||||
if chkpt_state_file.read() is not None:
|
||||
return
|
||||
time.sleep(10)
|
||||
def wait():
|
||||
return chkpt_state_file.read() is not None
|
||||
|
||||
# one last time
|
||||
if chkpt_state_file.read() is not None:
|
||||
return
|
||||
wait_for(wait, data_store, timeout)
|
||||
|
||||
raise ValueError((
|
||||
'Waited {timeout} seconds, but checkpoint never found in '
|
||||
'{checkpoint_dir}'
|
||||
).format(
|
||||
timeout=timeout,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
))
|
||||
|
||||
def wait_for_trainer_ready(checkpoint_dir, data_store=None, timeout=10):
    """
    Block until the trainer-ready marker file exists in checkpoint_dir.

    The polling itself (including the optional data_store sync) is delegated
    to wait_for.
    """

    def trainer_ready():
        # the marker's file name is taken from SyncFiles.TRAINER_READY
        marker_path = os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value)
        return os.path.exists(marker_path)

    wait_for(trainer_ready, data_store, timeout)
|
||||
|
||||
|
||||
def should_stop(checkpoint_dir):
|
||||
@@ -69,17 +89,18 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
|
||||
"""
|
||||
checkpoint_dir = task_parameters.checkpoint_restore_path
|
||||
wait_for_checkpoint(checkpoint_dir, data_store)
|
||||
wait_for_trainer_ready(checkpoint_dir, data_store)
|
||||
|
||||
graph_manager.create_graph(task_parameters)
|
||||
with graph_manager.phase_context(RunPhase.TRAIN):
|
||||
|
||||
chkpt_state_reader = CheckpointStateReader(checkpoint_dir, checkpoint_state_optional=False)
|
||||
last_checkpoint = 0
|
||||
last_checkpoint = chkpt_state_reader.get_latest().num
|
||||
|
||||
# this worker should play a fraction of the total playing steps per rollout
|
||||
act_steps = graph_manager.agent_params.algorithm.num_consecutive_playing_steps / num_workers
|
||||
|
||||
for i in range(graph_manager.improve_steps / act_steps):
|
||||
training_steps = (graph_manager.improve_steps / act_steps.num_steps).num_steps
|
||||
for i in range(training_steps):
|
||||
|
||||
if should_stop(checkpoint_dir):
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user