mirror of https://github.com/gryf/coach.git, synced 2025-12-18 03:30:19 +01:00
Rainbow DQN agent (WIP - still missing dueling and n-step) + adding support for Prioritized ER for C51
@@ -25,6 +25,7 @@ from rl_coach.base_parameters import AgentParameters
 from rl_coach.core_types import StateType
 from rl_coach.exploration_policies.e_greedy import EGreedyParameters
 from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters
+from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplay
 from rl_coach.schedules import LinearSchedule
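The hunk above only imports the new buffer class; selecting it for an agent happens in a preset, which this commit does not show. A minimal sketch of how a preset could opt C51 into prioritized replay, assuming the PrioritizedExperienceReplayParameters class that normally accompanies the buffer:

    from rl_coach.agents.categorical_dqn_agent import CategoricalDQNAgentParameters
    from rl_coach.memories.non_episodic.prioritized_experience_replay import \
        PrioritizedExperienceReplayParameters

    # swap the default (uniform) experience replay for the prioritized variant;
    # the agent code below then detects the buffer type with isinstance() checks
    agent_params = CategoricalDQNAgentParameters()
    agent_params.memory = PrioritizedExperienceReplayParameters()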
@@ -104,11 +105,22 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
             l = (np.floor(bj)).astype(int)
             m[batches, l] = m[batches, l] + (distributed_q_st_plus_1[batches, target_actions, j] * (u - bj))
             m[batches, u] = m[batches, u] + (distributed_q_st_plus_1[batches, target_actions, j] * (bj - l))
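The context lines above are the tail of C51's projection loop, which spreads each Bellman-updated atom's probability mass over the two nearest atoms of the fixed support. A self-contained sketch of that projection, with plain names (rewards, game_overs, next_probs, z_values, discount) standing in for the agent's attributes:

    import numpy as np

    def project_distribution(rewards, game_overs, next_probs, z_values, discount):
        # rewards:    (batch,)          immediate rewards
        # game_overs: (batch,)          1.0 where the episode terminated, else 0.0
        # next_probs: (batch, n_atoms)  target-network probabilities of the greedy next action
        # z_values:   (n_atoms,)        fixed support of the return distribution
        batch_size, n_atoms = next_probs.shape
        delta_z = z_values[1] - z_values[0]
        batches = np.arange(batch_size)
        m = np.zeros((batch_size, n_atoms))
        for j in range(n_atoms):
            # Bellman-update atom j, then clip it back into the support's range
            tzj = np.clip(rewards + (1.0 - game_overs) * discount * z_values[j],
                          z_values[0], z_values[-1])
            bj = (tzj - z_values[0]) / delta_z      # fractional index on the support
            l = np.floor(bj).astype(int)
            u = np.ceil(bj).astype(int)
            # split atom j's mass between its two nearest neighbours; when bj lands
            # exactly on an atom (l == u) both terms are zero, matching the diffed code
            m[batches, l] += next_probs[batches, j] * (u - bj)
            m[batches, u] += next_probs[batches, j] * (bj - l)
        return m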
         # total_loss = cross entropy between actual result above and predicted result for the given action
         TD_targets[batches, batch.actions()] = m

-        result = self.networks['main'].train_and_sync_networks(batch.states(network_keys), TD_targets)
+        # update errors in prioritized replay buffer
+        importance_weights = batch.info('weight') if isinstance(self.memory, PrioritizedExperienceReplay) else None
+
+        result = self.networks['main'].train_and_sync_networks(batch.states(network_keys), TD_targets,
+                                                               importance_weights=importance_weights)
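Passing importance_weights into train_and_sync_networks is what lets the loss correct for the biased sampling of a prioritized buffer. The head-side handling is not visible in this diff; what it typically amounts to is scaling each sample's loss by its weight, as in this standalone sketch:

    import numpy as np

    def weighted_categorical_loss(predicted_probs, target_probs, importance_weights=None):
        # per-sample cross-entropy between the projected target distribution and the
        # predicted distribution for the taken action; a sketch, not Coach's actual head
        eps = 1e-10
        per_sample = -np.sum(target_probs * np.log(predicted_probs + eps), axis=-1)
        if importance_weights is not None:
            # PER importance-sampling weights damp the updates of over-sampled transitions
            per_sample = importance_weights * per_sample
        return per_sample.mean()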
         total_loss, losses, unclipped_grads = result[:3]

+        # TODO: fix this spaghetti code
+        if isinstance(self.memory, PrioritizedExperienceReplay):
+            errors = losses[0][np.arange(batch.size), batch.actions()]
+            self.memory.update_priorities(batch.info('idx'), errors)

         return total_loss, losses, unclipped_grads
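update_priorities() feeds the per-sample errors back into the buffer so that high-error transitions are replayed more often; the 'weight' and 'idx' fields consumed above are what such a buffer attaches to each sampled transition. A toy proportional-prioritization sketch in the spirit of Schaul et al. (2015), not Coach's sum-tree implementation:

    import numpy as np

    class ProportionalPriorities:
        def __init__(self, capacity, alpha=0.6, beta=0.4, epsilon=1e-6):
            self.priorities = np.zeros(capacity)
            self.alpha, self.beta, self.epsilon = alpha, beta, epsilon

        def update_priorities(self, indices, errors):
            # larger error -> higher replay probability; epsilon keeps priorities positive
            self.priorities[indices] = (np.abs(errors) + self.epsilon) ** self.alpha

        def sample(self, batch_size):
            # assumes at least batch_size transitions already have a non-zero priority
            probs = self.priorities / self.priorities.sum()
            indices = np.random.choice(len(probs), batch_size, p=probs)
            # importance-sampling weights (the 'weight' info above) undo the sampling bias
            weights = (len(probs) * probs[indices]) ** (-self.beta)
            return indices, weights / weights.max()

Here alpha controls how strongly errors skew the sampling and beta how strongly the weights correct for it; the defaults are the paper's values, not necessarily Coach's.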