1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

add __truediv__, __rtruediv__ and __eq__ to StepMethod

This commit is contained in:
zach dwiel
2019-04-04 16:11:07 -04:00
committed by Zach Dwiel
parent 83da5cde2f
commit 2cb078b4c2
2 changed files with 70 additions and 1 deletions

View File

@@ -16,6 +16,7 @@
from collections import namedtuple from collections import namedtuple
import copy import copy
import math
from enum import Enum from enum import Enum
from random import shuffle from random import shuffle
from typing import List, Union, Dict, Any, Type from typing import List, Union, Dict, Any, Type
@@ -63,6 +64,37 @@ class StepMethod(object):
def num_steps(self, val: int) -> None: def num_steps(self, val: int) -> None:
self._num_steps = val self._num_steps = val
def __eq__(self, other):
    """
    compare this step method with other by number of steps.

    :param other: the object to compare against. any object exposing a num_steps attribute is
                  compared by value.
    :return: True if both objects hold the same num_steps and False if they differ.
             NotImplemented is returned when other has no num_steps attribute, so that python
             falls back to its default comparison (e.g. StepMethod(5) == 5 evaluates to False)
             instead of raising an AttributeError.
    """
    # NOTE(review): in python 3 a class that defines __eq__ without also defining __hash__ is
    # unhashable - confirm instances are never used as dict keys or set members.
    if not hasattr(other, "num_steps"):
        return NotImplemented
    return self.num_steps == other.num_steps
def __truediv__(self, other):
    """
    compute self / other, rounding any fractional result up.

    rounding up is used under the assumption that when dividing steps it is better to
    overestimate than to underestimate.

    :param other: an instance of type(self), or an integer
    :return: an integer when other is an instance of type(self); a new object of the same
             type as self when other is an integer
    :raises TypeError: when other is neither an instance of type(self) nor an integer
    """
    own_type = type(self)

    # steps / steps -> plain integer ratio
    if isinstance(other, own_type):
        ratio = self.num_steps / other.num_steps
        return math.ceil(ratio)

    # steps / integer -> steps of the same kind
    if isinstance(other, int):
        ratio = self.num_steps / other
        return own_type(math.ceil(ratio))

    raise TypeError("cannot divide {} by {}".format(own_type, type(other)))
def __rtruediv__(self, other):
    """
    compute other / self (the reflected division), rounding any fractional result up.

    rounding up is used under the assumption that when dividing steps it is better to
    overestimate than to underestimate.

    :param other: the left hand operand: an instance of type(self), or an integer
    :return: an integer when other is an instance of type(self); a new object of the same
             type as self when other is an integer
    :raises TypeError: when other is neither an instance of type(self) nor an integer
    """
    own_type = type(self)

    # steps / steps -> plain integer ratio
    if isinstance(other, own_type):
        ratio = other.num_steps / self.num_steps
        return math.ceil(ratio)

    # integer / steps -> steps of the same kind
    if isinstance(other, int):
        ratio = other / self.num_steps
        return own_type(math.ceil(ratio))

    raise TypeError("cannot divide {} by {}".format(type(other), own_type))
class Frames(StepMethod): class Frames(StepMethod):
def __init__(self, num_steps): def __init__(self, num_steps):

View File

@@ -1,4 +1,11 @@
from rl_coach.core_types import TotalStepsCounter, EnvironmentSteps, EnvironmentEpisodes from rl_coach.core_types import (
TotalStepsCounter,
EnvironmentSteps,
EnvironmentEpisodes,
StepMethod,
EnvironmentSteps,
EnvironmentEpisodes,
)
import pytest import pytest
@@ -25,3 +32,33 @@ def test_total_steps_counter_less_than():
assert not (counter < steps) assert not (counter < steps)
steps = counter + EnvironmentSteps(1) steps = counter + EnvironmentSteps(1)
assert counter < steps assert counter < steps
@pytest.mark.unit_test
def test_step_method_div():
    """an exact division gives StepMethod(5) for an integer divisor and 5 for a StepMethod divisor"""
    quotient_by_int = StepMethod(10) / 2
    assert quotient_by_int == StepMethod(5)

    quotient_by_steps = StepMethod(10) / StepMethod(2)
    assert quotient_by_steps == 5
@pytest.mark.unit_test
def test_step_method_div_ceil():
    """an inexact division (10 / 3) is rounded up to 4 for both kinds of divisor"""
    quotient_by_int = StepMethod(10) / 3
    assert quotient_by_int == StepMethod(4)

    quotient_by_steps = StepMethod(10) / StepMethod(3)
    assert quotient_by_steps == 4
@pytest.mark.unit_test
def test_step_method_rdiv_ceil():
    """an inexact reflected division (10 / 3) is rounded up to 4 on both __rtruediv__ branches"""
    # integer on the left: python dispatches to StepMethod.__rtruediv__
    assert 10 / StepMethod(3) == StepMethod(4)
    # StepMethod / StepMethod is served by __truediv__ and never reaches __rtruediv__, so the
    # same-type branch of the reflected method has to be called explicitly to be covered
    assert StepMethod(3).__rtruediv__(StepMethod(10)) == 4
@pytest.mark.unit_test
def test_step_method_rdiv():
    """an exact reflected division (10 / 2) gives 5 on both __rtruediv__ branches"""
    # integer on the left: python dispatches to StepMethod.__rtruediv__
    assert 10 / StepMethod(2) == StepMethod(5)
    # StepMethod / StepMethod is served by __truediv__ and never reaches __rtruediv__, so the
    # same-type branch of the reflected method has to be called explicitly to be covered
    assert StepMethod(2).__rtruediv__(StepMethod(10)) == 5
@pytest.mark.unit_test
def test_step_method_div_type():
    """dividing EnvironmentEpisodes by EnvironmentSteps is expected to raise a TypeError"""
    with pytest.raises(TypeError):
        # operands are built inside the context manager, exactly where the division happens
        numerator, denominator = EnvironmentEpisodes(10), EnvironmentSteps(2)
        numerator / denominator