1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

imitation related bug fixes

This commit is contained in:
itaicaspi-intel
2018-09-12 14:54:33 +03:00
parent a9bd1047c4
commit 171fe97a3a
7 changed files with 21 additions and 22 deletions

View File

@@ -29,7 +29,6 @@ from rl_coach.spaces import DiscreteActionSpace
class ImitationAgent(Agent):
def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
super().__init__(agent_parameters, parent)
self.imitation = True
def extract_action_values(self, prediction):
@@ -41,18 +40,9 @@ class ImitationAgent(Agent):
# get action values and extract the best action from it
action_values = self.extract_action_values(prediction)
if type(self.spaces.action) == DiscreteActionSpace:
# DISCRETE
self.exploration_policy.phase = RunPhase.TEST
action = self.exploration_policy.get_action(action_values)
action_info = ActionInfo(action=action,
action_probability=action_values[action])
else:
# CONTINUOUS
action = action_values
action_info = ActionInfo(action=action)
self.exploration_policy.change_phase(RunPhase.TEST)
action = self.exploration_policy.get_action(action_values)
action_info = ActionInfo(action=action)
return action_info