1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

fix more agents

This commit is contained in:
Zach Dwiel
2018-02-16 20:06:51 -05:00
parent 98f57a0d87
commit 8248caf35e
6 changed files with 52 additions and 42 deletions

View File

@@ -31,12 +31,7 @@ class ImitationAgent(Agent):
def choose_action(self, curr_state, phase=RunPhase.TRAIN):
# convert to batch so we can run it through the network
-observation = np.expand_dims(np.array(curr_state['observation']), 0)
-if self.tp.agent.use_measurements:
-measurements = np.expand_dims(np.array(curr_state['measurements']), 0)
-prediction = self.main_network.online_network.predict([observation, measurements])
-else:
-prediction = self.main_network.online_network.predict(observation)
+prediction = self.main_network.online_network.predict(self.tf_input_state(curr_state))
# get action values and extract the best action from it
action_values = self.extract_action_values(prediction)