mirror of https://github.com/gryf/coach.git
provide a helpful error message in the event that an exploration policy returns a vector of actions instead of a single action in a value optimization agent
@@ -16,7 +16,9 @@
 
 import numpy as np
 
-from agents.agent import *
+from agents.agent import Agent
+from architectures.network_wrapper import NetworkWrapper
+from utils import RunPhase, Signal
 
 
 class ValueOptimizationAgent(Agent):
@@ -54,15 +56,26 @@ class ValueOptimizationAgent(Agent):
     def get_prediction(self, curr_state):
         return self.main_network.online_network.predict(self.tf_input_state(curr_state))
 
+    def _validate_action(self, policy, action):
+        if np.array(action).shape != ():
+            raise ValueError((
+                'The exploration_policy {} returned a vector of actions '
+                'instead of a single action. ValueOptimizationAgents '
+                'require exploration policies which return a single action.'
+            ).format(policy.__class__.__name__))
+
     def choose_action(self, curr_state, phase=RunPhase.TRAIN):
         prediction = self.get_prediction(curr_state)
         actions_q_values = self.get_q_values(prediction)
 
         # choose action according to the exploration policy and the current phase (evaluating or training the agent)
         if phase == RunPhase.TRAIN:
-            action = self.exploration_policy.get_action(actions_q_values)
+            exploration_policy = self.exploration_policy
         else:
-            action = self.evaluation_exploration_policy.get_action(actions_q_values)
+            exploration_policy = self.evaluation_exploration_policy
+
+        action = exploration_policy.get_action(actions_q_values)
+        self._validate_action(exploration_policy, action)
 
         # this is for bootstrapped dqn
         if type(actions_q_values) == list and len(actions_q_values) > 0:
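For context, here is a minimal standalone sketch (not part of the commit) of how the np.array(action).shape != () check distinguishes a single action from a vector of actions. The validate_action function and the ScalarPolicy/VectorPolicy classes are hypothetical stand-ins for Coach exploration policies, used only to illustrate the check.

    import numpy as np

    def validate_action(policy, action):
        # np.array(x).shape is () only for scalar inputs (Python numbers, 0-d arrays);
        # lists, tuples and 1-d arrays yield a non-empty shape and are rejected.
        if np.array(action).shape != ():
            raise ValueError(
                'The exploration_policy {} returned a vector of actions '
                'instead of a single action.'.format(policy.__class__.__name__))

    class ScalarPolicy:
        def get_action(self, q_values):
            return int(np.argmax(q_values))      # a single discrete action index

    class VectorPolicy:
        def get_action(self, q_values):
            return np.ones_like(q_values)        # a vector of actions, invalid here

    q_values = np.array([0.1, 0.7, 0.2])
    validate_action(ScalarPolicy(), ScalarPolicy().get_action(q_values))      # passes silently
    try:
        validate_action(VectorPolicy(), VectorPolicy().get_action(q_values))  # raises
    except ValueError as err:
        print(err)

A policy that samples a continuous action vector would trip this check, which is the situation the new error message is intended to explain to the user.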