mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
fix qr_dqn_agent
This commit is contained in:
@@ -101,9 +101,7 @@ class ActorCriticAgent(PolicyOptimizationAgent):
|
|||||||
actions = np.expand_dims(actions, -1)
|
actions = np.expand_dims(actions, -1)
|
||||||
|
|
||||||
# train
|
# train
|
||||||
inputs = copy.copy(current_states)
|
result = self.main_network.online_network.accumulate_gradients({**current_states, 'output_1_0': actions},
|
||||||
inputs['output_1_0'] = actions
|
|
||||||
result = self.main_network.online_network.accumulate_gradients(inputs,
|
|
||||||
[state_value_head_targets, action_advantages])
|
[state_value_head_targets, action_advantages])
|
||||||
|
|
||||||
# logging
|
# logging
|
||||||
|
|||||||
@@ -56,8 +56,11 @@ class QuantileRegressionDQNAgent(ValueOptimizationAgent):
|
|||||||
quantile_midpoints[idx, :] = quantile_midpoints[idx, sorted_quantiles[idx]]
|
quantile_midpoints[idx, :] = quantile_midpoints[idx, sorted_quantiles[idx]]
|
||||||
|
|
||||||
# train
|
# train
|
||||||
result = self.main_network.train_and_sync_networks([current_states, actions_locations, quantile_midpoints], TD_targets)
|
result = self.main_network.train_and_sync_networks({
|
||||||
|
**current_states,
|
||||||
|
'output_0_0': actions_locations,
|
||||||
|
'output_0_1': quantile_midpoints,
|
||||||
|
}, TD_targets)
|
||||||
total_loss = result[0]
|
total_loss = result[0]
|
||||||
|
|
||||||
return total_loss
|
return total_loss
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user