mirror of
https://github.com/gryf/coach.git
synced 2026-02-15 05:25:55 +01:00
more clear names for methods of Space (#181)
* rename Space.val_matches_space_definition -> contains; Space.is_point_in_space_shape -> valid_index * rename valid_index -> is_valid_index
This commit is contained in:
@@ -280,7 +280,7 @@ class Environment(EnvironmentInterface):
|
||||
:return: the environment response as returned in get_last_env_response
|
||||
"""
|
||||
action = self.action_space.clip_action_to_space(action)
|
||||
if self.action_space and not self.action_space.val_matches_space_definition(action):
|
||||
if self.action_space and not self.action_space.contains(action):
|
||||
raise ValueError("The given action does not match the action space definition. "
|
||||
"Action = {}, action space definition = {}".format(action, self.action_space))
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ class ActionFilter(Filter):
|
||||
:param action: an action to validate
|
||||
:return: None
|
||||
"""
|
||||
if not self.output_action_space.val_matches_space_definition(action):
|
||||
if not self.output_action_space.contains(action):
|
||||
raise ValueError("The given action ({}) does not match the action space ({})"
|
||||
.format(action, self.output_action_space))
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ class PartialDiscreteActionSpaceMap(ActionFilter):
|
||||
if not self.target_actions:
|
||||
raise ValueError("The target actions were not set")
|
||||
for v in self.target_actions:
|
||||
if not output_action_space.val_matches_space_definition(v):
|
||||
if not output_action_space.contains(v):
|
||||
raise ValueError("The values in the output actions ({}) do not match the output action "
|
||||
"space definition ({})".format(v, output_action_space))
|
||||
|
||||
|
||||
@@ -71,8 +71,8 @@ class ObservationCropFilter(ObservationFilter):
|
||||
if np.any(crop_high > input_observation_space.shape) or \
|
||||
np.any(crop_low > input_observation_space.shape):
|
||||
raise ValueError("The cropping values are outside of the observation space")
|
||||
if not input_observation_space.is_point_in_space_shape(crop_low) or \
|
||||
not input_observation_space.is_point_in_space_shape(crop_high - 1):
|
||||
if not input_observation_space.is_valid_index(crop_low) or \
|
||||
not input_observation_space.is_valid_index(crop_high - 1):
|
||||
raise ValueError("The cropping indices are outside of the observation space")
|
||||
|
||||
def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
|
||||
|
||||
@@ -79,7 +79,7 @@ class ObservationStackingFilter(ObservationFilter):
|
||||
"functionality")
|
||||
|
||||
def validate_input_observation_space(self, input_observation_space: ObservationSpace):
|
||||
if len(self.stack) > 0 and not input_observation_space.val_matches_space_definition(self.stack[-1]):
|
||||
if len(self.stack) > 0 and not input_observation_space.contains(self.stack[-1]):
|
||||
raise ValueError("The given input observation space is different than the observations already stored in"
|
||||
"the filters memory")
|
||||
if input_observation_space.num_dimensions <= self.stacking_axis:
|
||||
|
||||
@@ -117,9 +117,10 @@ class Space(object):
|
||||
if type(self._high) == int or type(self._high) == float:
|
||||
self._high = np.ones(self.shape)*self._high
|
||||
|
||||
def val_matches_space_definition(self, val: Union[int, float, np.ndarray]) -> bool:
|
||||
def contains(self, val: Union[int, float, np.ndarray]) -> bool:
|
||||
"""
|
||||
Checks if the given value matches the space definition in terms of shape and values
|
||||
Checks if value is contained by this space. The shape must match and
|
||||
all of the values must be within the low and high bounds.
|
||||
|
||||
:param val: a value to check
|
||||
:return: True / False depending on if the val matches the space definition
|
||||
@@ -134,16 +135,16 @@ class Space(object):
|
||||
return False
|
||||
return True
|
||||
|
||||
def is_point_in_space_shape(self, point: np.ndarray) -> bool:
|
||||
def is_valid_index(self, index: np.ndarray) -> bool:
|
||||
"""
|
||||
Checks if a given multidimensional point is within the bounds of the shape of the space
|
||||
Checks if a given multidimensional index is within the bounds of the shape of the space
|
||||
|
||||
:param point: a multidimensional point
|
||||
:return: True if the point is within the shape of the space. False otherwise
|
||||
:param index: a multidimensional index
|
||||
:return: True if the index is within the shape of the space. False otherwise
|
||||
"""
|
||||
if len(point) != self.num_dimensions:
|
||||
if len(index) != self.num_dimensions:
|
||||
return False
|
||||
if np.any(point < np.zeros(self.num_dimensions)) or np.any(point >= self.shape):
|
||||
if np.any(index < np.zeros(self.num_dimensions)) or np.any(index >= self.shape):
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -160,6 +161,20 @@ class Space(object):
|
||||
else:
|
||||
return np.random.uniform(self.low, self.high, self.shape)
|
||||
|
||||
def val_matches_space_definition(self, val: Union[int, float, np.ndarray]) -> bool:
|
||||
screen.warning(
|
||||
"Space.val_matches_space_definition will be deprecated soon. Use "
|
||||
"contains instead."
|
||||
)
|
||||
return self.contains(val)
|
||||
|
||||
def is_point_in_space_shape(self, point: np.ndarray) -> bool:
|
||||
screen.warning(
|
||||
"Space.is_point_in_space_shape will be deprecated soon. Use "
|
||||
"is_valid_index instead."
|
||||
)
|
||||
return self.is_valid_index(point)
|
||||
|
||||
|
||||
class RewardSpace(Space):
|
||||
def __init__(self, shape: Union[int, np.ndarray], low: Union[None, int, float, np.ndarray]=-np.inf,
|
||||
|
||||
@@ -32,7 +32,7 @@ def test_filter():
|
||||
|
||||
result = filter.filter(action)
|
||||
assert np.all(result == np.array([[41.5, 0], [83., 41.5]]))
|
||||
assert output_space.val_matches_space_definition(result)
|
||||
assert output_space.contains(result)
|
||||
|
||||
# force int bins
|
||||
filter = AttentionDiscretization(2, force_int_bins=True)
|
||||
|
||||
@@ -26,7 +26,7 @@ def test_filter():
|
||||
|
||||
result = filter.filter(action)
|
||||
assert result == [7.5]
|
||||
assert output_space.val_matches_space_definition(result)
|
||||
assert output_space.contains(result)
|
||||
|
||||
# 2 dimensional box
|
||||
filter = BoxDiscretization(3)
|
||||
@@ -42,4 +42,4 @@ def test_filter():
|
||||
|
||||
result = filter.filter(action)
|
||||
assert result == [5., 15.]
|
||||
assert output_space.val_matches_space_definition(result)
|
||||
assert output_space.contains(result)
|
||||
|
||||
@@ -23,5 +23,5 @@ def test_filter():
|
||||
action = np.array([2])
|
||||
result = filter.filter(action)
|
||||
assert result == np.array([12])
|
||||
assert output_space.val_matches_space_definition(result)
|
||||
assert output_space.contains(result)
|
||||
|
||||
|
||||
@@ -25,5 +25,5 @@ def test_filter():
|
||||
action = np.array([12])
|
||||
result = filter.filter(action)
|
||||
assert result == np.array([11])
|
||||
assert output_space.val_matches_space_definition(result)
|
||||
assert output_space.contains(result)
|
||||
|
||||
|
||||
@@ -132,18 +132,18 @@ def test_agent_selection():
|
||||
def test_observation_space():
|
||||
observation_space = ObservationSpace(np.array([1, 10]), -10, 10)
|
||||
|
||||
# testing that val_matches_space_definition works
|
||||
assert observation_space.val_matches_space_definition(np.ones([1, 10]))
|
||||
assert not observation_space.val_matches_space_definition(np.ones([2, 10]))
|
||||
assert not observation_space.val_matches_space_definition(np.ones([1, 10]) * 100)
|
||||
assert not observation_space.val_matches_space_definition(np.ones([1, 1, 10]))
|
||||
# testing that contains works
|
||||
assert observation_space.contains(np.ones([1, 10]))
|
||||
assert not observation_space.contains(np.ones([2, 10]))
|
||||
assert not observation_space.contains(np.ones([1, 10]) * 100)
|
||||
assert not observation_space.contains(np.ones([1, 1, 10]))
|
||||
|
||||
# is_point_in_space_shape
|
||||
assert observation_space.is_point_in_space_shape(np.array([0, 9]))
|
||||
assert observation_space.is_point_in_space_shape(np.array([0, 0]))
|
||||
assert not observation_space.is_point_in_space_shape(np.array([1, 8]))
|
||||
assert not observation_space.is_point_in_space_shape(np.array([0, 10]))
|
||||
assert not observation_space.is_point_in_space_shape(np.array([-1, 6]))
|
||||
# is_valid_index
|
||||
assert observation_space.is_valid_index(np.array([0, 9]))
|
||||
assert observation_space.is_valid_index(np.array([0, 0]))
|
||||
assert not observation_space.is_valid_index(np.array([1, 8]))
|
||||
assert not observation_space.is_valid_index(np.array([0, 10]))
|
||||
assert not observation_space.is_valid_index(np.array([-1, 6]))
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
|
||||
Reference in New Issue
Block a user