Mirror of https://github.com/gryf/coach.git, synced 2025-12-17 19:20:19 +01:00
BCQ variant on top of DDQN (#276)
* kNN based model for predicting which actions to drop
* fix for seeds with batch rl
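The first bullet refers to the BCQ-style idea of constraining the DDQN argmax to actions that the logged batch actually supports. The sketch below is only a minimal illustration of that idea under assumed names, not the coach implementation: the KNNActionFilter class, the use of scikit-learn's NearestNeighbors, and the masking scheme are all assumptions made for the example.

# Minimal sketch (not coach code): keep only actions observed among the k
# nearest logged states, and restrict the Q-value argmax to that set.
import numpy as np
from sklearn.neighbors import NearestNeighbors

class KNNActionFilter:
    def __init__(self, states, actions, k=5):
        # states:  (N, state_dim) array of logged states
        # actions: (N,) array of the discrete actions taken in those states
        self.knn = NearestNeighbors(n_neighbors=k).fit(states)
        self.actions = np.asarray(actions)

    def allowed_actions(self, state, num_actions):
        # Actions seen among the k nearest logged states are kept;
        # everything else is dropped from the argmax.
        _, idx = self.knn.kneighbors(state.reshape(1, -1))
        mask = np.zeros(num_actions, dtype=bool)
        mask[np.unique(self.actions[idx[0]])] = True
        return mask

    def constrained_argmax(self, q_values, state):
        # Replace Q-values of dropped actions with -inf before the argmax.
        mask = self.allowed_actions(state, q_values.shape[-1])
        return int(np.argmax(np.where(mask, q_values, -np.inf)))

With such a filter in place, greedy action selection would call constrained_argmax(q_values, state) instead of a plain argmax over the online network's Q-values.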
@@ -139,6 +139,15 @@ class ValueOptimizationAgent(Agent):
         self.agent_logger.create_signal_value('Doubly Robust', dr)
         self.agent_logger.create_signal_value('Sequential Doubly Robust', seq_dr)
 
+    def get_reward_model_loss(self, batch: Batch):
+        network_keys = self.ap.network_wrappers['reward_model'].input_embedders_parameters.keys()
+        current_rewards_prediction_for_all_actions = self.networks['reward_model'].online_network.predict(
+            batch.states(network_keys))
+        current_rewards_prediction_for_all_actions[range(batch.size), batch.actions()] = batch.rewards()
+
+        return self.networks['reward_model'].train_and_sync_networks(
+            batch.states(network_keys), current_rewards_prediction_for_all_actions)[0]
+
     def improve_reward_model(self, epochs: int):
         """
         Train a reward model to be used by the doubly-robust estimator
@@ -147,7 +156,6 @@ class ValueOptimizationAgent(Agent):
         :return: None
         """
         batch_size = self.ap.network_wrappers['reward_model'].batch_size
-        network_keys = self.ap.network_wrappers['reward_model'].input_embedders_parameters.keys()
 
         # this is fitted from the training dataset
         for epoch in range(epochs):
@@ -155,10 +163,7 @@ class ValueOptimizationAgent(Agent):
             total_transitions_processed = 0
             for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)):
                 batch = Batch(batch)
-                current_rewards_prediction_for_all_actions = self.networks['reward_model'].online_network.predict(batch.states(network_keys))
-                current_rewards_prediction_for_all_actions[range(batch.size), batch.actions()] = batch.rewards()
-                loss += self.networks['reward_model'].train_and_sync_networks(
-                    batch.states(network_keys), current_rewards_prediction_for_all_actions)[0]
+                loss += self.get_reward_model_loss(batch)
                 total_transitions_processed += batch.size
 
             log = OrderedDict()
@@ -166,9 +171,3 @@ class ValueOptimizationAgent(Agent):
             log['loss'] = loss / total_transitions_processed
             screen.log_dict(log, prefix='Training Reward Model')
 
-
-
-
-
-
-
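For context, the extracted get_reward_model_loss trains the reward model by regressing its per-action reward predictions toward a target in which only the taken action's entry is overwritten with the observed reward; that model then supplies the direct-method term used by the doubly-robust signals logged above. A minimal numpy sketch of the target construction follows; the function and variable names are illustrative, not part of the coach API.

import numpy as np

def reward_model_target(predictions, actions, rewards):
    # predictions: (batch_size, num_actions) rewards predicted by the model
    # actions:     (batch_size,) indices of the actions actually taken
    # rewards:     (batch_size,) rewards observed for those actions
    # Copy the model's own predictions, then overwrite the entry of the taken
    # action with the observed reward, so the regression is anchored only on
    # the (state, action) pairs for which the logged data has ground truth.
    target = predictions.copy()
    target[np.arange(len(actions)), actions] = rewards
    return target

Because the target equals the prediction everywhere except at the logged action, the training loss is driven only by the entries where the batch provides an observed reward.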