
BCQ variant on top of DDQN (#276)

* kNN based model for predicting which actions to drop (see the illustrative sketch below)
* fix for seeds with batch rl
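
The commit message is terse, so here is a minimal, illustrative sketch of the idea rather than the coach implementation (the names `KNNActionFilter`, `distance_threshold` and `bcq_ddqn_target` are made up for this example): in a discrete BCQ-style variant of DDQN trained from a fixed batch, a kNN model fitted per action over the states where that action was taken can flag actions the behavior policy rarely took near the current state, and those actions are dropped from the argmax used to build the bootstrap target.

```python
# Illustrative sketch only -- not the coach implementation of the BCQ variant.
# Idea: actions that look out-of-distribution near a given state (large kNN
# distance to states where that action was actually taken) are excluded from
# the argmax used for the DDQN bootstrap target.
import numpy as np
from sklearn.neighbors import NearestNeighbors


class KNNActionFilter:
    """Fits one kNN model per action over the states where that action was taken."""

    def __init__(self, num_actions, k=5, distance_threshold=1.0):
        self.num_actions = num_actions
        self.k = k
        self.distance_threshold = distance_threshold
        self.models = [None] * num_actions

    def fit(self, states, actions):
        for a in range(self.num_actions):
            states_a = states[actions == a]
            if len(states_a) >= self.k:
                self.models[a] = NearestNeighbors(n_neighbors=self.k).fit(states_a)

    def allowed_actions_mask(self, state):
        """Boolean mask of actions considered in-distribution at `state`."""
        mask = np.zeros(self.num_actions, dtype=bool)
        for a, model in enumerate(self.models):
            if model is None:
                continue
            distances, _ = model.kneighbors(state.reshape(1, -1))
            mask[a] = distances.mean() <= self.distance_threshold
        return mask


def bcq_ddqn_target(reward, discount, q_next_online, q_next_target, mask):
    """DDQN-style target where the online argmax is restricted to allowed actions."""
    if not mask.any():
        mask = np.ones_like(mask)  # fall back to plain DDQN if everything is masked
    q_masked = np.where(mask, q_next_online, -np.inf)
    best_action = int(np.argmax(q_masked))
    return reward + discount * q_next_target[best_action]
```

The per-action threshold is the knob that trades off conservatism against coverage: a tight threshold keeps the target close to actions well supported by the batch, a loose one degenerates back to ordinary DDQN.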
Gal Leibovich authored on 2019-04-16 17:06:23 +03:00, committed by GitHub
Parent: bdb9b224a8
Commit: 4741b0b916
11 changed files with 451 additions and 62 deletions


@@ -139,6 +139,15 @@ class ValueOptimizationAgent(Agent):
         self.agent_logger.create_signal_value('Doubly Robust', dr)
         self.agent_logger.create_signal_value('Sequential Doubly Robust', seq_dr)

+    def get_reward_model_loss(self, batch: Batch):
+        network_keys = self.ap.network_wrappers['reward_model'].input_embedders_parameters.keys()
+        current_rewards_prediction_for_all_actions = self.networks['reward_model'].online_network.predict(
+            batch.states(network_keys))
+        current_rewards_prediction_for_all_actions[range(batch.size), batch.actions()] = batch.rewards()
+
+        return self.networks['reward_model'].train_and_sync_networks(
+            batch.states(network_keys), current_rewards_prediction_for_all_actions)[0]
+
     def improve_reward_model(self, epochs: int):
         """
         Train a reward model to be used by the doubly-robust estimator
@@ -147,7 +156,6 @@ class ValueOptimizationAgent(Agent):
         :return: None
         """
         batch_size = self.ap.network_wrappers['reward_model'].batch_size
-        network_keys = self.ap.network_wrappers['reward_model'].input_embedders_parameters.keys()

        # this is fitted from the training dataset
        for epoch in range(epochs):
@@ -155,10 +163,7 @@ class ValueOptimizationAgent(Agent):
             total_transitions_processed = 0
             for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)):
                 batch = Batch(batch)
-                current_rewards_prediction_for_all_actions = self.networks['reward_model'].online_network.predict(batch.states(network_keys))
-                current_rewards_prediction_for_all_actions[range(batch.size), batch.actions()] = batch.rewards()
-                loss += self.networks['reward_model'].train_and_sync_networks(
-                    batch.states(network_keys), current_rewards_prediction_for_all_actions)[0]
+                loss += self.get_reward_model_loss(batch)
                 total_transitions_processed += batch.size

             log = OrderedDict()
@@ -166,9 +171,3 @@ class ValueOptimizationAgent(Agent):
             log['loss'] = loss / total_transitions_processed
             screen.log_dict(log, prefix='Training Reward Model')
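
For context on the extracted `get_reward_model_loss`, here is a small NumPy-only sketch of the target it builds, using stand-in arrays rather than coach's network and `Batch` objects: the reward model predicts a reward per action, and the training target copies those predictions everywhere except at the action actually taken, where the observed reward is substituted, so only the taken action's entry contributes to the regression loss.

```python
# Stand-alone sketch of the target construction in get_reward_model_loss.
# The arrays below are placeholders, not coach objects.
import numpy as np

batch_size, num_actions = 4, 3
rng = np.random.default_rng(0)

predictions = rng.normal(size=(batch_size, num_actions))  # reward model output, one value per action
actions = rng.integers(num_actions, size=batch_size)      # actions taken in the batch
rewards = rng.normal(size=batch_size)                      # observed rewards for those actions

# Target = current predictions, with the taken action's entry replaced by the observed reward.
targets = predictions.copy()
targets[np.arange(batch_size), actions] = rewards

# With an MSE loss, every entry except the taken actions cancels out,
# so at most one entry per transition produces a non-zero error.
per_entry_error = (targets - predictions) ** 2
assert np.count_nonzero(per_entry_error) <= batch_size
```

This is why the refactoring is safe: the per-batch loop in `improve_reward_model` and the new helper compute exactly the same loss, the logic has just moved into one reusable method.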