1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

fix bc_agent

This commit is contained in:
Zach Dwiel
2018-02-16 20:39:00 -05:00
parent d8f5a35013
commit 8fc24a2bbe

View File

@@ -14,7 +14,9 @@
 # limitations under the License.
 #
-from agents.imitation_agent import *
+import numpy as np
+from agents.imitation_agent import ImitationAgent
 # Behavioral Cloning Agent
@@ -25,16 +27,13 @@ class BCAgent(ImitationAgent):
     def learn_from_batch(self, batch):
         current_states, _, actions, _, _, _ = self.extract_batch(batch)
-        # create the inputs for the network
-        input = current_states
         # the targets for the network are the actions since this is supervised learning
         if self.env.discrete_controls:
             targets = np.eye(self.env.action_space_size)[[actions]]
         else:
             targets = actions
-        result = self.main_network.train_and_sync_networks(input, targets)
+        result = self.main_network.train_and_sync_networks(current_states, targets)
         total_loss = result[0]
         return total_loss