From 9106b692271592683f73b7290c4c240dcab6a287 Mon Sep 17 00:00:00 2001
From: Guy Jacob
Date: Thu, 6 May 2021 18:02:02 +0300
Subject: [PATCH] Add is_on_policy property to agents (#480)

---
 rl_coach/agents/acer_agent.py              | 5 ++++-
 rl_coach/agents/actor_critic_agent.py      | 4 ++++
 rl_coach/agents/agent_interface.py         | 5 +++++
 rl_coach/agents/bc_agent.py                | 4 ++++
 rl_coach/agents/bootstrapped_dqn_agent.py  | 4 ++++
 rl_coach/agents/categorical_dqn_agent.py   | 4 ++++
 rl_coach/agents/clipped_ppo_agent.py       | 4 ++++
 rl_coach/agents/ddpg_agent.py              | 4 ++++
 rl_coach/agents/dfp_agent.py               | 6 ++++++
 rl_coach/agents/dqn_agent.py               | 4 ++++
 rl_coach/agents/imitation_agent.py         | 4 ++++
 rl_coach/agents/mmc_agent.py               | 4 ++++
 rl_coach/agents/n_step_q_agent.py          | 4 ++++
 rl_coach/agents/naf_agent.py               | 4 ++++
 rl_coach/agents/nec_agent.py               | 3 +++
 rl_coach/agents/pal_agent.py               | 4 ++++
 rl_coach/agents/policy_gradients_agent.py  | 4 ++++
 rl_coach/agents/ppo_agent.py               | 4 ++++
 rl_coach/agents/qr_dqn_agent.py            | 4 ++++
 rl_coach/agents/soft_actor_critic_agent.py | 4 ++++
 rl_coach/agents/td3_agent.py               | 4 ++++
 21 files changed, 86 insertions(+), 1 deletion(-)

diff --git a/rl_coach/agents/acer_agent.py b/rl_coach/agents/acer_agent.py
index 94e76b6..337f606 100644
--- a/rl_coach/agents/acer_agent.py
+++ b/rl_coach/agents/acer_agent.py
@@ -111,8 +111,11 @@ class ACERAgent(PolicyOptimizationAgent):
         self.V_Values = self.register_signal('Values')
         self.kl_divergence = self.register_signal('KL Divergence')
 
-    def _learn_from_batch(self, batch):
+    @property
+    def is_on_policy(self) -> bool:
+        return False
 
+    def _learn_from_batch(self, batch):
         fetches = [self.networks['main'].online_network.output_heads[1].probability_loss,
                    self.networks['main'].online_network.output_heads[1].bias_correction_loss,
                    self.networks['main'].online_network.output_heads[1].kl_divergence]
diff --git a/rl_coach/agents/actor_critic_agent.py b/rl_coach/agents/actor_critic_agent.py
index 35c8bf9..b70f3a6 100644
--- a/rl_coach/agents/actor_critic_agent.py
+++ b/rl_coach/agents/actor_critic_agent.py
@@ -100,6 +100,10 @@ class ActorCriticAgent(PolicyOptimizationAgent):
         self.value_loss = self.register_signal('Value Loss')
         self.policy_loss = self.register_signal('Policy Loss')
 
+    @property
+    def is_on_policy(self) -> bool:
+        return True
+
     # Discounting function used to calculate discounted returns.
     def discount(self, x, gamma):
         return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]
diff --git a/rl_coach/agents/agent_interface.py b/rl_coach/agents/agent_interface.py
index 87a9232..c476928 100644
--- a/rl_coach/agents/agent_interface.py
+++ b/rl_coach/agents/agent_interface.py
@@ -152,3 +152,8 @@ class AgentInterface(object):
         :return: None
         """
         raise NotImplementedError("")
+
+    @property
+    def is_on_policy(self) -> bool:
+        raise NotImplementedError("")
+
diff --git a/rl_coach/agents/bc_agent.py b/rl_coach/agents/bc_agent.py
index 2208690..4ebe938 100644
--- a/rl_coach/agents/bc_agent.py
+++ b/rl_coach/agents/bc_agent.py
@@ -63,6 +63,10 @@ class BCAgent(ImitationAgent):
     def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
         super().__init__(agent_parameters, parent)
 
+    @property
+    def is_on_policy(self) -> bool:
+        return False
+
     def learn_from_batch(self, batch):
         network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()
 
diff --git a/rl_coach/agents/bootstrapped_dqn_agent.py b/rl_coach/agents/bootstrapped_dqn_agent.py
index b291dff..a324019 100644
--- a/rl_coach/agents/bootstrapped_dqn_agent.py
+++ b/rl_coach/agents/bootstrapped_dqn_agent.py
@@ -46,6 +46,10 @@ class BootstrappedDQNAgent(ValueOptimizationAgent):
     def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
         super().__init__(agent_parameters, parent)
 
+    @property
+    def is_on_policy(self) -> bool:
+        return False
+
     def reset_internal_state(self):
         super().reset_internal_state()
         self.exploration_policy.select_head()
diff --git a/rl_coach/agents/categorical_dqn_agent.py b/rl_coach/agents/categorical_dqn_agent.py
index 34b0ec8..a1a726d 100644
--- a/rl_coach/agents/categorical_dqn_agent.py
+++ b/rl_coach/agents/categorical_dqn_agent.py
@@ -77,6 +77,10 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
         super().__init__(agent_parameters, parent)
         self.z_values = np.linspace(self.ap.algorithm.v_min, self.ap.algorithm.v_max, self.ap.algorithm.atoms)
 
+    @property
+    def is_on_policy(self) -> bool:
+        return False
+
     def distribution_prediction_to_q_values(self, prediction):
         return np.dot(prediction, self.z_values)
 
diff --git a/rl_coach/agents/clipped_ppo_agent.py b/rl_coach/agents/clipped_ppo_agent.py
index 1a9d202..98a39af 100644
--- a/rl_coach/agents/clipped_ppo_agent.py
+++ b/rl_coach/agents/clipped_ppo_agent.py
@@ -144,6 +144,10 @@ class ClippedPPOAgent(ActorCriticAgent):
         self.likelihood_ratio = self.register_signal('Likelihood Ratio')
         self.clipped_likelihood_ratio = self.register_signal('Clipped Likelihood Ratio')
 
+    @property
+    def is_on_policy(self) -> bool:
+        return True
+
     def set_session(self, sess):
         super().set_session(sess)
         if self.ap.algorithm.normalization_stats is not None:
diff --git a/rl_coach/agents/ddpg_agent.py b/rl_coach/agents/ddpg_agent.py
index dbf3821..1c5b974 100644
--- a/rl_coach/agents/ddpg_agent.py
+++ b/rl_coach/agents/ddpg_agent.py
@@ -130,6 +130,10 @@ class DDPGAgent(ActorCriticAgent):
         self.TD_targets_signal = self.register_signal("TD targets")
         self.action_signal = self.register_signal("actions")
 
+    @property
+    def is_on_policy(self) -> bool:
+        return False
+
     def learn_from_batch(self, batch):
         actor = self.networks['actor']
         critic = self.networks['critic']
diff --git a/rl_coach/agents/dfp_agent.py b/rl_coach/agents/dfp_agent.py
index 83c7412..2168434 100644
--- a/rl_coach/agents/dfp_agent.py
+++ b/rl_coach/agents/dfp_agent.py
@@ -141,6 +141,12 @@ class DFPAgent(Agent):
         self.current_goal = self.ap.algorithm.goal_vector
         self.target_measurements_scale_factors = None
 
+    @property
+    def is_on_policy(self) -> bool:
+        # This is only somewhat correct, as the algorithm uses a very small (20k) ER keeping only recent samples.
+        # So it is approximately on-policy (although, to be completely strict, it is off-policy).
+        return True
+
     def learn_from_batch(self, batch):
         network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()
 
diff --git a/rl_coach/agents/dqn_agent.py b/rl_coach/agents/dqn_agent.py
index 6689b31..badaf67 100644
--- a/rl_coach/agents/dqn_agent.py
+++ b/rl_coach/agents/dqn_agent.py
@@ -71,6 +71,10 @@ class DQNAgent(ValueOptimizationAgent):
     def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
         super().__init__(agent_parameters, parent)
 
+    @property
+    def is_on_policy(self) -> bool:
+        return False
+
     def select_actions(self, next_states, q_st_plus_1):
         return np.argmax(q_st_plus_1, 1)
 
diff --git a/rl_coach/agents/imitation_agent.py b/rl_coach/agents/imitation_agent.py
index 86e33e4..94ebee8 100644
--- a/rl_coach/agents/imitation_agent.py
+++ b/rl_coach/agents/imitation_agent.py
@@ -31,6 +31,10 @@ class ImitationAgent(Agent):
         super().__init__(agent_parameters, parent)
         self.imitation = True
 
+    @property
+    def is_on_policy(self) -> bool:
+        return False
+
     def extract_action_values(self, prediction):
         return prediction.squeeze()
 
diff --git a/rl_coach/agents/mmc_agent.py b/rl_coach/agents/mmc_agent.py
index e0ce76d..b693345 100644
--- a/rl_coach/agents/mmc_agent.py
+++ b/rl_coach/agents/mmc_agent.py
@@ -50,6 +50,10 @@ class MixedMonteCarloAgent(ValueOptimizationAgent):
         super().__init__(agent_parameters, parent)
         self.mixing_rate = agent_parameters.algorithm.monte_carlo_mixing_rate
 
+    @property
+    def is_on_policy(self) -> bool:
+        return False
+
     def learn_from_batch(self, batch):
         network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()
 
diff --git a/rl_coach/agents/n_step_q_agent.py b/rl_coach/agents/n_step_q_agent.py
index 21b9239..9e187ea 100644
--- a/rl_coach/agents/n_step_q_agent.py
+++ b/rl_coach/agents/n_step_q_agent.py
@@ -92,6 +92,10 @@ class NStepQAgent(ValueOptimizationAgent, PolicyOptimizationAgent):
         self.q_values = self.register_signal('Q Values')
         self.value_loss = self.register_signal('Value Loss')
 
+    @property
+    def is_on_policy(self) -> bool:
+        return False
+
     def learn_from_batch(self, batch):
         # batch contains a list of episodes to learn from
         network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()
diff --git a/rl_coach/agents/naf_agent.py b/rl_coach/agents/naf_agent.py
index df7c60d..7b373c3 100644
--- a/rl_coach/agents/naf_agent.py
+++ b/rl_coach/agents/naf_agent.py
@@ -73,6 +73,10 @@ class NAFAgent(ValueOptimizationAgent):
         self.v_values = self.register_signal("V")
         self.TD_targets = self.register_signal("TD targets")
 
+    @property
+    def is_on_policy(self) -> bool:
+        return False
+
     def learn_from_batch(self, batch):
         network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()
 
diff --git a/rl_coach/agents/nec_agent.py b/rl_coach/agents/nec_agent.py
index a184001..f16eb56 100644
--- a/rl_coach/agents/nec_agent.py
+++ b/rl_coach/agents/nec_agent.py
@@ -120,6 +120,9 @@ class NECAgent(ValueOptimizationAgent):
             Episode(discount=self.ap.algorithm.discount, n_step=self.ap.algorithm.n_step,
                     bootstrap_total_return_from_old_policy=self.ap.algorithm.bootstrap_total_return_from_old_policy)
 
+    @property
+    def is_on_policy(self) -> bool:
+        return False
     def learn_from_batch(self, batch):
         if not self.networks['main'].online_network.output_heads[0].DND.has_enough_entries(self.ap.algorithm.number_of_knn):
 
diff --git a/rl_coach/agents/pal_agent.py b/rl_coach/agents/pal_agent.py
index 44778d6..6e88e32 100644
--- a/rl_coach/agents/pal_agent.py
+++ b/rl_coach/agents/pal_agent.py
@@ -63,6 +63,10 @@ class PALAgent(ValueOptimizationAgent):
         self.persistent = agent_parameters.algorithm.persistent_advantage_learning
         self.monte_carlo_mixing_rate = agent_parameters.algorithm.monte_carlo_mixing_rate
 
+    @property
+    def is_on_policy(self) -> bool:
+        return False
+
     def learn_from_batch(self, batch):
         network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()
 
diff --git a/rl_coach/agents/policy_gradients_agent.py b/rl_coach/agents/policy_gradients_agent.py
index f64ed00..46f9527 100644
--- a/rl_coach/agents/policy_gradients_agent.py
+++ b/rl_coach/agents/policy_gradients_agent.py
@@ -91,6 +91,10 @@ class PolicyGradientsAgent(PolicyOptimizationAgent):
         self.returns_variance = self.register_signal('Returns Variance')
         self.last_gradient_update_step_idx = 0
 
+    @property
+    def is_on_policy(self) -> bool:
+        return True
+
     def learn_from_batch(self, batch):
         # batch contains a list of episodes to learn from
         network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()
diff --git a/rl_coach/agents/ppo_agent.py b/rl_coach/agents/ppo_agent.py
index f45800a..7ff4bf4 100644
--- a/rl_coach/agents/ppo_agent.py
+++ b/rl_coach/agents/ppo_agent.py
@@ -149,6 +149,10 @@ class PPOAgent(ActorCriticAgent):
         self.total_kl_divergence_during_training_process = 0.0
         self.unclipped_grads = self.register_signal('Grads (unclipped)')
 
+    @property
+    def is_on_policy(self) -> bool:
+        return True
+
     def fill_advantages(self, batch):
         batch = Batch(batch)
         network_keys = self.ap.network_wrappers['critic'].input_embedders_parameters.keys()
diff --git a/rl_coach/agents/qr_dqn_agent.py b/rl_coach/agents/qr_dqn_agent.py
index 4975523..a0dee56 100644
--- a/rl_coach/agents/qr_dqn_agent.py
+++ b/rl_coach/agents/qr_dqn_agent.py
@@ -67,6 +67,10 @@ class QuantileRegressionDQNAgent(ValueOptimizationAgent):
         super().__init__(agent_parameters, parent)
         self.quantile_probabilities = np.ones(self.ap.algorithm.atoms) / float(self.ap.algorithm.atoms)
 
+    @property
+    def is_on_policy(self) -> bool:
+        return False
+
     def get_q_values(self, quantile_values):
         return np.dot(quantile_values, self.quantile_probabilities)
 
diff --git a/rl_coach/agents/soft_actor_critic_agent.py b/rl_coach/agents/soft_actor_critic_agent.py
index 9187124..39d3880 100644
--- a/rl_coach/agents/soft_actor_critic_agent.py
+++ b/rl_coach/agents/soft_actor_critic_agent.py
@@ -161,6 +161,10 @@ class SoftActorCriticAgent(PolicyOptimizationAgent):
         self.v_onl_ys = self.register_signal('V_onl_ys')
         self.action_signal = self.register_signal("actions")
 
+    @property
+    def is_on_policy(self) -> bool:
+        return False
+
     def learn_from_batch(self, batch):
         #########################################
         # need to update the following networks:
diff --git a/rl_coach/agents/td3_agent.py b/rl_coach/agents/td3_agent.py
index 44dbf3a..13b6b13 100644
--- a/rl_coach/agents/td3_agent.py
+++ b/rl_coach/agents/td3_agent.py
@@ -141,6 +141,10 @@ class TD3Agent(DDPGAgent):
         self.TD_targets_signal = self.register_signal("TD targets")
         self.action_signal = self.register_signal("actions")
 
+    @property
+    def is_on_policy(self) -> bool:
+        return False
+
     def learn_from_batch(self, batch):
         actor = self.networks['actor']
         critic = self.networks['critic']
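
For reference, a minimal usage sketch of the new property, outside the patch itself. The helper below is hypothetical (it is not part of rl_coach); it only illustrates how a caller holding any AgentInterface implementation might branch on is_on_policy, which each concrete agent now answers while the base class raises NotImplementedError.

    # Hypothetical helper, shown only to illustrate querying the new property.
    def check_replay_configuration(agent):
        """Warn if an off-policy agent appears to be configured without a replay memory.

        `agent` is any object implementing AgentInterface. Off-policy agents are
        typically expected to learn from an experience replay buffer, while
        on-policy agents learn from freshly collected rollouts.
        """
        if not agent.is_on_policy and getattr(agent, 'memory', None) is None:
            print("Warning: off-policy agent configured without a replay memory")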