From acceb03ac0bd4c02e5e3e7812eb0003d148789dd Mon Sep 17 00:00:00 2001 From: Gal Leibovich Date: Tue, 21 May 2019 16:39:11 +0300 Subject: [PATCH] bug fixes for OPE (#311) --- rl_coach/agents/agent.py | 2 +- rl_coach/agents/ddqn_bcq_agent.py | 2 +- rl_coach/agents/value_optimization_agent.py | 2 +- .../graph_managers/batch_rl_graph_manager.py | 6 ++-- .../episodic/episodic_experience_replay.py | 31 +++++++++++++------ .../non_episodic/experience_replay.py | 2 +- .../rl/sequential_doubly_robust.py | 6 ++-- .../rl/weighted_importance_sampling.py | 8 +++-- 8 files changed, 38 insertions(+), 21 deletions(-) diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py index cd3f01e..1e93262 100644 --- a/rl_coach/agents/agent.py +++ b/rl_coach/agents/agent.py @@ -697,7 +697,7 @@ class Agent(AgentInterface): # we either go sequentially through the entire replay buffer in the batch RL mode, # or sample randomly for the basic RL case. - training_schedule = self.call_memory('get_shuffled_data_generator', batch_size) if \ + training_schedule = self.call_memory('get_shuffled_training_data_generator', batch_size) if \ self.ap.is_batch_rl_training else [self.call_memory('sample', batch_size) for _ in range(self.ap.algorithm.num_consecutive_training_steps)] diff --git a/rl_coach/agents/ddqn_bcq_agent.py b/rl_coach/agents/ddqn_bcq_agent.py index 32e90ef..ee9fb39 100644 --- a/rl_coach/agents/ddqn_bcq_agent.py +++ b/rl_coach/agents/ddqn_bcq_agent.py @@ -155,7 +155,7 @@ class DDQNBCQAgent(DQNAgent): reward_model_loss = 0 imitation_model_loss = 0 total_transitions_processed = 0 - for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)): + for i, batch in enumerate(self.call_memory('get_shuffled_training_data_generator', batch_size)): batch = Batch(batch) # reward model diff --git a/rl_coach/agents/value_optimization_agent.py b/rl_coach/agents/value_optimization_agent.py index e7af1ed..3a3ef8a 100644 --- a/rl_coach/agents/value_optimization_agent.py +++ b/rl_coach/agents/value_optimization_agent.py @@ -164,7 +164,7 @@ class ValueOptimizationAgent(Agent): for epoch in range(epochs): loss = 0 total_transitions_processed = 0 - for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)): + for i, batch in enumerate(self.call_memory('get_shuffled_training_data_generator', batch_size)): batch = Batch(batch) loss += self.get_reward_model_loss(batch) total_transitions_processed += batch.size diff --git a/rl_coach/graph_managers/batch_rl_graph_manager.py b/rl_coach/graph_managers/batch_rl_graph_manager.py index ccfb78d..e930fd6 100644 --- a/rl_coach/graph_managers/batch_rl_graph_manager.py +++ b/rl_coach/graph_managers/batch_rl_graph_manager.py @@ -173,12 +173,12 @@ class BatchRLGraphManager(BasicRLGraphManager): """ agent = self.level_managers[0].agents['agent'] - screen.log_title("Training a regression model for estimating MDP rewards") - agent.improve_reward_model(epochs=self.reward_model_num_epochs) - # prepare dataset to be consumed in the expected formats for OPE agent.memory.prepare_evaluation_dataset() + screen.log_title("Training a regression model for estimating MDP rewards") + agent.improve_reward_model(epochs=self.reward_model_num_epochs) + screen.log_title("Collecting static statistics for OPE") agent.ope_manager.gather_static_shared_stats(evaluation_dataset_as_transitions= agent.memory.evaluation_dataset_as_transitions, diff --git a/rl_coach/memories/episodic/episodic_experience_replay.py b/rl_coach/memories/episodic/episodic_experience_replay.py index 8bee63a..836756c 100644 --- a/rl_coach/memories/episodic/episodic_experience_replay.py +++ b/rl_coach/memories/episodic/episodic_experience_replay.py @@ -148,7 +148,7 @@ class EpisodicExperienceReplay(Memory): random.shuffle(self._buffer) self.transitions = [t for e in self._buffer for t in e.transitions] - def get_shuffled_data_generator(self, size: int) -> List[Transition]: + def get_shuffled_training_data_generator(self, size: int) -> List[Transition]: """ Get an generator for iterating through the shuffled replay buffer, for processing the data in epochs. If the requested size is larger than the number of samples available in the replay buffer then the batch will @@ -159,15 +159,6 @@ class EpisodicExperienceReplay(Memory): :return: a batch (list) of selected transitions from the replay buffer """ self.reader_writer_lock.lock_writing() - if self.last_training_set_transition_id is None: - if self.train_to_eval_ratio < 0 or self.train_to_eval_ratio >= 1: - raise ValueError('train_to_eval_ratio should be in the (0, 1] range.') - - transition = self.transitions[round(self.train_to_eval_ratio * self.num_transitions_in_complete_episodes())] - episode_num, episode = self.get_episode_for_transition(transition) - self.last_training_set_episode_id = episode_num - self.last_training_set_transition_id = \ - len([t for e in self.get_all_complete_episodes_from_to(0, self.last_training_set_episode_id + 1) for t in e]) shuffled_transition_indices = list(range(self.last_training_set_transition_id)) random.shuffle(shuffled_transition_indices) @@ -483,6 +474,9 @@ class EpisodicExperienceReplay(Memory): Gather the memory content that will be used for off-policy evaluation in episodes and transitions format :return: """ + self.reader_writer_lock.lock_writing_and_reading() + + self._split_training_and_evaluation_datasets() self.evaluation_dataset_as_episodes = deepcopy( self.get_all_complete_episodes_from_to(self.get_last_training_set_episode_id() + 1, self.num_complete_episodes())) @@ -493,3 +487,20 @@ class EpisodicExperienceReplay(Memory): self.evaluation_dataset_as_transitions = [t for e in self.evaluation_dataset_as_episodes for t in e.transitions] + self.reader_writer_lock.release_writing_and_reading() + + def _split_training_and_evaluation_datasets(self): + """ + If the data in the buffer was not split to training and evaluation yet, split it accordingly. + :return: None + """ + + if self.last_training_set_transition_id is None: + if self.train_to_eval_ratio < 0 or self.train_to_eval_ratio >= 1: + raise ValueError('train_to_eval_ratio should be in the (0, 1] range.') + + transition = self.transitions[round(self.train_to_eval_ratio * self.num_transitions_in_complete_episodes())] + episode_num, episode = self.get_episode_for_transition(transition) + self.last_training_set_episode_id = episode_num + self.last_training_set_transition_id = \ + len([t for e in self.get_all_complete_episodes_from_to(0, self.last_training_set_episode_id + 1) for t in e]) diff --git a/rl_coach/memories/non_episodic/experience_replay.py b/rl_coach/memories/non_episodic/experience_replay.py index ae580c5..1570c87 100644 --- a/rl_coach/memories/non_episodic/experience_replay.py +++ b/rl_coach/memories/non_episodic/experience_replay.py @@ -92,7 +92,7 @@ class ExperienceReplay(Memory): self.reader_writer_lock.release_writing() return batch - def get_shuffled_data_generator(self, size: int) -> List[Transition]: + def get_shuffled_training_data_generator(self, size: int) -> List[Transition]: """ Get an generator for iterating through the shuffled replay buffer, for processing the data in epochs. If the requested size is larger than the number of samples available in the replay buffer then the batch will diff --git a/rl_coach/off_policy_evaluators/rl/sequential_doubly_robust.py b/rl_coach/off_policy_evaluators/rl/sequential_doubly_robust.py index e172a80..50d6ca0 100644 --- a/rl_coach/off_policy_evaluators/rl/sequential_doubly_robust.py +++ b/rl_coach/off_policy_evaluators/rl/sequential_doubly_robust.py @@ -26,7 +26,9 @@ class SequentialDoublyRobust(object): """ Run the off-policy evaluator to get a score for the goodness of the new policy, based on the dataset, which was collected using other policy(ies). - + When the epsiodes are of changing lengths, this estimator might prove problematic due to its nature of recursion + of adding rewards up to the end of the episode (horizon). It will probably work best with episodes of fixed + length. Paper: https://arxiv.org/pdf/1511.03722.pdf :return: the evaluation score @@ -37,7 +39,7 @@ class SequentialDoublyRobust(object): for episode in evaluation_dataset_as_episodes: episode_seq_dr = 0 - for transition in episode.transitions: + for transition in reversed(episode.transitions): rho = transition.info['softmax_policy_prob'][transition.action] / \ transition.info['all_action_probabilities'][transition.action] episode_seq_dr = transition.info['v_value_q_model_based'] + rho * (transition.reward + discount_factor diff --git a/rl_coach/off_policy_evaluators/rl/weighted_importance_sampling.py b/rl_coach/off_policy_evaluators/rl/weighted_importance_sampling.py index f6b2a89..c97ad3c 100644 --- a/rl_coach/off_policy_evaluators/rl/weighted_importance_sampling.py +++ b/rl_coach/off_policy_evaluators/rl/weighted_importance_sampling.py @@ -46,8 +46,12 @@ class WeightedImportanceSampling(object): per_episode_w_i.append(w_i) total_w_i_sum_across_episodes = sum(per_episode_w_i) + wis = 0 - for i, episode in enumerate(evaluation_dataset_as_episodes): - wis += per_episode_w_i[i]/total_w_i_sum_across_episodes * episode.transitions[0].n_step_discounted_rewards + if total_w_i_sum_across_episodes != 0: + for i, episode in enumerate(evaluation_dataset_as_episodes): + if len(episode.transitions) != 0: + wis += per_episode_w_i[i] * episode.transitions[0].n_step_discounted_rewards + wis /= total_w_i_sum_across_episodes return wis