mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
moving to skimage.transform.resize (#321)
This commit is contained in:
@@ -14,39 +14,25 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
from enum import Enum
|
from skimage.transform import resize
|
||||||
|
|
||||||
import scipy.ndimage
|
|
||||||
|
|
||||||
from rl_coach.core_types import ObservationType
|
from rl_coach.core_types import ObservationType
|
||||||
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
||||||
from rl_coach.spaces import ObservationSpace
|
from rl_coach.spaces import ObservationSpace
|
||||||
|
|
||||||
|
|
||||||
# imresize interpolation types as defined by scipy here:
|
|
||||||
# https://docs.scipy.org/doc/scipy-0.18.1/reference/generated/scipy.misc.imresize.html
|
|
||||||
class RescaleInterpolationType(Enum):
|
|
||||||
NEAREST = 'nearest'
|
|
||||||
LANCZOS = 'lanczos'
|
|
||||||
BILINEAR = 'bilinear'
|
|
||||||
BICUBIC = 'bicubic'
|
|
||||||
CUBIC = 'cubic'
|
|
||||||
|
|
||||||
|
|
||||||
class ObservationRescaleSizeByFactorFilter(ObservationFilter):
|
class ObservationRescaleSizeByFactorFilter(ObservationFilter):
|
||||||
"""
|
"""
|
||||||
Rescales an image observation by some factor. For example, the image size
|
Rescales an image observation by some factor. For example, the image size
|
||||||
can be reduced by a factor of 2.
|
can be reduced by a factor of 2.
|
||||||
Warning: this requires the input observation to be of type uint8 due to scipy requirements!
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, rescale_factor: float, rescaling_interpolation_type: RescaleInterpolationType):
|
def __init__(self, rescale_factor: float):
|
||||||
"""
|
"""
|
||||||
:param rescale_factor: the factor by which the observation will be rescaled
|
:param rescale_factor: the factor by which the observation will be rescaled
|
||||||
:param rescaling_interpolation_type: the interpolation type for rescaling
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rescale_factor = float(rescale_factor) # scipy requires float scale factors
|
self.rescale_factor = float(rescale_factor)
|
||||||
self.rescaling_interpolation_type = rescaling_interpolation_type
|
|
||||||
# TODO: allow selecting the channels dim
|
# TODO: allow selecting the channels dim
|
||||||
|
|
||||||
def validate_input_observation_space(self, input_observation_space: ObservationSpace):
|
def validate_input_observation_space(self, input_observation_space: ObservationSpace):
|
||||||
@@ -58,13 +44,14 @@ class ObservationRescaleSizeByFactorFilter(ObservationFilter):
|
|||||||
raise ValueError("Observations with 3 dimensions must have 3 channels in the last axis (RGB)")
|
raise ValueError("Observations with 3 dimensions must have 3 channels in the last axis (RGB)")
|
||||||
|
|
||||||
def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
|
def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
|
||||||
# scipy works only with uint8
|
|
||||||
observation = observation.astype('uint8')
|
observation = observation.astype('uint8')
|
||||||
|
rescaled_output_size = tuple([int(self.rescale_factor * dim) for dim in observation.shape[:2]])
|
||||||
|
|
||||||
|
if len(observation.shape) == 3:
|
||||||
|
rescaled_output_size += (3,)
|
||||||
|
|
||||||
# rescale
|
# rescale
|
||||||
observation = scipy.misc.imresize(observation,
|
observation = resize(observation, rescaled_output_size, anti_aliasing=False, preserve_range=True).astype('uint8')
|
||||||
self.rescale_factor,
|
|
||||||
interp=self.rescaling_interpolation_type.value)
|
|
||||||
|
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
|
|||||||
@@ -15,41 +15,26 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
from enum import Enum
|
from skimage.transform import resize
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy.ndimage
|
|
||||||
|
|
||||||
from rl_coach.core_types import ObservationType
|
from rl_coach.core_types import ObservationType
|
||||||
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
||||||
from rl_coach.spaces import ObservationSpace, PlanarMapsObservationSpace, ImageObservationSpace
|
from rl_coach.spaces import ObservationSpace, PlanarMapsObservationSpace, ImageObservationSpace
|
||||||
|
|
||||||
|
|
||||||
# imresize interpolation types as defined by scipy here:
|
|
||||||
# https://docs.scipy.org/doc/scipy-0.18.1/reference/generated/scipy.misc.imresize.html
|
|
||||||
class RescaleInterpolationType(Enum):
|
|
||||||
NEAREST = 'nearest'
|
|
||||||
LANCZOS = 'lanczos'
|
|
||||||
BILINEAR = 'bilinear'
|
|
||||||
BICUBIC = 'bicubic'
|
|
||||||
CUBIC = 'cubic'
|
|
||||||
|
|
||||||
|
|
||||||
class ObservationRescaleToSizeFilter(ObservationFilter):
|
class ObservationRescaleToSizeFilter(ObservationFilter):
|
||||||
"""
|
"""
|
||||||
Rescales an image observation to a given size. The target size does not
|
Rescales an image observation to a given size. The target size does not
|
||||||
necessarily keep the aspect ratio of the original observation.
|
necessarily keep the aspect ratio of the original observation.
|
||||||
Warning: this requires the input observation to be of type uint8 due to scipy requirements!
|
Warning: this requires the input observation to be of type uint8 due to scipy requirements!
|
||||||
"""
|
"""
|
||||||
def __init__(self, output_observation_space: PlanarMapsObservationSpace,
|
def __init__(self, output_observation_space: PlanarMapsObservationSpace):
|
||||||
rescaling_interpolation_type: RescaleInterpolationType=RescaleInterpolationType.BILINEAR):
|
|
||||||
"""
|
"""
|
||||||
:param output_observation_space: the output observation space
|
:param output_observation_space: the output observation space
|
||||||
:param rescaling_interpolation_type: the interpolation type for rescaling
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.output_observation_space = output_observation_space
|
self.output_observation_space = output_observation_space
|
||||||
self.rescaling_interpolation_type = rescaling_interpolation_type
|
|
||||||
|
|
||||||
if not isinstance(output_observation_space, PlanarMapsObservationSpace):
|
if not isinstance(output_observation_space, PlanarMapsObservationSpace):
|
||||||
raise ValueError("The rescale filter only applies to observation spaces that inherit from "
|
raise ValueError("The rescale filter only applies to observation spaces that inherit from "
|
||||||
@@ -75,20 +60,19 @@ class ObservationRescaleToSizeFilter(ObservationFilter):
|
|||||||
self.output_observation_space.shape[self.output_observation_space.channels_axis]))
|
self.output_observation_space.shape[self.output_observation_space.channels_axis]))
|
||||||
|
|
||||||
def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
|
def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
|
||||||
# scipy works only with uint8
|
|
||||||
observation = observation.astype('uint8')
|
observation = observation.astype('uint8')
|
||||||
|
|
||||||
# rescale
|
# rescale
|
||||||
if isinstance(self.output_observation_space, ImageObservationSpace):
|
if isinstance(self.output_observation_space, ImageObservationSpace):
|
||||||
observation = scipy.misc.imresize(observation,
|
observation = resize(observation, tuple(self.output_observation_space.shape), anti_aliasing=False,
|
||||||
tuple(self.output_observation_space.shape),
|
preserve_range=True).astype('uint8')
|
||||||
interp=self.rescaling_interpolation_type.value)
|
|
||||||
else:
|
else:
|
||||||
new_observation = []
|
new_observation = []
|
||||||
for i in range(self.output_observation_space.shape[self.output_observation_space.channels_axis]):
|
for i in range(self.output_observation_space.shape[self.output_observation_space.channels_axis]):
|
||||||
new_observation.append(scipy.misc.imresize(observation.take(i, self.output_observation_space.channels_axis),
|
new_observation.append(resize(observation.take(i, self.output_observation_space.channels_axis),
|
||||||
tuple(self.planar_map_output_shape),
|
tuple(self.planar_map_output_shape),
|
||||||
interp=self.rescaling_interpolation_type.value))
|
preserve_range=True).astype('uint8'))
|
||||||
new_observation = np.array(new_observation)
|
new_observation = np.array(new_observation)
|
||||||
observation = new_observation.swapaxes(0, self.output_observation_space.channels_axis)
|
observation = new_observation.swapaxes(0, self.output_observation_space.channels_axis)
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
|||||||
import pytest
|
import pytest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from rl_coach.filters.observation.observation_rescale_size_by_factor_filter import ObservationRescaleSizeByFactorFilter, RescaleInterpolationType
|
from rl_coach.filters.observation.observation_rescale_size_by_factor_filter import ObservationRescaleSizeByFactorFilter
|
||||||
from rl_coach.spaces import ObservationSpace
|
from rl_coach.spaces import ObservationSpace
|
||||||
from rl_coach.core_types import EnvResponse
|
from rl_coach.core_types import EnvResponse
|
||||||
from rl_coach.filters.filter import InputFilter
|
from rl_coach.filters.filter import InputFilter
|
||||||
@@ -17,7 +17,7 @@ def test_filter():
|
|||||||
env_response = EnvResponse(next_state={'observation': np.ones([20, 30, 3])}, reward=0, game_over=False)
|
env_response = EnvResponse(next_state={'observation': np.ones([20, 30, 3])}, reward=0, game_over=False)
|
||||||
rescale_filter = InputFilter()
|
rescale_filter = InputFilter()
|
||||||
rescale_filter.add_observation_filter('observation', 'rescale',
|
rescale_filter.add_observation_filter('observation', 'rescale',
|
||||||
ObservationRescaleSizeByFactorFilter(0.5, RescaleInterpolationType.BILINEAR))
|
ObservationRescaleSizeByFactorFilter(0.5))
|
||||||
|
|
||||||
result = rescale_filter.filter(env_response)[0]
|
result = rescale_filter.filter(env_response)[0]
|
||||||
unfiltered_observation = env_response.next_state['observation']
|
unfiltered_observation = env_response.next_state['observation']
|
||||||
@@ -33,7 +33,7 @@ def test_filter():
|
|||||||
env_response = EnvResponse(next_state={'observation': np.ones([20, 30])}, reward=0, game_over=False)
|
env_response = EnvResponse(next_state={'observation': np.ones([20, 30])}, reward=0, game_over=False)
|
||||||
rescale_filter = InputFilter()
|
rescale_filter = InputFilter()
|
||||||
rescale_filter.add_observation_filter('observation', 'rescale',
|
rescale_filter.add_observation_filter('observation', 'rescale',
|
||||||
ObservationRescaleSizeByFactorFilter(2, RescaleInterpolationType.BILINEAR))
|
ObservationRescaleSizeByFactorFilter(2))
|
||||||
result = rescale_filter.filter(env_response)[0]
|
result = rescale_filter.filter(env_response)[0]
|
||||||
filtered_observation = result.next_state['observation']
|
filtered_observation = result.next_state['observation']
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ def test_get_filtered_observation_space():
|
|||||||
# error on wrong number of channels
|
# error on wrong number of channels
|
||||||
rescale_filter = InputFilter()
|
rescale_filter = InputFilter()
|
||||||
rescale_filter.add_observation_filter('observation', 'rescale',
|
rescale_filter.add_observation_filter('observation', 'rescale',
|
||||||
ObservationRescaleSizeByFactorFilter(0.5, RescaleInterpolationType.BILINEAR))
|
ObservationRescaleSizeByFactorFilter(0.5))
|
||||||
observation_space = ObservationSpace(np.array([10, 20, 5]))
|
observation_space = ObservationSpace(np.array([10, 20, 5]))
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
filtered_observation_space = rescale_filter.get_filtered_observation_space('observation', observation_space)
|
filtered_observation_space = rescale_filter.get_filtered_observation_space('observation', observation_space)
|
||||||
@@ -64,3 +64,6 @@ def test_get_filtered_observation_space():
|
|||||||
|
|
||||||
# make sure the original observation space is unchanged
|
# make sure the original observation space is unchanged
|
||||||
assert np.all(observation_space.shape == np.array([10, 20, 3]))
|
assert np.all(observation_space.shape == np.array([10, 20, 3]))
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_filter()
|
||||||
@@ -5,7 +5,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
|||||||
import pytest
|
import pytest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter, RescaleInterpolationType
|
from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter
|
||||||
from rl_coach.spaces import ObservationSpace, ImageObservationSpace, PlanarMapsObservationSpace
|
from rl_coach.spaces import ObservationSpace, ImageObservationSpace, PlanarMapsObservationSpace
|
||||||
from rl_coach.core_types import EnvResponse
|
from rl_coach.core_types import EnvResponse
|
||||||
from rl_coach.filters.filter import InputFilter
|
from rl_coach.filters.filter import InputFilter
|
||||||
@@ -18,9 +18,8 @@ def test_filter():
|
|||||||
transition = EnvResponse(next_state={'observation': np.ones([20, 30, 3])}, reward=0, game_over=False)
|
transition = EnvResponse(next_state={'observation': np.ones([20, 30, 3])}, reward=0, game_over=False)
|
||||||
rescale_filter = InputFilter()
|
rescale_filter = InputFilter()
|
||||||
rescale_filter.add_observation_filter('observation', 'rescale',
|
rescale_filter.add_observation_filter('observation', 'rescale',
|
||||||
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([10, 20, 3]),
|
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([10, 20, 3]),
|
||||||
high=255),
|
high=255)))
|
||||||
RescaleInterpolationType.BILINEAR))
|
|
||||||
|
|
||||||
result = rescale_filter.filter(transition)[0]
|
result = rescale_filter.filter(transition)[0]
|
||||||
unfiltered_observation = transition.next_state['observation']
|
unfiltered_observation = transition.next_state['observation']
|
||||||
@@ -38,8 +37,7 @@ def test_filter():
|
|||||||
rescale_filter = InputFilter()
|
rescale_filter = InputFilter()
|
||||||
rescale_filter.add_observation_filter('observation', 'rescale',
|
rescale_filter.add_observation_filter('observation', 'rescale',
|
||||||
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([40, 60]),
|
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([40, 60]),
|
||||||
high=255),
|
high=255)))
|
||||||
RescaleInterpolationType.BILINEAR))
|
|
||||||
result = rescale_filter.filter(transition)[0]
|
result = rescale_filter.filter(transition)[0]
|
||||||
filtered_observation = result.next_state['observation']
|
filtered_observation = result.next_state['observation']
|
||||||
|
|
||||||
@@ -52,21 +50,20 @@ def test_filter():
|
|||||||
# InputFilter(
|
# InputFilter(
|
||||||
# observation_filters=OrderedDict([('rescale',
|
# observation_filters=OrderedDict([('rescale',
|
||||||
# ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([10, 20, 1]),
|
# ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([10, 20, 1]),
|
||||||
# high=255),
|
# high=255)
|
||||||
# RescaleInterpolationType.BILINEAR))]))
|
# ))]))
|
||||||
|
|
||||||
# TODO: validate input to filter
|
# TODO: validate input to filter
|
||||||
# different number of axes -> error
|
# different number of axes -> error
|
||||||
# env_response = EnvResponse(state={'observation': np.ones([20, 30, 3])}, reward=0, game_over=False)
|
# env_response = EnvResponse(state={'observation': np.ones([20, 30, 3])}, reward=0, game_over=False)
|
||||||
# rescale_filter = ObservationRescaleToSizeFilter(ObservationSpace(np.array([10, 20])),
|
# rescale_filter = ObservationRescaleToSizeFilter(ObservationSpace(np.array([10, 20]))
|
||||||
# RescaleInterpolationType.BILINEAR)
|
# )
|
||||||
# with pytest.raises(ValueError):
|
# with pytest.raises(ValueError):
|
||||||
# result = rescale_filter.filter(transition)
|
# result = rescale_filter.filter(transition)
|
||||||
|
|
||||||
# channels first -> error
|
# channels first -> error
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([3, 10, 20]), high=255),
|
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([3, 10, 20]), high=255))
|
||||||
RescaleInterpolationType.BILINEAR)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit_test
|
@pytest.mark.unit_test
|
||||||
@@ -76,15 +73,13 @@ def test_get_filtered_observation_space():
|
|||||||
observation_filters = InputFilter()
|
observation_filters = InputFilter()
|
||||||
observation_filters.add_observation_filter('observation', 'rescale',
|
observation_filters.add_observation_filter('observation', 'rescale',
|
||||||
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([5, 10, 5]),
|
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([5, 10, 5]),
|
||||||
high=255),
|
high=255)))
|
||||||
RescaleInterpolationType.BILINEAR))
|
|
||||||
|
|
||||||
# mismatch and wrong number of channels
|
# mismatch and wrong number of channels
|
||||||
rescale_filter = InputFilter()
|
rescale_filter = InputFilter()
|
||||||
rescale_filter.add_observation_filter('observation', 'rescale',
|
rescale_filter.add_observation_filter('observation', 'rescale',
|
||||||
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([5, 10, 3]),
|
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([5, 10, 3]),
|
||||||
high=255),
|
high=255)))
|
||||||
RescaleInterpolationType.BILINEAR))
|
|
||||||
|
|
||||||
observation_space = PlanarMapsObservationSpace(np.array([10, 20, 5]), low=0, high=255)
|
observation_space = PlanarMapsObservationSpace(np.array([10, 20, 5]), low=0, high=255)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter, RescaleInterpolationType
|
from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter
|
||||||
from rl_coach.filters.observation.observation_crop_filter import ObservationCropFilter
|
from rl_coach.filters.observation.observation_crop_filter import ObservationCropFilter
|
||||||
from rl_coach.filters.reward.reward_clipping_filter import RewardClippingFilter
|
from rl_coach.filters.reward.reward_clipping_filter import RewardClippingFilter
|
||||||
from rl_coach.filters.observation.observation_stacking_filter import ObservationStackingFilter
|
from rl_coach.filters.observation.observation_stacking_filter import ObservationStackingFilter
|
||||||
@@ -31,7 +31,6 @@ def test_filter_stacking():
|
|||||||
|
|
||||||
filter1 = ObservationRescaleToSizeFilter(
|
filter1 = ObservationRescaleToSizeFilter(
|
||||||
output_observation_space=ImageObservationSpace(np.array([110, 84]), high=255),
|
output_observation_space=ImageObservationSpace(np.array([110, 84]), high=255),
|
||||||
rescaling_interpolation_type=RescaleInterpolationType.BILINEAR
|
|
||||||
)
|
)
|
||||||
|
|
||||||
filter2 = ObservationCropFilter(
|
filter2 = ObservationCropFilter(
|
||||||
|
|||||||
Reference in New Issue
Block a user