1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

simplify rollout worker steps with new magic methods on StepMethod

This commit is contained in:
zach dwiel
2019-04-04 16:13:56 -04:00
committed by Zach Dwiel
parent 2cb078b4c2
commit 54fdfe2da8

View File

@@ -77,17 +77,14 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
last_checkpoint = 0
# this worker should play a fraction of the total playing steps per rollout
act_steps = math.ceil(graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps/num_workers)
act_steps = graph_manager.agent_params.algorithm.num_consecutive_playing_steps / num_workers
for i in range(math.ceil(graph_manager.improve_steps.num_steps/act_steps)):
for i in range(graph_manager.improve_steps / act_steps):
if should_stop(checkpoint_dir):
break
if type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentSteps:
graph_manager.act(EnvironmentSteps(num_steps=act_steps), wait_for_full_episodes=graph_manager.agent_params.algorithm.act_for_full_episodes)
elif type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentEpisodes:
graph_manager.act(EnvironmentEpisodes(num_steps=act_steps))
graph_manager.act(act_steps, wait_for_full_episodes=graph_manager.agent_params.algorithm.act_for_full_episodes)
new_checkpoint = chkpt_state_reader.get_latest()
if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC: