mirror of https://github.com/gryf/coach.git
change method interface: AgentInterface.emulate_act_on_trainer(transition: Transition) -> emulate_act_on_trainer(action: ActionType)
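In short: the trainer-side emulation no longer receives the full rollout Transition; callers unpack the action first. A minimal sketch of the call-site migration, using only names from the diff below (the surrounding trainer-loop code is assumed context):

    # before this commit: the whole transition was handed to the agent
    acting_agent.emulate_act_on_trainer(transition)

    # after this commit: only the action is passed
    acting_agent.emulate_act_on_trainer(transition.action)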
@@ -813,6 +813,7 @@ class Agent(AgentInterface):
         curr_state = self.curr_state
         self.last_action_info = self.choose_action(curr_state)
 
+        # is it intentional that self.last_action_info is not filtered?
         filtered_action_info = self.output_filter.filter(self.last_action_info)
 
         return filtered_action_info
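Note on the question raised in the added comment: act() stores the raw choose_action() output in self.last_action_info and filters only the returned copy. If filtering the stored value were intended, the assignment would presumably become the following (a sketch only, using the diff's own identifiers; this is not part of the commit):

    # store the filtered action info instead of the raw one
    self.last_action_info = self.output_filter.filter(self.last_action_info)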
@@ -1037,7 +1038,7 @@ class Agent(AgentInterface):
 
     # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
     # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
-    def emulate_act_on_trainer(self, transition: Transition) -> ActionInfo:
+    def emulate_act_on_trainer(self, action: ActionType) -> ActionInfo:
         """
         This emulates the act using the transition obtained from the rollout worker on the training worker
         in case of distributed training.
@@ -1053,7 +1054,8 @@ class Agent(AgentInterface):
         self.total_steps_counter += 1
         self.current_episode_steps_counter += 1
 
-        self.last_action_info = transition.action
+        # these types don't match: ActionInfo = ActionType
+        self.last_action_info = action
 
         return self.last_action_info
 
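The new inline comment flags that self.last_action_info is typed as ActionInfo elsewhere but is now assigned a bare ActionType. One way this could be reconciled, assuming ActionInfo can be constructed from a bare action as its first argument (a sketch, not what the commit does):

    # hypothetical fix: wrap the raw action so that last_action_info
    # keeps its declared ActionInfo type
    self.last_action_info = ActionInfo(action)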
@@ -144,14 +144,11 @@ class AgentInterface(object):
 
     # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
     # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
-    def emulate_act_on_trainer(self, transition: Transition) -> ActionInfo:
+    def emulate_act_on_trainer(self, action: ActionType) -> ActionInfo:
         """
         This emulates the act using the transition obtained from the rollout worker on the training worker
         in case of distributed training.
-        Get a decision of the next action to take.
-        The action is dependent on the current state which the agent holds from resetting the environment or from
-        the observe function.
-        :return: A tuple containing the actual action and additional info on the action
+        :return: A tuple containing the actual action
         """
         raise NotImplementedError("")
 
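A minimal sketch of a concrete subclass satisfying the new abstract signature; the EchoAgent name and its body are illustrative only, not from the commit, and it assumes ActionInfo(action) is a valid constructor call in this codebase:

    class EchoAgent(AgentInterface):
        def emulate_act_on_trainer(self, action: ActionType) -> ActionInfo:
            # store and return the replayed action, mirroring Agent's trainer flow
            self.last_action_info = ActionInfo(action)
            return self.last_action_info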
@@ -313,7 +313,7 @@ class LevelManager(EnvironmentInterface):
         # for i in range(self.steps_limit.num_steps):
         # let the agent observe the result and decide if it wants to terminate the episode
         done = acting_agent.emulate_observe_on_trainer(transition)
-        acting_agent.emulate_act_on_trainer(transition)
+        acting_agent.emulate_act_on_trainer(transition.action)
 
         if done:
             self.handle_episode_ended()
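Taken together, the trainer-side emulation of an episode step now reads roughly as below. The rollout_transitions name and the loop framing are assumed context (the snippet is imagined to run inside a LevelManager method, hence self); the three calls come from the diff:

    for transition in rollout_transitions:
        # replay the stored step: observe first, then re-emit its action
        done = acting_agent.emulate_observe_on_trainer(transition)
        acting_agent.emulate_act_on_trainer(transition.action)
        if done:
            self.handle_episode_ended()
            break

Unpacking transition.action at the call site means the agent's trainer-facing API no longer needs to know about Transition at all, which is consistent with the commit's narrowing of the method signature.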