diff --git a/rl_coach/core_types.py b/rl_coach/core_types.py
index 9fa07ef..c173318 100644
--- a/rl_coach/core_types.py
+++ b/rl_coach/core_types.py
@@ -16,6 +16,7 @@
 
 from collections import namedtuple
 import copy
+import math
 from enum import Enum
 from random import shuffle
 from typing import List, Union, Dict, Any, Type
@@ -63,6 +64,43 @@ class StepMethod(object):
     def num_steps(self, val: int) -> None:
         self._num_steps = val
 
+    def __eq__(self, other):
+        """
+        two step methods are equal when they hold the same number of steps. Anything that is not
+        a StepMethod is left for python to handle, so comparing with it is False, not an error.
+        """
+        if not isinstance(other, StepMethod):
+            return NotImplemented
+        return self.num_steps == other.num_steps
+
+    def __truediv__(self, other):
+        """
+        divide this step method with other. If other is an integer, returns an object of the same
+        type as self. If other is the same type of self, returns an integer. In either case, any
+        floating point value is rounded up under the assumption that if we are dividing Steps, we
+        would rather overestimate than underestimate.
+        """
+        if isinstance(other, type(self)):
+            return math.ceil(self.num_steps / other.num_steps)
+        elif isinstance(other, int):
+            return type(self)(math.ceil(self.num_steps / other))
+        else:
+            raise TypeError("cannot divide {} by {}".format(type(self), type(other)))
+
+    def __rtruediv__(self, other):
+        """
+        divide other by this step method (the reflected form, i.e. other / self). If other is an
+        integer, returns an object of the same type as self. If other is the same type of self,
+        returns an integer. In either case, any floating point value is rounded up under the
+        assumption that if we are dividing Steps, we would rather overestimate than underestimate.
+        """
+        if isinstance(other, type(self)):
+            return math.ceil(other.num_steps / self.num_steps)
+        elif isinstance(other, int):
+            return type(self)(math.ceil(other / self.num_steps))
+        else:
+            raise TypeError("cannot divide {} by {}".format(type(other), type(self)))
+
 
 class Frames(StepMethod):
     def __init__(self, num_steps):
diff --git a/rl_coach/tests/test_core_types.py b/rl_coach/tests/test_core_types.py
index 58df0ad..c9be955 100644
--- a/rl_coach/tests/test_core_types.py
+++ b/rl_coach/tests/test_core_types.py
@@ -1,4 +1,9 @@
-from rl_coach.core_types import TotalStepsCounter, EnvironmentSteps, EnvironmentEpisodes
+from rl_coach.core_types import (
+    TotalStepsCounter,
+    EnvironmentSteps,
+    EnvironmentEpisodes,
+    StepMethod,
+)
 
 import pytest
 
@@ -25,3 +30,38 @@ def test_total_steps_counter_less_than():
     assert not (counter < steps)
     steps = counter + EnvironmentSteps(1)
     assert counter < steps
+
+
+@pytest.mark.unit_test
+def test_step_method_div():
+    assert StepMethod(10) / 2 == StepMethod(5)
+    assert StepMethod(10) / StepMethod(2) == 5
+
+
+@pytest.mark.unit_test
+def test_step_method_div_ceil():
+    assert StepMethod(10) / 3 == StepMethod(4)
+    assert StepMethod(10) / StepMethod(3) == 4
+
+
+@pytest.mark.unit_test
+def test_step_method_rdiv_ceil():
+    assert 10 / StepMethod(3) == StepMethod(4)
+    assert StepMethod(10) / StepMethod(3) == 4
+
+
+@pytest.mark.unit_test
+def test_step_method_rdiv():
+    assert 10 / StepMethod(2) == StepMethod(5)
+    assert StepMethod(10) / StepMethod(2) == 5
+
+
+@pytest.mark.unit_test
+def test_step_method_div_type():
+    with pytest.raises(TypeError):
+        EnvironmentEpisodes(10) / EnvironmentSteps(2)
+
+
+@pytest.mark.unit_test
+def test_step_method_eq_other_type():
+    assert StepMethod(5) != 5