From 2cb078b4c2dcbbdae48d5553855dcb08e13996be Mon Sep 17 00:00:00 2001 From: zach dwiel Date: Thu, 4 Apr 2019 16:11:07 -0400 Subject: [PATCH] add __truediv__, __rtruediv__ and __eq__ to StepMethod --- rl_coach/core_types.py | 32 +++++++++++++++++++++++++ rl_coach/tests/test_core_types.py | 39 ++++++++++++++++++++++++++++++- 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/rl_coach/core_types.py b/rl_coach/core_types.py index 9fa07ef..c173318 100644 --- a/rl_coach/core_types.py +++ b/rl_coach/core_types.py @@ -16,6 +16,7 @@ from collections import namedtuple import copy +import math from enum import Enum from random import shuffle from typing import List, Union, Dict, Any, Type @@ -63,6 +64,37 @@ class StepMethod(object): def num_steps(self, val: int) -> None: self._num_steps = val + def __eq__(self, other): + return self.num_steps == other.num_steps + + def __truediv__(self, other): + """ + divide this step method with other. If other is an integer, returns an object of the same + type as self. If other is the same type of self, returns an integer. In either case, any + floating point value is rounded up under the assumption that if we are dividing Steps, we + would rather overestimate than underestimate. + """ + if isinstance(other, type(self)): + return math.ceil(self.num_steps / other.num_steps) + elif isinstance(other, int): + return type(self)(math.ceil(self.num_steps / other)) + else: + raise TypeError("cannot divide {} by {}".format(type(self), type(other))) + + def __rtruediv__(self, other): + """ + divide this step method with other. If other is an integer, returns an object of the same + type as self. If other is the same type of self, returns an integer. In either case, any + floating point value is rounded up under the assumption that if we are dividing Steps, we + would rather overestimate than underestimate. + """ + if isinstance(other, type(self)): + return math.ceil(other.num_steps / self.num_steps) + elif isinstance(other, int): + return type(self)(math.ceil(other / self.num_steps)) + else: + raise TypeError("cannot divide {} by {}".format(type(other), type(self))) + class Frames(StepMethod): def __init__(self, num_steps): diff --git a/rl_coach/tests/test_core_types.py b/rl_coach/tests/test_core_types.py index 58df0ad..c9be955 100644 --- a/rl_coach/tests/test_core_types.py +++ b/rl_coach/tests/test_core_types.py @@ -1,4 +1,11 @@ -from rl_coach.core_types import TotalStepsCounter, EnvironmentSteps, EnvironmentEpisodes +from rl_coach.core_types import ( + TotalStepsCounter, + EnvironmentSteps, + EnvironmentEpisodes, + StepMethod, + EnvironmentSteps, + EnvironmentEpisodes, +) import pytest @@ -25,3 +32,33 @@ def test_total_steps_counter_less_than(): assert not (counter < steps) steps = counter + EnvironmentSteps(1) assert counter < steps + + +@pytest.mark.unit_test +def test_step_method_div(): + assert StepMethod(10) / 2 == StepMethod(5) + assert StepMethod(10) / StepMethod(2) == 5 + + +@pytest.mark.unit_test +def test_step_method_div_ceil(): + assert StepMethod(10) / 3 == StepMethod(4) + assert StepMethod(10) / StepMethod(3) == 4 + + +@pytest.mark.unit_test +def test_step_method_rdiv_ceil(): + assert 10 / StepMethod(3) == StepMethod(4) + assert StepMethod(10) / StepMethod(3) == 4 + + +@pytest.mark.unit_test +def test_step_method_rdiv(): + assert 10 / StepMethod(2) == StepMethod(5) + assert StepMethod(10) / StepMethod(2) == 5 + + +@pytest.mark.unit_test +def test_step_method_div_type(): + with pytest.raises(TypeError): + EnvironmentEpisodes(10) / EnvironmentSteps(2)