mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
fix bc_agent
This commit is contained in:
@@ -14,7 +14,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
from agents.imitation_agent import *
|
import numpy as np
|
||||||
|
|
||||||
|
from agents.imitation_agent import ImitationAgent
|
||||||
|
|
||||||
|
|
||||||
# Behavioral Cloning Agent
|
# Behavioral Cloning Agent
|
||||||
@@ -25,16 +27,13 @@ class BCAgent(ImitationAgent):
|
|||||||
def learn_from_batch(self, batch):
|
def learn_from_batch(self, batch):
|
||||||
current_states, _, actions, _, _, _ = self.extract_batch(batch)
|
current_states, _, actions, _, _, _ = self.extract_batch(batch)
|
||||||
|
|
||||||
# create the inputs for the network
|
|
||||||
input = current_states
|
|
||||||
|
|
||||||
# the targets for the network are the actions since this is supervised learning
|
# the targets for the network are the actions since this is supervised learning
|
||||||
if self.env.discrete_controls:
|
if self.env.discrete_controls:
|
||||||
targets = np.eye(self.env.action_space_size)[[actions]]
|
targets = np.eye(self.env.action_space_size)[[actions]]
|
||||||
else:
|
else:
|
||||||
targets = actions
|
targets = actions
|
||||||
|
|
||||||
result = self.main_network.train_and_sync_networks(input, targets)
|
result = self.main_network.train_and_sync_networks(current_states, targets)
|
||||||
total_loss = result[0]
|
total_loss = result[0]
|
||||||
|
|
||||||
return total_loss
|
return total_loss
|
||||||
|
|||||||
Reference in New Issue
Block a user