diff --git a/docker/Dockerfile b/docker/Dockerfile
index 9966943..d29c4f6 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -18,6 +18,7 @@ RUN pip3 install --upgrade pip
 
 RUN mkdir /root/src
 COPY setup.py /root/src/.
+COPY requirements.txt /root/src/.
 COPY README.md /root/src/.
 WORKDIR /root/src
 RUN pip3 install -e .
diff --git a/rl_coach/rollout_worker.py b/rl_coach/rollout_worker.py
index ef0ff7d..1624041 100644
--- a/rl_coach/rollout_worker.py
+++ b/rl_coach/rollout_worker.py
@@ -1,8 +1,15 @@
 """
-this rollout worker restores a model from disk, evaluates a predefined number of
-episodes, and contributes them to a distributed memory
+this rollout worker:
+
+- restores a model from disk
+- evaluates a predefined number of episodes
+- contributes them to a distributed memory
+- exits
 """
+
 import argparse
+import os
+import time
 
 from rl_coach.base_parameters import TaskParameters
 from rl_coach.coach import expand_preset
@@ -10,13 +17,38 @@ from rl_coach.core_types import EnvironmentEpisodes, RunPhase
 from rl_coach.utils import short_dynamic_import
 
 # Q: specify alternative distributed memory, or should this go in the preset?
-# A: preset must define distributed memory to be used. we aren't going to take a non-distributed preset and automatically distribute it.
+# A: preset must define distributed memory to be used. we aren't going to take
+# a non-distributed preset and automatically distribute it.
+
+def has_checkpoint(checkpoint_dir):
+    """
+    True if a checkpoint is present in checkpoint_dir
+    """
+    return len(os.listdir(checkpoint_dir)) > 0
+
+
+def wait_for_checkpoint(checkpoint_dir, timeout=10):
+    """
+    block until there is a checkpoint in checkpoint_dir
+    """
+    for i in range(timeout):
+        if has_checkpoint(checkpoint_dir):
+            return
+        time.sleep(1)
+
+    # one last time
+    if has_checkpoint(checkpoint_dir):
+        return
+
+    raise ValueError(f'checkpoint never found in {checkpoint_dir}')
 
 
 def rollout_worker(graph_manager, checkpoint_dir):
     """
     restore a checkpoint then perform rollouts using the restored model
     """
+    wait_for_checkpoint(checkpoint_dir)
+
     task_parameters = TaskParameters()
     task_parameters.__dict__['checkpoint_restore_dir'] = checkpoint_dir
     graph_manager.create_graph(task_parameters)
@@ -56,6 +88,5 @@ def main():
         checkpoint_dir=args.checkpoint_dir,
     )
 
-
 if __name__ == '__main__':
     main()
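
A note on the Dockerfile hunk: `requirements.txt` is copied next to `setup.py` before the editable install, presumably so that `pip3 install -e .` can resolve the pinned dependencies inside the image; no separate `RUN pip3 install -r requirements.txt` step is added by this change.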
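
To sanity-check the new polling helper outside a full Coach run, a sketch along these lines should do, assuming the module is importable as `rl_coach.rollout_worker`; the temporary directory, dummy file name, and background thread below are illustrative, not part of the change:

```python
import os
import tempfile
import threading
import time

from rl_coach.rollout_worker import wait_for_checkpoint

# empty directory that will receive a "checkpoint" shortly
checkpoint_dir = tempfile.mkdtemp()

def publish_checkpoint():
    # stand-in for a trainer that saves its first checkpoint after 3 seconds
    time.sleep(3)
    open(os.path.join(checkpoint_dir, 'dummy.ckpt'), 'w').close()

threading.Thread(target=publish_checkpoint).start()

# blocks until the file appears (~3s here); raises ValueError if nothing
# shows up within the default 10-second timeout
wait_for_checkpoint(checkpoint_dir)
print('checkpoint detected in', checkpoint_dir)
```

With the default `timeout=10`, the helper makes roughly ten one-second polls plus one final check before raising `ValueError`.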
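
On the design: `has_checkpoint` treats any directory entry as evidence of a checkpoint, which keeps the helper framework-agnostic but will also accept a partially written file. If the trainer saves TensorFlow-style checkpoints, a stricter variant could key off the saver's `checkpoint` state file instead; this is a hypothetical tightening, not something this diff implements:

```python
import os

def has_complete_checkpoint(checkpoint_dir):
    # hypothetical stricter check: the TensorFlow saver updates its
    # 'checkpoint' state file only after a save completes, so its presence
    # suggests at least one fully written checkpoint
    return os.path.isfile(os.path.join(checkpoint_dir, 'checkpoint'))
```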