Mirror of https://github.com/gryf/coach.git
Synced 2025-12-17 19:20:19 +01:00
107 lines · 3.2 KiB · Python
import os
|
|
import sys
|
|
|
|
from rl_coach.core_types import EnvironmentSteps
|
|
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
|
|
|
import pytest
|
|
|
|
from rl_coach.schedules import LinearSchedule, ConstantSchedule, ExponentialSchedule, PieceWiseSchedule
|
|
import numpy as np
|
|
|
|
|
|
@pytest.mark.unit_test
def test_constant_schedule():
    """A ConstantSchedule must report the same value however often it is stepped."""
    expected = 0.3
    schedule = ConstantSchedule(expected)

    for _ in range(1000):
        # neither the initial nor the current value may drift over time
        assert schedule.initial_value == expected
        assert schedule.current_value == expected
        schedule.step()
|
|
|
|
|
|
def _check_linear_schedule(initial_value, final_value, decay_steps):
    """Drive a LinearSchedule through its ramp and beyond, checking every value.

    Args:
        initial_value: value the schedule starts at.
        final_value: value the schedule is expected to reach after
            ``decay_steps`` calls to ``step()`` and then hold.
        decay_steps: number of ``step()`` calls the ramp takes.
    """
    schedule = LinearSchedule(initial_value, final_value, decay_steps)

    # The ramp is defined by the number of steps, so it passes through
    # decay_steps + 1 points (both endpoints included). np.linspace's third
    # argument is the number of samples, so decay_steps + 1 reproduces them.
    target_values = np.linspace(initial_value, final_value, decay_steps + 1)

    for expected in target_values[:-1]:
        # round to 4 digits because there is a tiny floating point difference
        # (~1e-10) between the schedule's values and linspace's values
        assert round(schedule.current_value, 4) == round(expected, 4)
        schedule.step()

    # Once the ramp is over, the value must stay at final_value. Keep stepping
    # while checking: without step() this loop would only re-read one value
    # and never verify that further steps leave it unchanged.
    for _ in range(decay_steps):
        assert schedule.current_value == final_value
        schedule.step()


@pytest.mark.unit_test
def test_linear_schedule():
    """LinearSchedule interpolates linearly, then holds its final value."""
    # increasing schedule
    _check_linear_schedule(1, 3, 10)
    # decreasing schedule
    _check_linear_schedule(3, 1, 10)
    # constant schedule (initial value == final value)
    _check_linear_schedule(3, 3, 10)
|
|
|
|
|
|
@pytest.mark.unit_test
def test_exponential_schedule():
    """ExponentialSchedule decays geometrically and never goes below its final value."""
    initial, final, decay = 10, 3, 0.99
    schedule = ExponentialSchedule(initial, final, decay)

    # Decreasing schedule: the expected value is tracked by multiplying the
    # decay factor once per step, mirroring initial * decay**k.
    factor = 1
    for _ in range(100):
        assert round(schedule.current_value, 6) == round(initial * factor, 6)
        factor *= decay
        schedule.step()

    # After 100 further steps the unclamped value would be far below `final`,
    # so the schedule must be sitting exactly at its final value.
    for _ in range(100):
        schedule.step()
    assert schedule.current_value == final
|
|
|
|
|
|
@pytest.mark.unit_test
def test_piece_wise_schedule():
    """PieceWiseSchedule must hand over between its segments at the right steps."""
    # Three segments: a linear ramp 1 -> 3 (limited to 5 steps), a constant 4
    # (10 steps), then an exponential decay 3 -> 1 (10 steps).
    schedule = PieceWiseSchedule(
        [(LinearSchedule(1, 3, 10), EnvironmentSteps(5)),
         (ConstantSchedule(4), EnvironmentSteps(10)),
         (ExponentialSchedule(3, 1, 0.99), EnvironmentSteps(10))
         ]
    )

    # Expected values for the first two segments. The linear ramp is observed
    # at 6 points (1.0, 1.2, ..., 2.0), and then the constant segment reports
    # 4 for the next 10 steps. Iterating the array directly keeps the number
    # of checks in sync with the expectations.
    target_values = np.append(np.linspace(1, 2, 6), np.ones(10) * 4)
    for expected in target_values:
        assert round(schedule.current_value, 4) == round(expected, 4)
        schedule.step()

    # The exponential segment starts from its own initial value (3) and then
    # decays by a factor of 0.99 per step.
    decay = 1
    for _ in range(10):
        assert round(schedule.current_value, 4) == round(3 * decay, 4)
        decay *= 0.99
        schedule.step()
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running the module directly (without pytest): execute every test
    # in declaration order; any failing assert aborts the run.
    for test_fn in (test_constant_schedule,
                    test_linear_schedule,
                    test_exponential_schedule,
                    test_piece_wise_schedule):
        test_fn()
|