
Rainbow DQN agent (WIP - still missing dueling and n-step) + adding support for Prioritized ER for C51

Gal Leibovich
2018-08-30 15:11:51 +03:00
parent fd2f4b0852
commit bbe7ac3338
4 changed files with 228 additions and 1 deletion

rl_coach/agents/categorical_dqn_agent.py

@@ -25,6 +25,7 @@ from rl_coach.base_parameters import AgentParameters
 from rl_coach.core_types import StateType
 from rl_coach.exploration_policies.e_greedy import EGreedyParameters
 from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters
+from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplay
 from rl_coach.schedules import LinearSchedule
@@ -104,11 +105,22 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
             l = (np.floor(bj)).astype(int)
             m[batches, l] = m[batches, l] + (distributed_q_st_plus_1[batches, target_actions, j] * (u - bj))
             m[batches, u] = m[batches, u] + (distributed_q_st_plus_1[batches, target_actions, j] * (bj - l))

         # total_loss = cross entropy between actual result above and predicted result for the given action
         TD_targets[batches, batch.actions()] = m

-        result = self.networks['main'].train_and_sync_networks(batch.states(network_keys), TD_targets)
+        # update errors in prioritized replay buffer
+        importance_weights = batch.info('weight') if isinstance(self.memory, PrioritizedExperienceReplay) else None
+
+        result = self.networks['main'].train_and_sync_networks(batch.states(network_keys), TD_targets,
+                                                               importance_weights=importance_weights)
         total_loss, losses, unclipped_grads = result[:3]

+        # TODO: fix this spaghetti code
+        if isinstance(self.memory, PrioritizedExperienceReplay):
+            errors = losses[0][np.arange(batch.size), batch.actions()]
+            self.memory.update_priorities(batch.info('idx'), errors)
+
         return total_loss, losses, unclipped_grads
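
For context on what the change wires together: the prioritized buffer supplies per-transition importance-sampling weights (batch.info('weight')) that scale the C51 cross-entropy loss, and the per-transition losses are then fed back as the new priorities. A minimal standalone NumPy sketch of that interplay (not Coach code; probs, targets, is_weights, sample_indices and the memory argument are hypothetical names) might look like this:

import numpy as np

def per_c51_update(probs, targets, is_weights, sample_indices, memory):
    # probs:          (batch, atoms) predicted distribution for the taken actions
    # targets:        (batch, atoms) projected target distribution (the `m` built above)
    # is_weights:     (batch,) importance-sampling weights from the prioritized buffer
    # sample_indices: (batch,) positions of the sampled transitions in the buffer
    eps = 1e-10

    # per-transition cross entropy between the projected target and the prediction
    per_sample_loss = -np.sum(targets * np.log(probs + eps), axis=-1)

    # the importance-sampling weights correct for the non-uniform sampling of PER
    weighted_loss = float(np.mean(is_weights * per_sample_loss))

    # the unweighted per-transition losses become the new priorities, analogous to
    # self.memory.update_priorities(batch.info('idx'), errors) in the diff above
    memory.update_priorities(sample_indices, per_sample_loss)

    return weighted_loss, per_sample_loss

Using the unweighted per-sample losses as the new priorities mirrors the errors = losses[0][...] line in the diff, which appears to pull the per-action loss of the taken action for each transition.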