1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

moving to skimage.transform.resize (#321)

This commit is contained in:
Gal Leibovich
2019-05-23 13:38:01 +03:00
committed by GitHub
parent acceb03ac0
commit 30c2b2fc45
5 changed files with 35 additions and 67 deletions

View File

@@ -14,39 +14,25 @@
# limitations under the License.
#
from enum import Enum
from skimage.transform import resize
import scipy.ndimage
from rl_coach.core_types import ObservationType
from rl_coach.filters.observation.observation_filter import ObservationFilter
from rl_coach.spaces import ObservationSpace
# imresize interpolation types as defined by scipy here:
# https://docs.scipy.org/doc/scipy-0.18.1/reference/generated/scipy.misc.imresize.html
class RescaleInterpolationType(Enum):
NEAREST = 'nearest'
LANCZOS = 'lanczos'
BILINEAR = 'bilinear'
BICUBIC = 'bicubic'
CUBIC = 'cubic'
class ObservationRescaleSizeByFactorFilter(ObservationFilter):
"""
Rescales an image observation by some factor. For example, the image size
can be reduced by a factor of 2.
Warning: this requires the input observation to be of type uint8 due to scipy requirements!
"""
def __init__(self, rescale_factor: float, rescaling_interpolation_type: RescaleInterpolationType):
def __init__(self, rescale_factor: float):
"""
:param rescale_factor: the factor by which the observation will be rescaled
:param rescaling_interpolation_type: the interpolation type for rescaling
"""
super().__init__()
self.rescale_factor = float(rescale_factor) # scipy requires float scale factors
self.rescaling_interpolation_type = rescaling_interpolation_type
self.rescale_factor = float(rescale_factor)
# TODO: allow selecting the channels dim
def validate_input_observation_space(self, input_observation_space: ObservationSpace):
@@ -58,13 +44,14 @@ class ObservationRescaleSizeByFactorFilter(ObservationFilter):
raise ValueError("Observations with 3 dimensions must have 3 channels in the last axis (RGB)")
def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
# scipy works only with uint8
observation = observation.astype('uint8')
rescaled_output_size = tuple([int(self.rescale_factor * dim) for dim in observation.shape[:2]])
if len(observation.shape) == 3:
rescaled_output_size += (3,)
# rescale
observation = scipy.misc.imresize(observation,
self.rescale_factor,
interp=self.rescaling_interpolation_type.value)
observation = resize(observation, rescaled_output_size, anti_aliasing=False, preserve_range=True).astype('uint8')
return observation

View File

@@ -15,41 +15,26 @@
#
import copy
from enum import Enum
from skimage.transform import resize
import numpy as np
import scipy.ndimage
from rl_coach.core_types import ObservationType
from rl_coach.filters.observation.observation_filter import ObservationFilter
from rl_coach.spaces import ObservationSpace, PlanarMapsObservationSpace, ImageObservationSpace
# imresize interpolation types as defined by scipy here:
# https://docs.scipy.org/doc/scipy-0.18.1/reference/generated/scipy.misc.imresize.html
class RescaleInterpolationType(Enum):
NEAREST = 'nearest'
LANCZOS = 'lanczos'
BILINEAR = 'bilinear'
BICUBIC = 'bicubic'
CUBIC = 'cubic'
class ObservationRescaleToSizeFilter(ObservationFilter):
"""
Rescales an image observation to a given size. The target size does not
necessarily keep the aspect ratio of the original observation.
Warning: this requires the input observation to be of type uint8 due to scipy requirements!
"""
def __init__(self, output_observation_space: PlanarMapsObservationSpace,
rescaling_interpolation_type: RescaleInterpolationType=RescaleInterpolationType.BILINEAR):
def __init__(self, output_observation_space: PlanarMapsObservationSpace):
"""
:param output_observation_space: the output observation space
:param rescaling_interpolation_type: the interpolation type for rescaling
"""
super().__init__()
self.output_observation_space = output_observation_space
self.rescaling_interpolation_type = rescaling_interpolation_type
if not isinstance(output_observation_space, PlanarMapsObservationSpace):
raise ValueError("The rescale filter only applies to observation spaces that inherit from "
@@ -75,20 +60,19 @@ class ObservationRescaleToSizeFilter(ObservationFilter):
self.output_observation_space.shape[self.output_observation_space.channels_axis]))
def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
# scipy works only with uint8
observation = observation.astype('uint8')
# rescale
if isinstance(self.output_observation_space, ImageObservationSpace):
observation = scipy.misc.imresize(observation,
tuple(self.output_observation_space.shape),
interp=self.rescaling_interpolation_type.value)
observation = resize(observation, tuple(self.output_observation_space.shape), anti_aliasing=False,
preserve_range=True).astype('uint8')
else:
new_observation = []
for i in range(self.output_observation_space.shape[self.output_observation_space.channels_axis]):
new_observation.append(scipy.misc.imresize(observation.take(i, self.output_observation_space.channels_axis),
tuple(self.planar_map_output_shape),
interp=self.rescaling_interpolation_type.value))
new_observation.append(resize(observation.take(i, self.output_observation_space.channels_axis),
tuple(self.planar_map_output_shape),
preserve_range=True).astype('uint8'))
new_observation = np.array(new_observation)
observation = new_observation.swapaxes(0, self.output_observation_space.channels_axis)