mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
merge AgentInterface.emulate_act_on_trainer and AgentInterface.act
This commit is contained in:
@@ -783,10 +783,11 @@ class Agent(AgentInterface):
|
|||||||
|
|
||||||
return batches_dict
|
return batches_dict
|
||||||
|
|
||||||
def act(self) -> ActionInfo:
|
def act(self, action: Union[None, ActionType]=None) -> ActionInfo:
|
||||||
"""
|
"""
|
||||||
Given the agents current knowledge, decide on the next action to apply to the environment
|
Given the agents current knowledge, decide on the next action to apply to the environment
|
||||||
|
|
||||||
|
:param action: An action to take, overriding whatever the current policy is
|
||||||
:return: An ActionInfo object, which contains the action and any additional info from the action decision process
|
:return: An ActionInfo object, which contains the action and any additional info from the action decision process
|
||||||
"""
|
"""
|
||||||
if self.phase == RunPhase.TRAIN and self.ap.algorithm.num_consecutive_playing_steps.num_steps == 0:
|
if self.phase == RunPhase.TRAIN and self.ap.algorithm.num_consecutive_playing_steps.num_steps == 0:
|
||||||
@@ -799,19 +800,23 @@ class Agent(AgentInterface):
|
|||||||
self.current_episode_steps_counter += 1
|
self.current_episode_steps_counter += 1
|
||||||
|
|
||||||
# decide on the action
|
# decide on the action
|
||||||
if self.phase == RunPhase.HEATUP and not self.ap.algorithm.heatup_using_network_decisions:
|
if action is None:
|
||||||
# random action
|
if self.phase == RunPhase.HEATUP and not self.ap.algorithm.heatup_using_network_decisions:
|
||||||
self.last_action_info = self.spaces.action.sample_with_info()
|
# random action
|
||||||
else:
|
action = self.spaces.action.sample_with_info()
|
||||||
# informed action
|
|
||||||
if self.pre_network_filter is not None:
|
|
||||||
# before choosing an action, first use the pre_network_filter to filter out the current state
|
|
||||||
update_filter_internal_state = self.phase is not RunPhase.TEST
|
|
||||||
curr_state = self.run_pre_network_filter_for_inference(self.curr_state, update_filter_internal_state)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
curr_state = self.curr_state
|
# informed action
|
||||||
self.last_action_info = self.choose_action(curr_state)
|
if self.pre_network_filter is not None:
|
||||||
|
# before choosing an action, first use the pre_network_filter to filter out the current state
|
||||||
|
update_filter_internal_state = self.phase is not RunPhase.TEST
|
||||||
|
curr_state = self.run_pre_network_filter_for_inference(self.curr_state, update_filter_internal_state)
|
||||||
|
|
||||||
|
else:
|
||||||
|
curr_state = self.curr_state
|
||||||
|
action = self.choose_action(curr_state)
|
||||||
|
assert isinstance(action, ActionInfo)
|
||||||
|
|
||||||
|
self.last_action_info = action
|
||||||
|
|
||||||
# is it intentional that self.last_action_info is not filtered?
|
# is it intentional that self.last_action_info is not filtered?
|
||||||
filtered_action_info = self.output_filter.filter(self.last_action_info)
|
filtered_action_info = self.output_filter.filter(self.last_action_info)
|
||||||
@@ -1036,29 +1041,6 @@ class Agent(AgentInterface):
|
|||||||
|
|
||||||
return transition.game_over
|
return transition.game_over
|
||||||
|
|
||||||
# TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
|
|
||||||
# an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
|
|
||||||
def emulate_act_on_trainer(self, action: ActionType) -> ActionInfo:
|
|
||||||
"""
|
|
||||||
This emulates the act using the transition obtained from the rollout worker on the training worker
|
|
||||||
in case of distributed training.
|
|
||||||
Given the agents current knowledge, decide on the next action to apply to the environment
|
|
||||||
:return: an action and a dictionary containing any additional info from the action decision process
|
|
||||||
"""
|
|
||||||
if self.phase == RunPhase.TRAIN and self.ap.algorithm.num_consecutive_playing_steps.num_steps == 0:
|
|
||||||
# This agent never plays while training (e.g. behavioral cloning)
|
|
||||||
return None
|
|
||||||
|
|
||||||
# count steps (only when training or if we are in the evaluation worker)
|
|
||||||
if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only:
|
|
||||||
self.total_steps_counter += 1
|
|
||||||
self.current_episode_steps_counter += 1
|
|
||||||
|
|
||||||
# these types don't match: ActionInfo = ActionType
|
|
||||||
self.last_action_info = action
|
|
||||||
|
|
||||||
return self.last_action_info
|
|
||||||
|
|
||||||
def get_success_rate(self) -> float:
|
def get_success_rate(self) -> float:
|
||||||
return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed
|
return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed
|
||||||
|
|
||||||
|
|||||||
@@ -142,16 +142,6 @@ class AgentInterface(object):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError("")
|
raise NotImplementedError("")
|
||||||
|
|
||||||
# TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
|
|
||||||
# an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
|
|
||||||
def emulate_act_on_trainer(self, action: ActionType) -> ActionInfo:
|
|
||||||
"""
|
|
||||||
This emulates the act using the transition obtained from the rollout worker on the training worker
|
|
||||||
in case of distributed training.
|
|
||||||
:return: A tuple containing the actual action
|
|
||||||
"""
|
|
||||||
raise NotImplementedError("")
|
|
||||||
|
|
||||||
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
|
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
|
||||||
"""
|
"""
|
||||||
Collect all of agent savers
|
Collect all of agent savers
|
||||||
|
|||||||
@@ -313,7 +313,7 @@ class LevelManager(EnvironmentInterface):
|
|||||||
# for i in range(self.steps_limit.num_steps):
|
# for i in range(self.steps_limit.num_steps):
|
||||||
# let the agent observe the result and decide if it wants to terminate the episode
|
# let the agent observe the result and decide if it wants to terminate the episode
|
||||||
done = acting_agent.emulate_observe_on_trainer(transition)
|
done = acting_agent.emulate_observe_on_trainer(transition)
|
||||||
acting_agent.emulate_act_on_trainer(transition.action)
|
acting_agent.act(transition.action)
|
||||||
|
|
||||||
if done:
|
if done:
|
||||||
self.handle_episode_ended()
|
self.handle_episode_ended()
|
||||||
|
|||||||
Reference in New Issue
Block a user