mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

wait for checkpoint before trying to start rollout worker

Zach Dwiel
2018-09-15 00:55:50 +00:00
committed by zach dwiel
parent 4352d6735d
commit f5b7122d56
2 changed files with 35 additions and 4 deletions


@@ -18,6 +18,7 @@ RUN pip3 install --upgrade pip
 RUN mkdir /root/src
 COPY setup.py /root/src/.
+COPY requirements.txt /root/src/.
 COPY README.md /root/src/.
 WORKDIR /root/src
 RUN pip3 install -e .
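Copying requirements.txt before `pip install -e .` matters when setup.py reads that file at install time. A minimal sketch of that pattern follows; it is an assumption about how this repo's setup.py might consume requirements.txt, not something shown in the diff:

# setup.py -- hypothetical sketch; assumes install_requires is populated
# from requirements.txt, which would explain why the file has to be COPYed
# into the image before `pip install -e .` can succeed.
from setuptools import setup, find_packages

with open('requirements.txt') as f:
    install_requires = [
        line.strip() for line in f
        if line.strip() and not line.startswith('#')
    ]

setup(
    name='rl-coach',              # placeholder name for illustration
    packages=find_packages(),
    install_requires=install_requires,
)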


@@ -1,8 +1,14 @@
 """
-this rollout worker restores a model from disk, evaluates a predefined number of
-episodes, and contributes them to a distributed memory
+this rollout worker:
+
+- restores a model from disk
+- evaluates a predefined number of episodes
+- contributes them to a distributed memory
+- exits
 """
 import argparse
+import time
 
 from rl_coach.base_parameters import TaskParameters
 from rl_coach.coach import expand_preset
@@ -10,13 +16,38 @@ from rl_coach.core_types import EnvironmentEpisodes, RunPhase
 from rl_coach.utils import short_dynamic_import
 
 # Q: specify alternative distributed memory, or should this go in the preset?
-# A: preset must define distributed memory to be used. we aren't going to take a non-distributed preset and automatically distribute it.
+# A: preset must define distributed memory to be used. we aren't going to take
+# a non-distributed preset and automatically distribute it.
+
+
+def has_checkpoint(checkpoint_dir):
+    """
+    True if a checkpoint is present in checkpoint_dir
+    """
+    return len(os.listdir(checkpoint_dir)) > 0
+
+
+def wait_for_checkpoint(checkpoint_dir, timeout=10):
+    """
+    block until there is a checkpoint in checkpoint_dir
+    """
+    for i in range(timeout):
+        if has_checkpoint(checkpoint_dir):
+            return
+        time.sleep(1)
+
+    # one last time
+    if has_checkpoint(checkpoint_dir):
+        return
+
+    raise ValueError(f'checkpoint never found in {checkpoint_dir}')
 
 
 def rollout_worker(graph_manager, checkpoint_dir):
     """
     restore a checkpoint then perform rollouts using the restored model
     """
+    wait_for_checkpoint(checkpoint_dir)
+
     task_parameters = TaskParameters()
     task_parameters.__dict__['checkpoint_restore_dir'] = checkpoint_dir
     graph_manager.create_graph(task_parameters)
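One thing worth flagging: has_checkpoint calls os.listdir, but this hunk only adds import time, so the patched module would raise a NameError unless os is imported somewhere not visible in this diff. Below is a small self-contained sketch of the polling behaviour, with the helpers copied from the hunk and a made-up writer thread standing in for the trainer process:

import os
import tempfile
import threading
import time


def has_checkpoint(checkpoint_dir):
    """True if a checkpoint is present in checkpoint_dir."""
    return len(os.listdir(checkpoint_dir)) > 0


def wait_for_checkpoint(checkpoint_dir, timeout=10):
    """Block until there is a checkpoint in checkpoint_dir."""
    for _ in range(timeout):
        if has_checkpoint(checkpoint_dir):
            return
        time.sleep(1)
    # one last time
    if has_checkpoint(checkpoint_dir):
        return
    raise ValueError(f'checkpoint never found in {checkpoint_dir}')


checkpoint_dir = tempfile.mkdtemp()

def write_checkpoint_later():
    # stand-in for the trainer; the file name is a placeholder, since
    # has_checkpoint only tests that the directory is non-empty
    time.sleep(2)
    open(os.path.join(checkpoint_dir, 'checkpoint-0'), 'w').close()

threading.Thread(target=write_checkpoint_later, daemon=True).start()

wait_for_checkpoint(checkpoint_dir)    # returns once the file appears
print(has_checkpoint(checkpoint_dir))  # True

Expressing the timeout as a loop count works here only because the sleep interval is fixed at one second; a deadline computed from time.time() would be needed if the interval ever changed.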
@@ -56,6 +87,5 @@ def main():
         checkpoint_dir=args.checkpoint_dir,
     )
 
-
 if __name__ == '__main__':
     main()
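The body of rollout_worker after create_graph, and the argparse handling in main(), are elided from this diff. For orientation only, here is a rough sketch of the evaluation loop the docstring describes; graph_manager.phase, graph_manager.act, and the EnvironmentEpisodes signature are assumptions based on the imports, not code from this commit:

# hypothetical sketch, not the commit's actual code
def run_evaluation_episodes(graph_manager, num_episodes):
    # assumption: the graph manager exposes a writable phase and an act() call
    graph_manager.phase = RunPhase.TEST
    for _ in range(num_episodes):
        # play one full episode; with a distributed memory defined in the
        # preset, its transitions are contributed to that shared memory
        graph_manager.act(EnvironmentEpisodes(num_steps=1))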