diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py index 53412a4..cd3f01e 100644 --- a/rl_coach/agents/agent.py +++ b/rl_coach/agents/agent.py @@ -522,6 +522,7 @@ class Agent(AgentInterface): self.agent_logger.create_signal_value('Inverse Propensity Score', np.nan, overwrite=False) self.agent_logger.create_signal_value('Direct Method Reward', np.nan, overwrite=False) self.agent_logger.create_signal_value('Doubly Robust', np.nan, overwrite=False) + self.agent_logger.create_signal_value('Weighted Importance Sampling', np.nan, overwrite=False) self.agent_logger.create_signal_value('Sequential Doubly Robust', np.nan, overwrite=False) for signal in self.episode_signals: diff --git a/rl_coach/agents/value_optimization_agent.py b/rl_coach/agents/value_optimization_agent.py index 60f0a1d..e7af1ed 100644 --- a/rl_coach/agents/value_optimization_agent.py +++ b/rl_coach/agents/value_optimization_agent.py @@ -20,6 +20,7 @@ import numpy as np from rl_coach.agents.agent import Agent from rl_coach.core_types import ActionInfo, StateType, Batch +from rl_coach.filters.filter import NoInputFilter from rl_coach.logger import screen from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplay from rl_coach.spaces import DiscreteActionSpace @@ -108,18 +109,18 @@ class ValueOptimizationAgent(Agent): :return: None """ assert self.ope_manager - dataset_as_episodes = self.call_memory('get_all_complete_episodes_from_to', - (self.call_memory('get_last_training_set_episode_id') + 1, - self.call_memory('num_complete_episodes'))) - if len(dataset_as_episodes) == 0: - raise ValueError('train_to_eval_ratio is too high causing the evaluation set to be empty. ' - 'Consider decreasing its value.') - ips, dm, dr, seq_dr = self.ope_manager.evaluate( - dataset_as_episodes=dataset_as_episodes, + if not isinstance(self.pre_network_filter, NoInputFilter) and len(self.pre_network_filter.reward_filters) != 0: + raise ValueError("Defining a pre-network reward filter when OPEs are calculated will result in a mismatch " + "between q values (which are scaled), and actual rewards, which are not. It is advisable " + "to use an input_filter, if possible, instead, which will filter the transitions directly " + "in the replay buffer, affecting both the q_values and the rewards themselves. 
") + + ips, dm, dr, seq_dr, wis = self.ope_manager.evaluate( + evaluation_dataset_as_episodes=self.memory.evaluation_dataset_as_episodes, + evaluation_dataset_as_transitions=self.memory.evaluation_dataset_as_transitions, batch_size=self.ap.network_wrappers['main'].batch_size, discount_factor=self.ap.algorithm.discount, - reward_model=self.networks['reward_model'].online_network, q_network=self.networks['main'].online_network, network_keys=list(self.ap.network_wrappers['main'].input_embedders_parameters.keys())) @@ -129,6 +130,7 @@ class ValueOptimizationAgent(Agent): log['IPS'] = ips log['DM'] = dm log['DR'] = dr + log['WIS'] = wis log['Sequential-DR'] = seq_dr screen.log_dict(log, prefix='Off-Policy Evaluation') @@ -138,6 +140,7 @@ class ValueOptimizationAgent(Agent): self.agent_logger.create_signal_value('Direct Method Reward', dm) self.agent_logger.create_signal_value('Doubly Robust', dr) self.agent_logger.create_signal_value('Sequential Doubly Robust', seq_dr) + self.agent_logger.create_signal_value('Weighted Importance Sampling', wis) def get_reward_model_loss(self, batch: Batch): network_keys = self.ap.network_wrappers['reward_model'].input_embedders_parameters.keys() diff --git a/rl_coach/graph_managers/batch_rl_graph_manager.py b/rl_coach/graph_managers/batch_rl_graph_manager.py index 2e148e2..ccfb78d 100644 --- a/rl_coach/graph_managers/batch_rl_graph_manager.py +++ b/rl_coach/graph_managers/batch_rl_graph_manager.py @@ -127,7 +127,10 @@ class BatchRLGraphManager(BasicRLGraphManager): if self.env_params is not None and not self.agent_params.memory.load_memory_from_file_path: self.heatup(self.heatup_steps) - self.improve_reward_model() + # from this point onwards, the dataset cannot be changed anymore. Allows for performance improvements. + self.level_managers[0].agents['agent'].memory.freeze() + + self.initialize_ope_models_and_stats() # improve if self.task_parameters.task_index is not None: @@ -163,13 +166,26 @@ class BatchRLGraphManager(BasicRLGraphManager): # we might want to evaluate vs. the simulator every now and then. break - def improve_reward_model(self): + def initialize_ope_models_and_stats(self): """ :return: """ + agent = self.level_managers[0].agents['agent'] + screen.log_title("Training a regression model for estimating MDP rewards") - self.level_managers[0].agents['agent'].improve_reward_model(epochs=self.reward_model_num_epochs) + agent.improve_reward_model(epochs=self.reward_model_num_epochs) + + # prepare dataset to be consumed in the expected formats for OPE + agent.memory.prepare_evaluation_dataset() + + screen.log_title("Collecting static statistics for OPE") + agent.ope_manager.gather_static_shared_stats(evaluation_dataset_as_transitions= + agent.memory.evaluation_dataset_as_transitions, + batch_size=agent.ap.network_wrappers['main'].batch_size, + reward_model=agent.networks['reward_model'].online_network, + network_keys=list(agent.ap.network_wrappers['main']. + input_embedders_parameters.keys())) def run_off_policy_evaluation(self): """ diff --git a/rl_coach/memories/episodic/episodic_experience_replay.py b/rl_coach/memories/episodic/episodic_experience_replay.py index f556c9d..8bee63a 100644 --- a/rl_coach/memories/episodic/episodic_experience_replay.py +++ b/rl_coach/memories/episodic/episodic_experience_replay.py @@ -15,6 +15,8 @@ # limitations under the License. 
# import ast +from copy import deepcopy + import math import pandas as pd @@ -64,6 +66,10 @@ class EpisodicExperienceReplay(Memory): self.last_training_set_episode_id = None # used in batch-rl self.last_training_set_transition_id = None # used in batch-rl self.train_to_eval_ratio = train_to_eval_ratio # used in batch-rl + self.evaluation_dataset_as_episodes = None + self.evaluation_dataset_as_transitions = None + + self.frozen = False def length(self, lock: bool = False) -> int: """ @@ -137,6 +143,8 @@ class EpisodicExperienceReplay(Memory): Shuffle all the episodes in the replay buffer :return: """ + self.assert_not_frozen() + random.shuffle(self._buffer) self.transitions = [t for e in self._buffer for t in e.transitions] @@ -256,6 +264,7 @@ class EpisodicExperienceReplay(Memory): :param transition: a transition to store :return: None """ + self.assert_not_frozen() # Calling super.store() so that in case a memory backend is used, the memory backend can store this transition. super().store(transition) @@ -281,6 +290,8 @@ class EpisodicExperienceReplay(Memory): :param episode: the new episode to store :return: None """ + self.assert_not_frozen() + # Calling super.store() so that in case a memory backend is used, the memory backend can store this episode. super().store_episode(episode) @@ -322,6 +333,8 @@ class EpisodicExperienceReplay(Memory): :param episode_index: the index of the episode to remove :return: None """ + self.assert_not_frozen() + if len(self._buffer) > episode_index: episode_length = self._buffer[episode_index].length() self._length -= 1 @@ -381,6 +394,7 @@ class EpisodicExperienceReplay(Memory): Clean the memory by removing all the episodes :return: None """ + self.assert_not_frozen() self.reader_writer_lock.lock_writing_and_reading() self.transitions = [] @@ -409,6 +423,8 @@ class EpisodicExperienceReplay(Memory): The csv file is assumed to include a list of transitions. :param csv_dataset: A construct which holds the dataset parameters """ + self.assert_not_frozen() + df = pd.read_csv(csv_dataset.filepath) if len(df) > self.max_size[1]: screen.warning("Warning! The number of transitions to load into the replay buffer ({}) is " @@ -446,3 +462,34 @@ class EpisodicExperienceReplay(Memory): progress_bar.close() self.shuffle_episodes() + + def freeze(self): + """ + Freezing the replay buffer does not allow any new transitions to be added to the memory. + Useful when working with a dataset (e.g. batch-rl or imitation learning). + :return: None + """ + self.frozen = True + + def assert_not_frozen(self): + """ + Check that the memory is not frozen, and can be changed. + :return: + """ + assert self.frozen is False, "Memory is frozen, and cannot be changed." + + def prepare_evaluation_dataset(self): + """ + Gather the memory content that will be used for off-policy evaluation in episodes and transitions format + :return: + """ + self.evaluation_dataset_as_episodes = deepcopy( + self.get_all_complete_episodes_from_to(self.get_last_training_set_episode_id() + 1, + self.num_complete_episodes())) + + if len(self.evaluation_dataset_as_episodes) == 0: + raise ValueError('train_to_eval_ratio is too high causing the evaluation set to be empty. 
' + 'Consider decreasing its value.') + + self.evaluation_dataset_as_transitions = [t for e in self.evaluation_dataset_as_episodes + for t in e.transitions] diff --git a/rl_coach/memories/non_episodic/experience_replay.py b/rl_coach/memories/non_episodic/experience_replay.py index 4226ba4..ae580c5 100644 --- a/rl_coach/memories/non_episodic/experience_replay.py +++ b/rl_coach/memories/non_episodic/experience_replay.py @@ -54,6 +54,7 @@ class ExperienceReplay(Memory): self.allow_duplicates_in_batch_sampling = allow_duplicates_in_batch_sampling self.reader_writer_lock = ReaderWriterLock() + self.frozen = False def length(self) -> int: """ @@ -135,6 +136,8 @@ class ExperienceReplay(Memory): locks and then calls store with lock = True :return: None """ + self.assert_not_frozen() + # Calling super.store() so that in case a memory backend is used, the memory backend can store this transition. super().store(transition) if lock: @@ -175,6 +178,8 @@ class ExperienceReplay(Memory): :param transition_index: the index of the transition to remove :return: None """ + self.assert_not_frozen() + if lock: self.reader_writer_lock.lock_writing_and_reading() @@ -207,6 +212,8 @@ class ExperienceReplay(Memory): Clean the memory by removing all the episodes :return: None """ + self.assert_not_frozen() + if lock: self.reader_writer_lock.lock_writing_and_reading() @@ -242,6 +249,8 @@ class ExperienceReplay(Memory): The pickle file is assumed to include a list of transitions. :param file_path: The path to a pickle file to restore """ + self.assert_not_frozen() + with open(file_path, 'rb') as file: transitions = pickle.load(file) num_transitions = len(transitions) @@ -260,3 +269,17 @@ class ExperienceReplay(Memory): progress_bar.close() + def freeze(self): + """ + Freezing the replay buffer does not allow any new transitions to be added to the memory. + Useful when working with a dataset (e.g. batch-rl or imitation learning). + :return: None + """ + self.frozen = True + + def assert_not_frozen(self): + """ + Check that the memory is not frozen, and can be changed. + :return: + """ + assert self.frozen is False, "Memory is frozen, and cannot be changed." 
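The freeze mechanism added to both replay buffers above is a plain guard flag: every mutating method now calls assert_not_frozen() first, so once the offline dataset has been loaded and freeze() has been called, any attempt to store, remove, or clean transitions fails immediately. Below is a minimal standalone sketch of the same pattern; the toy class and transition dicts are illustrative placeholders, not the actual rl_coach ExperienceReplay API.

class FrozenBufferSketch:
    """Toy stand-in for the frozen-buffer guard; not the real rl_coach class."""

    def __init__(self):
        self.transitions = []
        self.frozen = False

    def assert_not_frozen(self):
        # Same guard as in the diff above: mutating a frozen buffer is a bug.
        assert self.frozen is False, "Memory is frozen, and cannot be changed."

    def store(self, transition):
        self.assert_not_frozen()
        self.transitions.append(transition)

    def freeze(self):
        # From this point on the dataset is immutable, so statistics computed
        # over it once can safely be reused later.
        self.frozen = True


buffer = FrozenBufferSketch()
buffer.store({'state': 0, 'action': 1, 'reward': 0.5})   # allowed before freezing
buffer.freeze()
try:
    buffer.store({'state': 1, 'action': 0, 'reward': 0.0})
except AssertionError as error:
    print(error)   # -> Memory is frozen, and cannot be changed.

Keeping the dataset immutable from this point on is what makes it safe to gather the static OPE statistics once (see gather_static_shared_stats in ope_manager.py below) and reuse them on every evaluation call.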
diff --git a/rl_coach/off_policy_evaluators/ope_manager.py b/rl_coach/off_policy_evaluators/ope_manager.py index 81b11ea..514d4ff 100644 --- a/rl_coach/off_policy_evaluators/ope_manager.py +++ b/rl_coach/off_policy_evaluators/ope_manager.py @@ -26,56 +26,60 @@ from rl_coach.off_policy_evaluators.rl.sequential_doubly_robust import Sequentia from rl_coach.core_types import Transition +from rl_coach.off_policy_evaluators.rl.weighted_importance_sampling import WeightedImportanceSampling + OpeSharedStats = namedtuple("OpeSharedStats", ['all_reward_model_rewards', 'all_policy_probs', 'all_v_values_reward_model_based', 'all_rewards', 'all_actions', 'all_old_policy_probs', 'new_policy_prob', 'rho_all_dataset']) -OpeEstimation = namedtuple("OpeEstimation", ['ips', 'dm', 'dr', 'seq_dr']) +OpeEstimation = namedtuple("OpeEstimation", ['ips', 'dm', 'dr', 'seq_dr', 'wis']) class OpeManager(object): def __init__(self): + self.evaluation_dataset_as_transitions = None self.doubly_robust = DoublyRobust() self.sequential_doubly_robust = SequentialDoublyRobust() + self.weighted_importance_sampling = WeightedImportanceSampling() + self.all_reward_model_rewards = None + self.all_old_policy_probs = None + self.all_rewards = None + self.all_actions = None + self.is_gathered_static_shared_data = False - @staticmethod - def _prepare_ope_shared_stats(dataset_as_transitions: List[Transition], batch_size: int, - reward_model: Architecture, q_network: Architecture, - network_keys: List) -> OpeSharedStats: + def _prepare_ope_shared_stats(self, evaluation_dataset_as_transitions: List[Transition], batch_size: int, + q_network: Architecture, network_keys: List) -> OpeSharedStats: """ Do the preparations needed for different estimators. Some of the calcuations are shared, so we centralize all the work here. - :param dataset_as_transitions: The evaluation dataset in the form of transitions. + :param evaluation_dataset_as_transitions: The evaluation dataset in the form of transitions. :param batch_size: The batch size to use. :param reward_model: A reward model to be used by DR :param q_network: The Q network whose its policy we evaluate. :param network_keys: The network keys used for feeding the neural networks. :return: """ - # IPS - all_reward_model_rewards, all_policy_probs, all_old_policy_probs = [], [], [] - all_v_values_reward_model_based, all_v_values_q_model_based, all_rewards, all_actions = [], [], [], [] - for i in range(math.ceil(len(dataset_as_transitions) / batch_size)): - batch = dataset_as_transitions[i * batch_size: (i + 1) * batch_size] + assert self.is_gathered_static_shared_data, "gather_static_shared_stats() should be called once before " \ + "calling _prepare_ope_shared_stats()" + # IPS + all_policy_probs = [] + all_v_values_reward_model_based, all_v_values_q_model_based = [], [] + + for i in range(math.ceil(len(evaluation_dataset_as_transitions) / batch_size)): + batch = evaluation_dataset_as_transitions[i * batch_size: (i + 1) * batch_size] batch_for_inference = Batch(batch) - all_reward_model_rewards.append(reward_model.predict( - batch_for_inference.states(network_keys))) - # we always use the first Q head to calculate OPEs. might want to change this in the future. - # for instance, this means that for bootstrapped we always use the first QHead to calculate the OPEs. + # for instance, this means that for bootstrapped dqn we always use the first QHead to calculate the OPEs. 
q_values, sm_values = q_network.predict(batch_for_inference.states(network_keys), outputs=[q_network.output_heads[0].q_values, q_network.output_heads[0].softmax]) all_policy_probs.append(sm_values) - all_v_values_reward_model_based.append(np.sum(all_policy_probs[-1] * all_reward_model_rewards[-1], axis=1)) + all_v_values_reward_model_based.append(np.sum(all_policy_probs[-1] * self.all_reward_model_rewards[i], + axis=1)) all_v_values_q_model_based.append(np.sum(all_policy_probs[-1] * q_values, axis=1)) - all_rewards.append(batch_for_inference.rewards()) - all_actions.append(batch_for_inference.actions()) - all_old_policy_probs.append(batch_for_inference.info('all_action_probabilities') - [range(len(batch_for_inference.actions())), batch_for_inference.actions()]) for j, t in enumerate(batch): t.update_info({ @@ -85,26 +89,50 @@ class OpeManager(object): }) - all_reward_model_rewards = np.concatenate(all_reward_model_rewards, axis=0) all_policy_probs = np.concatenate(all_policy_probs, axis=0) all_v_values_reward_model_based = np.concatenate(all_v_values_reward_model_based, axis=0) - all_rewards = np.concatenate(all_rewards, axis=0) - all_actions = np.concatenate(all_actions, axis=0) - all_old_policy_probs = np.concatenate(all_old_policy_probs, axis=0) # generate model probabilities - new_policy_prob = all_policy_probs[np.arange(all_actions.shape[0]), all_actions] - rho_all_dataset = new_policy_prob / all_old_policy_probs + new_policy_prob = all_policy_probs[np.arange(self.all_actions.shape[0]), self.all_actions] + rho_all_dataset = new_policy_prob / self.all_old_policy_probs - return OpeSharedStats(all_reward_model_rewards, all_policy_probs, all_v_values_reward_model_based, - all_rewards, all_actions, all_old_policy_probs, new_policy_prob, rho_all_dataset) + return OpeSharedStats(self.all_reward_model_rewards, all_policy_probs, all_v_values_reward_model_based, + self.all_rewards, self.all_actions, self.all_old_policy_probs, new_policy_prob, + rho_all_dataset) - def evaluate(self, dataset_as_episodes: List[Episode], batch_size: int, discount_factor: float, - reward_model: Architecture, q_network: Architecture, network_keys: List) -> OpeEstimation: + def gather_static_shared_stats(self, evaluation_dataset_as_transitions: List[Transition], batch_size: int, + reward_model: Architecture, network_keys: List) -> None: + all_reward_model_rewards = [] + all_old_policy_probs = [] + all_rewards = [] + all_actions = [] + + for i in range(math.ceil(len(evaluation_dataset_as_transitions) / batch_size)): + batch = evaluation_dataset_as_transitions[i * batch_size: (i + 1) * batch_size] + batch_for_inference = Batch(batch) + + all_reward_model_rewards.append(reward_model.predict(batch_for_inference.states(network_keys))) + all_rewards.append(batch_for_inference.rewards()) + all_actions.append(batch_for_inference.actions()) + all_old_policy_probs.append(batch_for_inference.info('all_action_probabilities') + [range(len(batch_for_inference.actions())), + batch_for_inference.actions()]) + + self.all_reward_model_rewards = np.concatenate(all_reward_model_rewards, axis=0) + self.all_old_policy_probs = np.concatenate(all_old_policy_probs, axis=0) + self.all_rewards = np.concatenate(all_rewards, axis=0) + self.all_actions = np.concatenate(all_actions, axis=0) + + # mark that static shared data was collected and ready to be used + self.is_gathered_static_shared_data = True + + def evaluate(self, evaluation_dataset_as_episodes: List[Episode], evaluation_dataset_as_transitions: List[Transition], batch_size: int, + 
discount_factor: float, q_network: Architecture, network_keys: List) -> OpeEstimation: """ Run all the OPEs and get estimations of the current policy performance based on the evaluation dataset. - :param dataset_as_episodes: The evaluation dataset. + :param evaluation_dataset_as_episodes: The evaluation dataset in a form of episodes. + :param evaluation_dataset_as_transitions: The evaluation dataset in a form of transitions. :param batch_size: Batch size to use for the estimators. :param discount_factor: The standard RL discount factor. :param reward_model: A reward model to be used by DR @@ -113,12 +141,12 @@ class OpeManager(object): :return: An OpeEstimation tuple which groups together all the OPE estimations """ - # TODO this seems kind of slow, review performance - dataset_as_transitions = [t for e in dataset_as_episodes for t in e.transitions] - ope_shared_stats = self._prepare_ope_shared_stats(dataset_as_transitions, batch_size, reward_model, - q_network, network_keys) + ope_shared_stats = self._prepare_ope_shared_stats(evaluation_dataset_as_transitions, batch_size, q_network, + network_keys) ips, dm, dr = self.doubly_robust.evaluate(ope_shared_stats) - seq_dr = self.sequential_doubly_robust.evaluate(dataset_as_episodes, discount_factor) - return OpeEstimation(ips, dm, dr, seq_dr) + seq_dr = self.sequential_doubly_robust.evaluate(evaluation_dataset_as_episodes, discount_factor) + wis = self.weighted_importance_sampling.evaluate(evaluation_dataset_as_episodes) + + return OpeEstimation(ips, dm, dr, seq_dr, wis) diff --git a/rl_coach/off_policy_evaluators/rl/sequential_doubly_robust.py b/rl_coach/off_policy_evaluators/rl/sequential_doubly_robust.py index 633a747..e172a80 100644 --- a/rl_coach/off_policy_evaluators/rl/sequential_doubly_robust.py +++ b/rl_coach/off_policy_evaluators/rl/sequential_doubly_robust.py @@ -22,7 +22,7 @@ from rl_coach.core_types import Episode class SequentialDoublyRobust(object): @staticmethod - def evaluate(dataset_as_episodes: List[Episode], discount_factor: float) -> float: + def evaluate(evaluation_dataset_as_episodes: List[Episode], discount_factor: float) -> float: """ Run the off-policy evaluator to get a score for the goodness of the new policy, based on the dataset, which was collected using other policy(ies). @@ -35,7 +35,7 @@ class SequentialDoublyRobust(object): # Sequential Doubly Robust per_episode_seq_dr = [] - for episode in dataset_as_episodes: + for episode in evaluation_dataset_as_episodes: episode_seq_dr = 0 for transition in episode.transitions: rho = transition.info['softmax_policy_prob'][transition.action] / \ diff --git a/rl_coach/off_policy_evaluators/rl/weighted_importance_sampling.py b/rl_coach/off_policy_evaluators/rl/weighted_importance_sampling.py new file mode 100644 index 0000000..f6b2a89 --- /dev/null +++ b/rl_coach/off_policy_evaluators/rl/weighted_importance_sampling.py @@ -0,0 +1,53 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+from typing import List
+import numpy as np
+
+from rl_coach.core_types import Episode
+
+
+class WeightedImportanceSampling(object):
+    # TODO rename and add PDIS
+    @staticmethod
+    def evaluate(evaluation_dataset_as_episodes: List[Episode]) -> float:
+        """
+        Run the off-policy evaluator to get a score for the goodness of the new policy, based on the dataset,
+        which was collected using other policy(ies).
+
+        References:
+        - Sutton, R. S. & Barto, A. G. Reinforcement Learning: An Introduction. Chapter 5.5.
+        - https://people.cs.umass.edu/~pthomas/papers/Thomas2015c.pdf
+        - http://videolectures.net/deeplearning2017_thomas_safe_rl/
+
+        :return: the evaluation score
+        """
+
+        # Weighted Importance Sampling
+        per_episode_w_i = []
+
+        for episode in evaluation_dataset_as_episodes:
+            w_i = 1
+            for transition in episode.transitions:
+                w_i *= transition.info['softmax_policy_prob'][transition.action] / \
+                       transition.info['all_action_probabilities'][transition.action]
+            per_episode_w_i.append(w_i)
+
+        total_w_i_sum_across_episodes = sum(per_episode_w_i)
+        wis = 0
+        for i, episode in enumerate(evaluation_dataset_as_episodes):
+            wis += per_episode_w_i[i] / total_w_i_sum_across_episodes * episode.transitions[0].n_step_discounted_rewards
+
+        return wis
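Taken on its own, the estimator above does two things: it forms a per-episode weight w_i as the product of per-step probability ratios between the evaluated policy and the behavior policy, and it then averages the first-transition discounted returns using the normalized weights w_i / sum_j w_j. Below is a self-contained sketch with toy numbers; the made-up per-step probabilities and returns stand in for transition.info['softmax_policy_prob'], transition.info['all_action_probabilities'] and n_step_discounted_rewards, so it illustrates the formula rather than reproducing rl_coach code.

import numpy as np

# Toy evaluation set: per-step action probabilities under the evaluated policy
# and the behavior policy, plus each episode's discounted return.
episodes = [
    {'new_probs': [0.9, 0.8], 'old_probs': [0.5, 0.5], 'return': 1.0},
    {'new_probs': [0.2, 0.1], 'old_probs': [0.5, 0.5], 'return': 3.0},
]

# Per-episode importance weight: product of per-step probability ratios.
w = np.array([np.prod(np.array(e['new_probs']) / np.array(e['old_probs']))
              for e in episodes])
returns = np.array([e['return'] for e in episodes])

# Weighted importance sampling normalizes the weights by their sum.
wis = np.sum(w / w.sum() * returns)
ois = np.mean(w * returns)  # ordinary importance sampling, for comparison

print('WIS estimate: {:.3f}, IS estimate: {:.3f}'.format(wis, ois))

Normalizing by the sum of the weights is what distinguishes WIS from ordinary importance sampling: the estimate becomes slightly biased, but its variance drops sharply, which usually matters more when episodes are long and the evaluated policy disagrees often with the behavior policy.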