diff --git a/agents/value_optimization_agent.py b/agents/value_optimization_agent.py
index f318577..0e4596f 100644
--- a/agents/value_optimization_agent.py
+++ b/agents/value_optimization_agent.py
@@ -16,7 +16,9 @@
 
 import numpy as np
 
-from agents.agent import *
+from agents.agent import Agent
+from architectures.network_wrapper import NetworkWrapper
+from utils import RunPhase, Signal
 
 
 class ValueOptimizationAgent(Agent):
@@ -54,15 +56,26 @@ class ValueOptimizationAgent(Agent):
 
     def get_prediction(self, curr_state):
         return self.main_network.online_network.predict(self.tf_input_state(curr_state))
 
+    def _validate_action(self, policy, action):
+        if np.array(action).shape != ():
+            raise ValueError((
+                'The exploration_policy {} returned a vector of actions '
+                'instead of a single action. ValueOptimizationAgents '
+                'require exploration policies which return a single action.'
+            ).format(policy.__class__.__name__))
+
     def choose_action(self, curr_state, phase=RunPhase.TRAIN):
         prediction = self.get_prediction(curr_state)
         actions_q_values = self.get_q_values(prediction)
         # choose action according to the exploration policy and the current phase (evaluating or training the agent)
         if phase == RunPhase.TRAIN:
-            action = self.exploration_policy.get_action(actions_q_values)
+            exploration_policy = self.exploration_policy
         else:
-            action = self.evaluation_exploration_policy.get_action(actions_q_values)
+            exploration_policy = self.evaluation_exploration_policy
+
+        action = exploration_policy.get_action(actions_q_values)
+        self._validate_action(exploration_policy, action)
 
         # this is for bootstrapped dqn
         if type(actions_q_values) == list and len(actions_q_values) > 0:
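
For context, a minimal standalone sketch of the scalar check that the new _validate_action relies on: np.array(...) of a single action yields a 0-d array with shape (), while a vector of actions yields a non-empty shape. The DummyPolicy class below is hypothetical, used only to exercise the error path.

    import numpy as np


    def _validate_action(policy, action):
        # A single action converts to a 0-d array with shape (); any
        # non-empty shape means the policy returned a vector of actions.
        if np.array(action).shape != ():
            raise ValueError((
                'The exploration_policy {} returned a vector of actions '
                'instead of a single action. ValueOptimizationAgents '
                'require exploration policies which return a single action.'
            ).format(policy.__class__.__name__))


    class DummyPolicy:
        """Hypothetical stand-in for an exploration policy."""


    _validate_action(DummyPolicy(), 3)            # OK: Python scalar
    _validate_action(DummyPolicy(), np.int64(2))  # OK: 0-d numpy scalar
    try:
        _validate_action(DummyPolicy(), [0, 1])   # a vector of actions
    except ValueError as err:
        print(err)                                # reports DummyPolicy by name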