
OPE: Weighted Importance Sampling (#299)

Author:    Gal Leibovich
Date:      2019-05-02 19:25:42 +03:00
Committer: GitHub
Parent:    74db141d5e
Commit:    582921ffe3

8 changed files with 222 additions and 51 deletions
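
Weighted importance sampling (WIS) reweights each logged trajectory's return by the cumulative ratio of evaluation-policy to behavior-policy action probabilities, and normalizes by the sum of those ratios rather than by the number of episodes, trading a small bias for much lower variance than ordinary importance sampling. A minimal NumPy sketch of the textbook estimator follows; the function and variable names are illustrative, not Coach's API:

import numpy as np

def weighted_importance_sampling(episodes, gamma=1.0):
    """Estimate the evaluation policy's value from behavior-policy episodes.

    `episodes` is a list of trajectories; each trajectory is a list of
    (pi_e_prob, pi_b_prob, reward) tuples, where pi_e_prob and pi_b_prob are
    the taken action's probability under the evaluation and behavior policies.
    """
    weights, returns = [], []
    for trajectory in episodes:
        rho = 1.0   # cumulative importance ratio over the whole trajectory
        ret = 0.0   # discounted Monte Carlo return of the trajectory
        for t, (pi_e_prob, pi_b_prob, reward) in enumerate(trajectory):
            rho *= pi_e_prob / pi_b_prob
            ret += (gamma ** t) * reward
        weights.append(rho)
        returns.append(ret)
    weights, returns = np.asarray(weights), np.asarray(returns)
    # ordinary IS divides by the number of episodes; WIS normalizes by the
    # sum of the importance weights, which lowers variance at a small bias cost
    return float(np.sum(weights * returns) / np.sum(weights))

As a sanity check, when the evaluation and behavior policies agree (every ratio is 1), the estimate reduces to the plain average Monte Carlo return over the dataset.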


@@ -127,7 +127,10 @@ class BatchRLGraphManager(BasicRLGraphManager):
         if self.env_params is not None and not self.agent_params.memory.load_memory_from_file_path:
             self.heatup(self.heatup_steps)
 
-        self.improve_reward_model()
+        # from this point onwards, the dataset cannot be changed anymore. Allows for performance improvements.
+        self.level_managers[0].agents['agent'].memory.freeze()
+
+        self.initialize_ope_models_and_stats()
 
         # improve
         if self.task_parameters.task_index is not None:
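
The freeze() call added above is what makes the rest of the commit cheap: once the evaluation dataset can no longer change, derived views of it can be computed once and cached instead of being rebuilt on every OPE run. A minimal sketch of that pattern, using a hypothetical buffer class rather than Coach's actual memory implementation:

class FreezableReplayBuffer:
    """Illustrative only: a buffer that refuses writes after freeze(), so
    expensive derived views (e.g. a flat transition list) can be cached."""

    def __init__(self):
        self._episodes = []
        self._frozen = False
        self._transitions_cache = None

    def store(self, episode):
        if self._frozen:
            raise RuntimeError("buffer is frozen; the evaluation dataset may not change")
        self._episodes.append(episode)

    def freeze(self):
        self._frozen = True

    @property
    def transitions(self):
        if not self._frozen:
            raise RuntimeError("freeze() the buffer before taking cached views")
        # safe to cache only because the contents can no longer change
        if self._transitions_cache is None:
            self._transitions_cache = [t for ep in self._episodes for t in ep]
        return self._transitions_cache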
@@ -163,13 +166,26 @@ class BatchRLGraphManager(BasicRLGraphManager):
                 # we might want to evaluate vs. the simulator every now and then.
                 break
 
-    def improve_reward_model(self):
+    def initialize_ope_models_and_stats(self):
         """
 
         :return:
         """
+        agent = self.level_managers[0].agents['agent']
+
         screen.log_title("Training a regression model for estimating MDP rewards")
-        self.level_managers[0].agents['agent'].improve_reward_model(epochs=self.reward_model_num_epochs)
+        agent.improve_reward_model(epochs=self.reward_model_num_epochs)
+
+        # prepare dataset to be consumed in the expected formats for OPE
+        agent.memory.prepare_evaluation_dataset()
+
+        screen.log_title("Collecting static statistics for OPE")
+        agent.ope_manager.gather_static_shared_stats(evaluation_dataset_as_transitions=
+                                                     agent.memory.evaluation_dataset_as_transitions,
+                                                     batch_size=agent.ap.network_wrappers['main'].batch_size,
+                                                     reward_model=agent.networks['reward_model'].online_network,
+                                                     network_keys=list(agent.ap.network_wrappers['main'].
+                                                                       input_embedders_parameters.keys()))
 
     def run_off_policy_evaluation(self):
         """