From 8fc24a2bbe2fbcc5123a7c38d63e4651ab562209 Mon Sep 17 00:00:00 2001 From: Zach Dwiel Date: Fri, 16 Feb 2018 20:39:00 -0500 Subject: [PATCH] fix bc_agent: use explicit imports and stop shadowing builtin input --- agents/bc_agent.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/agents/bc_agent.py b/agents/bc_agent.py index e065731..70fe3e6 100644 --- a/agents/bc_agent.py +++ b/agents/bc_agent.py @@ -14,7 +14,9 @@ # limitations under the License. # -from agents.imitation_agent import * +import numpy as np + +from agents.imitation_agent import ImitationAgent # Behavioral Cloning Agent @@ -25,16 +27,13 @@ class BCAgent(ImitationAgent): def learn_from_batch(self, batch): current_states, _, actions, _, _, _ = self.extract_batch(batch) - # create the inputs for the network - input = current_states - # the targets for the network are the actions since this is supervised learning if self.env.discrete_controls: targets = np.eye(self.env.action_space_size)[[actions]] else: targets = actions - result = self.main_network.train_and_sync_networks(input, targets) + result = self.main_network.train_and_sync_networks(current_states, targets) total_loss = result[0] return total_loss