1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

wait for checkpoint before trying to start rollout worker

This commit is contained in:
Zach Dwiel
2018-09-15 00:55:50 +00:00
committed by Zach Dwiel
parent 4352d6735d
commit f5b7122d56
2 changed files with 35 additions and 4 deletions

View File

@@ -1,8 +1,14 @@
"""
this rollout worker restores a model from disk, evaluates a predefined number of
episodes, and contributes them to a distributed memory
this rollout worker:
- restores a model from disk
- evaluates a predefined number of episodes
- contributes them to a distributed memory
- exits
"""
import argparse
import time
from rl_coach.base_parameters import TaskParameters
from rl_coach.coach import expand_preset
@@ -10,13 +16,38 @@ from rl_coach.core_types import EnvironmentEpisodes, RunPhase
from rl_coach.utils import short_dynamic_import
# Q: specify alternative distributed memory, or should this go in the preset?
# A: preset must define distributed memory to be used. we aren't going to take a non-distributed preset and automatically distribute it.
# A: preset must define distributed memory to be used. we aren't going to take
# a non-distributed preset and automatically distribute it.
def has_checkpoint(checkpoint_dir):
    """
    True if a checkpoint is present in checkpoint_dir.

    :param checkpoint_dir: directory to inspect for checkpoint files
    :return: True if the directory exists and contains at least one entry

    Returns False (rather than raising FileNotFoundError) when the
    directory does not exist yet, so pollers such as
    wait_for_checkpoint can run before the writer has created it.
    """
    if not os.path.isdir(checkpoint_dir):
        return False
    return len(os.listdir(checkpoint_dir)) > 0
def wait_for_checkpoint(checkpoint_dir, timeout=10):
    """
    Block until there is a checkpoint in checkpoint_dir.

    :param checkpoint_dir: directory polled for a checkpoint
    :param timeout: maximum number of seconds to wait
    :raises ValueError: if no checkpoint appears within the timeout
    """
    # Poll against a monotonic deadline so time spent inside
    # has_checkpoint() counts toward the timeout instead of silently
    # extending it (the old per-iteration counter ignored it), and the
    # clock cannot jump backwards/forwards with wall-clock changes.
    deadline = time.monotonic() + timeout
    while not has_checkpoint(checkpoint_dir):
        if time.monotonic() >= deadline:
            raise ValueError(f'checkpoint never found in {checkpoint_dir}')
        time.sleep(1)
def rollout_worker(graph_manager, checkpoint_dir):
"""
restore a checkpoint then perform rollouts using the restored model
"""
wait_for_checkpoint(checkpoint_dir)
task_parameters = TaskParameters()
task_parameters.__dict__['checkpoint_restore_dir'] = checkpoint_dir
graph_manager.create_graph(task_parameters)
@@ -56,6 +87,5 @@ def main():
checkpoint_dir=args.checkpoint_dir,
)
# Script entry point: run the rollout worker CLI only when executed
# directly (not when this module is imported).
if __name__ == '__main__':
    main()