1
0
mirror of https://github.com/gryf/coach.git synced 2026-02-15 05:25:55 +01:00

more clear names for methods of Space (#181)

* rename Space.val_matches_space_definition -> contains; Space.is_point_in_space_shape -> valid_index
* rename valid_index -> is_valid_index
This commit is contained in:
Zach Dwiel
2019-01-14 15:02:53 -05:00
committed by GitHub
parent 0ccc333d77
commit cd812b0d25
19 changed files with 77 additions and 62 deletions

View File

@@ -280,7 +280,7 @@ class Environment(EnvironmentInterface):
:return: the environment response as returned in get_last_env_response
"""
action = self.action_space.clip_action_to_space(action)
if self.action_space and not self.action_space.val_matches_space_definition(action):
if self.action_space and not self.action_space.contains(action):
raise ValueError("The given action does not match the action space definition. "
"Action = {}, action space definition = {}".format(action, self.action_space))

View File

@@ -47,7 +47,7 @@ class ActionFilter(Filter):
:param action: an action to validate
:return: None
"""
if not self.output_action_space.val_matches_space_definition(action):
if not self.output_action_space.contains(action):
raise ValueError("The given action ({}) does not match the action space ({})"
.format(action, self.output_action_space))

View File

@@ -42,7 +42,7 @@ class PartialDiscreteActionSpaceMap(ActionFilter):
if not self.target_actions:
raise ValueError("The target actions were not set")
for v in self.target_actions:
if not output_action_space.val_matches_space_definition(v):
if not output_action_space.contains(v):
raise ValueError("The values in the output actions ({}) do not match the output action "
"space definition ({})".format(v, output_action_space))

View File

@@ -71,8 +71,8 @@ class ObservationCropFilter(ObservationFilter):
if np.any(crop_high > input_observation_space.shape) or \
np.any(crop_low > input_observation_space.shape):
raise ValueError("The cropping values are outside of the observation space")
if not input_observation_space.is_point_in_space_shape(crop_low) or \
not input_observation_space.is_point_in_space_shape(crop_high - 1):
if not input_observation_space.is_valid_index(crop_low) or \
not input_observation_space.is_valid_index(crop_high - 1):
raise ValueError("The cropping indices are outside of the observation space")
def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:

View File

@@ -79,7 +79,7 @@ class ObservationStackingFilter(ObservationFilter):
"functionality")
def validate_input_observation_space(self, input_observation_space: ObservationSpace):
if len(self.stack) > 0 and not input_observation_space.val_matches_space_definition(self.stack[-1]):
if len(self.stack) > 0 and not input_observation_space.contains(self.stack[-1]):
raise ValueError("The given input observation space is different than the observations already stored in"
"the filters memory")
if input_observation_space.num_dimensions <= self.stacking_axis:

View File

@@ -117,9 +117,10 @@ class Space(object):
if type(self._high) == int or type(self._high) == float:
self._high = np.ones(self.shape)*self._high
def val_matches_space_definition(self, val: Union[int, float, np.ndarray]) -> bool:
def contains(self, val: Union[int, float, np.ndarray]) -> bool:
"""
Checks if the given value matches the space definition in terms of shape and values
Checks if value is contained by this space. The shape must match and
all of the values must be within the low and high bounds.
:param val: a value to check
:return: True / False depending on if the val matches the space definition
@@ -134,16 +135,16 @@ class Space(object):
return False
return True
def is_point_in_space_shape(self, point: np.ndarray) -> bool:
def is_valid_index(self, index: np.ndarray) -> bool:
"""
Checks if a given multidimensional point is within the bounds of the shape of the space
Checks if a given multidimensional index is within the bounds of the shape of the space
:param point: a multidimensional point
:return: True if the point is within the shape of the space. False otherwise
:param index: a multidimensional index
:return: True if the index is within the shape of the space. False otherwise
"""
if len(point) != self.num_dimensions:
if len(index) != self.num_dimensions:
return False
if np.any(point < np.zeros(self.num_dimensions)) or np.any(point >= self.shape):
if np.any(index < np.zeros(self.num_dimensions)) or np.any(index >= self.shape):
return False
return True
@@ -160,6 +161,20 @@ class Space(object):
else:
return np.random.uniform(self.low, self.high, self.shape)
def val_matches_space_definition(self, val: Union[int, float, np.ndarray]) -> bool:
screen.warning(
"Space.val_matches_space_definition will be deprecated soon. Use "
"contains instead."
)
return self.contains(val)
def is_point_in_space_shape(self, point: np.ndarray) -> bool:
screen.warning(
"Space.is_point_in_space_shape will be deprecated soon. Use "
"is_valid_index instead."
)
return self.is_valid_index(point)
class RewardSpace(Space):
def __init__(self, shape: Union[int, np.ndarray], low: Union[None, int, float, np.ndarray]=-np.inf,

View File

@@ -32,7 +32,7 @@ def test_filter():
result = filter.filter(action)
assert np.all(result == np.array([[41.5, 0], [83., 41.5]]))
assert output_space.val_matches_space_definition(result)
assert output_space.contains(result)
# force int bins
filter = AttentionDiscretization(2, force_int_bins=True)

View File

@@ -26,7 +26,7 @@ def test_filter():
result = filter.filter(action)
assert result == [7.5]
assert output_space.val_matches_space_definition(result)
assert output_space.contains(result)
# 2 dimensional box
filter = BoxDiscretization(3)
@@ -42,4 +42,4 @@ def test_filter():
result = filter.filter(action)
assert result == [5., 15.]
assert output_space.val_matches_space_definition(result)
assert output_space.contains(result)

View File

@@ -23,5 +23,5 @@ def test_filter():
action = np.array([2])
result = filter.filter(action)
assert result == np.array([12])
assert output_space.val_matches_space_definition(result)
assert output_space.contains(result)

View File

@@ -25,5 +25,5 @@ def test_filter():
action = np.array([12])
result = filter.filter(action)
assert result == np.array([11])
assert output_space.val_matches_space_definition(result)
assert output_space.contains(result)

View File

@@ -132,18 +132,18 @@ def test_agent_selection():
def test_observation_space():
observation_space = ObservationSpace(np.array([1, 10]), -10, 10)
# testing that val_matches_space_definition works
assert observation_space.val_matches_space_definition(np.ones([1, 10]))
assert not observation_space.val_matches_space_definition(np.ones([2, 10]))
assert not observation_space.val_matches_space_definition(np.ones([1, 10]) * 100)
assert not observation_space.val_matches_space_definition(np.ones([1, 1, 10]))
# testing that contains works
assert observation_space.contains(np.ones([1, 10]))
assert not observation_space.contains(np.ones([2, 10]))
assert not observation_space.contains(np.ones([1, 10]) * 100)
assert not observation_space.contains(np.ones([1, 1, 10]))
# is_point_in_space_shape
assert observation_space.is_point_in_space_shape(np.array([0, 9]))
assert observation_space.is_point_in_space_shape(np.array([0, 0]))
assert not observation_space.is_point_in_space_shape(np.array([1, 8]))
assert not observation_space.is_point_in_space_shape(np.array([0, 10]))
assert not observation_space.is_point_in_space_shape(np.array([-1, 6]))
# is_valid_index
assert observation_space.is_valid_index(np.array([0, 9]))
assert observation_space.is_valid_index(np.array([0, 0]))
assert not observation_space.is_valid_index(np.array([1, 8]))
assert not observation_space.is_valid_index(np.array([0, 10]))
assert not observation_space.is_valid_index(np.array([-1, 6]))
@pytest.mark.unit_test