From 54fdfe2da8538e9bedf8418950e70313dfd3de6e Mon Sep 17 00:00:00 2001 From: zach dwiel Date: Thu, 4 Apr 2019 16:13:56 -0400 Subject: [PATCH] simplify rollout worker steps with new magic methods on StepMethod --- rl_coach/rollout_worker.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/rl_coach/rollout_worker.py b/rl_coach/rollout_worker.py index 2e60cfe..1d681cb 100644 --- a/rl_coach/rollout_worker.py +++ b/rl_coach/rollout_worker.py @@ -77,17 +77,14 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters): last_checkpoint = 0 # this worker should play a fraction of the total playing steps per rollout - act_steps = math.ceil(graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps/num_workers) + act_steps = graph_manager.agent_params.algorithm.num_consecutive_playing_steps / num_workers - for i in range(math.ceil(graph_manager.improve_steps.num_steps/act_steps)): + for i in range(graph_manager.improve_steps / act_steps): if should_stop(checkpoint_dir): break - if type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentSteps: - graph_manager.act(EnvironmentSteps(num_steps=act_steps), wait_for_full_episodes=graph_manager.agent_params.algorithm.act_for_full_episodes) - elif type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentEpisodes: - graph_manager.act(EnvironmentEpisodes(num_steps=act_steps)) + graph_manager.act(act_steps, wait_for_full_episodes=graph_manager.agent_params.algorithm.act_for_full_episodes) new_checkpoint = chkpt_state_reader.get_latest() if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC: