From 913ab75e8a78c4664c7858f860a73ad2ccbed9b8 Mon Sep 17 00:00:00 2001
From: Itai Caspi
Date: Tue, 31 Oct 2017 10:49:50 +0200
Subject: [PATCH] bug fix - preventing crashes when the probability of one of
 the actions is 0 in the policy head

---
 agents/actor_critic_agent.py                 | 2 +-
 agents/policy_gradients_agent.py             | 2 +-
 architectures/tensorflow_components/heads.py | 3 ++-
 utils.py                                     | 1 +
 4 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/agents/actor_critic_agent.py b/agents/actor_critic_agent.py
index 279075f..ed35ee6 100644
--- a/agents/actor_critic_agent.py
+++ b/agents/actor_critic_agent.py
@@ -121,7 +121,7 @@ class ActorCriticAgent(PolicyOptimizationAgent):
             else:
                 action = np.argmax(action_probabilities)
             action_info = {"action_probability": action_probabilities[action], "state_value": state_value}
-            self.entropy.add_sample(-np.sum(action_probabilities * np.log(action_probabilities)))
+            self.entropy.add_sample(-np.sum(action_probabilities * np.log(action_probabilities + eps)))
 
         else:  # CONTINUOUS
             state_value, action_values_mean, action_values_std = self.main_network.online_network.predict(observation)
diff --git a/agents/policy_gradients_agent.py b/agents/policy_gradients_agent.py
index bf873d1..11cef75 100644
--- a/agents/policy_gradients_agent.py
+++ b/agents/policy_gradients_agent.py
@@ -73,7 +73,7 @@ class PolicyGradientsAgent(PolicyOptimizationAgent):
             else:
                 action = np.argmax(action_values)
             action_value = {"action_probability": action_values[action]}
-            self.entropy.add_sample(-np.sum(action_values * np.log(action_values)))
+            self.entropy.add_sample(-np.sum(action_values * np.log(action_values + eps)))
 
         else:  # CONTINUOUS
             result = self.main_network.online_network.predict(observation)
diff --git a/architectures/tensorflow_components/heads.py b/architectures/tensorflow_components/heads.py
index 4d8ec17..ab2bc2c 100644
--- a/architectures/tensorflow_components/heads.py
+++ b/architectures/tensorflow_components/heads.py
@@ -177,7 +177,8 @@ class PolicyHead(Head):
             self.policy_mean = tf.nn.softmax(policy_values, name="policy")
 
             # define the distributions for the policy and the old policy
-            self.policy_distribution = tf.contrib.distributions.Categorical(probs=self.policy_mean)
+            # (the + eps is to prevent probability 0 which will cause the log later on to be -inf)
+            self.policy_distribution = tf.contrib.distributions.Categorical(probs=(self.policy_mean + eps))
             self.output = self.policy_mean
         else:
             # mean
diff --git a/utils.py b/utils.py
index c1fe872..c660724 100644
--- a/utils.py
+++ b/utils.py
@@ -23,6 +23,7 @@ from subprocess import call, Popen
 
 killed_processes = []
 
+eps = np.finfo(np.float32).eps
 
 class Enum(object):
     def __init__(self):
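
Not part of the patch: a minimal illustrative sketch (array values and variable names invented for the example) of the failure mode this commit guards against, and of how the added eps term avoids it.

    import numpy as np

    eps = np.finfo(np.float32).eps  # the same tiny constant the patch defines in utils.py

    # a policy that has collapsed onto two actions; one probability is exactly 0
    action_probabilities = np.array([0.0, 0.25, 0.75], dtype=np.float32)

    # without eps: np.log(0) is -inf, 0 * -inf is nan, so the entropy sample becomes nan
    broken = -np.sum(action_probabilities * np.log(action_probabilities))

    # with eps: log(0 + eps) is large but finite, and the 0 probability zeroes out its term
    fixed = -np.sum(action_probabilities * np.log(action_probabilities + eps))

    print(broken)  # nan (with RuntimeWarnings for the divide by zero and invalid multiply)
    print(fixed)   # ~0.56, a finite entropy value

The change in heads.py follows the same reasoning: as the in-diff comment notes, the softmax output can also contain an exact 0, and the log taken on it later would be -inf, so the Categorical distribution is built from the probabilities shifted by eps.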