Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 11:10:20 +01:00)
multiple bug fixes in dealing with measurements + CartPole_DFP preset (#92)
This commit is contained in:
@@ -31,7 +31,7 @@ class DFPAgent(Agent):
 
 # create the inputs for the network
 input = current_states
-input.append(np.repeat(np.expand_dims(self.current_goal, 0), self.tp.batch_size, 0))
+input['goal'] = np.repeat(np.expand_dims(self.current_goal, 0), self.tp.batch_size, 0)
 
 # get the current outputs of the network
 targets = self.main_network.online_network.predict(input)
@@ -40,7 +40,7 @@ class DFPAgent(Agent):
 for i in range(self.tp.batch_size):
 targets[i, actions[i]] = batch[i].info['future_measurements'].flatten()
 
-result = self.main_network.train_and_sync_networks(current_states, targets)
+result = self.main_network.train_and_sync_networks(input, targets)
 total_loss = result[0]
 
 return total_loss
@@ -52,7 +52,10 @@ class DFPAgent(Agent):
 goal = np.expand_dims(self.current_goal, 0)
 
 # predict the future measurements
-measurements_future_prediction = self.main_network.online_network.predict([observation, measurements, goal])[0]
+measurements_future_prediction = self.main_network.online_network.predict({
+"observation": observation,
+"measurements": measurements,
+"goal": goal})[0]
 action_values = np.zeros((self.action_space_size,))
 num_steps_used_for_objective = len(self.tp.agent.future_measurements_weights)
 
Reference in New Issue
Block a user