mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
fix bc_agent
This commit is contained in:
@@ -14,7 +14,9 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from agents.imitation_agent import *
|
||||
import numpy as np
|
||||
|
||||
from agents.imitation_agent import ImitationAgent
|
||||
|
||||
|
||||
# Behavioral Cloning Agent
|
||||
@@ -25,16 +27,13 @@ class BCAgent(ImitationAgent):
|
||||
def learn_from_batch(self, batch):
|
||||
current_states, _, actions, _, _, _ = self.extract_batch(batch)
|
||||
|
||||
# create the inputs for the network
|
||||
input = current_states
|
||||
|
||||
# the targets for the network are the actions since this is supervised learning
|
||||
if self.env.discrete_controls:
|
||||
targets = np.eye(self.env.action_space_size)[[actions]]
|
||||
else:
|
||||
targets = actions
|
||||
|
||||
result = self.main_network.train_and_sync_networks(input, targets)
|
||||
result = self.main_network.train_and_sync_networks(current_states, targets)
|
||||
total_loss = result[0]
|
||||
|
||||
return total_loss
|
||||
|
||||
Reference in New Issue
Block a user