
OPE: Weighted Importance Sampling (#299)

Author: Gal Leibovich
Date: 2019-05-02 19:25:42 +03:00 (committed by GitHub)
Parent: 74db141d5e
Commit: 582921ffe3
8 changed files with 222 additions and 51 deletions

View File

@@ -522,6 +522,7 @@ class Agent(AgentInterface):
         self.agent_logger.create_signal_value('Inverse Propensity Score', np.nan, overwrite=False)
         self.agent_logger.create_signal_value('Direct Method Reward', np.nan, overwrite=False)
         self.agent_logger.create_signal_value('Doubly Robust', np.nan, overwrite=False)
+        self.agent_logger.create_signal_value('Weighted Importance Sampling', np.nan, overwrite=False)
         self.agent_logger.create_signal_value('Sequential Doubly Robust', np.nan, overwrite=False)
         for signal in self.episode_signals:

View File

@@ -20,6 +20,7 @@ import numpy as np
 from rl_coach.agents.agent import Agent
 from rl_coach.core_types import ActionInfo, StateType, Batch
+from rl_coach.filters.filter import NoInputFilter
 from rl_coach.logger import screen
 from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplay
 from rl_coach.spaces import DiscreteActionSpace
@@ -108,18 +109,18 @@ class ValueOptimizationAgent(Agent):
         :return: None
         """
         assert self.ope_manager
-        dataset_as_episodes = self.call_memory('get_all_complete_episodes_from_to',
-                                               (self.call_memory('get_last_training_set_episode_id') + 1,
-                                                self.call_memory('num_complete_episodes')))
-        if len(dataset_as_episodes) == 0:
-            raise ValueError('train_to_eval_ratio is too high causing the evaluation set to be empty. '
-                             'Consider decreasing its value.')
-        ips, dm, dr, seq_dr = self.ope_manager.evaluate(
-            dataset_as_episodes=dataset_as_episodes,
+        if not isinstance(self.pre_network_filter, NoInputFilter) and len(self.pre_network_filter.reward_filters) != 0:
+            raise ValueError("Defining a pre-network reward filter when OPEs are calculated will result in a mismatch "
+                             "between q values (which are scaled), and actual rewards, which are not. It is advisable "
+                             "to use an input_filter, if possible, instead, which will filter the transitions directly "
+                             "in the replay buffer, affecting both the q_values and the rewards themselves. ")
+
+        ips, dm, dr, seq_dr, wis = self.ope_manager.evaluate(
+            evaluation_dataset_as_episodes=self.memory.evaluation_dataset_as_episodes,
+            evaluation_dataset_as_transitions=self.memory.evaluation_dataset_as_transitions,
             batch_size=self.ap.network_wrappers['main'].batch_size,
             discount_factor=self.ap.algorithm.discount,
-            reward_model=self.networks['reward_model'].online_network,
             q_network=self.networks['main'].online_network,
             network_keys=list(self.ap.network_wrappers['main'].input_embedders_parameters.keys()))
@@ -129,6 +130,7 @@ class ValueOptimizationAgent(Agent):
         log['IPS'] = ips
         log['DM'] = dm
         log['DR'] = dr
+        log['WIS'] = wis
         log['Sequential-DR'] = seq_dr
         screen.log_dict(log, prefix='Off-Policy Evaluation')
@@ -138,6 +140,7 @@ class ValueOptimizationAgent(Agent):
         self.agent_logger.create_signal_value('Direct Method Reward', dm)
         self.agent_logger.create_signal_value('Doubly Robust', dr)
         self.agent_logger.create_signal_value('Sequential Doubly Robust', seq_dr)
+        self.agent_logger.create_signal_value('Weighted Importance Sampling', wis)

     def get_reward_model_loss(self, batch: Batch):
         network_keys = self.ap.network_wrappers['reward_model'].input_embedders_parameters.keys()

View File

@@ -127,7 +127,10 @@ class BatchRLGraphManager(BasicRLGraphManager):
         if self.env_params is not None and not self.agent_params.memory.load_memory_from_file_path:
             self.heatup(self.heatup_steps)

-        self.improve_reward_model()
+        # from this point onwards, the dataset cannot be changed anymore. Allows for performance improvements.
+        self.level_managers[0].agents['agent'].memory.freeze()
+
+        self.initialize_ope_models_and_stats()

         # improve
         if self.task_parameters.task_index is not None:
@@ -163,13 +166,26 @@ class BatchRLGraphManager(BasicRLGraphManager):
                 # we might want to evaluate vs. the simulator every now and then.
                 break

-    def improve_reward_model(self):
+    def initialize_ope_models_and_stats(self):
         """
         :return:
         """
+        agent = self.level_managers[0].agents['agent']
+
         screen.log_title("Training a regression model for estimating MDP rewards")
-        self.level_managers[0].agents['agent'].improve_reward_model(epochs=self.reward_model_num_epochs)
+        agent.improve_reward_model(epochs=self.reward_model_num_epochs)
+
+        # prepare dataset to be consumed in the expected formats for OPE
+        agent.memory.prepare_evaluation_dataset()
+
+        screen.log_title("Collecting static statistics for OPE")
+        agent.ope_manager.gather_static_shared_stats(evaluation_dataset_as_transitions=
+                                                     agent.memory.evaluation_dataset_as_transitions,
+                                                     batch_size=agent.ap.network_wrappers['main'].batch_size,
+                                                     reward_model=agent.networks['reward_model'].online_network,
+                                                     network_keys=list(agent.ap.network_wrappers['main'].
+                                                                       input_embedders_parameters.keys()))

     def run_off_policy_evaluation(self):
         """

View File

@@ -15,6 +15,8 @@
 # limitations under the License.
 #
 import ast
+from copy import deepcopy
 import math
 import pandas as pd
@@ -64,6 +66,10 @@ class EpisodicExperienceReplay(Memory):
         self.last_training_set_episode_id = None  # used in batch-rl
         self.last_training_set_transition_id = None  # used in batch-rl
         self.train_to_eval_ratio = train_to_eval_ratio  # used in batch-rl
+        self.evaluation_dataset_as_episodes = None
+        self.evaluation_dataset_as_transitions = None
+        self.frozen = False

     def length(self, lock: bool = False) -> int:
         """
@@ -137,6 +143,8 @@ class EpisodicExperienceReplay(Memory):
         Shuffle all the episodes in the replay buffer
         :return:
         """
+        self.assert_not_frozen()
+
         random.shuffle(self._buffer)
         self.transitions = [t for e in self._buffer for t in e.transitions]
@@ -256,6 +264,7 @@ class EpisodicExperienceReplay(Memory):
         :param transition: a transition to store
         :return: None
         """
+        self.assert_not_frozen()
         # Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
         super().store(transition)
@@ -281,6 +290,8 @@ class EpisodicExperienceReplay(Memory):
         :param episode: the new episode to store
         :return: None
         """
+        self.assert_not_frozen()
+
         # Calling super.store() so that in case a memory backend is used, the memory backend can store this episode.
         super().store_episode(episode)
@@ -322,6 +333,8 @@ class EpisodicExperienceReplay(Memory):
         :param episode_index: the index of the episode to remove
         :return: None
         """
+        self.assert_not_frozen()
+
         if len(self._buffer) > episode_index:
             episode_length = self._buffer[episode_index].length()
             self._length -= 1
@@ -381,6 +394,7 @@ class EpisodicExperienceReplay(Memory):
         Clean the memory by removing all the episodes
         :return: None
         """
+        self.assert_not_frozen()
         self.reader_writer_lock.lock_writing_and_reading()
         self.transitions = []
@@ -409,6 +423,8 @@ class EpisodicExperienceReplay(Memory):
         The csv file is assumed to include a list of transitions.
         :param csv_dataset: A construct which holds the dataset parameters
         """
+        self.assert_not_frozen()
+
         df = pd.read_csv(csv_dataset.filepath)
         if len(df) > self.max_size[1]:
             screen.warning("Warning! The number of transitions to load into the replay buffer ({}) is "
@@ -446,3 +462,34 @@ class EpisodicExperienceReplay(Memory):
         progress_bar.close()
         self.shuffle_episodes()
+
+    def freeze(self):
+        """
+        Freezing the replay buffer does not allow any new transitions to be added to the memory.
+        Useful when working with a dataset (e.g. batch-rl or imitation learning).
+        :return: None
+        """
+        self.frozen = True
+
+    def assert_not_frozen(self):
+        """
+        Check that the memory is not frozen, and can be changed.
+        :return:
+        """
+        assert self.frozen is False, "Memory is frozen, and cannot be changed."
+
+    def prepare_evaluation_dataset(self):
+        """
+        Gather the memory content that will be used for off-policy evaluation in episodes and transitions format
+        :return:
+        """
+        self.evaluation_dataset_as_episodes = deepcopy(
+            self.get_all_complete_episodes_from_to(self.get_last_training_set_episode_id() + 1,
+                                                   self.num_complete_episodes()))
+        if len(self.evaluation_dataset_as_episodes) == 0:
+            raise ValueError('train_to_eval_ratio is too high causing the evaluation set to be empty. '
+                             'Consider decreasing its value.')
+
+        self.evaluation_dataset_as_transitions = [t for e in self.evaluation_dataset_as_episodes
+                                                  for t in e.transitions]
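
The freeze()/assert_not_frozen() pair added here (and to ExperienceReplay below) is a simple guard: every mutating method checks the flag before touching the buffer. A standalone sketch of that contract, not rl_coach code:

# Standalone illustration of the freeze contract (illustrative only, not rl_coach code).
class FreezableBuffer:
    def __init__(self):
        self.transitions = []
        self.frozen = False

    def freeze(self):
        # from this point on, the buffer is read-only
        self.frozen = True

    def assert_not_frozen(self):
        assert self.frozen is False, "Memory is frozen, and cannot be changed."

    def store(self, transition):
        self.assert_not_frozen()    # every mutating method calls this first
        self.transitions.append(transition)


buf = FreezableBuffer()
buf.store({'reward': 1.0})
buf.freeze()
# buf.store({'reward': 0.0})   # would now raise AssertionError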

View File

@@ -54,6 +54,7 @@ class ExperienceReplay(Memory):
         self.allow_duplicates_in_batch_sampling = allow_duplicates_in_batch_sampling
         self.reader_writer_lock = ReaderWriterLock()
+        self.frozen = False

     def length(self) -> int:
         """
@@ -135,6 +136,8 @@ class ExperienceReplay(Memory):
         locks and then calls store with lock = True
         :return: None
         """
+        self.assert_not_frozen()
+
         # Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
         super().store(transition)
         if lock:
@@ -175,6 +178,8 @@ class ExperienceReplay(Memory):
         :param transition_index: the index of the transition to remove
         :return: None
         """
+        self.assert_not_frozen()
+
         if lock:
             self.reader_writer_lock.lock_writing_and_reading()
@@ -207,6 +212,8 @@ class ExperienceReplay(Memory):
         Clean the memory by removing all the episodes
         :return: None
         """
+        self.assert_not_frozen()
+
         if lock:
             self.reader_writer_lock.lock_writing_and_reading()
@@ -242,6 +249,8 @@ class ExperienceReplay(Memory):
         The pickle file is assumed to include a list of transitions.
         :param file_path: The path to a pickle file to restore
         """
+        self.assert_not_frozen()
+
         with open(file_path, 'rb') as file:
             transitions = pickle.load(file)
         num_transitions = len(transitions)
@@ -260,3 +269,17 @@ class ExperienceReplay(Memory):
         progress_bar.close()
+
+    def freeze(self):
+        """
+        Freezing the replay buffer does not allow any new transitions to be added to the memory.
+        Useful when working with a dataset (e.g. batch-rl or imitation learning).
+        :return: None
+        """
+        self.frozen = True
+
+    def assert_not_frozen(self):
+        """
+        Check that the memory is not frozen, and can be changed.
+        :return:
+        """
+        assert self.frozen is False, "Memory is frozen, and cannot be changed."

View File

@@ -26,56 +26,60 @@ from rl_coach.off_policy_evaluators.rl.sequential_doubly_robust import SequentialDoublyRobust
 from rl_coach.core_types import Transition
+from rl_coach.off_policy_evaluators.rl.weighted_importance_sampling import WeightedImportanceSampling

 OpeSharedStats = namedtuple("OpeSharedStats", ['all_reward_model_rewards', 'all_policy_probs',
                                                'all_v_values_reward_model_based', 'all_rewards', 'all_actions',
                                                'all_old_policy_probs', 'new_policy_prob', 'rho_all_dataset'])
-OpeEstimation = namedtuple("OpeEstimation", ['ips', 'dm', 'dr', 'seq_dr'])
+OpeEstimation = namedtuple("OpeEstimation", ['ips', 'dm', 'dr', 'seq_dr', 'wis'])


 class OpeManager(object):
     def __init__(self):
+        self.evaluation_dataset_as_transitions = None
         self.doubly_robust = DoublyRobust()
         self.sequential_doubly_robust = SequentialDoublyRobust()
+        self.weighted_importance_sampling = WeightedImportanceSampling()
+        self.all_reward_model_rewards = None
+        self.all_old_policy_probs = None
+        self.all_rewards = None
+        self.all_actions = None
+        self.is_gathered_static_shared_data = False

-    @staticmethod
-    def _prepare_ope_shared_stats(dataset_as_transitions: List[Transition], batch_size: int,
-                                  reward_model: Architecture, q_network: Architecture,
-                                  network_keys: List) -> OpeSharedStats:
+    def _prepare_ope_shared_stats(self, evaluation_dataset_as_transitions: List[Transition], batch_size: int,
+                                  q_network: Architecture, network_keys: List) -> OpeSharedStats:
         """
         Do the preparations needed for different estimators.
         Some of the calcuations are shared, so we centralize all the work here.

-        :param dataset_as_transitions: The evaluation dataset in the form of transitions.
+        :param evaluation_dataset_as_transitions: The evaluation dataset in the form of transitions.
         :param batch_size: The batch size to use.
-        :param reward_model: A reward model to be used by DR
         :param q_network: The Q network whose its policy we evaluate.
         :param network_keys: The network keys used for feeding the neural networks.
         :return:
         """
-        # IPS
-        all_reward_model_rewards, all_policy_probs, all_old_policy_probs = [], [], []
-        all_v_values_reward_model_based, all_v_values_q_model_based, all_rewards, all_actions = [], [], [], []
+        assert self.is_gathered_static_shared_data, "gather_static_shared_stats() should be called once before " \
+                                                    "calling _prepare_ope_shared_stats()"

-        for i in range(math.ceil(len(dataset_as_transitions) / batch_size)):
-            batch = dataset_as_transitions[i * batch_size: (i + 1) * batch_size]
+        # IPS
+        all_policy_probs = []
+        all_v_values_reward_model_based, all_v_values_q_model_based = [], []
+
+        for i in range(math.ceil(len(evaluation_dataset_as_transitions) / batch_size)):
+            batch = evaluation_dataset_as_transitions[i * batch_size: (i + 1) * batch_size]
             batch_for_inference = Batch(batch)
-            all_reward_model_rewards.append(reward_model.predict(
-                batch_for_inference.states(network_keys)))

             # we always use the first Q head to calculate OPEs. might want to change this in the future.
-            # for instance, this means that for bootstrapped we always use the first QHead to calculate the OPEs.
+            # for instance, this means that for bootstrapped dqn we always use the first QHead to calculate the OPEs.
             q_values, sm_values = q_network.predict(batch_for_inference.states(network_keys),
                                                     outputs=[q_network.output_heads[0].q_values,
                                                              q_network.output_heads[0].softmax])

             all_policy_probs.append(sm_values)
-            all_v_values_reward_model_based.append(np.sum(all_policy_probs[-1] * all_reward_model_rewards[-1], axis=1))
+            all_v_values_reward_model_based.append(np.sum(all_policy_probs[-1] * self.all_reward_model_rewards[i],
+                                                          axis=1))
             all_v_values_q_model_based.append(np.sum(all_policy_probs[-1] * q_values, axis=1))
-            all_rewards.append(batch_for_inference.rewards())
-            all_actions.append(batch_for_inference.actions())
-            all_old_policy_probs.append(batch_for_inference.info('all_action_probabilities')
-                                        [range(len(batch_for_inference.actions())), batch_for_inference.actions()])

             for j, t in enumerate(batch):
                 t.update_info({
@@ -85,26 +89,50 @@ class OpeManager(object):
                 })

-        all_reward_model_rewards = np.concatenate(all_reward_model_rewards, axis=0)
         all_policy_probs = np.concatenate(all_policy_probs, axis=0)
         all_v_values_reward_model_based = np.concatenate(all_v_values_reward_model_based, axis=0)
-        all_rewards = np.concatenate(all_rewards, axis=0)
-        all_actions = np.concatenate(all_actions, axis=0)
-        all_old_policy_probs = np.concatenate(all_old_policy_probs, axis=0)

         # generate model probabilities
-        new_policy_prob = all_policy_probs[np.arange(all_actions.shape[0]), all_actions]
-        rho_all_dataset = new_policy_prob / all_old_policy_probs
+        new_policy_prob = all_policy_probs[np.arange(self.all_actions.shape[0]), self.all_actions]
+        rho_all_dataset = new_policy_prob / self.all_old_policy_probs

-        return OpeSharedStats(all_reward_model_rewards, all_policy_probs, all_v_values_reward_model_based,
-                              all_rewards, all_actions, all_old_policy_probs, new_policy_prob, rho_all_dataset)
+        return OpeSharedStats(self.all_reward_model_rewards, all_policy_probs, all_v_values_reward_model_based,
+                              self.all_rewards, self.all_actions, self.all_old_policy_probs, new_policy_prob,
+                              rho_all_dataset)

-    def evaluate(self, dataset_as_episodes: List[Episode], batch_size: int, discount_factor: float,
-                 reward_model: Architecture, q_network: Architecture, network_keys: List) -> OpeEstimation:
+    def gather_static_shared_stats(self, evaluation_dataset_as_transitions: List[Transition], batch_size: int,
+                                   reward_model: Architecture, network_keys: List) -> None:
+        all_reward_model_rewards = []
+        all_old_policy_probs = []
+        all_rewards = []
+        all_actions = []
+
+        for i in range(math.ceil(len(evaluation_dataset_as_transitions) / batch_size)):
+            batch = evaluation_dataset_as_transitions[i * batch_size: (i + 1) * batch_size]
+            batch_for_inference = Batch(batch)
+
+            all_reward_model_rewards.append(reward_model.predict(batch_for_inference.states(network_keys)))
+            all_rewards.append(batch_for_inference.rewards())
+            all_actions.append(batch_for_inference.actions())
+            all_old_policy_probs.append(batch_for_inference.info('all_action_probabilities')
+                                        [range(len(batch_for_inference.actions())),
+                                         batch_for_inference.actions()])
+
+        self.all_reward_model_rewards = np.concatenate(all_reward_model_rewards, axis=0)
+        self.all_old_policy_probs = np.concatenate(all_old_policy_probs, axis=0)
+        self.all_rewards = np.concatenate(all_rewards, axis=0)
+        self.all_actions = np.concatenate(all_actions, axis=0)
+
+        # mark that static shared data was collected and ready to be used
+        self.is_gathered_static_shared_data = True
+
+    def evaluate(self, evaluation_dataset_as_episodes: List[Episode],
+                 evaluation_dataset_as_transitions: List[Transition], batch_size: int,
+                 discount_factor: float, q_network: Architecture, network_keys: List) -> OpeEstimation:
         """
         Run all the OPEs and get estimations of the current policy performance based on the evaluation dataset.

-        :param dataset_as_episodes: The evaluation dataset.
+        :param evaluation_dataset_as_episodes: The evaluation dataset in a form of episodes.
+        :param evaluation_dataset_as_transitions: The evaluation dataset in a form of transitions.
         :param batch_size: Batch size to use for the estimators.
         :param discount_factor: The standard RL discount factor.
-        :param reward_model: A reward model to be used by DR
@@ -113,12 +141,12 @@ class OpeManager(object):
         :return: An OpeEstimation tuple which groups together all the OPE estimations
         """
-        # TODO this seems kind of slow, review performance
-        dataset_as_transitions = [t for e in dataset_as_episodes for t in e.transitions]
-        ope_shared_stats = self._prepare_ope_shared_stats(dataset_as_transitions, batch_size, reward_model,
-                                                          q_network, network_keys)
+        ope_shared_stats = self._prepare_ope_shared_stats(evaluation_dataset_as_transitions, batch_size, q_network,
+                                                          network_keys)

         ips, dm, dr = self.doubly_robust.evaluate(ope_shared_stats)
-        seq_dr = self.sequential_doubly_robust.evaluate(dataset_as_episodes, discount_factor)
-        return OpeEstimation(ips, dm, dr, seq_dr)
+        seq_dr = self.sequential_doubly_robust.evaluate(evaluation_dataset_as_episodes, discount_factor)
+        wis = self.weighted_importance_sampling.evaluate(evaluation_dataset_as_episodes)
+
+        return OpeEstimation(ips, dm, dr, seq_dr, wis)
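
With this refactoring the OPE workflow is split in two: gather_static_shared_stats() computes the policy-independent quantities (reward-model predictions, logged rewards, actions and behavior-policy probabilities) once over the frozen evaluation set, and evaluate() only re-runs the policy-dependent Q/softmax predictions. A hedged usage sketch, with ope_manager, memory and the network handles assumed to be set up as in the agent code above:

# Phase 1 - once, after the dataset has been frozen and prepare_evaluation_dataset() has run:
ope_manager.gather_static_shared_stats(
    evaluation_dataset_as_transitions=memory.evaluation_dataset_as_transitions,
    batch_size=batch_size,
    reward_model=reward_model_network,      # assumed handle to the trained reward model
    network_keys=network_keys)

# Phase 2 - after every training epoch, re-scoring only the current policy:
ips, dm, dr, seq_dr, wis = ope_manager.evaluate(
    evaluation_dataset_as_episodes=memory.evaluation_dataset_as_episodes,
    evaluation_dataset_as_transitions=memory.evaluation_dataset_as_transitions,
    batch_size=batch_size,
    discount_factor=0.99,                   # illustrative discount
    q_network=q_network,                    # assumed handle to the online Q network
    network_keys=network_keys)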

View File

@@ -22,7 +22,7 @@ from rl_coach.core_types import Episode
 class SequentialDoublyRobust(object):

     @staticmethod
-    def evaluate(dataset_as_episodes: List[Episode], discount_factor: float) -> float:
+    def evaluate(evaluation_dataset_as_episodes: List[Episode], discount_factor: float) -> float:
         """
         Run the off-policy evaluator to get a score for the goodness of the new policy, based on the dataset,
         which was collected using other policy(ies).
@@ -35,7 +35,7 @@ class SequentialDoublyRobust(object):
         # Sequential Doubly Robust
         per_episode_seq_dr = []

-        for episode in dataset_as_episodes:
+        for episode in evaluation_dataset_as_episodes:
             episode_seq_dr = 0
             for transition in episode.transitions:
                 rho = transition.info['softmax_policy_prob'][transition.action] / \

View File

@@ -0,0 +1,53 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import List
+
+import numpy as np
+
+from rl_coach.core_types import Episode
+
+
+class WeightedImportanceSampling(object):
+    # TODO rename and add PDIS
+
+    @staticmethod
+    def evaluate(evaluation_dataset_as_episodes: List[Episode]) -> float:
+        """
+        Run the off-policy evaluator to get a score for the goodness of the new policy, based on the dataset,
+        which was collected using other policy(ies).
+
+        References:
+        - Sutton, R. S. & Barto, A. G. Reinforcement Learning: An Introduction. Chapter 5.5.
+        - https://people.cs.umass.edu/~pthomas/papers/Thomas2015c.pdf
+        - http://videolectures.net/deeplearning2017_thomas_safe_rl/
+
+        :return: the evaluation score
+        """
+
+        # Weighted Importance Sampling
+        per_episode_w_i = []
+        for episode in evaluation_dataset_as_episodes:
+            w_i = 1
+            for transition in episode.transitions:
+                w_i *= transition.info['softmax_policy_prob'][transition.action] / \
+                       transition.info['all_action_probabilities'][transition.action]
+            per_episode_w_i.append(w_i)
+
+        total_w_i_sum_across_episodes = sum(per_episode_w_i)
+
+        wis = 0
+        for i, episode in enumerate(evaluation_dataset_as_episodes):
+            wis += per_episode_w_i[i]/total_w_i_sum_across_episodes * episode.transitions[0].n_step_discounted_rewards
+
+        return wis
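
To make the estimator concrete, here is a self-contained numeric sketch of the same computation, detached from rl_coach's Episode/Transition types; the per-step probabilities below are made up for illustration. In the class above they come from transition.info['softmax_policy_prob'] (target policy) and transition.info['all_action_probabilities'] (behavior policy), and the episode return is episode.transitions[0].n_step_discounted_rewards.

# Self-contained numeric sketch of weighted importance sampling (illustrative data).
# Each episode contributes a full-trajectory importance weight
#     w_i = prod_t pi(a_t | s_t) / mu(a_t | s_t)
# and the estimate is the weight-normalized average of the episodes' discounted returns.
episodes = [
    # (per-step target-policy probs, per-step behavior-policy probs, discounted episode return)
    ([0.9, 0.8], [0.5, 0.5], 10.0),
    ([0.2, 0.1], [0.5, 0.5], 3.0),
]

per_episode_w = []
for target_probs, behavior_probs, _ in episodes:
    w = 1.0
    for pi_a, mu_a in zip(target_probs, behavior_probs):
        w *= pi_a / mu_a
    per_episode_w.append(w)            # [2.88, 0.08]

total_w = sum(per_episode_w)           # 2.96
wis = sum(w / total_w * ret for w, (_, _, ret) in zip(per_episode_w, episodes))
print(wis)                             # ~9.81, dominated by the high-weight episode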