mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
fix more agents
This commit is contained in:
@@ -31,12 +31,7 @@ class ImitationAgent(Agent):
|
||||
|
||||
def choose_action(self, curr_state, phase=RunPhase.TRAIN):
|
||||
# convert to batch so we can run it through the network
|
||||
observation = np.expand_dims(np.array(curr_state['observation']), 0)
|
||||
if self.tp.agent.use_measurements:
|
||||
measurements = np.expand_dims(np.array(curr_state['measurements']), 0)
|
||||
prediction = self.main_network.online_network.predict([observation, measurements])
|
||||
else:
|
||||
prediction = self.main_network.online_network.predict(observation)
|
||||
prediction = self.main_network.online_network.predict(self.tf_input_state(curr_state))
|
||||
|
||||
# get action values and extract the best action from it
|
||||
action_values = self.extract_action_values(prediction)
|
||||
|
||||
Reference in New Issue
Block a user