mirror of https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
provide a helpful error message in the event that an exploration policy returns a vector of actions instead of a single action in a value optimization agent
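The guard this commit adds relies on NumPy's shape semantics: wrapping a single (scalar) action in np.array yields a 0-d array whose shape is the empty tuple (), while any sequence of actions yields a non-empty shape. A minimal sketch of just that check:

import numpy as np

# A single discrete action wraps to a 0-d array with empty shape ...
assert np.array(3).shape == ()

# ... while a vector of actions has a non-empty shape, which the new
# _validate_action method (below) rejects with a descriptive ValueError.
assert np.array([3, 1]).shape == (2,)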
agents/value_optimization_agent.py
@@ -16,7 +16,9 @@
 
-from agents.agent import *
+import numpy as np
+
+from agents.agent import Agent
 from architectures.network_wrapper import NetworkWrapper
 from utils import RunPhase, Signal
 
 
 class ValueOptimizationAgent(Agent):
@@ -54,15 +56,26 @@ class ValueOptimizationAgent(Agent):
     def get_prediction(self, curr_state):
         return self.main_network.online_network.predict(self.tf_input_state(curr_state))
 
+    def _validate_action(self, policy, action):
+        if np.array(action).shape != ():
+            raise ValueError((
+                'The exploration_policy {} returned a vector of actions '
+                'instead of a single action. ValueOptimizationAgents '
+                'require exploration policies which return a single action.'
+            ).format(policy.__class__.__name__))
+
     def choose_action(self, curr_state, phase=RunPhase.TRAIN):
         prediction = self.get_prediction(curr_state)
         actions_q_values = self.get_q_values(prediction)
 
         # choose action according to the exploration policy and the current phase (evaluating or training the agent)
         if phase == RunPhase.TRAIN:
-            action = self.exploration_policy.get_action(actions_q_values)
+            exploration_policy = self.exploration_policy
         else:
-            action = self.evaluation_exploration_policy.get_action(actions_q_values)
+            exploration_policy = self.evaluation_exploration_policy
+
+        action = exploration_policy.get_action(actions_q_values)
+        self._validate_action(exploration_policy, action)
 
         # this is for bootstrapped dqn
         if type(actions_q_values) == list and len(actions_q_values) > 0:
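As an illustration of how the new guard behaves, a minimal, self-contained sketch: the VectorPolicy stub is hypothetical, standing in for any misbehaving exploration policy, and the standalone _validate_action mirrors the method added above.

import numpy as np

class VectorPolicy:
    """Hypothetical stub: wrongly returns several actions at once."""
    def get_action(self, q_values):
        # indices of the two highest q-values -- a vector, not a single action
        return np.argsort(q_values)[-2:]

def _validate_action(policy, action):
    # same check as the method added in this commit
    if np.array(action).shape != ():
        raise ValueError((
            'The exploration_policy {} returned a vector of actions '
            'instead of a single action. ValueOptimizationAgents '
            'require exploration policies which return a single action.'
        ).format(policy.__class__.__name__))

policy = VectorPolicy()
action = policy.get_action(np.array([0.1, 0.9, 0.4]))
_validate_action(policy, action)  # raises ValueError naming VectorPolicy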