mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Uploading checkpoint if crd provided (#191)

* Uploading checkpoint if crd provided
* Changing the calculation of total steps because of a recent change in core_types

Fixes #195
Authored by Ajay Deshpande, 2019-04-26 12:27:33 -07:00; committed by Scott Leishman
parent b3db9ce77d
commit 33dc29ee99
8 changed files with 122 additions and 40 deletions


@@ -33,30 +33,50 @@ from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
 from rl_coach.data_stores.data_store import SyncFiles
 
 
+def wait_for(wait_func, data_store=None, timeout=10):
+    """
+    block until wait_func is true
+    """
+    for i in range(timeout):
+        if data_store:
+            data_store.load_from_store()
+        if wait_func():
+            return
+        time.sleep(10)
+
+    # one last time
+    if wait_func():
+        return
+
+    raise ValueError((
+        'Waited {timeout} seconds, but condition timed out'
+    ).format(
+        timeout=timeout,
+    ))
+
+
 def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
     """
     block until there is a checkpoint in checkpoint_dir
     """
     chkpt_state_file = CheckpointStateFile(checkpoint_dir)
-    for i in range(timeout):
-        if data_store:
-            data_store.load_from_store()
-        if chkpt_state_file.read() is not None:
-            return
-        time.sleep(10)
 
-    # one last time
-    if chkpt_state_file.read() is not None:
-        return
+    def wait():
+        return chkpt_state_file.read() is not None
 
-    raise ValueError((
-        'Waited {timeout} seconds, but checkpoint never found in '
-        '{checkpoint_dir}'
-    ).format(
-        timeout=timeout,
-        checkpoint_dir=checkpoint_dir,
-    ))
+    wait_for(wait, data_store, timeout)
+
+
+def wait_for_trainer_ready(checkpoint_dir, data_store=None, timeout=10):
+    """
+    Block until trainer is ready
+    """
+    def wait():
+        return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value))
+
+    wait_for(wait, data_store, timeout)
 
 
 def should_stop(checkpoint_dir):
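The hunk above replaces the checkpoint-specific polling loop with a generic wait_for helper that takes a zero-argument predicate and retries it, syncing the data store between attempts. A minimal standalone sketch of the same pattern, independent of Coach (the marker-file name, temp directory, and timeout value are illustrative only, not from this commit):

import os
import tempfile
import time


def wait_for(wait_func, data_store=None, timeout=10):
    # poll the predicate up to `timeout` times, refreshing the data store before each check
    for _ in range(timeout):
        if data_store:
            data_store.load_from_store()
        if wait_func():
            return
        time.sleep(10)
    # one last check before giving up
    if wait_func():
        return
    raise ValueError('condition still false after {} polling attempts'.format(timeout))


# example usage: once the marker file exists, wait_for returns immediately
marker = os.path.join(tempfile.gettempdir(), 'trainer.ready')
open(marker, 'w').close()
wait_for(lambda: os.path.exists(marker), timeout=3)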
@@ -69,17 +89,18 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
     """
     checkpoint_dir = task_parameters.checkpoint_restore_path
     wait_for_checkpoint(checkpoint_dir, data_store)
+    wait_for_trainer_ready(checkpoint_dir, data_store)
 
     graph_manager.create_graph(task_parameters)
     with graph_manager.phase_context(RunPhase.TRAIN):
         chkpt_state_reader = CheckpointStateReader(checkpoint_dir, checkpoint_state_optional=False)
-        last_checkpoint = 0
+        last_checkpoint = chkpt_state_reader.get_latest().num
 
         # this worker should play a fraction of the total playing steps per rollout
         act_steps = graph_manager.agent_params.algorithm.num_consecutive_playing_steps / num_workers
-        for i in range(graph_manager.improve_steps / act_steps):
+        training_steps = (graph_manager.improve_steps / act_steps.num_steps).num_steps
+        for i in range(training_steps):
             if should_stop(checkpoint_dir):
                 break
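The new training_steps expression in the hunk above divides by the plain step count (act_steps.num_steps) and then reads num_steps off the result, instead of dividing one step object by another. A rough worked example of the intended arithmetic with plain integers (the values below are illustrative, not taken from the commit; the Coach step types such as EnvironmentSteps expose the underlying count as num_steps):

# illustrative numbers only
improve_steps = 10000    # total environment steps to improve over
playing_steps = 400      # num_consecutive_playing_steps per rollout
num_workers = 4

# each rollout worker plays a fraction of the playing steps per rollout
act_steps = playing_steps // num_workers        # 100 steps per worker per rollout
# number of rollout iterations this worker runs overall
training_steps = improve_steps // act_steps     # 100 iterations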