
merge AgentInterface.emulate_act_on_trainer and AgentInterface.act

Author:       zach dwiel
Date:         2019-04-05 11:49:09 -04:00
Committed by: Zach Dwiel
Parent:       f2fead57e5
Commit:       f8741522e4

3 changed files with 19 additions and 47 deletions


@@ -783,10 +783,11 @@ class Agent(AgentInterface):
         return batches_dict
-    def act(self) -> ActionInfo:
+    def act(self, action: Union[None, ActionType]=None) -> ActionInfo:
         """
         Given the agents current knowledge, decide on the next action to apply to the environment
+        :param action: An action to take, overriding whatever the current policy is
         :return: An ActionInfo object, which contains the action and any additional info from the action decision process
         """
         if self.phase == RunPhase.TRAIN and self.ap.algorithm.num_consecutive_playing_steps.num_steps == 0:
@@ -799,19 +800,23 @@ class Agent(AgentInterface):
             self.current_episode_steps_counter += 1
         # decide on the action
-        if self.phase == RunPhase.HEATUP and not self.ap.algorithm.heatup_using_network_decisions:
-            # random action
-            self.last_action_info = self.spaces.action.sample_with_info()
-        else:
-            # informed action
-            if self.pre_network_filter is not None:
-                # before choosing an action, first use the pre_network_filter to filter out the current state
-                update_filter_internal_state = self.phase is not RunPhase.TEST
-                curr_state = self.run_pre_network_filter_for_inference(self.curr_state, update_filter_internal_state)
+        if action is None:
+            if self.phase == RunPhase.HEATUP and not self.ap.algorithm.heatup_using_network_decisions:
+                # random action
+                action = self.spaces.action.sample_with_info()
             else:
-                curr_state = self.curr_state
-            self.last_action_info = self.choose_action(curr_state)
+                # informed action
+                if self.pre_network_filter is not None:
+                    # before choosing an action, first use the pre_network_filter to filter out the current state
+                    update_filter_internal_state = self.phase is not RunPhase.TEST
+                    curr_state = self.run_pre_network_filter_for_inference(self.curr_state, update_filter_internal_state)
+                else:
+                    curr_state = self.curr_state
+                action = self.choose_action(curr_state)
+        assert isinstance(action, ActionInfo)
+        self.last_action_info = action
         # is it intentional that self.last_action_info is not filtered?
         filtered_action_info = self.output_filter.filter(self.last_action_info)
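Note on the hunk above: act() now has two call modes. With no argument the agent decides for itself (random action during heatup, otherwise choose_action()); with an explicit action it skips selection and records the caller-supplied choice. Although the signature annotates the override as Union[None, ActionType], the added assert requires an ActionInfo. A minimal usage sketch, assuming an already-constructed agent, a hypothetical raw action forwarded_action, and the ActionInfo import path from the upstream rl_coach package:

    from rl_coach.core_types import ActionInfo

    # normal rollout: the agent decides on its own
    action_info = agent.act()

    # override: the caller supplies the action; it must be wrapped in an
    # ActionInfo, not passed as a bare ActionType, or the new assert fails
    action_info = agent.act(action=ActionInfo(action=forwarded_action))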
@@ -1036,29 +1041,6 @@ class Agent(AgentInterface):
         return transition.game_over
-    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
-    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
-    def emulate_act_on_trainer(self, action: ActionType) -> ActionInfo:
-        """
-        This emulates the act using the transition obtained from the rollout worker on the training worker
-        in case of distributed training.
-        Given the agents current knowledge, decide on the next action to apply to the environment
-        :return: an action and a dictionary containing any additional info from the action decision process
-        """
-        if self.phase == RunPhase.TRAIN and self.ap.algorithm.num_consecutive_playing_steps.num_steps == 0:
-            # This agent never plays while training (e.g. behavioral cloning)
-            return None
-        # count steps (only when training or if we are in the evaluation worker)
-        if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only:
-            self.total_steps_counter += 1
-            self.current_episode_steps_counter += 1
-        # these types don't match: ActionInfo = ActionType
-        self.last_action_info = action
-        return self.last_action_info
     def get_success_rate(self) -> float:
         return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed
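With emulate_act_on_trainer() removed, trainer-side code that previously replayed a rollout worker's action through that method would route it through act(action=...) instead. A rough sketch of that call-site change, assuming a hypothetical training loop in which transition.action holds the action chosen by the rollout worker (the calling code itself is not part of this diff):

    from rl_coach.core_types import ActionInfo

    # before: action_info = agent.emulate_act_on_trainer(transition.action)
    # after:  wrap the raw action so it satisfies the isinstance(action, ActionInfo) assert
    action_info = agent.act(action=ActionInfo(action=transition.action))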