Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 19:20:19 +01:00)
OPE: Weighted Importance Sampling (#299)
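For context, Weighted Importance Sampling (WIS) is the self-normalized variant of the ordinary importance sampling (IPS) estimator: each logged episode is weighted by the product of its per-step policy probability ratios, and the weighted returns are averaged by the sum of the weights instead of by the number of episodes. The sketch below is a minimal, self-contained illustration of both estimators, not the ope_manager code touched by this commit; the (reward, behavior_prob, target_prob) episode layout is an assumption made for the example.

import numpy as np

def ips_and_wis(episodes):
    # Each episode is a list of (reward, behavior_prob, target_prob) tuples:
    # the logged reward, the behavior policy's probability of the logged action,
    # and the evaluated policy's probability of that same action. These fields
    # are illustrative; they are not rl_coach's Transition API.
    weights, returns = [], []
    for episode in episodes:
        rewards, mu, pi = (np.asarray(x, dtype=float) for x in zip(*episode))
        weights.append(np.prod(pi / mu))   # per-episode importance weight
        returns.append(rewards.sum())      # undiscounted episode return
    weights, returns = np.asarray(weights), np.asarray(returns)
    ips = np.mean(weights * returns)                     # normalized by episode count
    wis = np.sum(weights * returns) / np.sum(weights)    # self-normalized
    return ips, wis

Because the weights are normalized by their own sum, WIS is biased but typically has much lower variance than IPS, which is why the commit reports it alongside IPS, DM, DR and Sequential-DR rather than replacing any of them.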
@@ -20,6 +20,7 @@ import numpy as np
 from rl_coach.agents.agent import Agent
 from rl_coach.core_types import ActionInfo, StateType, Batch
+from rl_coach.filters.filter import NoInputFilter
 from rl_coach.logger import screen
 from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplay
 from rl_coach.spaces import DiscreteActionSpace
@@ -108,18 +109,18 @@ class ValueOptimizationAgent(Agent):
         :return: None
         """
         assert self.ope_manager
-        dataset_as_episodes = self.call_memory('get_all_complete_episodes_from_to',
-                                               (self.call_memory('get_last_training_set_episode_id') + 1,
-                                                self.call_memory('num_complete_episodes')))
-        if len(dataset_as_episodes) == 0:
-            raise ValueError('train_to_eval_ratio is too high causing the evaluation set to be empty. '
-                             'Consider decreasing its value.')
-
-        ips, dm, dr, seq_dr = self.ope_manager.evaluate(
-            dataset_as_episodes=dataset_as_episodes,
+        if not isinstance(self.pre_network_filter, NoInputFilter) and len(self.pre_network_filter.reward_filters) != 0:
+            raise ValueError("Defining a pre-network reward filter when OPEs are calculated will result in a mismatch "
+                             "between q values (which are scaled), and actual rewards, which are not. It is advisable "
+                             "to use an input_filter, if possible, instead, which will filter the transitions directly "
+                             "in the replay buffer, affecting both the q_values and the rewards themselves. ")
+
+        ips, dm, dr, seq_dr, wis = self.ope_manager.evaluate(
+            evaluation_dataset_as_episodes=self.memory.evaluation_dataset_as_episodes,
+            evaluation_dataset_as_transitions=self.memory.evaluation_dataset_as_transitions,
             batch_size=self.ap.network_wrappers['main'].batch_size,
             discount_factor=self.ap.algorithm.discount,
             reward_model=self.networks['reward_model'].online_network,
             q_network=self.networks['main'].online_network,
             network_keys=list(self.ap.network_wrappers['main'].input_embedders_parameters.keys()))
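The new guard in this hunk encodes a real failure mode: a pre-network reward filter rescales the rewards the q-network is trained on, while the rewards stored in the replay buffer stay raw, so any off-policy estimate that mixes the two (a doubly-robust-style correction, for instance) is computed on mismatched scales. The toy computation below illustrates the effect with assumed numbers (a 0.1 reward scale); it is not code from the repository.

# Illustration only (assumed numbers). A DR-style per-step correction mixes a
# logged reward with model q-values. If the q-network was trained on rewards
# scaled by 0.1 while the replay buffer keeps raw rewards, the term is distorted.
scale = 0.1
raw_reward = 10.0
gamma = 0.99

q_sa = scale * 9.0      # Q(s, a) learned from scaled rewards
v_next = scale * 8.0    # V(s') learned from scaled rewards

consistent = scale * raw_reward + gamma * v_next - q_sa   # everything on one scale
mismatched = raw_reward + gamma * v_next - q_sa           # raw reward, scaled q-values

print(round(consistent, 3))   # 0.892 -> the intended, small correction
print(round(mismatched, 3))   # 9.892 -> dominated by the unscaled reward

Filtering inside the replay buffer with an input_filter, as the error message suggests, keeps the q-values and the rewards on the same scale, so the estimators stay consistent.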
@@ -129,6 +130,7 @@ class ValueOptimizationAgent(Agent):
         log['IPS'] = ips
         log['DM'] = dm
         log['DR'] = dr
+        log['WIS'] = wis
         log['Sequential-DR'] = seq_dr
         screen.log_dict(log, prefix='Off-Policy Evaluation')
@@ -138,6 +140,7 @@ class ValueOptimizationAgent(Agent):
         self.agent_logger.create_signal_value('Direct Method Reward', dm)
         self.agent_logger.create_signal_value('Doubly Robust', dr)
         self.agent_logger.create_signal_value('Sequential Doubly Robust', seq_dr)
+        self.agent_logger.create_signal_value('Weighted Importance Sampling', wis)

     def get_reward_model_loss(self, batch: Batch):
         network_keys = self.ap.network_wrappers['reward_model'].input_embedders_parameters.keys()
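The last hunk ends at get_reward_model_loss, whose body is truncated in this view; it computes the loss used to train the reward model that the Direct Method (and hence the Doubly Robust estimator) relies on. As a rough sketch of what such a loss usually looks like, the snippet below regresses the reward predicted for the logged action onto the logged reward with a mean squared error; the array shapes and the function name are assumptions for illustration, not the repository's implementation.

import numpy as np

def reward_model_mse(predicted_rewards, actions, logged_rewards):
    # predicted_rewards: (batch, num_actions) per-action reward predictions
    # actions:           (batch,) integer ids of the logged (behavior) actions
    # logged_rewards:    (batch,) rewards observed in the dataset
    batch_idx = np.arange(len(actions))
    taken = predicted_rewards[batch_idx, actions]   # prediction for the action actually taken
    return float(np.mean((taken - logged_rewards) ** 2))

A reward model fit this way supplies the Direct Method with its estimate and the Doubly Robust estimator with its model-based baseline.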