OPE: Weighted Importance Sampling (#299)
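The estimator named in the commit title, weighted importance sampling (WIS), scores a target policy against logged data by reweighting each episode's return with the product of per-step probability ratios between the evaluated policy and the behavior policy, and then normalizing by the sum of those weights instead of by the episode count. As a rough, self-contained illustration of the estimator (not Coach's OpeManager code; the function and field names below are assumptions made for explanation only):

# Illustrative sketch of weighted importance sampling (WIS) for off-policy
# evaluation. Names and data layout are assumptions, not Coach's implementation.
import numpy as np

def weighted_importance_sampling(episodes):
    """
    episodes: list of trajectories; each trajectory is a list of dicts with keys
              'pi_e' (evaluated-policy probability of the logged action),
              'pi_b' (behavior-policy probability of the logged action),
              'reward'.
    Returns the WIS estimate of the evaluated policy's per-episode return.
    """
    weights, returns = [], []
    for episode in episodes:
        # cumulative importance weight of the whole trajectory
        rho = np.prod([step['pi_e'] / step['pi_b'] for step in episode])
        weights.append(rho)
        returns.append(sum(step['reward'] for step in episode))
    weights = np.asarray(weights)
    returns = np.asarray(returns)
    # ordinary importance sampling divides by len(episodes); WIS normalizes by
    # the sum of the weights instead
    return float(np.dot(weights, returns) / np.sum(weights))

Normalizing by the sum of the weights introduces a small bias but usually reduces variance sharply, which is the standard reason to prefer WIS over ordinary importance sampling on the small, fixed datasets typical of batch RL.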
@@ -127,7 +127,10 @@ class BatchRLGraphManager(BasicRLGraphManager):
         if self.env_params is not None and not self.agent_params.memory.load_memory_from_file_path:
             self.heatup(self.heatup_steps)
 
-        self.improve_reward_model()
+        # from this point onwards, the dataset cannot be changed anymore. Allows for performance improvements.
+        self.level_managers[0].agents['agent'].memory.freeze()
+
+        self.initialize_ope_models_and_stats()
 
         # improve
         if self.task_parameters.task_index is not None:
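The freeze() call added above is what makes the statistics gathered in the next hunk safe to cache: once the dataset is frozen, nothing computed over it can go stale. A minimal sketch of the idea behind such a freezable buffer, using a hypothetical class rather than Coach's actual memory implementation:

# Minimal sketch of a freezable replay memory, illustrating why freeze() runs
# before any OPE statistics are cached. Hypothetical class, for explanation only.
class FreezableMemory:
    def __init__(self):
        self.transitions = []
        self._frozen = False

    def store(self, transition):
        if self._frozen:
            raise RuntimeError("memory is frozen; the OPE dataset must stay fixed")
        self.transitions.append(transition)

    def freeze(self):
        # from here on, cached per-transition quantities (behavior-policy
        # probabilities, reward-model predictions, ...) stay consistent
        self._frozen = True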
@@ -163,13 +166,26 @@ class BatchRLGraphManager(BasicRLGraphManager):
                 # we might want to evaluate vs. the simulator every now and then.
                 break
 
-    def improve_reward_model(self):
+    def initialize_ope_models_and_stats(self):
         """
 
         :return:
         """
+        agent = self.level_managers[0].agents['agent']
+
         screen.log_title("Training a regression model for estimating MDP rewards")
-        self.level_managers[0].agents['agent'].improve_reward_model(epochs=self.reward_model_num_epochs)
+        agent.improve_reward_model(epochs=self.reward_model_num_epochs)
 
+        # prepare dataset to be consumed in the expected formats for OPE
+        agent.memory.prepare_evaluation_dataset()
+
+        screen.log_title("Collecting static statistics for OPE")
+        agent.ope_manager.gather_static_shared_stats(evaluation_dataset_as_transitions=
+                                                     agent.memory.evaluation_dataset_as_transitions,
+                                                     batch_size=agent.ap.network_wrappers['main'].batch_size,
+                                                     reward_model=agent.networks['reward_model'].online_network,
+                                                     network_keys=list(agent.ap.network_wrappers['main'].
+                                                                       input_embedders_parameters.keys()))
+
     def run_off_policy_evaluation(self):
         """
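gather_static_shared_stats() appears here only through its call site, but the arguments suggest it precomputes, once and in batches, the per-transition quantities that all OPE estimators share, such as the trained reward model's predictions over the frozen evaluation dataset. A hedged sketch of that pattern with hypothetical names (not the OpeManager API):

# Hedged sketch of "static shared stats" for OPE: quantities computed once over
# the frozen evaluation dataset and reused by every estimator. The function and
# field names are assumptions, not Coach's OpeManager API.
import numpy as np

def gather_static_shared_stats_sketch(evaluation_transitions, batch_size, predict_rewards):
    """
    evaluation_transitions: list of transitions, each with a 'state' entry.
    predict_rewards: callable mapping a batch of states to per-action reward
                     estimates (standing in for the trained reward model).
    Returns the stacked reward-model predictions for the whole dataset.
    """
    all_predictions = []
    for start in range(0, len(evaluation_transitions), batch_size):
        batch = evaluation_transitions[start:start + batch_size]
        states = np.stack([t['state'] for t in batch])
        # run the reward model once per batch; the frozen dataset guarantees
        # these cached predictions never go stale
        all_predictions.append(predict_rewards(states))
    return np.concatenate(all_predictions, axis=0)

run_off_policy_evaluation() can then combine such cached predictions with per-episode importance weights (as in the WIS sketch above) without re-running the networks over the dataset on every evaluation.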