
Add RedisDataStore (#295)

* GraphManager.set_session also sets self.sess

* make sure that GraphManager.fetch_from_worker uses the training phase

* remove unnecessary phase setting in training worker

* reorganize rollout worker

* provide a default name to GlobalVariableSaver.__init__ since it isn't really used anyway

* allow dividing TrainingSteps and EnvironmentSteps (see the division sketch after this list)

* add timestamps to the log (see the logging sketch after this list)

* add Redis data store

* fix merge conflicts
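
The "allow dividing TrainingSteps and EnvironmentSteps" item is what lets the reworked rollout loop below write graph_manager.improve_steps / act_steps directly. A minimal sketch of how such division might behave, assuming a shared base class with a num_steps attribute; the class layout is illustrative, not rl_coach's actual core_types:

import math


class Steps:
    """Illustrative stand-in for rl_coach's step counters."""

    def __init__(self, num_steps):
        self.num_steps = num_steps

    def __truediv__(self, other):
        # Dividing by a plain number splits the steps across workers,
        # rounding up so no steps are dropped.
        if isinstance(other, (int, float)):
            return type(self)(math.ceil(self.num_steps / other))
        # Dividing by another Steps instance yields a plain int: how many
        # chunks of `other` fit in `self`, usable directly in range().
        if isinstance(other, Steps):
            return math.ceil(self.num_steps / other.num_steps)
        return NotImplemented


class EnvironmentSteps(Steps):
    pass


class TrainingSteps(Steps):
    pass


# e.g. splitting 1000 playing steps across 4 workers:
act_steps = EnvironmentSteps(1000) / 4         # EnvironmentSteps(250)
iterations = TrainingSteps(10000) / act_steps  # 40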
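
For "add timestamps to the log", Coach ships its own screen logger, so the following is only a generic illustration of the idea using the standard logging module:

import logging

# Generic illustration, not Coach's actual logger: prefix every record
# with a wall-clock timestamp so distributed workers can be correlated.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logging.info("rollout worker started")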
Zach Dwiel authored on 2019-08-28 14:15:58 -04:00, committed by shadiendrawis
parent 34e1c04f29, commit 7b0fccb041
18 changed files with 528 additions and 120 deletions


@@ -23,13 +23,13 @@ this rollout worker:
- exits
"""
import time
import os
import math
from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
from rl_coach.checkpoint import CheckpointStateFile, CheckpointStateReader
from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
from rl_coach.data_stores.data_store import SyncFiles
@@ -56,18 +56,6 @@ def wait_for(wait_func, data_store=None, timeout=10):
    ))


-def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
-    """
-    block until there is a checkpoint in checkpoint_dir
-    """
-    chkpt_state_file = CheckpointStateFile(checkpoint_dir)
-
-    def wait():
-        return chkpt_state_file.read() is not None
-
-    wait_for(wait, data_store, timeout)


def wait_for_trainer_ready(checkpoint_dir, data_store=None, timeout=10):
    """
    Block until trainer is ready
@@ -79,48 +67,38 @@ def wait_for_trainer_ready(checkpoint_dir, data_store=None, timeout=10):
    wait_for(wait, data_store, timeout)


def should_stop(checkpoint_dir):
    return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value))


def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
    """
    wait for first checkpoint then perform rollouts using the model
    """
-    checkpoint_dir = task_parameters.checkpoint_restore_path
-    wait_for_checkpoint(checkpoint_dir, data_store)
-    wait_for_trainer_ready(checkpoint_dir, data_store)
+    if (
+        graph_manager.agent_params.algorithm.distributed_coach_synchronization_type
+        == DistributedCoachSynchronizationType.SYNC
+    ):
+        timeout = float("inf")
+    else:
+        timeout = None

    # this could probably be moved up into coach.py
    graph_manager.create_graph(task_parameters)

+    data_store.load_policy(graph_manager, require_new_policy=False, timeout=60)

    with graph_manager.phase_context(RunPhase.TRAIN):
-        chkpt_state_reader = CheckpointStateReader(checkpoint_dir, checkpoint_state_optional=False)
-        last_checkpoint = chkpt_state_reader.get_latest().num

        # this worker should play a fraction of the total playing steps per rollout
-        act_steps = graph_manager.agent_params.algorithm.num_consecutive_playing_steps / num_workers
-        training_steps = (graph_manager.improve_steps / act_steps.num_steps).num_steps
-        for i in range(training_steps):
-            if should_stop(checkpoint_dir):
+        act_steps = (
+            graph_manager.agent_params.algorithm.num_consecutive_playing_steps
+            / num_workers
+        )
+        for i in range(graph_manager.improve_steps / act_steps):
+            if data_store.end_of_policies():
                break

-            graph_manager.act(act_steps, wait_for_full_episodes=graph_manager.agent_params.algorithm.act_for_full_episodes)
+            graph_manager.act(
+                act_steps,
+                wait_for_full_episodes=graph_manager.agent_params.algorithm.act_for_full_episodes,
+            )

-            new_checkpoint = chkpt_state_reader.get_latest()
-            if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
-                while new_checkpoint is None or new_checkpoint.num < last_checkpoint + 1:
-                    if should_stop(checkpoint_dir):
-                        break
-                    if data_store:
-                        data_store.load_from_store()
-                    new_checkpoint = chkpt_state_reader.get_latest()
-                graph_manager.restore_checkpoint()
-
-            if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.ASYNC:
-                if new_checkpoint is not None and new_checkpoint.num > last_checkpoint:
-                    graph_manager.restore_checkpoint()
-
-                if new_checkpoint is not None:
-                    last_checkpoint = new_checkpoint.num
+            data_store.load_policy(graph_manager, require_new_policy=True, timeout=timeout)
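
After this change the rollout worker only assumes a small data-store contract: end_of_policies() to detect the end of training, and load_policy(graph_manager, require_new_policy, timeout) to restore a (possibly new) policy within a timeout. A minimal sketch of a Redis-backed store that could satisfy that contract via redis-py; the key names, the version counter, and the reading of timeout=None as a single non-blocking attempt are all assumptions for illustration, not the PR's actual RedisDataStore:

# Hypothetical sketch, not the RedisDataStore added by this PR: a Redis-backed
# store exposing the two calls the rollout loop above relies on.
import time

import redis  # redis-py


class RedisPolicyStore:
    POLICY_KEY = "policy"           # serialized policy blob (assumed layout)
    VERSION_KEY = "policy_version"  # monotonically increasing counter
    FINISHED_KEY = "finished"       # set by the trainer when training ends

    def __init__(self, host="localhost", port=6379):
        self._redis = redis.Redis(host=host, port=port)
        self._last_version = -1

    def end_of_policies(self):
        # The trainer is assumed to set a flag key when it is done.
        return bool(self._redis.exists(self.FINISHED_KEY))

    def load_policy(self, graph_manager, require_new_policy=True, timeout=None):
        # timeout=None is read here as "one non-blocking attempt" (the ASYNC
        # case above); float("inf") waits indefinitely (the SYNC case).
        deadline = time.time() if timeout is None else time.time() + timeout
        while True:
            version = int(self._redis.get(self.VERSION_KEY) or -1)
            if version >= 0 and (version > self._last_version or not require_new_policy):
                self._restore(graph_manager, self._redis.get(self.POLICY_KEY))
                self._last_version = version
                return True
            if time.time() >= deadline:
                return False
            time.sleep(1)

    def _restore(self, graph_manager, blob):
        # Restoring the blob into the TF session is out of scope for this
        # sketch; the real store restores a checkpoint into the graph.
        raise NotImplementedError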