1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Simulating the act on the trainer. (#65)

* Remove the use of daemon threads for Redis subscribe.
* Emulate act and observe on trainer side to update internal vars.
This commit is contained in:
Ajay Deshpande
2018-11-15 08:38:58 -08:00
committed by Balaji Subramaniam
parent fe6857eabd
commit fde73ced13
13 changed files with 221 additions and 55 deletions

View File

@@ -12,7 +12,7 @@ import os
import math
from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
from rl_coach.core_types import EnvironmentSteps, RunPhase
from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
from google.protobuf import text_format
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from rl_coach.data_stores.data_store import SyncFiles
@@ -81,21 +81,23 @@ def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers):
task_parameters = TaskParameters()
task_parameters.__dict__['checkpoint_restore_dir'] = checkpoint_dir
time.sleep(30)
graph_manager.create_graph(task_parameters)
with graph_manager.phase_context(RunPhase.TRAIN):
error_compensation = 100
last_checkpoint = 0
act_steps = math.ceil((graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps + error_compensation)/num_workers)
act_steps = math.ceil((graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps)/num_workers)
for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
if should_stop(checkpoint_dir):
break
graph_manager.act(EnvironmentSteps(num_steps=act_steps))
if type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentSteps:
graph_manager.act(EnvironmentSteps(num_steps=act_steps), wait_for_full_episode=graph_manager.agent_params.algorithm.act_for_full_episodes)
elif type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentEpisodes:
graph_manager.act(EnvironmentEpisodes(num_steps=act_steps))
new_checkpoint = get_latest_checkpoint(checkpoint_dir)