mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
pre-release 0.10.0
This commit is contained in:
34
rl_coach/tests/exploration_policies/test_greedy.py
Normal file
34
rl_coach/tests/exploration_policies/test_greedy.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import os
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
|
||||
import pytest
|
||||
|
||||
from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace
|
||||
from rl_coach.exploration_policies.greedy import Greedy
|
||||
import numpy as np
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_get_action():
|
||||
# discrete control
|
||||
action_space = DiscreteActionSpace(3)
|
||||
policy = Greedy(action_space)
|
||||
|
||||
best_action = policy.get_action(np.array([10, 20, 30]))
|
||||
assert best_action == 2
|
||||
|
||||
# continuous control
|
||||
action_space = BoxActionSpace(np.array([10]))
|
||||
policy = Greedy(action_space)
|
||||
|
||||
best_action = policy.get_action(np.array([1, 1, 1]))
|
||||
assert np.all(best_action == np.array([1, 1, 1]))
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_get_control_param():
|
||||
action_space = DiscreteActionSpace(3)
|
||||
policy = Greedy(action_space)
|
||||
assert policy.get_control_param() == 0
|
||||
|
||||
Reference in New Issue
Block a user