1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

more clear names for methods of Space (#181)

* rename Space.val_matches_space_definition -> contains; Space.is_point_in_space_shape -> valid_index
* rename valid_index -> is_valid_index
This commit is contained in:
Zach Dwiel
2019-01-14 15:02:53 -05:00
committed by GitHub
parent 0ccc333d77
commit cd812b0d25
19 changed files with 77 additions and 62 deletions

View File

@@ -47,7 +47,7 @@ class ActionFilter(Filter):
:param action: an action to validate
:return: None
"""
if not self.output_action_space.val_matches_space_definition(action):
if not self.output_action_space.contains(action):
raise ValueError("The given action ({}) does not match the action space ({})"
.format(action, self.output_action_space))

View File

@@ -42,7 +42,7 @@ class PartialDiscreteActionSpaceMap(ActionFilter):
if not self.target_actions:
raise ValueError("The target actions were not set")
for v in self.target_actions:
if not output_action_space.val_matches_space_definition(v):
if not output_action_space.contains(v):
raise ValueError("The values in the output actions ({}) do not match the output action "
"space definition ({})".format(v, output_action_space))

View File

@@ -71,8 +71,8 @@ class ObservationCropFilter(ObservationFilter):
if np.any(crop_high > input_observation_space.shape) or \
np.any(crop_low > input_observation_space.shape):
raise ValueError("The cropping values are outside of the observation space")
if not input_observation_space.is_point_in_space_shape(crop_low) or \
not input_observation_space.is_point_in_space_shape(crop_high - 1):
if not input_observation_space.is_valid_index(crop_low) or \
not input_observation_space.is_valid_index(crop_high - 1):
raise ValueError("The cropping indices are outside of the observation space")
def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:

View File

@@ -79,7 +79,7 @@ class ObservationStackingFilter(ObservationFilter):
"functionality")
def validate_input_observation_space(self, input_observation_space: ObservationSpace):
if len(self.stack) > 0 and not input_observation_space.val_matches_space_definition(self.stack[-1]):
if len(self.stack) > 0 and not input_observation_space.contains(self.stack[-1]):
raise ValueError("The given input observation space is different than the observations already stored in"
"the filters memory")
if input_observation_space.num_dimensions <= self.stacking_axis: