diff --git a/agents/actor_critic_agent.py b/agents/actor_critic_agent.py index 279075f..ed35ee6 100644 --- a/agents/actor_critic_agent.py +++ b/agents/actor_critic_agent.py @@ -121,7 +121,7 @@ class ActorCriticAgent(PolicyOptimizationAgent): else: action = np.argmax(action_probabilities) action_info = {"action_probability": action_probabilities[action], "state_value": state_value} - self.entropy.add_sample(-np.sum(action_probabilities * np.log(action_probabilities))) + self.entropy.add_sample(-np.sum(action_probabilities * np.log(action_probabilities + eps))) else: # CONTINUOUS state_value, action_values_mean, action_values_std = self.main_network.online_network.predict(observation) diff --git a/agents/policy_gradients_agent.py b/agents/policy_gradients_agent.py index bf873d1..11cef75 100644 --- a/agents/policy_gradients_agent.py +++ b/agents/policy_gradients_agent.py @@ -73,7 +73,7 @@ class PolicyGradientsAgent(PolicyOptimizationAgent): else: action = np.argmax(action_values) action_value = {"action_probability": action_values[action]} - self.entropy.add_sample(-np.sum(action_values * np.log(action_values))) + self.entropy.add_sample(-np.sum(action_values * np.log(action_values + eps))) else: # CONTINUOUS result = self.main_network.online_network.predict(observation) diff --git a/architectures/tensorflow_components/heads.py b/architectures/tensorflow_components/heads.py index 4d8ec17..ab2bc2c 100644 --- a/architectures/tensorflow_components/heads.py +++ b/architectures/tensorflow_components/heads.py @@ -177,7 +177,8 @@ class PolicyHead(Head): self.policy_mean = tf.nn.softmax(policy_values, name="policy") # define the distributions for the policy and the old policy - self.policy_distribution = tf.contrib.distributions.Categorical(probs=self.policy_mean) + # (the + eps is to prevent probability 0 which will cause the log later on to be -inf) + self.policy_distribution = tf.contrib.distributions.Categorical(probs=(self.policy_mean + eps)) self.output = self.policy_mean else: # mean diff --git a/utils.py b/utils.py index c1fe872..c660724 100644 --- a/utils.py +++ b/utils.py @@ -23,6 +23,7 @@ from subprocess import call, Popen killed_processes = [] +eps = np.finfo(np.float32).eps class Enum(object): def __init__(self):