From 30c2b2fc4586d1bf6b50e1b6f2c2a5c04acf2947 Mon Sep 17 00:00:00 2001 From: Gal Leibovich Date: Thu, 23 May 2019 13:38:01 +0300 Subject: [PATCH] moving to skimage.transform.resize (#321) --- ...servation_rescale_size_by_factor_filter.py | 29 +++++------------ .../observation_rescale_to_size_filter.py | 32 +++++-------------- ...servation_rescale_size_by_factor_filter.py | 11 ++++--- ...test_observation_rescale_to_size_filter.py | 27 +++++++--------- .../tests/filters/test_filters_stacking.py | 3 +- 5 files changed, 35 insertions(+), 67 deletions(-) diff --git a/rl_coach/filters/observation/observation_rescale_size_by_factor_filter.py b/rl_coach/filters/observation/observation_rescale_size_by_factor_filter.py index 6d8f07d..df7cc08 100644 --- a/rl_coach/filters/observation/observation_rescale_size_by_factor_filter.py +++ b/rl_coach/filters/observation/observation_rescale_size_by_factor_filter.py @@ -14,39 +14,25 @@ # limitations under the License. # -from enum import Enum +from skimage.transform import resize -import scipy.ndimage from rl_coach.core_types import ObservationType from rl_coach.filters.observation.observation_filter import ObservationFilter from rl_coach.spaces import ObservationSpace -# imresize interpolation types as defined by scipy here: -# https://docs.scipy.org/doc/scipy-0.18.1/reference/generated/scipy.misc.imresize.html -class RescaleInterpolationType(Enum): - NEAREST = 'nearest' - LANCZOS = 'lanczos' - BILINEAR = 'bilinear' - BICUBIC = 'bicubic' - CUBIC = 'cubic' - - class ObservationRescaleSizeByFactorFilter(ObservationFilter): """ Rescales an image observation by some factor. For example, the image size can be reduced by a factor of 2. - Warning: this requires the input observation to be of type uint8 due to scipy requirements! """ - def __init__(self, rescale_factor: float, rescaling_interpolation_type: RescaleInterpolationType): + def __init__(self, rescale_factor: float): """ :param rescale_factor: the factor by which the observation will be rescaled - :param rescaling_interpolation_type: the interpolation type for rescaling """ super().__init__() - self.rescale_factor = float(rescale_factor) # scipy requires float scale factors - self.rescaling_interpolation_type = rescaling_interpolation_type + self.rescale_factor = float(rescale_factor) # TODO: allow selecting the channels dim def validate_input_observation_space(self, input_observation_space: ObservationSpace): @@ -58,13 +44,14 @@ class ObservationRescaleSizeByFactorFilter(ObservationFilter): raise ValueError("Observations with 3 dimensions must have 3 channels in the last axis (RGB)") def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType: - # scipy works only with uint8 observation = observation.astype('uint8') + rescaled_output_size = tuple([int(self.rescale_factor * dim) for dim in observation.shape[:2]]) + + if len(observation.shape) == 3: + rescaled_output_size += (3,) # rescale - observation = scipy.misc.imresize(observation, - self.rescale_factor, - interp=self.rescaling_interpolation_type.value) + observation = resize(observation, rescaled_output_size, anti_aliasing=False, preserve_range=True).astype('uint8') return observation diff --git a/rl_coach/filters/observation/observation_rescale_to_size_filter.py b/rl_coach/filters/observation/observation_rescale_to_size_filter.py index 9037b8c..4dcff56 100644 --- a/rl_coach/filters/observation/observation_rescale_to_size_filter.py +++ b/rl_coach/filters/observation/observation_rescale_to_size_filter.py @@ -15,41 +15,26 @@ # import copy -from enum import Enum - +from skimage.transform import resize import numpy as np -import scipy.ndimage from rl_coach.core_types import ObservationType from rl_coach.filters.observation.observation_filter import ObservationFilter from rl_coach.spaces import ObservationSpace, PlanarMapsObservationSpace, ImageObservationSpace -# imresize interpolation types as defined by scipy here: -# https://docs.scipy.org/doc/scipy-0.18.1/reference/generated/scipy.misc.imresize.html -class RescaleInterpolationType(Enum): - NEAREST = 'nearest' - LANCZOS = 'lanczos' - BILINEAR = 'bilinear' - BICUBIC = 'bicubic' - CUBIC = 'cubic' - - class ObservationRescaleToSizeFilter(ObservationFilter): """ Rescales an image observation to a given size. The target size does not necessarily keep the aspect ratio of the original observation. Warning: this requires the input observation to be of type uint8 due to scipy requirements! """ - def __init__(self, output_observation_space: PlanarMapsObservationSpace, - rescaling_interpolation_type: RescaleInterpolationType=RescaleInterpolationType.BILINEAR): + def __init__(self, output_observation_space: PlanarMapsObservationSpace): """ :param output_observation_space: the output observation space - :param rescaling_interpolation_type: the interpolation type for rescaling """ super().__init__() self.output_observation_space = output_observation_space - self.rescaling_interpolation_type = rescaling_interpolation_type if not isinstance(output_observation_space, PlanarMapsObservationSpace): raise ValueError("The rescale filter only applies to observation spaces that inherit from " @@ -75,20 +60,19 @@ class ObservationRescaleToSizeFilter(ObservationFilter): self.output_observation_space.shape[self.output_observation_space.channels_axis])) def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType: - # scipy works only with uint8 observation = observation.astype('uint8') # rescale if isinstance(self.output_observation_space, ImageObservationSpace): - observation = scipy.misc.imresize(observation, - tuple(self.output_observation_space.shape), - interp=self.rescaling_interpolation_type.value) + observation = resize(observation, tuple(self.output_observation_space.shape), anti_aliasing=False, + preserve_range=True).astype('uint8') + else: new_observation = [] for i in range(self.output_observation_space.shape[self.output_observation_space.channels_axis]): - new_observation.append(scipy.misc.imresize(observation.take(i, self.output_observation_space.channels_axis), - tuple(self.planar_map_output_shape), - interp=self.rescaling_interpolation_type.value)) + new_observation.append(resize(observation.take(i, self.output_observation_space.channels_axis), + tuple(self.planar_map_output_shape), + preserve_range=True).astype('uint8')) new_observation = np.array(new_observation) observation = new_observation.swapaxes(0, self.output_observation_space.channels_axis) diff --git a/rl_coach/tests/filters/observation/test_observation_rescale_size_by_factor_filter.py b/rl_coach/tests/filters/observation/test_observation_rescale_size_by_factor_filter.py index e5703d0..c3e8f65 100644 --- a/rl_coach/tests/filters/observation/test_observation_rescale_size_by_factor_filter.py +++ b/rl_coach/tests/filters/observation/test_observation_rescale_size_by_factor_filter.py @@ -5,7 +5,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) import pytest import numpy as np -from rl_coach.filters.observation.observation_rescale_size_by_factor_filter import ObservationRescaleSizeByFactorFilter, RescaleInterpolationType +from rl_coach.filters.observation.observation_rescale_size_by_factor_filter import ObservationRescaleSizeByFactorFilter from rl_coach.spaces import ObservationSpace from rl_coach.core_types import EnvResponse from rl_coach.filters.filter import InputFilter @@ -17,7 +17,7 @@ def test_filter(): env_response = EnvResponse(next_state={'observation': np.ones([20, 30, 3])}, reward=0, game_over=False) rescale_filter = InputFilter() rescale_filter.add_observation_filter('observation', 'rescale', - ObservationRescaleSizeByFactorFilter(0.5, RescaleInterpolationType.BILINEAR)) + ObservationRescaleSizeByFactorFilter(0.5)) result = rescale_filter.filter(env_response)[0] unfiltered_observation = env_response.next_state['observation'] @@ -33,7 +33,7 @@ def test_filter(): env_response = EnvResponse(next_state={'observation': np.ones([20, 30])}, reward=0, game_over=False) rescale_filter = InputFilter() rescale_filter.add_observation_filter('observation', 'rescale', - ObservationRescaleSizeByFactorFilter(2, RescaleInterpolationType.BILINEAR)) + ObservationRescaleSizeByFactorFilter(2)) result = rescale_filter.filter(env_response)[0] filtered_observation = result.next_state['observation'] @@ -47,7 +47,7 @@ def test_get_filtered_observation_space(): # error on wrong number of channels rescale_filter = InputFilter() rescale_filter.add_observation_filter('observation', 'rescale', - ObservationRescaleSizeByFactorFilter(0.5, RescaleInterpolationType.BILINEAR)) + ObservationRescaleSizeByFactorFilter(0.5)) observation_space = ObservationSpace(np.array([10, 20, 5])) with pytest.raises(ValueError): filtered_observation_space = rescale_filter.get_filtered_observation_space('observation', observation_space) @@ -64,3 +64,6 @@ def test_get_filtered_observation_space(): # make sure the original observation space is unchanged assert np.all(observation_space.shape == np.array([10, 20, 3])) + +if __name__ == '__main__': + test_filter() \ No newline at end of file diff --git a/rl_coach/tests/filters/observation/test_observation_rescale_to_size_filter.py b/rl_coach/tests/filters/observation/test_observation_rescale_to_size_filter.py index 48b00e1..f960b08 100644 --- a/rl_coach/tests/filters/observation/test_observation_rescale_to_size_filter.py +++ b/rl_coach/tests/filters/observation/test_observation_rescale_to_size_filter.py @@ -5,7 +5,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) import pytest import numpy as np -from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter, RescaleInterpolationType +from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter from rl_coach.spaces import ObservationSpace, ImageObservationSpace, PlanarMapsObservationSpace from rl_coach.core_types import EnvResponse from rl_coach.filters.filter import InputFilter @@ -18,9 +18,8 @@ def test_filter(): transition = EnvResponse(next_state={'observation': np.ones([20, 30, 3])}, reward=0, game_over=False) rescale_filter = InputFilter() rescale_filter.add_observation_filter('observation', 'rescale', - ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([10, 20, 3]), - high=255), - RescaleInterpolationType.BILINEAR)) + ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([10, 20, 3]), + high=255))) result = rescale_filter.filter(transition)[0] unfiltered_observation = transition.next_state['observation'] @@ -38,8 +37,7 @@ def test_filter(): rescale_filter = InputFilter() rescale_filter.add_observation_filter('observation', 'rescale', ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([40, 60]), - high=255), - RescaleInterpolationType.BILINEAR)) + high=255))) result = rescale_filter.filter(transition)[0] filtered_observation = result.next_state['observation'] @@ -52,21 +50,20 @@ def test_filter(): # InputFilter( # observation_filters=OrderedDict([('rescale', # ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([10, 20, 1]), - # high=255), - # RescaleInterpolationType.BILINEAR))])) + # high=255) + # ))])) # TODO: validate input to filter # different number of axes -> error # env_response = EnvResponse(state={'observation': np.ones([20, 30, 3])}, reward=0, game_over=False) - # rescale_filter = ObservationRescaleToSizeFilter(ObservationSpace(np.array([10, 20])), - # RescaleInterpolationType.BILINEAR) + # rescale_filter = ObservationRescaleToSizeFilter(ObservationSpace(np.array([10, 20])) + # ) # with pytest.raises(ValueError): # result = rescale_filter.filter(transition) # channels first -> error with pytest.raises(ValueError): - ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([3, 10, 20]), high=255), - RescaleInterpolationType.BILINEAR) + ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([3, 10, 20]), high=255)) @pytest.mark.unit_test @@ -76,15 +73,13 @@ def test_get_filtered_observation_space(): observation_filters = InputFilter() observation_filters.add_observation_filter('observation', 'rescale', ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([5, 10, 5]), - high=255), - RescaleInterpolationType.BILINEAR)) + high=255))) # mismatch and wrong number of channels rescale_filter = InputFilter() rescale_filter.add_observation_filter('observation', 'rescale', ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([5, 10, 3]), - high=255), - RescaleInterpolationType.BILINEAR)) + high=255))) observation_space = PlanarMapsObservationSpace(np.array([10, 20, 5]), low=0, high=255) with pytest.raises(ValueError): diff --git a/rl_coach/tests/filters/test_filters_stacking.py b/rl_coach/tests/filters/test_filters_stacking.py index 3534243..1469ad5 100644 --- a/rl_coach/tests/filters/test_filters_stacking.py +++ b/rl_coach/tests/filters/test_filters_stacking.py @@ -4,7 +4,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) import pytest -from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter, RescaleInterpolationType +from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter from rl_coach.filters.observation.observation_crop_filter import ObservationCropFilter from rl_coach.filters.reward.reward_clipping_filter import RewardClippingFilter from rl_coach.filters.observation.observation_stacking_filter import ObservationStackingFilter @@ -31,7 +31,6 @@ def test_filter_stacking(): filter1 = ObservationRescaleToSizeFilter( output_observation_space=ImageObservationSpace(np.array([110, 84]), high=255), - rescaling_interpolation_type=RescaleInterpolationType.BILINEAR ) filter2 = ObservationCropFilter(