1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

provide a helpful error message in the event that an exploration policy returns a vector of actions instead of a single action during value optimization agent

This commit is contained in:
Zach Dwiel
2018-01-11 11:43:37 -05:00
parent 40e5c628c6
commit fff8c8f568

View File

@@ -16,7 +16,9 @@
import numpy as np
from agents.agent import *
from agents.agent import Agent
from architectures.network_wrapper import NetworkWrapper
from utils import RunPhase, Signal
class ValueOptimizationAgent(Agent):
@@ -54,15 +56,26 @@ class ValueOptimizationAgent(Agent):
def get_prediction(self, curr_state):
return self.main_network.online_network.predict(self.tf_input_state(curr_state))
def _validate_action(self, policy, action):
if np.array(action).shape != ():
raise ValueError((
'The exploration_policy {} returned a vector of actions '
'instead of a single action. ValueOptimizationAgents '
'require exploration policies which return a single action.'
).format(policy.__class__.__name__))
def choose_action(self, curr_state, phase=RunPhase.TRAIN):
prediction = self.get_prediction(curr_state)
actions_q_values = self.get_q_values(prediction)
# choose action according to the exploration policy and the current phase (evaluating or training the agent)
if phase == RunPhase.TRAIN:
action = self.exploration_policy.get_action(actions_q_values)
exploration_policy = self.exploration_policy
else:
action = self.evaluation_exploration_policy.get_action(actions_q_values)
exploration_policy = self.evaluation_exploration_policy
action = exploration_policy.get_action(actions_q_values)
self._validate_action(exploration_policy, action)
# this is for bootstrapped dqn
if type(actions_q_values) == list and len(actions_q_values) > 0: