mirror of
https://github.com/gryf/coach.git
synced 2026-02-17 23:05:51 +01:00
N-step returns for rainbow (#67)
* n_step returns for rainbow * Rename CartPole_PPO -> CartPole_ClippedPPO
This commit is contained in:
@@ -26,10 +26,10 @@ def test_sum_tree():
|
||||
sum_tree.add(5, "5")
|
||||
assert sum_tree.total_value() == 20
|
||||
|
||||
assert sum_tree.get(2) == (0, 2.5, '2.5')
|
||||
assert sum_tree.get(3) == (1, 5.0, '5')
|
||||
assert sum_tree.get(10) == (2, 5.0, '5')
|
||||
assert sum_tree.get(13) == (3, 7.5, '7.5')
|
||||
assert sum_tree.get_element_by_partial_sum(2) == (0, 2.5, '2.5')
|
||||
assert sum_tree.get_element_by_partial_sum(3) == (1, 5.0, '5')
|
||||
assert sum_tree.get_element_by_partial_sum(10) == (2, 5.0, '5')
|
||||
assert sum_tree.get_element_by_partial_sum(13) == (3, 7.5, '7.5')
|
||||
|
||||
sum_tree.update(2, 10)
|
||||
assert sum_tree.__str__() == "[25.]\n[ 7.5 17.5]\n[ 2.5 5. 10. 7.5]\n"
|
||||
|
||||
@@ -41,8 +41,8 @@ def test_store_and_get(buffer: SingleEpisodeBuffer):
|
||||
# check that the episode is valid
|
||||
episode = buffer.get(0)
|
||||
assert episode.length() == 2
|
||||
assert episode.get_transition(0).total_return == 1 + 0.99
|
||||
assert episode.get_transition(1).total_return == 1
|
||||
assert episode.get_transition(0).n_step_discounted_rewards == 1 + 0.99
|
||||
assert episode.get_transition(1).n_step_discounted_rewards == 1
|
||||
assert buffer.mean_reward() == 1
|
||||
|
||||
# only one episode in the replay buffer
|
||||
|
||||
Reference in New Issue
Block a user