1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

moving to skimage.transform.resize (#321)

This commit is contained in:
Gal Leibovich
2019-05-23 13:38:01 +03:00
committed by GitHub
parent acceb03ac0
commit 30c2b2fc45
5 changed files with 35 additions and 67 deletions

View File

@@ -14,39 +14,25 @@
# limitations under the License. # limitations under the License.
# #
from enum import Enum from skimage.transform import resize
import scipy.ndimage
from rl_coach.core_types import ObservationType from rl_coach.core_types import ObservationType
from rl_coach.filters.observation.observation_filter import ObservationFilter from rl_coach.filters.observation.observation_filter import ObservationFilter
from rl_coach.spaces import ObservationSpace from rl_coach.spaces import ObservationSpace
# imresize interpolation types as defined by scipy here:
# https://docs.scipy.org/doc/scipy-0.18.1/reference/generated/scipy.misc.imresize.html
class RescaleInterpolationType(Enum):
NEAREST = 'nearest'
LANCZOS = 'lanczos'
BILINEAR = 'bilinear'
BICUBIC = 'bicubic'
CUBIC = 'cubic'
class ObservationRescaleSizeByFactorFilter(ObservationFilter): class ObservationRescaleSizeByFactorFilter(ObservationFilter):
""" """
Rescales an image observation by some factor. For example, the image size Rescales an image observation by some factor. For example, the image size
can be reduced by a factor of 2. can be reduced by a factor of 2.
Warning: this requires the input observation to be of type uint8 due to scipy requirements!
""" """
def __init__(self, rescale_factor: float, rescaling_interpolation_type: RescaleInterpolationType): def __init__(self, rescale_factor: float):
""" """
:param rescale_factor: the factor by which the observation will be rescaled :param rescale_factor: the factor by which the observation will be rescaled
:param rescaling_interpolation_type: the interpolation type for rescaling
""" """
super().__init__() super().__init__()
self.rescale_factor = float(rescale_factor) # scipy requires float scale factors self.rescale_factor = float(rescale_factor)
self.rescaling_interpolation_type = rescaling_interpolation_type
# TODO: allow selecting the channels dim # TODO: allow selecting the channels dim
def validate_input_observation_space(self, input_observation_space: ObservationSpace): def validate_input_observation_space(self, input_observation_space: ObservationSpace):
@@ -58,13 +44,14 @@ class ObservationRescaleSizeByFactorFilter(ObservationFilter):
raise ValueError("Observations with 3 dimensions must have 3 channels in the last axis (RGB)") raise ValueError("Observations with 3 dimensions must have 3 channels in the last axis (RGB)")
def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType: def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
# scipy works only with uint8
observation = observation.astype('uint8') observation = observation.astype('uint8')
rescaled_output_size = tuple([int(self.rescale_factor * dim) for dim in observation.shape[:2]])
if len(observation.shape) == 3:
rescaled_output_size += (3,)
# rescale # rescale
observation = scipy.misc.imresize(observation, observation = resize(observation, rescaled_output_size, anti_aliasing=False, preserve_range=True).astype('uint8')
self.rescale_factor,
interp=self.rescaling_interpolation_type.value)
return observation return observation

View File

@@ -15,41 +15,26 @@
# #
import copy import copy
from enum import Enum from skimage.transform import resize
import numpy as np import numpy as np
import scipy.ndimage
from rl_coach.core_types import ObservationType from rl_coach.core_types import ObservationType
from rl_coach.filters.observation.observation_filter import ObservationFilter from rl_coach.filters.observation.observation_filter import ObservationFilter
from rl_coach.spaces import ObservationSpace, PlanarMapsObservationSpace, ImageObservationSpace from rl_coach.spaces import ObservationSpace, PlanarMapsObservationSpace, ImageObservationSpace
# imresize interpolation types as defined by scipy here:
# https://docs.scipy.org/doc/scipy-0.18.1/reference/generated/scipy.misc.imresize.html
class RescaleInterpolationType(Enum):
NEAREST = 'nearest'
LANCZOS = 'lanczos'
BILINEAR = 'bilinear'
BICUBIC = 'bicubic'
CUBIC = 'cubic'
class ObservationRescaleToSizeFilter(ObservationFilter): class ObservationRescaleToSizeFilter(ObservationFilter):
""" """
Rescales an image observation to a given size. The target size does not Rescales an image observation to a given size. The target size does not
necessarily keep the aspect ratio of the original observation. necessarily keep the aspect ratio of the original observation.
Warning: this requires the input observation to be of type uint8 due to scipy requirements! Warning: this requires the input observation to be of type uint8 due to scipy requirements!
""" """
def __init__(self, output_observation_space: PlanarMapsObservationSpace, def __init__(self, output_observation_space: PlanarMapsObservationSpace):
rescaling_interpolation_type: RescaleInterpolationType=RescaleInterpolationType.BILINEAR):
""" """
:param output_observation_space: the output observation space :param output_observation_space: the output observation space
:param rescaling_interpolation_type: the interpolation type for rescaling
""" """
super().__init__() super().__init__()
self.output_observation_space = output_observation_space self.output_observation_space = output_observation_space
self.rescaling_interpolation_type = rescaling_interpolation_type
if not isinstance(output_observation_space, PlanarMapsObservationSpace): if not isinstance(output_observation_space, PlanarMapsObservationSpace):
raise ValueError("The rescale filter only applies to observation spaces that inherit from " raise ValueError("The rescale filter only applies to observation spaces that inherit from "
@@ -75,20 +60,19 @@ class ObservationRescaleToSizeFilter(ObservationFilter):
self.output_observation_space.shape[self.output_observation_space.channels_axis])) self.output_observation_space.shape[self.output_observation_space.channels_axis]))
def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType: def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
# scipy works only with uint8
observation = observation.astype('uint8') observation = observation.astype('uint8')
# rescale # rescale
if isinstance(self.output_observation_space, ImageObservationSpace): if isinstance(self.output_observation_space, ImageObservationSpace):
observation = scipy.misc.imresize(observation, observation = resize(observation, tuple(self.output_observation_space.shape), anti_aliasing=False,
tuple(self.output_observation_space.shape), preserve_range=True).astype('uint8')
interp=self.rescaling_interpolation_type.value)
else: else:
new_observation = [] new_observation = []
for i in range(self.output_observation_space.shape[self.output_observation_space.channels_axis]): for i in range(self.output_observation_space.shape[self.output_observation_space.channels_axis]):
new_observation.append(scipy.misc.imresize(observation.take(i, self.output_observation_space.channels_axis), new_observation.append(resize(observation.take(i, self.output_observation_space.channels_axis),
tuple(self.planar_map_output_shape), tuple(self.planar_map_output_shape),
interp=self.rescaling_interpolation_type.value)) preserve_range=True).astype('uint8'))
new_observation = np.array(new_observation) new_observation = np.array(new_observation)
observation = new_observation.swapaxes(0, self.output_observation_space.channels_axis) observation = new_observation.swapaxes(0, self.output_observation_space.channels_axis)

View File

@@ -5,7 +5,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import pytest import pytest
import numpy as np import numpy as np
from rl_coach.filters.observation.observation_rescale_size_by_factor_filter import ObservationRescaleSizeByFactorFilter, RescaleInterpolationType from rl_coach.filters.observation.observation_rescale_size_by_factor_filter import ObservationRescaleSizeByFactorFilter
from rl_coach.spaces import ObservationSpace from rl_coach.spaces import ObservationSpace
from rl_coach.core_types import EnvResponse from rl_coach.core_types import EnvResponse
from rl_coach.filters.filter import InputFilter from rl_coach.filters.filter import InputFilter
@@ -17,7 +17,7 @@ def test_filter():
env_response = EnvResponse(next_state={'observation': np.ones([20, 30, 3])}, reward=0, game_over=False) env_response = EnvResponse(next_state={'observation': np.ones([20, 30, 3])}, reward=0, game_over=False)
rescale_filter = InputFilter() rescale_filter = InputFilter()
rescale_filter.add_observation_filter('observation', 'rescale', rescale_filter.add_observation_filter('observation', 'rescale',
ObservationRescaleSizeByFactorFilter(0.5, RescaleInterpolationType.BILINEAR)) ObservationRescaleSizeByFactorFilter(0.5))
result = rescale_filter.filter(env_response)[0] result = rescale_filter.filter(env_response)[0]
unfiltered_observation = env_response.next_state['observation'] unfiltered_observation = env_response.next_state['observation']
@@ -33,7 +33,7 @@ def test_filter():
env_response = EnvResponse(next_state={'observation': np.ones([20, 30])}, reward=0, game_over=False) env_response = EnvResponse(next_state={'observation': np.ones([20, 30])}, reward=0, game_over=False)
rescale_filter = InputFilter() rescale_filter = InputFilter()
rescale_filter.add_observation_filter('observation', 'rescale', rescale_filter.add_observation_filter('observation', 'rescale',
ObservationRescaleSizeByFactorFilter(2, RescaleInterpolationType.BILINEAR)) ObservationRescaleSizeByFactorFilter(2))
result = rescale_filter.filter(env_response)[0] result = rescale_filter.filter(env_response)[0]
filtered_observation = result.next_state['observation'] filtered_observation = result.next_state['observation']
@@ -47,7 +47,7 @@ def test_get_filtered_observation_space():
# error on wrong number of channels # error on wrong number of channels
rescale_filter = InputFilter() rescale_filter = InputFilter()
rescale_filter.add_observation_filter('observation', 'rescale', rescale_filter.add_observation_filter('observation', 'rescale',
ObservationRescaleSizeByFactorFilter(0.5, RescaleInterpolationType.BILINEAR)) ObservationRescaleSizeByFactorFilter(0.5))
observation_space = ObservationSpace(np.array([10, 20, 5])) observation_space = ObservationSpace(np.array([10, 20, 5]))
with pytest.raises(ValueError): with pytest.raises(ValueError):
filtered_observation_space = rescale_filter.get_filtered_observation_space('observation', observation_space) filtered_observation_space = rescale_filter.get_filtered_observation_space('observation', observation_space)
@@ -64,3 +64,6 @@ def test_get_filtered_observation_space():
# make sure the original observation space is unchanged # make sure the original observation space is unchanged
assert np.all(observation_space.shape == np.array([10, 20, 3])) assert np.all(observation_space.shape == np.array([10, 20, 3]))
if __name__ == '__main__':
test_filter()

View File

@@ -5,7 +5,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import pytest import pytest
import numpy as np import numpy as np
from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter, RescaleInterpolationType from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter
from rl_coach.spaces import ObservationSpace, ImageObservationSpace, PlanarMapsObservationSpace from rl_coach.spaces import ObservationSpace, ImageObservationSpace, PlanarMapsObservationSpace
from rl_coach.core_types import EnvResponse from rl_coach.core_types import EnvResponse
from rl_coach.filters.filter import InputFilter from rl_coach.filters.filter import InputFilter
@@ -19,8 +19,7 @@ def test_filter():
rescale_filter = InputFilter() rescale_filter = InputFilter()
rescale_filter.add_observation_filter('observation', 'rescale', rescale_filter.add_observation_filter('observation', 'rescale',
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([10, 20, 3]), ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([10, 20, 3]),
high=255), high=255)))
RescaleInterpolationType.BILINEAR))
result = rescale_filter.filter(transition)[0] result = rescale_filter.filter(transition)[0]
unfiltered_observation = transition.next_state['observation'] unfiltered_observation = transition.next_state['observation']
@@ -38,8 +37,7 @@ def test_filter():
rescale_filter = InputFilter() rescale_filter = InputFilter()
rescale_filter.add_observation_filter('observation', 'rescale', rescale_filter.add_observation_filter('observation', 'rescale',
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([40, 60]), ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([40, 60]),
high=255), high=255)))
RescaleInterpolationType.BILINEAR))
result = rescale_filter.filter(transition)[0] result = rescale_filter.filter(transition)[0]
filtered_observation = result.next_state['observation'] filtered_observation = result.next_state['observation']
@@ -52,21 +50,20 @@ def test_filter():
# InputFilter( # InputFilter(
# observation_filters=OrderedDict([('rescale', # observation_filters=OrderedDict([('rescale',
# ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([10, 20, 1]), # ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([10, 20, 1]),
# high=255), # high=255)
# RescaleInterpolationType.BILINEAR))])) # ))]))
# TODO: validate input to filter # TODO: validate input to filter
# different number of axes -> error # different number of axes -> error
# env_response = EnvResponse(state={'observation': np.ones([20, 30, 3])}, reward=0, game_over=False) # env_response = EnvResponse(state={'observation': np.ones([20, 30, 3])}, reward=0, game_over=False)
# rescale_filter = ObservationRescaleToSizeFilter(ObservationSpace(np.array([10, 20])), # rescale_filter = ObservationRescaleToSizeFilter(ObservationSpace(np.array([10, 20]))
# RescaleInterpolationType.BILINEAR) # )
# with pytest.raises(ValueError): # with pytest.raises(ValueError):
# result = rescale_filter.filter(transition) # result = rescale_filter.filter(transition)
# channels first -> error # channels first -> error
with pytest.raises(ValueError): with pytest.raises(ValueError):
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([3, 10, 20]), high=255), ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([3, 10, 20]), high=255))
RescaleInterpolationType.BILINEAR)
@pytest.mark.unit_test @pytest.mark.unit_test
@@ -76,15 +73,13 @@ def test_get_filtered_observation_space():
observation_filters = InputFilter() observation_filters = InputFilter()
observation_filters.add_observation_filter('observation', 'rescale', observation_filters.add_observation_filter('observation', 'rescale',
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([5, 10, 5]), ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([5, 10, 5]),
high=255), high=255)))
RescaleInterpolationType.BILINEAR))
# mismatch and wrong number of channels # mismatch and wrong number of channels
rescale_filter = InputFilter() rescale_filter = InputFilter()
rescale_filter.add_observation_filter('observation', 'rescale', rescale_filter.add_observation_filter('observation', 'rescale',
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([5, 10, 3]), ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([5, 10, 3]),
high=255), high=255)))
RescaleInterpolationType.BILINEAR))
observation_space = PlanarMapsObservationSpace(np.array([10, 20, 5]), low=0, high=255) observation_space = PlanarMapsObservationSpace(np.array([10, 20, 5]), low=0, high=255)
with pytest.raises(ValueError): with pytest.raises(ValueError):

View File

@@ -4,7 +4,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import pytest import pytest
from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter, RescaleInterpolationType from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter
from rl_coach.filters.observation.observation_crop_filter import ObservationCropFilter from rl_coach.filters.observation.observation_crop_filter import ObservationCropFilter
from rl_coach.filters.reward.reward_clipping_filter import RewardClippingFilter from rl_coach.filters.reward.reward_clipping_filter import RewardClippingFilter
from rl_coach.filters.observation.observation_stacking_filter import ObservationStackingFilter from rl_coach.filters.observation.observation_stacking_filter import ObservationStackingFilter
@@ -31,7 +31,6 @@ def test_filter_stacking():
filter1 = ObservationRescaleToSizeFilter( filter1 = ObservationRescaleToSizeFilter(
output_observation_space=ImageObservationSpace(np.array([110, 84]), high=255), output_observation_space=ImageObservationSpace(np.array([110, 84]), high=255),
rescaling_interpolation_type=RescaleInterpolationType.BILINEAR
) )
filter2 = ObservationCropFilter( filter2 = ObservationCropFilter(