Mirror of https://github.com/gryf/coach.git, synced 2025-12-17 19:20:19 +01:00
Parallel agents fixes (#95)
* Parallel-agents-related bug fixes: checkpoint restore, TensorBoard integration. Added support for narrow networks. Reference code for an unlimited number of checkpoints.
@@ -395,7 +395,6 @@ class PPOHead(Head):
     def _build_module(self, input_layer):
         eps = 1e-15

         if self.discrete_controls:
             self.actions = tf.placeholder(tf.int32, [None], name="actions")
         else:
@@ -410,7 +409,7 @@ class PPOHead(Head):
             self.policy_mean = tf.nn.softmax(policy_values, name="policy")

             # define the distributions for the policy and the old policy
-            self.policy_distribution = tf.contrib.distributions.Categorical(probs=self.policy_mean)
+            self.policy_distribution = tf.contrib.distributions.Categorical(probs=(self.policy_mean + eps))
             self.old_policy_distribution = tf.contrib.distributions.Categorical(probs=self.old_policy_mean)

         self.output = self.policy_mean
@@ -445,7 +444,7 @@ class PPOHead(Head):
         # calculate surrogate loss
         self.advantages = tf.placeholder(tf.float32, [None], name="advantages")
         self.target = self.advantages
-        self.likelihood_ratio = self.action_probs_wrt_policy / self.action_probs_wrt_old_policy
+        self.likelihood_ratio = self.action_probs_wrt_policy / (self.action_probs_wrt_old_policy + eps)
         if self.clip_likelihood_ratio_using_epsilon is not None:
            max_value = 1 + self.clip_likelihood_ratio_using_epsilon
            min_value = 1 - self.clip_likelihood_ratio_using_epsilon
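Both changes guard against probabilities that are exactly zero. A minimal NumPy sketch, separate from Coach's code and with made-up probability values, of what the eps term in the likelihood-ratio denominator buys:

import numpy as np

eps = 1e-15  # same constant the diff uses in PPOHead

# Hypothetical per-action probabilities under the new and old policies;
# the old policy assigns zero mass to the second action.
p_new = np.array([0.5, 0.3])
p_old = np.array([0.5, 0.0])

ratio_raw = p_new / p_old           # [1., inf] plus a divide-by-zero warning
ratio_safe = p_new / (p_old + eps)  # [1., 3.e+14] -- finite, so the downstream
                                    # clip to [min_value, max_value] can bound it
print(ratio_raw, ratio_safe)

Without the smoothing, an inf (or a NaN from 0/0) would flow through the surrogate loss and its gradients; with it, the clipping applied when clip_likelihood_ratio_using_epsilon is set bounds the ratio as usual. The probs=(self.policy_mean + eps) change presumably serves the same purpose, keeping the Categorical log-probabilities finite when a softmax output underflows to zero.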