# Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 19:20:19 +01:00).
# Original file: 35 lines, 889 B, Python.
import os
import sys
# Append the directory three `os.path.dirname` levels above this file to
# sys.path, presumably so the `rl_coach` imports below resolve when the test is
# run from a source checkout (TODO confirm). It must stay before those imports.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import pytest
from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace
from rl_coach.exploration_policies.greedy import Greedy
import numpy as np


@pytest.mark.unit_test
def test_get_action():
    """Check the action chosen by a Greedy policy.

    Discrete: for action values [10, 20, 30] the chosen action must be index 2,
    i.e. the position of the largest value.
    Continuous: the values passed to ``get_action`` must come back unchanged,
    with the same shape.
    """
    # discrete control -- get_action returns a pair here; only the chosen
    # action (first element) is checked.
    action_space = DiscreteActionSpace(3)
    policy = Greedy(action_space)

    best_action, _ = policy.get_action(np.array([10, 20, 30]))
    assert best_action == 2

    # continuous control
    # NOTE(review): the space is built from np.array([10]) while three action
    # values are passed in; confirm this mismatch is intentional.
    action_space = BoxActionSpace(np.array([10]))
    policy = Greedy(action_space)

    best_action = policy.get_action(np.array([1, 1, 1]))
    # Compare against a fresh array so in-place mutation of the input cannot
    # make the check pass trivially.
    expected = np.array([1, 1, 1])
    # `np.all(a == b)` broadcasts, so a scalar or length-1 result would also
    # pass; pin the shape explicitly, then let numpy report mismatches.
    assert np.shape(best_action) == expected.shape
    np.testing.assert_array_equal(best_action, expected)


@pytest.mark.unit_test
def test_get_control_param():
    """A Greedy policy over a 3-action discrete space reports a control parameter of 0."""
    greedy = Greedy(DiscreteActionSpace(3))
    control_param = greedy.get_control_param()
    assert control_param == 0