1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00
Files
coach/rl_coach/tests/test_schedules.py
2018-08-13 17:11:34 +03:00

107 lines
3.2 KiB
Python

import os
import sys
from rl_coach.core_types import EnvironmentSteps
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import pytest
from rl_coach.schedules import LinearSchedule, ConstantSchedule, ExponentialSchedule, PieceWiseSchedule
import numpy as np
@pytest.mark.unit_test
def test_constant_schedule():
    """Check that a ConstantSchedule keeps reporting the value it was built with.

    The schedule is stepped 1000 times; before every step both
    ``initial_value`` and ``current_value`` must still equal 0.3.
    """
    expected = 0.3
    constant = ConstantSchedule(expected)

    for _ in range(1000):
        # neither the initial nor the current value may drift while stepping
        assert constant.initial_value == expected
        assert constant.current_value == expected
        constant.step()
@pytest.mark.unit_test
def test_linear_schedule():
    """Check LinearSchedule on an increasing, a decreasing and a flat range.

    For the increasing (1 -> 3) and decreasing (3 -> 1) cases the first 10
    readings are compared, rounded to 4 decimals, against an evenly spaced
    reference, and the schedule is then stepped 10 more times to check that
    the reading stays at the final value. For the flat case (3 -> 3) every
    reading over 10 steps must round to 3.
    """
    # increasing schedule
    schedule = LinearSchedule(1, 3, 10)

    # the schedule is defined in number of steps to get from 1 to 3, so there are 10 steps.
    # the linspace is defined in number of points between 1 and 3, so there are 11 points
    # and only the first 10 are compared while stepping.
    target_values = np.linspace(1, 3, 11)
    for target in target_values[:10]:
        # we round to 4 because there is a very small floating point division difference (1e-10)
        assert round(schedule.current_value, 4) == round(target, 4)
        schedule.step()

    # make sure the value does not change after 10 steps. the schedule has to keep
    # being stepped here, otherwise the same untouched value is simply read 10 times.
    for _ in range(10):
        assert schedule.current_value == 3
        schedule.step()
    assert schedule.current_value == 3

    # decreasing schedule
    schedule = LinearSchedule(3, 1, 10)

    target_values = np.linspace(3, 1, 11)
    for target in target_values[:10]:
        # we round to 4 because there is a very small floating point division difference (1e-10)
        assert round(schedule.current_value, 4) == round(target, 4)
        schedule.step()

    # make sure the value does not change after 10 steps (stepping for real, as above)
    for _ in range(10):
        assert schedule.current_value == 1
        schedule.step()
    assert schedule.current_value == 1

    # constant schedule: equal start and end values, so every reading must be 3
    schedule = LinearSchedule(3, 3, 10)
    for _ in range(10):
        # we round to 4 because there is a very small floating point division difference (1e-10)
        assert round(schedule.current_value, 4) == 3
        schedule.step()
@pytest.mark.unit_test
def test_exponential_schedule():
    """Check the decay curve and the end value of a decreasing ExponentialSchedule.

    The schedule is built with (10, 3, 0.99). During the first 100 steps every
    reading must match 10 * 0.99**k (compared after rounding to 6 decimals).
    After 100 further steps the reading is expected to be exactly 3.
    """
    start_value = 10
    decay_rate = 0.99
    decaying = ExponentialSchedule(start_value, 3, decay_rate)

    # follow the expected curve by repeated multiplication, one factor per step
    decay_factor = 1
    for _ in range(100):
        expected = round(start_value * decay_factor, 6)
        assert round(decaying.current_value, 6) == expected
        decay_factor *= decay_rate
        decaying.step()

    # advance another 100 steps and only then check the end value
    # NOTE(review): presumably the second constructor argument is the bound the
    # schedule stops at - confirm against ExponentialSchedule
    for _ in range(100):
        decaying.step()
    assert decaying.current_value == 3
@pytest.mark.unit_test
def test_piece_wise_schedule():
    """Check the readings of a PieceWiseSchedule chaining three sub-schedules.

    The pieces are a LinearSchedule(1, 3, 10) paired with EnvironmentSteps(5),
    a ConstantSchedule(4) paired with EnvironmentSteps(10) and an
    ExponentialSchedule(3, 1, 0.99) paired with EnvironmentSteps(10).
    The first 16 readings must match 1, 1.2, ..., 2 followed by ten 4s, and
    the next 10 readings must match 3 * 0.99**k (all rounded to 4 decimals).
    """
    pieces = [
        (LinearSchedule(1, 3, 10), EnvironmentSteps(5)),
        (ConstantSchedule(4), EnvironmentSteps(10)),
        (ExponentialSchedule(3, 1, 0.99), EnvironmentSteps(10)),
    ]
    piecewise = PieceWiseSchedule(pieces)

    # reference for the linear and constant pieces; only the first 16 entries are compared
    linear_part = np.linspace(1, 2, 6)
    constant_part = np.ones(11) * 4
    target_values = np.append(linear_part, constant_part)
    for target in target_values[:16]:
        assert round(piecewise.current_value, 4) == round(target, 4)
        piecewise.step()

    # exponential piece: follow 3 * 0.99**k by repeated multiplication
    decay_factor = 1
    for _ in range(10):
        expected = round(3 * decay_factor, 4)
        assert round(piecewise.current_value, 4) == expected
        decay_factor *= 0.99
        piecewise.step()
if __name__ == "__main__":
    # allow running this file directly (without pytest): call every test in definition order
    for run_test in (
        test_constant_schedule,
        test_linear_schedule,
        test_exponential_schedule,
        test_piece_wise_schedule,
    ):
        run_test()