
fix more agents

Author: Zach Dwiel
Date:   2018-02-16 20:06:51 -05:00
parent 98f57a0d87
commit 8248caf35e

6 changed files with 52 additions and 42 deletions


@@ -64,17 +64,16 @@ class PolicyGradientsAgent(PolicyOptimizationAgent):
         self.returns_mean.add_sample(np.mean(total_returns))
         self.returns_variance.add_sample(np.std(total_returns))
-        result = self.main_network.online_network.accumulate_gradients([current_states, actions], targets)
+        result = self.main_network.online_network.accumulate_gradients({**current_states, 'output_0_0': actions}, targets)
         total_loss = result[0]
         return total_loss
 
     def choose_action(self, curr_state, phase=RunPhase.TRAIN):
         # convert to batch so we can run it through the network
-        observation = np.expand_dims(np.array(curr_state['observation']), 0)
         if self.env.discrete_controls:
             # DISCRETE
-            action_values = self.main_network.online_network.predict(observation).squeeze()
+            action_values = self.main_network.online_network.predict(self.tf_input_state(curr_state)).squeeze()
             if phase == RunPhase.TRAIN:
                 action = self.exploration_policy.get_action(action_values)
             else:
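The accumulate_gradients change above switches from a positional list to a single dict of named inputs: the state dict is spread in as-is and the sampled actions are fed under the 'output_0_0' key of the first output head. A minimal sketch of what that merged feed is presumed to look like (key names and shapes here are illustrative assumptions, not taken from this commit):

    import numpy as np

    # Presumed shape of the merged feed: state inputs keyed by name,
    # plus the actions for output head 0 under 'output_0_0'.
    current_states = {'observation': np.zeros((32, 84, 84, 3), dtype=np.float32)}
    actions = np.random.randint(0, 4, size=(32,))

    network_inputs = {**current_states, 'output_0_0': actions}
    # result = self.main_network.online_network.accumulate_gradients(network_inputs, targets)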
@@ -83,7 +82,7 @@ class PolicyGradientsAgent(PolicyOptimizationAgent):
             self.entropy.add_sample(-np.sum(action_values * np.log(action_values + eps)))
         else:
             # CONTINUOUS
-            result = self.main_network.online_network.predict(observation)
+            result = self.main_network.online_network.predict(self.tf_input_state(curr_state))
             action_values = result[0].squeeze()
             if phase == RunPhase.TRAIN:
                 action = self.exploration_policy.get_action(action_values)
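Both predict calls now take self.tf_input_state(curr_state) instead of a manually batched observation array. The helper itself lives in another of this commit's changed files; a rough sketch of the behaviour the agent appears to rely on (the method body and the single-observation assumption are guesses, not the actual implementation):

    import numpy as np

    def tf_input_state(self, curr_state):
        # Presumably wraps each state entry in a batch dimension and keys it
        # by the network's input name, mirroring the removed expand_dims line.
        observation = np.expand_dims(np.array(curr_state['observation']), 0)
        return {'observation': observation}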