mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

OPE: Weighted Importance Sampling (#299)

This commit is contained in:
Gal Leibovich
2019-05-02 19:25:42 +03:00
committed by GitHub
parent 74db141d5e
commit 582921ffe3
8 changed files with 222 additions and 51 deletions
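
For context on the estimator this commit wires in: Weighted Importance Sampling (WIS) normalizes the per-episode importance weights by their sum rather than by the number of episodes, trading a small bias for considerably lower variance than ordinary importance sampling (the existing IPS signal). The following is a minimal, self-contained sketch of that normalization, not the OpeManager implementation; the (reward, pi_e_prob, pi_b_prob) episode layout is an assumption made for the example.

import numpy as np

def weighted_importance_sampling(episodes):
    """Trajectory-level WIS sketch (illustrative only, not Coach's OpeManager).

    Each episode is assumed to be a list of (reward, pi_e_prob, pi_b_prob) tuples:
    the step reward, the evaluated policy's probability for the taken action, and
    the behavior policy's probability for that same action.
    """
    weights, returns = [], []
    for episode in episodes:
        rho, total_reward = 1.0, 0.0
        for reward, pi_e, pi_b in episode:
            rho *= pi_e / pi_b       # cumulative importance weight of the trajectory
            total_reward += reward   # undiscounted return, for simplicity
        weights.append(rho)
        returns.append(total_reward)
    weights, returns = np.array(weights), np.array(returns)
    # Ordinary IS divides by the number of episodes; WIS divides by the total weight.
    return float(np.sum(weights * returns) / np.sum(weights))

Dividing by np.sum(weights) instead of len(episodes) is the only difference from the plain importance-sampling estimate reported alongside it as IPS.
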


@@ -20,6 +20,7 @@ import numpy as np
 from rl_coach.agents.agent import Agent
 from rl_coach.core_types import ActionInfo, StateType, Batch
+from rl_coach.filters.filter import NoInputFilter
 from rl_coach.logger import screen
 from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplay
 from rl_coach.spaces import DiscreteActionSpace
@@ -108,18 +109,18 @@ class ValueOptimizationAgent(Agent):
         :return: None
         """
         assert self.ope_manager
-        dataset_as_episodes = self.call_memory('get_all_complete_episodes_from_to',
-                                               (self.call_memory('get_last_training_set_episode_id') + 1,
-                                                self.call_memory('num_complete_episodes')))
-        if len(dataset_as_episodes) == 0:
-            raise ValueError('train_to_eval_ratio is too high causing the evaluation set to be empty. '
-                             'Consider decreasing its value.')
-        ips, dm, dr, seq_dr = self.ope_manager.evaluate(
-            dataset_as_episodes=dataset_as_episodes,
+        if not isinstance(self.pre_network_filter, NoInputFilter) and len(self.pre_network_filter.reward_filters) != 0:
+            raise ValueError("Defining a pre-network reward filter when OPEs are calculated will result in a mismatch "
+                             "between q values (which are scaled) and actual rewards, which are not. It is advisable "
+                             "to use an input_filter instead, if possible, which will filter the transitions directly "
+                             "in the replay buffer, affecting both the q values and the rewards themselves.")
+        ips, dm, dr, seq_dr, wis = self.ope_manager.evaluate(
+            evaluation_dataset_as_episodes=self.memory.evaluation_dataset_as_episodes,
+            evaluation_dataset_as_transitions=self.memory.evaluation_dataset_as_transitions,
             batch_size=self.ap.network_wrappers['main'].batch_size,
             discount_factor=self.ap.algorithm.discount,
             reward_model=self.networks['reward_model'].online_network,
             q_network=self.networks['main'].online_network,
             network_keys=list(self.ap.network_wrappers['main'].input_embedders_parameters.keys()))
@@ -129,6 +130,7 @@ class ValueOptimizationAgent(Agent):
         log['IPS'] = ips
         log['DM'] = dm
         log['DR'] = dr
+        log['WIS'] = wis
         log['Sequential-DR'] = seq_dr
         screen.log_dict(log, prefix='Off-Policy Evaluation')
@@ -138,6 +140,7 @@ class ValueOptimizationAgent(Agent):
         self.agent_logger.create_signal_value('Direct Method Reward', dm)
         self.agent_logger.create_signal_value('Doubly Robust', dr)
         self.agent_logger.create_signal_value('Sequential Doubly Robust', seq_dr)
+        self.agent_logger.create_signal_value('Weighted Importance Sampling', wis)

     def get_reward_model_loss(self, batch: Batch):
         network_keys = self.ap.network_wrappers['reward_model'].input_embedders_parameters.keys()
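
The new guard in run_off_policy_evaluation above rejects pre-network reward filters. A toy illustration of the mismatch it prevents, with made-up numbers rather than Coach code: if a hypothetical filter rescales rewards by 0.01 on their way into the network, the learned q values live on the rescaled scale while the replay-buffer rewards consumed by IPS/DR stay raw, so any estimator combining the two mixes scales. An input_filter avoids this because it rescales the transitions inside the buffer itself.

# Toy numbers only (not Coach code). Assume raw rewards of ~100 per step and a
# hypothetical pre-network reward filter that rescales rewards by 0.01.
scale = 0.01
discount = 0.99
raw_reward = 100.0                 # what the replay buffer stores and what IPS/DR consume
q_value = 10 * scale * raw_reward  # Q trained on rescaled rewards (roughly a 10-step return)

# A doubly-robust correction has the form r + discount * V(s') - Q(s, a).
# Mixing the raw reward with a rescaled Q-value yields a number on neither scale:
mismatched = raw_reward + discount * q_value - q_value       # ~99.9, dominated by the raw reward

# With an input_filter the buffer reward itself is rescaled, so everything shares one scale:
filtered_reward = scale * raw_reward
consistent = filtered_reward + discount * q_value - q_value  # ~0.9
print(mismatched, consistent)
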