1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

more clear names for methods of Space (#181)

* rename Space.val_matches_space_definition -> contains; Space.is_point_in_space_shape -> valid_index
* rename valid_index -> is_valid_index
This commit is contained in:
Zach Dwiel
2019-01-14 15:02:53 -05:00
committed by GitHub
parent 0ccc333d77
commit cd812b0d25
19 changed files with 77 additions and 62 deletions

View File

@@ -117,9 +117,10 @@ class Space(object):
if type(self._high) == int or type(self._high) == float:
self._high = np.ones(self.shape)*self._high
def val_matches_space_definition(self, val: Union[int, float, np.ndarray]) -> bool:
def contains(self, val: Union[int, float, np.ndarray]) -> bool:
"""
Checks if the given value matches the space definition in terms of shape and values
Checks if value is contained by this space. The shape must match and
all of the values must be within the low and high bounds.
:param val: a value to check
:return: True / False depending on if the val matches the space definition
@@ -134,16 +135,16 @@ class Space(object):
return False
return True
def is_point_in_space_shape(self, point: np.ndarray) -> bool:
def is_valid_index(self, index: np.ndarray) -> bool:
"""
Checks if a given multidimensional point is within the bounds of the shape of the space
Checks if a given multidimensional index is within the bounds of the shape of the space
:param point: a multidimensional point
:return: True if the point is within the shape of the space. False otherwise
:param index: a multidimensional index
:return: True if the index is within the shape of the space. False otherwise
"""
if len(point) != self.num_dimensions:
if len(index) != self.num_dimensions:
return False
if np.any(point < np.zeros(self.num_dimensions)) or np.any(point >= self.shape):
if np.any(index < np.zeros(self.num_dimensions)) or np.any(index >= self.shape):
return False
return True
@@ -160,6 +161,20 @@ class Space(object):
else:
return np.random.uniform(self.low, self.high, self.shape)
def val_matches_space_definition(self, val: Union[int, float, np.ndarray]) -> bool:
screen.warning(
"Space.val_matches_space_definition will be deprecated soon. Use "
"contains instead."
)
return self.contains(val)
def is_point_in_space_shape(self, point: np.ndarray) -> bool:
screen.warning(
"Space.is_point_in_space_shape will be deprecated soon. Use "
"is_valid_index instead."
)
return self.is_valid_index(point)
class RewardSpace(Space):
def __init__(self, shape: Union[int, np.ndarray], low: Union[None, int, float, np.ndarray]=-np.inf,