diff --git a/rl_coach/rollout_worker.py b/rl_coach/rollout_worker.py index 2e60cfe..1d681cb 100644 --- a/rl_coach/rollout_worker.py +++ b/rl_coach/rollout_worker.py @@ -77,17 +77,14 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters): last_checkpoint = 0 # this worker should play a fraction of the total playing steps per rollout - act_steps = math.ceil(graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps/num_workers) + act_steps = graph_manager.agent_params.algorithm.num_consecutive_playing_steps / num_workers - for i in range(math.ceil(graph_manager.improve_steps.num_steps/act_steps)): + for i in range(graph_manager.improve_steps / act_steps): if should_stop(checkpoint_dir): break - if type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentSteps: - graph_manager.act(EnvironmentSteps(num_steps=act_steps), wait_for_full_episodes=graph_manager.agent_params.algorithm.act_for_full_episodes) - elif type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentEpisodes: - graph_manager.act(EnvironmentEpisodes(num_steps=act_steps)) + graph_manager.act(act_steps, wait_for_full_episodes=graph_manager.agent_params.algorithm.act_for_full_episodes) new_checkpoint = chkpt_state_reader.get_latest() if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC: