1
0
mirror of https://github.com/gryf/coach.git synced 2026-02-13 12:25:47 +01:00

moving to skimage.transform.resize (#321)

This commit is contained in:
Gal Leibovich
2019-05-23 13:38:01 +03:00
committed by GitHub
parent acceb03ac0
commit 30c2b2fc45
5 changed files with 35 additions and 67 deletions

View File

@@ -5,7 +5,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import pytest
import numpy as np
from rl_coach.filters.observation.observation_rescale_size_by_factor_filter import ObservationRescaleSizeByFactorFilter, RescaleInterpolationType
from rl_coach.filters.observation.observation_rescale_size_by_factor_filter import ObservationRescaleSizeByFactorFilter
from rl_coach.spaces import ObservationSpace
from rl_coach.core_types import EnvResponse
from rl_coach.filters.filter import InputFilter
@@ -17,7 +17,7 @@ def test_filter():
env_response = EnvResponse(next_state={'observation': np.ones([20, 30, 3])}, reward=0, game_over=False)
rescale_filter = InputFilter()
rescale_filter.add_observation_filter('observation', 'rescale',
ObservationRescaleSizeByFactorFilter(0.5, RescaleInterpolationType.BILINEAR))
ObservationRescaleSizeByFactorFilter(0.5))
result = rescale_filter.filter(env_response)[0]
unfiltered_observation = env_response.next_state['observation']
@@ -33,7 +33,7 @@ def test_filter():
env_response = EnvResponse(next_state={'observation': np.ones([20, 30])}, reward=0, game_over=False)
rescale_filter = InputFilter()
rescale_filter.add_observation_filter('observation', 'rescale',
ObservationRescaleSizeByFactorFilter(2, RescaleInterpolationType.BILINEAR))
ObservationRescaleSizeByFactorFilter(2))
result = rescale_filter.filter(env_response)[0]
filtered_observation = result.next_state['observation']
@@ -47,7 +47,7 @@ def test_get_filtered_observation_space():
# error on wrong number of channels
rescale_filter = InputFilter()
rescale_filter.add_observation_filter('observation', 'rescale',
ObservationRescaleSizeByFactorFilter(0.5, RescaleInterpolationType.BILINEAR))
ObservationRescaleSizeByFactorFilter(0.5))
observation_space = ObservationSpace(np.array([10, 20, 5]))
with pytest.raises(ValueError):
filtered_observation_space = rescale_filter.get_filtered_observation_space('observation', observation_space)
@@ -64,3 +64,6 @@ def test_get_filtered_observation_space():
# make sure the original observation space is unchanged
assert np.all(observation_space.shape == np.array([10, 20, 3]))
if __name__ == '__main__':
test_filter()

View File

@@ -5,7 +5,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import pytest
import numpy as np
from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter, RescaleInterpolationType
from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter
from rl_coach.spaces import ObservationSpace, ImageObservationSpace, PlanarMapsObservationSpace
from rl_coach.core_types import EnvResponse
from rl_coach.filters.filter import InputFilter
@@ -18,9 +18,8 @@ def test_filter():
transition = EnvResponse(next_state={'observation': np.ones([20, 30, 3])}, reward=0, game_over=False)
rescale_filter = InputFilter()
rescale_filter.add_observation_filter('observation', 'rescale',
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([10, 20, 3]),
high=255),
RescaleInterpolationType.BILINEAR))
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([10, 20, 3]),
high=255)))
result = rescale_filter.filter(transition)[0]
unfiltered_observation = transition.next_state['observation']
@@ -38,8 +37,7 @@ def test_filter():
rescale_filter = InputFilter()
rescale_filter.add_observation_filter('observation', 'rescale',
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([40, 60]),
high=255),
RescaleInterpolationType.BILINEAR))
high=255)))
result = rescale_filter.filter(transition)[0]
filtered_observation = result.next_state['observation']
@@ -52,21 +50,20 @@ def test_filter():
# InputFilter(
# observation_filters=OrderedDict([('rescale',
# ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([10, 20, 1]),
# high=255),
# RescaleInterpolationType.BILINEAR))]))
# high=255)
# ))]))
# TODO: validate input to filter
# different number of axes -> error
# env_response = EnvResponse(state={'observation': np.ones([20, 30, 3])}, reward=0, game_over=False)
# rescale_filter = ObservationRescaleToSizeFilter(ObservationSpace(np.array([10, 20])),
# RescaleInterpolationType.BILINEAR)
# rescale_filter = ObservationRescaleToSizeFilter(ObservationSpace(np.array([10, 20]))
# )
# with pytest.raises(ValueError):
# result = rescale_filter.filter(transition)
# channels first -> error
with pytest.raises(ValueError):
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([3, 10, 20]), high=255),
RescaleInterpolationType.BILINEAR)
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([3, 10, 20]), high=255))
@pytest.mark.unit_test
@@ -76,15 +73,13 @@ def test_get_filtered_observation_space():
observation_filters = InputFilter()
observation_filters.add_observation_filter('observation', 'rescale',
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([5, 10, 5]),
high=255),
RescaleInterpolationType.BILINEAR))
high=255)))
# mismatch and wrong number of channels
rescale_filter = InputFilter()
rescale_filter.add_observation_filter('observation', 'rescale',
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([5, 10, 3]),
high=255),
RescaleInterpolationType.BILINEAR))
high=255)))
observation_space = PlanarMapsObservationSpace(np.array([10, 20, 5]), low=0, high=255)
with pytest.raises(ValueError):

View File

@@ -4,7 +4,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import pytest
from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter, RescaleInterpolationType
from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter
from rl_coach.filters.observation.observation_crop_filter import ObservationCropFilter
from rl_coach.filters.reward.reward_clipping_filter import RewardClippingFilter
from rl_coach.filters.observation.observation_stacking_filter import ObservationStackingFilter
@@ -31,7 +31,6 @@ def test_filter_stacking():
filter1 = ObservationRescaleToSizeFilter(
output_observation_space=ImageObservationSpace(np.array([110, 84]), high=255),
rescaling_interpolation_type=RescaleInterpolationType.BILINEAR
)
filter2 = ObservationCropFilter(