Mirror of https://github.com/gryf/coach.git, synced 2025-12-18 03:30:19 +01:00
bug fixes for OPE (#311)
@@ -697,7 +697,7 @@ class Agent(AgentInterface):
             # we either go sequentially through the entire replay buffer in the batch RL mode,
             # or sample randomly for the basic RL case.
-            training_schedule = self.call_memory('get_shuffled_data_generator', batch_size) if \
+            training_schedule = self.call_memory('get_shuffled_training_data_generator', batch_size) if \
                 self.ap.is_batch_rl_training else [self.call_memory('sample', batch_size) for _ in
                                                    range(self.ap.algorithm.num_consecutive_training_steps)]
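The renamed call feeds a training_schedule that behaves differently in the two modes described by the comment: a single shuffled pass over the whole stored dataset in batch RL, or a fixed number of independently sampled batches otherwise. Below is a minimal, self-contained sketch of that pattern; ToyReplayMemory and its method bodies are hypothetical illustrations of the call shape, not Coach's actual memory implementation.

import random
from typing import Any, Dict, Iterator, List


class ToyReplayMemory:
    """Hypothetical stand-in for the agent's memory; not Coach's implementation."""

    def __init__(self, transitions: List[Dict[str, Any]]):
        self.transitions = transitions

    def get_shuffled_training_data_generator(self, batch_size: int) -> Iterator[List[Dict[str, Any]]]:
        # Batch RL mode: one shuffled pass over the entire stored dataset.
        indices = list(range(len(self.transitions)))
        random.shuffle(indices)
        for start in range(0, len(indices), batch_size):
            yield [self.transitions[i] for i in indices[start:start + batch_size]]

    def sample(self, batch_size: int) -> List[Dict[str, Any]]:
        # Online RL mode: an independent random batch per call.
        return random.sample(self.transitions, min(batch_size, len(self.transitions)))


memory = ToyReplayMemory([{"idx": i} for i in range(10)])
is_batch_rl_training = True
num_consecutive_training_steps = 3

# Same conditional shape as in the diff: a generator over the whole dataset in
# batch RL mode, otherwise a fixed-length list of randomly sampled batches.
training_schedule = memory.get_shuffled_training_data_generator(4) if \
    is_batch_rl_training else [memory.sample(4) for _ in range(num_consecutive_training_steps)]

for batch in training_schedule:
    print(len(batch))  # 4, 4, 2 in batch RL mode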
@@ -155,7 +155,7 @@ class DDQNBCQAgent(DQNAgent):
             reward_model_loss = 0
             imitation_model_loss = 0
             total_transitions_processed = 0
-            for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)):
+            for i, batch in enumerate(self.call_memory('get_shuffled_training_data_generator', batch_size)):
                 batch = Batch(batch)
 
                 # reward model
@@ -164,7 +164,7 @@ class ValueOptimizationAgent(Agent):
         for epoch in range(epochs):
             loss = 0
             total_transitions_processed = 0
-            for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)):
+            for i, batch in enumerate(self.call_memory('get_shuffled_training_data_generator', batch_size)):
                 batch = Batch(batch)
                 loss += self.get_reward_model_loss(batch)
                 total_transitions_processed += batch.size
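The hunk above accumulates the reward-model loss and the number of transitions seen across one shuffled pass per epoch. The sketch below shows the same accumulation pattern in isolation; ToyBatch, toy_reward_model_loss, and the per-epoch averaging print are hypothetical stand-ins added for illustration, and only the loop structure mirrors the diff.

from typing import List


class ToyBatch:
    """Hypothetical stand-in for rl_coach's Batch wrapper."""

    def __init__(self, transitions: List[dict]):
        self.transitions = transitions

    @property
    def size(self) -> int:
        return len(self.transitions)


def toy_reward_model_loss(batch: ToyBatch) -> float:
    # Stand-in for evaluating the reward model's regression loss on one batch.
    return 0.1 * batch.size


def improve_reward_model(batches_per_epoch: List[List[dict]], epochs: int) -> None:
    for epoch in range(epochs):
        loss = 0.0
        total_transitions_processed = 0
        for batch_data in batches_per_epoch:  # stands in for the shuffled training data generator
            batch = ToyBatch(batch_data)
            loss += toy_reward_model_loss(batch)
            total_transitions_processed += batch.size
        # Report a mean per-transition loss for the epoch (the averaging step is an assumption).
        print(f"epoch {epoch}: mean loss = {loss / max(total_transitions_processed, 1):.4f}")


improve_reward_model([[{}] * 4, [{}] * 4, [{}] * 2], epochs=2)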