1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

fix bc_agent

This commit is contained in:
Zach Dwiel
2018-02-16 20:39:00 -05:00
parent d8f5a35013
commit 8fc24a2bbe

View File

@@ -14,7 +14,9 @@
# limitations under the License.
#
from agents.imitation_agent import *
import numpy as np
from agents.imitation_agent import ImitationAgent
# Behavioral Cloning Agent
@@ -25,16 +27,13 @@ class BCAgent(ImitationAgent):
def learn_from_batch(self, batch):
    """Train the main network on one batch of demonstrations.

    The states extracted from ``batch`` are the network inputs and the
    demonstrated actions are the supervised targets.

    :param batch: a batch in whatever form ``self.extract_batch`` accepts.
    :return: the first element of the value returned by
        ``self.main_network.train_and_sync_networks`` (used as the total loss).
    """
    current_states, _, actions, _, _, _ = self.extract_batch(batch)

    # the targets for the network are the actions since this is supervised learning
    if self.env.discrete_controls:
        # One-hot encode by selecting rows of the identity matrix.
        # Index with ``actions`` directly: the former ``[[actions]]`` form relied
        # on NumPy's deprecated list-as-tuple indexing and, on NumPy >= 1.23,
        # produces an extra leading axis instead of one row per action.
        # NOTE(review): assumes ``actions`` holds integer action indices -- confirm.
        targets = np.eye(self.env.action_space_size)[actions]
    else:
        targets = actions

    # Train exactly once per batch, feeding the states straight to the network.
    result = self.main_network.train_and_sync_networks(current_states, targets)
    total_loss = result[0]

    return total_loss