mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
simplify rollout worker steps with new magic methods on StepMethod
This commit is contained in:
@@ -77,17 +77,14 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
|
|||||||
last_checkpoint = 0
|
last_checkpoint = 0
|
||||||
|
|
||||||
# this worker should play a fraction of the total playing steps per rollout
|
# this worker should play a fraction of the total playing steps per rollout
|
||||||
act_steps = math.ceil(graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps/num_workers)
|
act_steps = graph_manager.agent_params.algorithm.num_consecutive_playing_steps / num_workers
|
||||||
|
|
||||||
for i in range(math.ceil(graph_manager.improve_steps.num_steps/act_steps)):
|
for i in range(graph_manager.improve_steps / act_steps):
|
||||||
|
|
||||||
if should_stop(checkpoint_dir):
|
if should_stop(checkpoint_dir):
|
||||||
break
|
break
|
||||||
|
|
||||||
if type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentSteps:
|
graph_manager.act(act_steps, wait_for_full_episodes=graph_manager.agent_params.algorithm.act_for_full_episodes)
|
||||||
graph_manager.act(EnvironmentSteps(num_steps=act_steps), wait_for_full_episodes=graph_manager.agent_params.algorithm.act_for_full_episodes)
|
|
||||||
elif type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentEpisodes:
|
|
||||||
graph_manager.act(EnvironmentEpisodes(num_steps=act_steps))
|
|
||||||
|
|
||||||
new_checkpoint = chkpt_state_reader.get_latest()
|
new_checkpoint = chkpt_state_reader.get_latest()
|
||||||
if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
|
if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
|
||||||
|
|||||||
Reference in New Issue
Block a user