diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py
index 99c1080..145be5d 100644
--- a/rl_coach/agents/agent.py
+++ b/rl_coach/agents/agent.py
@@ -783,10 +783,11 @@ class Agent(AgentInterface):
         return batches_dict
 
-    def act(self) -> ActionInfo:
+    def act(self, action: Union[None, ActionType]=None) -> ActionInfo:
         """
         Given the agents current knowledge, decide on the next action to apply to the environment
+        :param action: An action to take, overriding whatever the current policy is
         :return: An ActionInfo object, which contains the action and any additional info from the action decision process
         """
         if self.phase == RunPhase.TRAIN and self.ap.algorithm.num_consecutive_playing_steps.num_steps == 0:
             # This agent never plays while training (e.g. behavioral cloning)
             return None
@@ -799,19 +800,23 @@ class Agent(AgentInterface):
             self.current_episode_steps_counter += 1
 
         # decide on the action
-        if self.phase == RunPhase.HEATUP and not self.ap.algorithm.heatup_using_network_decisions:
-            # random action
-            self.last_action_info = self.spaces.action.sample_with_info()
-        else:
-            # informed action
-            if self.pre_network_filter is not None:
-                # before choosing an action, first use the pre_network_filter to filter out the current state
-                update_filter_internal_state = self.phase is not RunPhase.TEST
-                curr_state = self.run_pre_network_filter_for_inference(self.curr_state, update_filter_internal_state)
-
+        if action is None:
+            if self.phase == RunPhase.HEATUP and not self.ap.algorithm.heatup_using_network_decisions:
+                # random action
+                action = self.spaces.action.sample_with_info()
             else:
-                curr_state = self.curr_state
-            self.last_action_info = self.choose_action(curr_state)
+                # informed action
+                if self.pre_network_filter is not None:
+                    # before choosing an action, first use the pre_network_filter to filter out the current state
+                    update_filter_internal_state = self.phase is not RunPhase.TEST
+                    curr_state = self.run_pre_network_filter_for_inference(self.curr_state, update_filter_internal_state)
+
+                else:
+                    curr_state = self.curr_state
+                action = self.choose_action(curr_state)
+        assert isinstance(action, ActionInfo)
+
+        self.last_action_info = action
 
         # is it intentional that self.last_action_info is not filtered?
         filtered_action_info = self.output_filter.filter(self.last_action_info)
@@ -1036,29 +1041,6 @@ class Agent(AgentInterface):
 
         return transition.game_over
 
-    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
-    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
-    def emulate_act_on_trainer(self, action: ActionType) -> ActionInfo:
-        """
-        This emulates the act using the transition obtained from the rollout worker on the training worker
-        in case of distributed training.
-        Given the agents current knowledge, decide on the next action to apply to the environment
-        :return: an action and a dictionary containing any additional info from the action decision process
-        """
-        if self.phase == RunPhase.TRAIN and self.ap.algorithm.num_consecutive_playing_steps.num_steps == 0:
-            # This agent never plays while training (e.g. behavioral cloning)
-            return None
-
-        # count steps (only when training or if we are in the evaluation worker)
-        if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only:
-            self.total_steps_counter += 1
-            self.current_episode_steps_counter += 1
-
-        # these types don't match: ActionInfo = ActionType
-        self.last_action_info = action
-
-        return self.last_action_info
-
     def get_success_rate(self) -> float:
         return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed
 
diff --git a/rl_coach/agents/agent_interface.py b/rl_coach/agents/agent_interface.py
index 199daf6..7bd550a 100644
--- a/rl_coach/agents/agent_interface.py
+++ b/rl_coach/agents/agent_interface.py
@@ -142,16 +142,6 @@ class AgentInterface(object):
         """
         raise NotImplementedError("")
 
-    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
-    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
-    def emulate_act_on_trainer(self, action: ActionType) -> ActionInfo:
-        """
-        This emulates the act using the transition obtained from the rollout worker on the training worker
-        in case of distributed training.
-        :return: A tuple containing the actual action
-        """
-        raise NotImplementedError("")
-
     def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
         """
         Collect all of agent savers
diff --git a/rl_coach/level_manager.py b/rl_coach/level_manager.py
index a58ed0b..ed60c1e 100644
--- a/rl_coach/level_manager.py
+++ b/rl_coach/level_manager.py
@@ -313,7 +313,7 @@ class LevelManager(EnvironmentInterface):
         # for i in range(self.steps_limit.num_steps):
         # let the agent observe the result and decide if it wants to terminate the episode
         done = acting_agent.emulate_observe_on_trainer(transition)
-        acting_agent.emulate_act_on_trainer(transition.action)
+        acting_agent.act(transition.action)
         if done:
             self.handle_episode_ended()
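Usage note (a minimal sketch, not part of the patch): the pattern this diff introduces is a policy override. act() keeps its usual decide-and-record flow, but a caller that already knows the action (here, the trainer worker replaying a rollout transition) can pass it in and bypass the policy. The toy classes below are illustrative stand-ins, not the actual rl_coach Agent or ActionInfo:

import random
from typing import List, Optional

class ActionInfo:
    """Carries a chosen action plus any metadata from the decision process."""
    def __init__(self, action: int):
        self.action = action

class ToyAgent:
    def __init__(self, actions: List[int]):
        self.actions = actions
        self.last_action_info: Optional[ActionInfo] = None

    def choose_action(self) -> ActionInfo:
        # stand-in for the policy / network decision
        return ActionInfo(random.choice(self.actions))

    def act(self, action: Optional[ActionInfo] = None) -> ActionInfo:
        # a supplied action (e.g. one replayed from a recorded transition)
        # skips the policy entirely, mirroring the patched Agent.act()
        if action is None:
            action = self.choose_action()
        assert isinstance(action, ActionInfo)
        self.last_action_info = action
        return action

agent = ToyAgent(actions=[0, 1, 2])
agent.act()               # rollout flow: the policy decides
agent.act(ActionInfo(1))  # trainer flow: replay a recorded action

One caveat the patch itself carries: LevelManager feeds transition.action straight into act(), so the new assert isinstance(action, ActionInfo) holds only if the transition stores a full ActionInfo rather than a bare ActionType; the removed emulate_act_on_trainer flagged the same mismatch with its "these types don't match" comment.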