mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
add __truediv__, __rtruediv__ and __eq__ to StepMethod
This commit is contained in:
@@ -16,6 +16,7 @@
|
|||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
import math
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from random import shuffle
|
from random import shuffle
|
||||||
from typing import List, Union, Dict, Any, Type
|
from typing import List, Union, Dict, Any, Type
|
||||||
@@ -63,6 +64,37 @@ class StepMethod(object):
|
|||||||
def num_steps(self, val: int) -> None:
|
def num_steps(self, val: int) -> None:
|
||||||
self._num_steps = val
|
self._num_steps = val
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return self.num_steps == other.num_steps
|
||||||
|
|
||||||
|
def __truediv__(self, other):
|
||||||
|
"""
|
||||||
|
divide this step method with other. If other is an integer, returns an object of the same
|
||||||
|
type as self. If other is the same type of self, returns an integer. In either case, any
|
||||||
|
floating point value is rounded up under the assumption that if we are dividing Steps, we
|
||||||
|
would rather overestimate than underestimate.
|
||||||
|
"""
|
||||||
|
if isinstance(other, type(self)):
|
||||||
|
return math.ceil(self.num_steps / other.num_steps)
|
||||||
|
elif isinstance(other, int):
|
||||||
|
return type(self)(math.ceil(self.num_steps / other))
|
||||||
|
else:
|
||||||
|
raise TypeError("cannot divide {} by {}".format(type(self), type(other)))
|
||||||
|
|
||||||
|
def __rtruediv__(self, other):
|
||||||
|
"""
|
||||||
|
divide this step method with other. If other is an integer, returns an object of the same
|
||||||
|
type as self. If other is the same type of self, returns an integer. In either case, any
|
||||||
|
floating point value is rounded up under the assumption that if we are dividing Steps, we
|
||||||
|
would rather overestimate than underestimate.
|
||||||
|
"""
|
||||||
|
if isinstance(other, type(self)):
|
||||||
|
return math.ceil(other.num_steps / self.num_steps)
|
||||||
|
elif isinstance(other, int):
|
||||||
|
return type(self)(math.ceil(other / self.num_steps))
|
||||||
|
else:
|
||||||
|
raise TypeError("cannot divide {} by {}".format(type(other), type(self)))
|
||||||
|
|
||||||
|
|
||||||
class Frames(StepMethod):
|
class Frames(StepMethod):
|
||||||
def __init__(self, num_steps):
|
def __init__(self, num_steps):
|
||||||
|
|||||||
@@ -1,4 +1,11 @@
|
|||||||
from rl_coach.core_types import TotalStepsCounter, EnvironmentSteps, EnvironmentEpisodes
|
from rl_coach.core_types import (
|
||||||
|
TotalStepsCounter,
|
||||||
|
EnvironmentSteps,
|
||||||
|
EnvironmentEpisodes,
|
||||||
|
StepMethod,
|
||||||
|
EnvironmentSteps,
|
||||||
|
EnvironmentEpisodes,
|
||||||
|
)
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -25,3 +32,33 @@ def test_total_steps_counter_less_than():
|
|||||||
assert not (counter < steps)
|
assert not (counter < steps)
|
||||||
steps = counter + EnvironmentSteps(1)
|
steps = counter + EnvironmentSteps(1)
|
||||||
assert counter < steps
|
assert counter < steps
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit_test
|
||||||
|
def test_step_method_div():
|
||||||
|
assert StepMethod(10) / 2 == StepMethod(5)
|
||||||
|
assert StepMethod(10) / StepMethod(2) == 5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit_test
|
||||||
|
def test_step_method_div_ceil():
|
||||||
|
assert StepMethod(10) / 3 == StepMethod(4)
|
||||||
|
assert StepMethod(10) / StepMethod(3) == 4
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit_test
|
||||||
|
def test_step_method_rdiv_ceil():
|
||||||
|
assert 10 / StepMethod(3) == StepMethod(4)
|
||||||
|
assert StepMethod(10) / StepMethod(3) == 4
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit_test
|
||||||
|
def test_step_method_rdiv():
|
||||||
|
assert 10 / StepMethod(2) == StepMethod(5)
|
||||||
|
assert StepMethod(10) / StepMethod(2) == 5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit_test
|
||||||
|
def test_step_method_div_type():
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
EnvironmentEpisodes(10) / EnvironmentSteps(2)
|
||||||
|
|||||||
Reference in New Issue
Block a user