mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
add __truediv__, __rtruediv__ and __eq__ to StepMethod
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
from collections import namedtuple
|
||||
|
||||
import copy
|
||||
import math
|
||||
from enum import Enum
|
||||
from random import shuffle
|
||||
from typing import List, Union, Dict, Any, Type
|
||||
@@ -63,6 +64,37 @@ class StepMethod(object):
|
||||
def num_steps(self, val: int) -> None:
|
||||
self._num_steps = val
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.num_steps == other.num_steps
|
||||
|
||||
def __truediv__(self, other):
|
||||
"""
|
||||
divide this step method with other. If other is an integer, returns an object of the same
|
||||
type as self. If other is the same type of self, returns an integer. In either case, any
|
||||
floating point value is rounded up under the assumption that if we are dividing Steps, we
|
||||
would rather overestimate than underestimate.
|
||||
"""
|
||||
if isinstance(other, type(self)):
|
||||
return math.ceil(self.num_steps / other.num_steps)
|
||||
elif isinstance(other, int):
|
||||
return type(self)(math.ceil(self.num_steps / other))
|
||||
else:
|
||||
raise TypeError("cannot divide {} by {}".format(type(self), type(other)))
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
"""
|
||||
divide this step method with other. If other is an integer, returns an object of the same
|
||||
type as self. If other is the same type of self, returns an integer. In either case, any
|
||||
floating point value is rounded up under the assumption that if we are dividing Steps, we
|
||||
would rather overestimate than underestimate.
|
||||
"""
|
||||
if isinstance(other, type(self)):
|
||||
return math.ceil(other.num_steps / self.num_steps)
|
||||
elif isinstance(other, int):
|
||||
return type(self)(math.ceil(other / self.num_steps))
|
||||
else:
|
||||
raise TypeError("cannot divide {} by {}".format(type(other), type(self)))
|
||||
|
||||
|
||||
class Frames(StepMethod):
|
||||
def __init__(self, num_steps):
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
from rl_coach.core_types import TotalStepsCounter, EnvironmentSteps, EnvironmentEpisodes
|
||||
from rl_coach.core_types import (
|
||||
TotalStepsCounter,
|
||||
EnvironmentSteps,
|
||||
EnvironmentEpisodes,
|
||||
StepMethod,
|
||||
EnvironmentSteps,
|
||||
EnvironmentEpisodes,
|
||||
)
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -25,3 +32,33 @@ def test_total_steps_counter_less_than():
|
||||
assert not (counter < steps)
|
||||
steps = counter + EnvironmentSteps(1)
|
||||
assert counter < steps
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_step_method_div():
|
||||
assert StepMethod(10) / 2 == StepMethod(5)
|
||||
assert StepMethod(10) / StepMethod(2) == 5
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_step_method_div_ceil():
|
||||
assert StepMethod(10) / 3 == StepMethod(4)
|
||||
assert StepMethod(10) / StepMethod(3) == 4
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_step_method_rdiv_ceil():
|
||||
assert 10 / StepMethod(3) == StepMethod(4)
|
||||
assert StepMethod(10) / StepMethod(3) == 4
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_step_method_rdiv():
|
||||
assert 10 / StepMethod(2) == StepMethod(5)
|
||||
assert StepMethod(10) / StepMethod(2) == 5
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_step_method_div_type():
|
||||
with pytest.raises(TypeError):
|
||||
EnvironmentEpisodes(10) / EnvironmentSteps(2)
|
||||
|
||||
Reference in New Issue
Block a user