1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00
Files
coach/rl_coach/rollout_worker.py
2019-11-03 14:42:51 +02:00

106 lines
3.0 KiB
Python

#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
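This rollout worker:
- restores a model (policy) through the data store
- plays its share of the playing steps in rounds, reloading the policy after each round
- contributes the collected experience to a distributed memory
- exits once all rounds are done or no more policies are available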
"""
import time
import os
from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
from rl_coach.checkpoint import CheckpointStateFile, CheckpointStateReader
from rl_coach.data_stores.data_store import SyncFiles
from rl_coach.core_types import RunPhase
def wait_for(wait_func, data_store=None, timeout=10, poll_interval=10):
    """
    Block until ``wait_func()`` returns a truthy value.

    The condition is polled ``timeout + 1`` times, sleeping ``poll_interval``
    seconds between consecutive polls (no sleep after the final poll).
    Before every poll, including the final one, ``data_store.load_from_store()``
    is called when a data store is given, so the condition is evaluated
    against freshly loaded state (presumably synced from remote storage --
    confirm against the data store implementation).

    :param wait_func: zero-argument callable; polling stops as soon as it
        returns a truthy value.
    :param data_store: optional object whose ``load_from_store()`` is invoked
        before each poll.
    :param timeout: number of poll/sleep rounds before giving up. Note this is
        a round count, not seconds: the total wait is about
        ``timeout * poll_interval`` seconds.
    :param poll_interval: seconds to sleep between polls (default 10, the
        previously hard-coded value).
    :raises ValueError: if the condition never became true.
    """
    rounds = max(timeout, 0)
    for attempt in range(rounds + 1):
        if data_store is not None:
            # refresh before *every* check, including the last one; previously
            # the final check ran against stale state
            data_store.load_from_store()
        if wait_func():
            return
        # sleeping after the final poll would only delay the error
        if attempt < rounds:
            time.sleep(poll_interval)

    # report the time actually spent waiting, not the round count
    raise ValueError(
        'Waited {waited} seconds, but condition timed out'.format(
            waited=rounds * poll_interval,
        )
    )
def wait_for_trainer_ready(checkpoint_dir, data_store=None, timeout=10):
    """
    Block until the trainer has signalled that it is ready.

    Readiness means the ``SyncFiles.TRAINER_READY`` marker file exists inside
    ``checkpoint_dir``. Polling (and the optional ``data_store`` refresh and
    ``timeout``) is delegated to ``wait_for``.
    """
    wait_for(
        lambda: os.path.exists(
            os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value)
        ),
        data_store=data_store,
        timeout=timeout,
    )
def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
    """
    Wait for the first policy, then perform rollouts using it.

    After creating the graph and loading an initial policy, the worker runs
    ``graph_manager.improve_steps / act_steps`` rounds in the TRAIN phase,
    where ``act_steps`` is this worker's share
    (``num_consecutive_playing_steps / num_workers``) of the playing steps.
    After each round it requests a new policy via
    ``data_store.load_policy(..., require_new_policy=True)``. It stops early
    once ``data_store.end_of_policies()`` reports true.

    :param graph_manager: graph manager whose graph is created and acted upon.
    :param data_store: store used to load policies and to signal the end of
        policies.
    :param num_workers: number of rollout workers sharing the playing steps.
    :param task_parameters: forwarded to ``graph_manager.create_graph``.
    """
    sync_type = graph_manager.agent_params.algorithm.distributed_coach_synchronization_type
    # SYNC mode waits without limit for each new policy; other modes pass None
    # (its meaning is defined by data_store.load_policy)
    policy_timeout = (
        float("inf")
        if sync_type == DistributedCoachSynchronizationType.SYNC
        else None
    )

    # this could probably be moved up into coach.py
    graph_manager.create_graph(task_parameters)

    # the initial policy does not have to be new
    data_store.load_policy(graph_manager, require_new_policy=False, timeout=60)

    with graph_manager.phase_context(RunPhase.TRAIN):
        graph_manager.reset_internal_state(force_environment_reset=True)

        # this worker should play a fraction of the total playing steps per rollout
        steps_per_round = (
            graph_manager.agent_params.algorithm.num_consecutive_playing_steps
            / num_workers
        )
        num_rounds = graph_manager.improve_steps / steps_per_round

        for _ in range(num_rounds):
            if data_store.end_of_policies():
                break
            full_episodes = graph_manager.agent_params.algorithm.act_for_full_episodes
            graph_manager.act(steps_per_round, wait_for_full_episodes=full_episodes)
            data_store.load_policy(
                graph_manager, require_new_policy=True, timeout=policy_timeout
            )