mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
fix clipped ppo
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -53,7 +53,7 @@ class PPOAgent(ActorCriticAgent):
|
||||
# * Found not to have any impact *
|
||||
# current_states_with_timestep = self.concat_state_and_timestep(batch)
|
||||
|
||||
current_state_values = self.critic_network.online_network.predict([current_states]).squeeze()
|
||||
current_state_values = self.critic_network.online_network.predict(current_state).squeeze()
|
||||
|
||||
# calculate advantages
|
||||
advantages = []
|
||||
@@ -105,11 +105,11 @@ class PPOAgent(ActorCriticAgent):
|
||||
current_states_batch = current_states[i * batch_size:(i + 1) * batch_size]
|
||||
total_return_batch = total_return[i * batch_size:(i + 1) * batch_size]
|
||||
old_policy_values = force_list(self.critic_network.target_network.predict(
|
||||
[current_states_batch]).squeeze())
|
||||
current_states_batch).squeeze())
|
||||
if self.critic_network.online_network.optimizer_type != 'LBFGS':
|
||||
targets = total_return_batch
|
||||
else:
|
||||
current_values = self.critic_network.online_network.predict([current_states_batch])
|
||||
current_values = self.critic_network.online_network.predict(current_states_batch)
|
||||
targets = current_values * (1 - mix_fraction) + total_return_batch * mix_fraction
|
||||
|
||||
value_loss = self.critic_network.online_network.\
|
||||
|
||||
Reference in New Issue
Block a user