mirror of https://github.com/gryf/coach.git

merge AgentInterface.emulate_act_on_trainer and AgentInterface.act

zach dwiel
2019-04-05 11:49:09 -04:00
committed by Zach Dwiel
parent f2fead57e5
commit f8741522e4
3 changed files with 19 additions and 47 deletions


@@ -783,10 +783,11 @@ class Agent(AgentInterface):
         return batches_dict
 
-    def act(self) -> ActionInfo:
+    def act(self, action: Union[None, ActionType]=None) -> ActionInfo:
         """
         Given the agents current knowledge, decide on the next action to apply to the environment
 
+        :param action: An action to take, overriding whatever the current policy is
         :return: An ActionInfo object, which contains the action and any additional info from the action decision process
         """
         if self.phase == RunPhase.TRAIN and self.ap.algorithm.num_consecutive_playing_steps.num_steps == 0:
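
The signature change above is the whole public surface of this merge: existing callers keep calling act() with no arguments, while trainer-side code can now pass a previously recorded action instead of going through a separate emulate method. Roughly, as illustrative usage (agent and stored_transition are assumed names, not taken from the diff):

    # Illustrative calls against the merged signature (assumed variable names)
    action_info = agent.act()                          # policy-driven decision, as before
    action_info = agent.act(stored_transition.action)  # override with a recorded action
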
@@ -799,9 +800,10 @@
         self.current_episode_steps_counter += 1
 
         # decide on the action
+        if action is None:
             if self.phase == RunPhase.HEATUP and not self.ap.algorithm.heatup_using_network_decisions:
                 # random action
-                self.last_action_info = self.spaces.action.sample_with_info()
+                action = self.spaces.action.sample_with_info()
             else:
                 # informed action
                 if self.pre_network_filter is not None:
@@ -811,7 +813,10 @@
                 else:
                     curr_state = self.curr_state
-                self.last_action_info = self.choose_action(curr_state)
+                action = self.choose_action(curr_state)
+        assert isinstance(action, ActionInfo)
+        self.last_action_info = action
+        # is it intentional that self.last_action_info is not filtered?
 
         filtered_action_info = self.output_filter.filter(self.last_action_info)
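
Structurally, the change wraps the old decision logic behind the new "if action is None:" check and moves the bookkeeping (the isinstance assert and the last_action_info assignment) below it, so an externally supplied action and a policy-chosen one leave through the same path. A minimal standalone sketch of that shape, using illustrative stand-in names (SketchAgent, sample_random_action and choose_action are not the Coach API):

    from typing import Optional

    class SketchAgent:
        """Illustrative only: mirrors the 'optional override' shape of Agent.act()."""

        def __init__(self, heatup: bool = False):
            self.heatup = heatup
            self.last_action_info = None

        def sample_random_action(self) -> int:
            return 0  # stands in for self.spaces.action.sample_with_info()

        def choose_action(self, state) -> int:
            return 1  # stands in for the network/policy decision

        def act(self, action: Optional[int] = None) -> int:
            if action is None:
                # no override supplied: decide exactly as before
                if self.heatup:
                    action = self.sample_random_action()
                else:
                    action = self.choose_action(state=None)
            # bookkeeping is shared by the overridden and the policy-chosen paths
            self.last_action_info = action
            return action
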
@@ -1036,29 +1041,6 @@ class Agent(AgentInterface):
         return transition.game_over
 
-    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
-    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
-    def emulate_act_on_trainer(self, action: ActionType) -> ActionInfo:
-        """
-        This emulates the act using the transition obtained from the rollout worker on the training worker
-        in case of distributed training.
-        Given the agents current knowledge, decide on the next action to apply to the environment
-        :return: an action and a dictionary containing any additional info from the action decision process
-        """
-        if self.phase == RunPhase.TRAIN and self.ap.algorithm.num_consecutive_playing_steps.num_steps == 0:
-            # This agent never plays while training (e.g. behavioral cloning)
-            return None
-
-        # count steps (only when training or if we are in the evaluation worker)
-        if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only:
-            self.total_steps_counter += 1
-        self.current_episode_steps_counter += 1
-
-        # these types don't match: ActionInfo = ActionType
-        self.last_action_info = action
-
-        return self.last_action_info
 
     def get_success_rate(self) -> float:
         return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed


@@ -142,16 +142,6 @@ class AgentInterface(object):
         """
         raise NotImplementedError("")
 
-    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
-    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
-    def emulate_act_on_trainer(self, action: ActionType) -> ActionInfo:
-        """
-        This emulates the act using the transition obtained from the rollout worker on the training worker
-        in case of distributed training.
-        :return: A tuple containing the actual action
-        """
-        raise NotImplementedError("")
-
     def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
         """
         Collect all of agent savers


@@ -313,7 +313,7 @@ class LevelManager(EnvironmentInterface):
         # for i in range(self.steps_limit.num_steps):
             # let the agent observe the result and decide if it wants to terminate the episode
             done = acting_agent.emulate_observe_on_trainer(transition)
-            acting_agent.emulate_act_on_trainer(transition.action)
+            acting_agent.act(transition.action)
 
             if done:
                 self.handle_episode_ended()
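
With emulate_act_on_trainer gone, the trainer-side loop shown here just replays each stored transition's action through the unified act(). A hedged sketch of that loop in isolation (the function name and parameters are assumed for illustration; emulate_observe_on_trainer and act are the calls visible in the diff):

    # Illustrative trainer-side replay loop (assumed function name and parameters)
    def replay_on_trainer(acting_agent, transitions, handle_episode_ended):
        for transition in transitions:
            # let the agent observe the stored result and decide if the episode ended
            done = acting_agent.emulate_observe_on_trainer(transition)
            # replay the recorded action through the merged act() entry point
            acting_agent.act(transition.action)
            if done:
                handle_episode_ended()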