mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
imitation related bug fixes
This commit is contained in:
@@ -29,7 +29,6 @@ from rl_coach.spaces import DiscreteActionSpace
|
||||
class ImitationAgent(Agent):
|
||||
def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
|
||||
super().__init__(agent_parameters, parent)
|
||||
|
||||
self.imitation = True
|
||||
|
||||
def extract_action_values(self, prediction):
|
||||
@@ -41,18 +40,9 @@ class ImitationAgent(Agent):
|
||||
|
||||
# get action values and extract the best action from it
|
||||
action_values = self.extract_action_values(prediction)
|
||||
if type(self.spaces.action) == DiscreteActionSpace:
|
||||
# DISCRETE
|
||||
self.exploration_policy.phase = RunPhase.TEST
|
||||
action = self.exploration_policy.get_action(action_values)
|
||||
|
||||
action_info = ActionInfo(action=action,
|
||||
action_probability=action_values[action])
|
||||
else:
|
||||
# CONTINUOUS
|
||||
action = action_values
|
||||
|
||||
action_info = ActionInfo(action=action)
|
||||
self.exploration_policy.change_phase(RunPhase.TEST)
|
||||
action = self.exploration_policy.get_action(action_values)
|
||||
action_info = ActionInfo(action=action)
|
||||
|
||||
return action_info
|
||||
|
||||
|
||||
Reference in New Issue
Block a user