mirror of https://github.com/gryf/coach.git
pre-release 0.10.0
@@ -0,0 +1,97 @@
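"""Unit tests for rl_coach's EpisodicHindsightExperienceReplay.

Exercises hindsight goal sampling with the Future selection method, and the
hindsight transitions generated when an episode is stored.
"""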
# nasty hack to deal with issue #46
import os
import sys

from rl_coach.memories.episodic.episodic_hindsight_experience_replay import EpisodicHindsightExperienceReplayParameters
from rl_coach.spaces import GoalsSpace, ReachingGoal

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
# print(sys.path)

import pytest
import numpy as np

from rl_coach.core_types import Transition, Episode
from rl_coach.memories.memory import MemoryGranularity
from rl_coach.memories.episodic.episodic_hindsight_experience_replay import EpisodicHindsightExperienceReplay, \
    HindsightGoalSelectionMethod
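

# Hindsight Experience Replay (HER) re-stores each transition several times,
# replacing the desired goal with a goal that was actually achieved later in
# the same episode. With HindsightGoalSelectionMethod.Future, replacement
# goals are drawn only from states reached after the transition in question.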
# TODO: change from defining a new class to creating an instance from the parameters
class Parameters(EpisodicHindsightExperienceReplayParameters):
    def __init__(self):
        super().__init__()
        self.max_size = (MemoryGranularity.Transitions, 100)
        self.hindsight_transitions_per_regular_transition = 4
        self.hindsight_goal_selection_method = HindsightGoalSelectionMethod.Future
        self.goals_space = GoalsSpace(goal_name='observation',
                                      reward_type=ReachingGoal(distance_from_goal_threshold=0.1),
                                      distance_metric=GoalsSpace.DistanceMetric.Euclidean)
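
# One possible shape for the TODO above (untested sketch; it assumes the base
# parameters object tolerates plain attribute assignment, which is exactly how
# the subclass above configures it):
#
#     params = EpisodicHindsightExperienceReplayParameters()
#     params.max_size = (MemoryGranularity.Transitions, 100)
#     params.hindsight_transitions_per_regular_transition = 4
#     params.hindsight_goal_selection_method = HindsightGoalSelectionMethod.Future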


@pytest.fixture
def episode():
    # 10 transitions whose achieved goal equals the observation (0 .. 9)
    episode = []
    for i in range(10):
        episode.append(Transition(
            state={'observation': np.array([i]), 'desired_goal': np.array([i]), 'achieved_goal': np.array([i])},
            action=i,
        ))
    return episode


@pytest.fixture
def her():
    params = Parameters().__dict__

    # the parameters object carries more fields than the memory's constructor
    # accepts, so keep only the keys that match its argument names
    import inspect
    args = set(inspect.getfullargspec(EpisodicHindsightExperienceReplay.__init__).args).intersection(params)
    params = {k: params[k] for k in args}

    return EpisodicHindsightExperienceReplay(**params)


@pytest.mark.unit_test
def test_sample_goal(her, episode):
    # the only state after index 8 is index 9, so Future sampling must pick it
    assert her._sample_goal(episode, 8) == 9


@pytest.mark.unit_test
def test_sample_goal_range(her, episode):
    # sampling from index 0 should eventually cover every future observation
    unseen_goals = set(range(1, 9))
    for _ in range(500):
        unseen_goals -= {int(her._sample_goal(episode, 0))}
        if not unseen_goals:
            return

    assert unseen_goals == set()


@pytest.mark.unit_test
def test_update_episode(her):
    episode = Episode()
    for i in range(10):
        episode.insert(Transition(
            state={'observation': np.array([i]), 'desired_goal': np.array([i+1]), 'achieved_goal': np.array([i+1])},
            action=i,
            game_over=i == 9,
            reward=0 if i == 9 else -1,
        ))

    her.store_episode(episode)
    # print('her._num_transitions', her._num_transitions)

    # 10 original transitions, plus 4 hindsight transitions for each of the
    # first 9 (the final transition has no future state to take a goal from)
    assert her.num_transitions() == 10 + (4 * 9)

    # make sure that the goal state was never sampled from the past
    for transition in her.transitions:
        assert transition.state['desired_goal'] > transition.state['observation']
        assert transition.next_state['desired_goal'] >= transition.next_state['observation']

        if transition.reward == 0:
            assert transition.game_over
        else:
            assert not transition.game_over
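

# Note: pytest fixtures cannot be called directly, so the tests above are run
# through the collector, selected by their marker, e.g.:
#
#     pytest -m unit_test <path to this test file>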