mirror of https://github.com/gryf/coach.git
fix more agents
@@ -64,17 +64,16 @@ class PolicyGradientsAgent(PolicyOptimizationAgent):
         self.returns_mean.add_sample(np.mean(total_returns))
         self.returns_variance.add_sample(np.std(total_returns))

-        result = self.main_network.online_network.accumulate_gradients([current_states, actions], targets)
+        result = self.main_network.online_network.accumulate_gradients({**current_states, 'output_0_0': actions}, targets)
         total_loss = result[0]

         return total_loss

     def choose_action(self, curr_state, phase=RunPhase.TRAIN):
         # convert to batch so we can run it through the network
-        observation = np.expand_dims(np.array(curr_state['observation']), 0)
         if self.env.discrete_controls:
             # DISCRETE
-            action_values = self.main_network.online_network.predict(observation).squeeze()
+            action_values = self.main_network.online_network.predict(self.tf_input_state(curr_state)).squeeze()
             if phase == RunPhase.TRAIN:
                 action = self.exploration_policy.get_action(action_values)
             else:
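The change in this hunk is how inputs are handed to accumulate_gradients: the old call packed current_states and actions into a positional list, while the new call builds a single flat dict in which the action batch is keyed as 'output_0_0' next to the state inputs. A minimal sketch of that merge, assuming current_states is a dict of named, already-batched numpy arrays (the 'observation' and 'measurements' keys below are illustrative placeholders, not necessarily Coach's real input names):

    import numpy as np

    # Hypothetical batched inputs; the state keys are made up for illustration.
    current_states = {
        'observation': np.random.rand(32, 4),    # batch of 32 observations
        'measurements': np.random.rand(32, 2),    # a second, optional input head
    }
    actions = np.random.randint(0, 3, size=(32,))  # one discrete action per sample

    # Old style: a positional list of input batches.
    old_inputs = [current_states, actions]

    # New style: one flat dict; the action batch travels under the
    # 'output_0_0' key together with the named state inputs.
    new_inputs = {**current_states, 'output_0_0': actions}

    print(sorted(new_inputs.keys()))  # ['measurements', 'observation', 'output_0_0']

Feeding one keyed dict lets each array be matched to its placeholder by name rather than by position, which is presumably why the same treatment is applied to predict in the hunk below.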
@@ -83,7 +82,7 @@ class PolicyGradientsAgent(PolicyOptimizationAgent):
             self.entropy.add_sample(-np.sum(action_values * np.log(action_values + eps)))
         else:
             # CONTINUOUS
-            result = self.main_network.online_network.predict(observation)
+            result = self.main_network.online_network.predict(self.tf_input_state(curr_state))
             action_values = result[0].squeeze()
             if phase == RunPhase.TRAIN:
                 action = self.exploration_policy.get_action(action_values)
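This hunk applies the same input handling to prediction: instead of manually expanding the raw observation into a batch of one (the observation = np.expand_dims(...) line removed above), choose_action now passes self.tf_input_state(curr_state) to predict in both the discrete and continuous branches. A rough, self-contained sketch of what such a helper might do, assuming its only job is to wrap each relevant state field in a leading batch dimension; this is a guess for illustration, not Coach's actual implementation:

    import numpy as np

    def tf_input_state(curr_state, input_keys=('observation',)):
        # Hypothetical stand-in for the agent's self.tf_input_state():
        # wrap each selected state field in a leading batch dimension of 1 so a
        # single state can be pushed through the network's predict() call.
        return {key: np.expand_dims(np.array(curr_state[key]), 0) for key in input_keys}

    curr_state = {'observation': np.random.rand(4)}  # one un-batched state
    batched = tf_input_state(curr_state)
    print(batched['observation'].shape)              # (1, 4)

Under that assumption the discrete branch receives a (1, num_actions) prediction, which .squeeze() reduces back to a flat action_values vector before exploration picks an action and the entropy estimate -np.sum(p * np.log(p + eps)) is recorded.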