mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
moving to skimage.transform.resize (#321)
This commit is contained in:
@@ -14,39 +14,25 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from enum import Enum
|
||||
from skimage.transform import resize
|
||||
|
||||
import scipy.ndimage
|
||||
|
||||
from rl_coach.core_types import ObservationType
|
||||
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
||||
from rl_coach.spaces import ObservationSpace
|
||||
|
||||
|
||||
# imresize interpolation types as defined by scipy here:
|
||||
# https://docs.scipy.org/doc/scipy-0.18.1/reference/generated/scipy.misc.imresize.html
|
||||
class RescaleInterpolationType(Enum):
|
||||
NEAREST = 'nearest'
|
||||
LANCZOS = 'lanczos'
|
||||
BILINEAR = 'bilinear'
|
||||
BICUBIC = 'bicubic'
|
||||
CUBIC = 'cubic'
|
||||
|
||||
|
||||
class ObservationRescaleSizeByFactorFilter(ObservationFilter):
|
||||
"""
|
||||
Rescales an image observation by some factor. For example, the image size
|
||||
can be reduced by a factor of 2.
|
||||
Warning: this requires the input observation to be of type uint8 due to scipy requirements!
|
||||
"""
|
||||
def __init__(self, rescale_factor: float, rescaling_interpolation_type: RescaleInterpolationType):
|
||||
def __init__(self, rescale_factor: float):
|
||||
"""
|
||||
:param rescale_factor: the factor by which the observation will be rescaled
|
||||
:param rescaling_interpolation_type: the interpolation type for rescaling
|
||||
"""
|
||||
super().__init__()
|
||||
self.rescale_factor = float(rescale_factor) # scipy requires float scale factors
|
||||
self.rescaling_interpolation_type = rescaling_interpolation_type
|
||||
self.rescale_factor = float(rescale_factor)
|
||||
# TODO: allow selecting the channels dim
|
||||
|
||||
def validate_input_observation_space(self, input_observation_space: ObservationSpace):
|
||||
@@ -58,13 +44,14 @@ class ObservationRescaleSizeByFactorFilter(ObservationFilter):
|
||||
raise ValueError("Observations with 3 dimensions must have 3 channels in the last axis (RGB)")
|
||||
|
||||
def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
|
||||
# scipy works only with uint8
|
||||
observation = observation.astype('uint8')
|
||||
rescaled_output_size = tuple([int(self.rescale_factor * dim) for dim in observation.shape[:2]])
|
||||
|
||||
if len(observation.shape) == 3:
|
||||
rescaled_output_size += (3,)
|
||||
|
||||
# rescale
|
||||
observation = scipy.misc.imresize(observation,
|
||||
self.rescale_factor,
|
||||
interp=self.rescaling_interpolation_type.value)
|
||||
observation = resize(observation, rescaled_output_size, anti_aliasing=False, preserve_range=True).astype('uint8')
|
||||
|
||||
return observation
|
||||
|
||||
|
||||
@@ -15,41 +15,26 @@
|
||||
#
|
||||
|
||||
import copy
|
||||
from enum import Enum
|
||||
|
||||
from skimage.transform import resize
|
||||
import numpy as np
|
||||
import scipy.ndimage
|
||||
|
||||
from rl_coach.core_types import ObservationType
|
||||
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
||||
from rl_coach.spaces import ObservationSpace, PlanarMapsObservationSpace, ImageObservationSpace
|
||||
|
||||
|
||||
# imresize interpolation types as defined by scipy here:
|
||||
# https://docs.scipy.org/doc/scipy-0.18.1/reference/generated/scipy.misc.imresize.html
|
||||
class RescaleInterpolationType(Enum):
|
||||
NEAREST = 'nearest'
|
||||
LANCZOS = 'lanczos'
|
||||
BILINEAR = 'bilinear'
|
||||
BICUBIC = 'bicubic'
|
||||
CUBIC = 'cubic'
|
||||
|
||||
|
||||
class ObservationRescaleToSizeFilter(ObservationFilter):
|
||||
"""
|
||||
Rescales an image observation to a given size. The target size does not
|
||||
necessarily keep the aspect ratio of the original observation.
|
||||
Warning: this requires the input observation to be of type uint8 due to scipy requirements!
|
||||
"""
|
||||
def __init__(self, output_observation_space: PlanarMapsObservationSpace,
|
||||
rescaling_interpolation_type: RescaleInterpolationType=RescaleInterpolationType.BILINEAR):
|
||||
def __init__(self, output_observation_space: PlanarMapsObservationSpace):
|
||||
"""
|
||||
:param output_observation_space: the output observation space
|
||||
:param rescaling_interpolation_type: the interpolation type for rescaling
|
||||
"""
|
||||
super().__init__()
|
||||
self.output_observation_space = output_observation_space
|
||||
self.rescaling_interpolation_type = rescaling_interpolation_type
|
||||
|
||||
if not isinstance(output_observation_space, PlanarMapsObservationSpace):
|
||||
raise ValueError("The rescale filter only applies to observation spaces that inherit from "
|
||||
@@ -75,20 +60,19 @@ class ObservationRescaleToSizeFilter(ObservationFilter):
|
||||
self.output_observation_space.shape[self.output_observation_space.channels_axis]))
|
||||
|
||||
def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
|
||||
# scipy works only with uint8
|
||||
observation = observation.astype('uint8')
|
||||
|
||||
# rescale
|
||||
if isinstance(self.output_observation_space, ImageObservationSpace):
|
||||
observation = scipy.misc.imresize(observation,
|
||||
tuple(self.output_observation_space.shape),
|
||||
interp=self.rescaling_interpolation_type.value)
|
||||
observation = resize(observation, tuple(self.output_observation_space.shape), anti_aliasing=False,
|
||||
preserve_range=True).astype('uint8')
|
||||
|
||||
else:
|
||||
new_observation = []
|
||||
for i in range(self.output_observation_space.shape[self.output_observation_space.channels_axis]):
|
||||
new_observation.append(scipy.misc.imresize(observation.take(i, self.output_observation_space.channels_axis),
|
||||
tuple(self.planar_map_output_shape),
|
||||
interp=self.rescaling_interpolation_type.value))
|
||||
new_observation.append(resize(observation.take(i, self.output_observation_space.channels_axis),
|
||||
tuple(self.planar_map_output_shape),
|
||||
preserve_range=True).astype('uint8'))
|
||||
new_observation = np.array(new_observation)
|
||||
observation = new_observation.swapaxes(0, self.output_observation_space.channels_axis)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user