mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
more clear names for methods of Space (#181)
* rename Space.val_matches_space_definition -> contains; Space.is_point_in_space_shape -> valid_index * rename valid_index -> is_valid_index
This commit is contained in:
@@ -47,7 +47,7 @@ class ActionFilter(Filter):
|
||||
:param action: an action to validate
|
||||
:return: None
|
||||
"""
|
||||
if not self.output_action_space.val_matches_space_definition(action):
|
||||
if not self.output_action_space.contains(action):
|
||||
raise ValueError("The given action ({}) does not match the action space ({})"
|
||||
.format(action, self.output_action_space))
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ class PartialDiscreteActionSpaceMap(ActionFilter):
|
||||
if not self.target_actions:
|
||||
raise ValueError("The target actions were not set")
|
||||
for v in self.target_actions:
|
||||
if not output_action_space.val_matches_space_definition(v):
|
||||
if not output_action_space.contains(v):
|
||||
raise ValueError("The values in the output actions ({}) do not match the output action "
|
||||
"space definition ({})".format(v, output_action_space))
|
||||
|
||||
|
||||
@@ -71,8 +71,8 @@ class ObservationCropFilter(ObservationFilter):
|
||||
if np.any(crop_high > input_observation_space.shape) or \
|
||||
np.any(crop_low > input_observation_space.shape):
|
||||
raise ValueError("The cropping values are outside of the observation space")
|
||||
if not input_observation_space.is_point_in_space_shape(crop_low) or \
|
||||
not input_observation_space.is_point_in_space_shape(crop_high - 1):
|
||||
if not input_observation_space.is_valid_index(crop_low) or \
|
||||
not input_observation_space.is_valid_index(crop_high - 1):
|
||||
raise ValueError("The cropping indices are outside of the observation space")
|
||||
|
||||
def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
|
||||
|
||||
@@ -79,7 +79,7 @@ class ObservationStackingFilter(ObservationFilter):
|
||||
"functionality")
|
||||
|
||||
def validate_input_observation_space(self, input_observation_space: ObservationSpace):
|
||||
if len(self.stack) > 0 and not input_observation_space.val_matches_space_definition(self.stack[-1]):
|
||||
if len(self.stack) > 0 and not input_observation_space.contains(self.stack[-1]):
|
||||
raise ValueError("The given input observation space is different than the observations already stored in"
|
||||
"the filters memory")
|
||||
if input_observation_space.num_dimensions <= self.stacking_axis:
|
||||
|
||||
Reference in New Issue
Block a user