import os
import sys

# Put the directory three levels above this file on the import path so the
# rl_coach package imported below can be resolved from the source tree.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))

import pytest
import numpy as np

from rl_coach.spaces import DiscreteActionSpace
from rl_coach.exploration_policies.e_greedy import EGreedy
from rl_coach.schedules import LinearSchedule
from rl_coach.core_types import RunPhase


def _sample_action_values():
    """Return a fresh copy of the fixed input array fed to ``get_action``.

    A new array is built on every call, so each ``get_action`` invocation
    receives its own object. The largest entry sits at index 2.
    """
    return np.array([10, 20, 30])


@pytest.mark.unit_test
def test_get_action():
    """Check EGreedy action selection in the TEST and TRAIN phases.

    Uses a 3-action discrete space and a schedule that stays at 1.0, with
    ``evaluation_epsilon=0``.
    """
    # discrete control
    policy = EGreedy(
        DiscreteActionSpace(3),
        LinearSchedule(1.0, 1.0, 1000),
        evaluation_epsilon=0,
    )

    # TEST phase with evaluation_epsilon = 0: every action must be the greedy
    # one, i.e. the index of the largest input value.
    policy.change_phase(RunPhase.TEST)
    for _ in range(100):
        chosen, _extra = policy.get_action(_sample_action_values())
        assert chosen == 2

    # TRAIN phase with exploration = 1: actions should be spread uniformly,
    # so count how often each of the three actions is picked.
    policy.change_phase(RunPhase.TRAIN)
    hits = np.array([0, 0, 0])
    for _ in range(30000):
        chosen, _extra = policy.get_action(_sample_action_values())
        hits[chosen] += 1
    # this is noisy so we allow 5% error
    assert np.all(hits > 9500)

    # TODO: test continuous actions


@pytest.mark.unit_test
def test_change_phase():
    """Check that the epsilon schedule does not advance outside training.

    After acting once in each of the TEST, HEATUP and UNDEFINED phases, the
    schedule's current value must still be its initial 1.0.
    """
    # discrete control
    policy = EGreedy(
        DiscreteActionSpace(3),
        LinearSchedule(1.0, 0.1, 1000),
        evaluation_epsilon=0.01,
    )

    # verify schedule not applying if not in training phase
    assert policy.get_control_param() == 1.0
    for phase in (RunPhase.TEST, RunPhase.HEATUP, RunPhase.UNDEFINED):
        policy.change_phase(phase)
        policy.get_action(_sample_action_values())
        assert policy.epsilon_schedule.current_value == 1.0


@pytest.mark.unit_test
def test_get_control_param():
    """Check the value reported by ``get_control_param`` across phases.

    The schedule runs from 1.0 to 0.1 over 1000 steps and
    ``evaluation_epsilon`` is 0.01.
    """
    # discrete control
    policy = EGreedy(
        DiscreteActionSpace(3),
        LinearSchedule(1.0, 0.1, 1000),
        evaluation_epsilon=0.01,
    )

    # verify schedule applies to TRAIN phase
    policy.change_phase(RunPhase.TRAIN)
    for _ in range(999):
        policy.get_action(_sample_action_values())
        # Checked after every step: strictly between the two end points.
        current = policy.get_control_param()
        assert 1.0 > current > 0.1
    # One more step (the 1000th) must land exactly on the final value.
    policy.get_action(_sample_action_values())
    assert policy.get_control_param() == 0.1

    # test phases
    expected_by_phase = (
        (RunPhase.TEST, 0.01),
        (RunPhase.TRAIN, 0.1),
        (RunPhase.HEATUP, 0.1),
    )
    for phase, expected in expected_by_phase:
        policy.change_phase(phase)
        assert policy.get_control_param() == expected