mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
pre-release 0.10.0
This commit is contained in:
106
rl_coach/tests/test_schedules.py
Normal file
106
rl_coach/tests/test_schedules.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from rl_coach.core_types import EnvironmentSteps
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
|
||||
import pytest
|
||||
|
||||
from rl_coach.schedules import LinearSchedule, ConstantSchedule, ExponentialSchedule, PieceWiseSchedule
|
||||
import numpy as np
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_constant_schedule():
|
||||
schedule = ConstantSchedule(0.3)
|
||||
|
||||
# make sure the values in the constant schedule don't change over time
|
||||
for i in range(1000):
|
||||
assert schedule.initial_value == 0.3
|
||||
assert schedule.current_value == 0.3
|
||||
schedule.step()
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_linear_schedule():
|
||||
# increasing schedule
|
||||
schedule = LinearSchedule(1, 3, 10)
|
||||
|
||||
# the schedule is defined in number of steps to get from 1 to 3 so there are 10 steps
|
||||
# the linspace is defined in number of bins between 1 and 3 so theres are 11 bins
|
||||
target_values = np.linspace(1, 3, 11)
|
||||
for i in range(10):
|
||||
# we round to 4 because there is a very small floating point division difference (1e-10)
|
||||
assert round(schedule.current_value, 4) == round(target_values[i], 4)
|
||||
schedule.step()
|
||||
|
||||
# make sure the value does not change after 10 steps
|
||||
for i in range(10):
|
||||
assert schedule.current_value == 3
|
||||
|
||||
# decreasing schedule
|
||||
schedule = LinearSchedule(3, 1, 10)
|
||||
|
||||
target_values = np.linspace(3, 1, 11)
|
||||
for i in range(10):
|
||||
# we round to 4 because there is a very small floating point division difference (1e-10)
|
||||
assert round(schedule.current_value, 4) == round(target_values[i], 4)
|
||||
schedule.step()
|
||||
|
||||
# make sure the value does not change after 10 steps
|
||||
for i in range(10):
|
||||
assert schedule.current_value == 1
|
||||
|
||||
# constant schedule
|
||||
schedule = LinearSchedule(3, 3, 10)
|
||||
|
||||
for i in range(10):
|
||||
# we round to 4 because there is a very small floating point division difference (1e-10)
|
||||
assert round(schedule.current_value, 4) == 3
|
||||
schedule.step()
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_exponential_schedule():
|
||||
# decreasing schedule
|
||||
schedule = ExponentialSchedule(10, 3, 0.99)
|
||||
|
||||
current_power = 1
|
||||
for i in range(100):
|
||||
assert round(schedule.current_value,6) == round(10*current_power,6)
|
||||
current_power *= 0.99
|
||||
schedule.step()
|
||||
|
||||
for i in range(100):
|
||||
schedule.step()
|
||||
assert schedule.current_value == 3
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_piece_wise_schedule():
|
||||
# decreasing schedule
|
||||
schedule = PieceWiseSchedule(
|
||||
[(LinearSchedule(1, 3, 10), EnvironmentSteps(5)),
|
||||
(ConstantSchedule(4), EnvironmentSteps(10)),
|
||||
(ExponentialSchedule(3, 1, 0.99), EnvironmentSteps(10))
|
||||
]
|
||||
)
|
||||
|
||||
target_values = np.append(np.linspace(1, 2, 6), np.ones(11)*4)
|
||||
for i in range(16):
|
||||
assert round(schedule.current_value, 4) == round(target_values[i], 4)
|
||||
schedule.step()
|
||||
|
||||
current_power = 1
|
||||
for i in range(10):
|
||||
assert round(schedule.current_value, 4) == round(3*current_power, 4)
|
||||
current_power *= 0.99
|
||||
schedule.step()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_constant_schedule()
|
||||
test_linear_schedule()
|
||||
test_exponential_schedule()
|
||||
test_piece_wise_schedule()
|
||||
Reference in New Issue
Block a user