mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

merge AgentInterface.emulate_act_on_trainer and AgentInterface.act

commit f8741522e4 (parent f2fead57e5)
Author: zach dwiel
Date: 2019-04-05 11:49:09 -04:00
Committed by: Zach Dwiel
3 changed files with 19 additions and 47 deletions
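From a caller's point of view the merge is small: rollout workers keep calling act() with no arguments, while the trainer worker now passes the action replayed from the rollout worker instead of going through the duplicated emulate_act_on_trainer() path. A minimal before/after sketch of the two call sites, based on the diff below (agent and transition stand in for whatever the surrounding worker code provides):

# before: two entry points with duplicated bookkeeping
action_info = agent.act()                                       # rollout worker: policy decides
action_info = agent.emulate_act_on_trainer(transition.action)   # trainer worker: replay the stored action

# after: a single entry point; passing an action overrides the policy
action_info = agent.act()                    # rollout worker: policy decides
action_info = agent.act(transition.action)   # trainer worker: replay the stored action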


@@ -783,10 +783,11 @@ class Agent(AgentInterface):
         return batches_dict
 
-    def act(self) -> ActionInfo:
+    def act(self, action: Union[None, ActionType]=None) -> ActionInfo:
         """
         Given the agents current knowledge, decide on the next action to apply to the environment
 
+        :param action: An action to take, overriding whatever the current policy is
         :return: An ActionInfo object, which contains the action and any additional info from the action decision process
         """
         if self.phase == RunPhase.TRAIN and self.ap.algorithm.num_consecutive_playing_steps.num_steps == 0:
@@ -799,19 +800,23 @@ class Agent(AgentInterface):
             self.current_episode_steps_counter += 1
 
         # decide on the action
-        if self.phase == RunPhase.HEATUP and not self.ap.algorithm.heatup_using_network_decisions:
-            # random action
-            self.last_action_info = self.spaces.action.sample_with_info()
-        else:
-            # informed action
-            if self.pre_network_filter is not None:
-                # before choosing an action, first use the pre_network_filter to filter out the current state
-                update_filter_internal_state = self.phase is not RunPhase.TEST
-                curr_state = self.run_pre_network_filter_for_inference(self.curr_state, update_filter_internal_state)
-            else:
-                curr_state = self.curr_state
-            self.last_action_info = self.choose_action(curr_state)
+        if action is None:
+            if self.phase == RunPhase.HEATUP and not self.ap.algorithm.heatup_using_network_decisions:
+                # random action
+                action = self.spaces.action.sample_with_info()
+            else:
+                # informed action
+                if self.pre_network_filter is not None:
+                    # before choosing an action, first use the pre_network_filter to filter out the current state
+                    update_filter_internal_state = self.phase is not RunPhase.TEST
+                    curr_state = self.run_pre_network_filter_for_inference(self.curr_state, update_filter_internal_state)
+                else:
+                    curr_state = self.curr_state
+                action = self.choose_action(curr_state)
+
+        assert isinstance(action, ActionInfo)
+        self.last_action_info = action
 
         # is it intentional that self.last_action_info is not filtered?
         filtered_action_info = self.output_filter.filter(self.last_action_info)
@@ -1036,29 +1041,6 @@ class Agent(AgentInterface):
         return transition.game_over
 
-    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
-    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
-    def emulate_act_on_trainer(self, action: ActionType) -> ActionInfo:
-        """
-        This emulates the act using the transition obtained from the rollout worker on the training worker
-        in case of distributed training.
-        Given the agents current knowledge, decide on the next action to apply to the environment
-        :return: an action and a dictionary containing any additional info from the action decision process
-        """
-        if self.phase == RunPhase.TRAIN and self.ap.algorithm.num_consecutive_playing_steps.num_steps == 0:
-            # This agent never plays while training (e.g. behavioral cloning)
-            return None
-
-        # count steps (only when training or if we are in the evaluation worker)
-        if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only:
-            self.total_steps_counter += 1
-            self.current_episode_steps_counter += 1
-
-        # these types don't match: ActionInfo = ActionType
-        self.last_action_info = action
-
-        return self.last_action_info
-
     def get_success_rate(self) -> float:
         return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed
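The block removed above was a stripped-down copy of act() that only counted steps and stored the given action. After the merge the same behaviour falls out of the new optional parameter: when action is None the agent consults its policy (or samples randomly during heatup), otherwise the supplied action is used directly. A self-contained toy sketch of the merged control flow, with Coach's filters, phases, and spaces deliberately left out (ActionInfo and TinyAgent here are simplified stand-ins, not the real Coach classes):

class ActionInfo:
    """Stand-in for Coach's ActionInfo: an action plus metadata from the decision."""
    def __init__(self, action):
        self.action = action

class TinyAgent:
    """Toy agent showing only the control flow of the merged act()."""
    def __init__(self):
        self.total_steps_counter = 0
        self.current_episode_steps_counter = 0
        self.last_action_info = None

    def choose_action(self):
        # pretend the policy always picks action 0
        return ActionInfo(action=0)

    def act(self, action=None):
        # bookkeeping that emulate_act_on_trainer used to duplicate
        self.total_steps_counter += 1
        self.current_episode_steps_counter += 1

        if action is None:
            # rollout path: the current policy decides
            action = self.choose_action()

        # trainer path: the replayed action is used as-is
        assert isinstance(action, ActionInfo)
        self.last_action_info = action
        return self.last_action_info

agent = TinyAgent()
agent.act()                      # rollout worker: policy decides
agent.act(ActionInfo(action=1))  # trainer worker: replay a stored action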


@@ -142,16 +142,6 @@ class AgentInterface(object):
         """
         raise NotImplementedError("")
 
-    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
-    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
-    def emulate_act_on_trainer(self, action: ActionType) -> ActionInfo:
-        """
-        This emulates the act using the transition obtained from the rollout worker on the training worker
-        in case of distributed training.
-        :return: A tuple containing the actual action
-        """
-        raise NotImplementedError("")
-
     def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
         """
         Collect all of agent savers


@@ -313,7 +313,7 @@ class LevelManager(EnvironmentInterface):
         # for i in range(self.steps_limit.num_steps):
         # let the agent observe the result and decide if it wants to terminate the episode
         done = acting_agent.emulate_observe_on_trainer(transition)
-        acting_agent.emulate_act_on_trainer(transition.action)
+        acting_agent.act(transition.action)
 
         if done:
             self.handle_episode_ended()
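For context, the line changed above sits in the level manager's trainer-side replay of rollout transitions and was the only caller of the removed method. A rough sketch of the surrounding flow, with the loop over transitions assumed from the commented-out line in the hunk (transitions is illustrative, not a name taken from the source):

# trainer worker: replay transitions collected by rollout workers
for transition in transitions:
    # let the agent observe the result and decide if it wants to terminate the episode
    done = acting_agent.emulate_observe_on_trainer(transition)
    # replay the stored action through the now-unified act() entry point
    acting_agent.act(transition.action)

    if done:
        self.handle_episode_ended()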