1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Add tensor input type for arbitrary dimensional observation (#125)

* Allow arbitrary dimensional observation (non vector or image)
* Added creating PlanarMapsObservationSpace to GymEnvironment when number of channels is not 1 or 3
This commit is contained in:
Sina Afrooze
2018-11-19 06:41:12 -08:00
committed by Gal Leibovich
parent 7ba1a4393f
commit 67a90ee87e
10 changed files with 194 additions and 24 deletions

View File

@@ -16,6 +16,7 @@
import gym
import numpy as np
from enum import IntEnum
import scipy.ndimage
from rl_coach.graph_managers.graph_manager import ScheduleParameters
@@ -44,7 +45,7 @@ from typing import Dict, Any, Union
from rl_coach.core_types import RunPhase, EnvironmentSteps
from rl_coach.environments.environment import Environment, EnvironmentParameters, LevelSelection
from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace, ImageObservationSpace, VectorObservationSpace, \
StateSpace, RewardSpace
PlanarMapsObservationSpace, TensorObservationSpace, StateSpace, RewardSpace
from rl_coach.filters.filter import NoInputFilter, NoOutputFilter
from rl_coach.filters.reward.reward_clipping_filter import RewardClippingFilter
from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter
@@ -176,11 +177,26 @@ class MaxOverFramesAndFrameskipEnvWrapper(gym.Wrapper):
# Environment
class ObservationSpaceType(IntEnum):
Tensor = 0
Image = 1
Vector = 2
class GymEnvironment(Environment):
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
target_success_rate: float=1.0, additional_simulator_parameters: Dict[str, Any] = {}, seed: Union[None, int]=None,
human_control: bool=False, custom_reward_threshold: Union[int, float]=None,
random_initialization_steps: int=1, max_over_num_frames: int=1, **kwargs):
def __init__(self,
level: LevelSelection,
frame_skip: int,
visualization_parameters: VisualizationParameters,
target_success_rate: float=1.0,
additional_simulator_parameters: Dict[str, Any] = {},
seed: Union[None, int] = None,
human_control: bool=False,
custom_reward_threshold: Union[int, float]=None,
random_initialization_steps: int=1,
max_over_num_frames: int=1,
observation_space_type: ObservationSpaceType=None,
**kwargs):
"""
:param level: (str)
A string representing the gym level to run. This can also be a LevelSelection object.
@@ -215,6 +231,11 @@ class GymEnvironment(Environment):
This value will be used for merging multiple frames into a single frame by taking the maximum value for each
of the pixels in the frame. This is particularly used in Atari games, where the frames flicker, and objects
can be seen in one frame but disappear in the next.
:param observation_space_type:
This value will be used for generating observation space. Allows a custom space. Should be one of
ObservationSpaceType. If not specified, observation space is inferred from the number of dimensions
of the observation: 1D: Vector space, 3D: Image space if 1 or 3 channels, PlanarMaps space otherwise.
"""
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold,
visualization_parameters, target_success_rate)
@@ -305,20 +326,40 @@ class GymEnvironment(Environment):
state_space = self.env.observation_space.spaces
for observation_space_name, observation_space in state_space.items():
if len(observation_space.shape) == 3:
if observation_space_type == ObservationSpaceType.Tensor:
# we consider arbitrary input tensor which does not necessarily represent images
self.state_space[observation_space_name] = TensorObservationSpace(
shape=np.array(observation_space.shape),
low=observation_space.low,
high=observation_space.high
)
elif observation_space_type == ObservationSpaceType.Image or len(observation_space.shape) == 3:
# we assume gym has image observations (with arbitrary number of channels) where their values are
# within 0-255, and where the channel dimension is the last dimension
self.state_space[observation_space_name] = ImageObservationSpace(
shape=np.array(observation_space.shape),
high=255,
channels_axis=-1
)
else:
if observation_space.shape[-1] in [1, 3]:
self.state_space[observation_space_name] = ImageObservationSpace(
shape=np.array(observation_space.shape),
high=255,
channels_axis=-1
)
else:
# For any number of channels other than 1 or 3, use the generic PlanarMaps space
self.state_space[observation_space_name] = PlanarMapsObservationSpace(
shape=np.array(observation_space.shape),
low=0,
high=255,
channels_axis=-1
)
elif observation_space_type == ObservationSpaceType.Vector or len(observation_space.shape) == 1:
self.state_space[observation_space_name] = VectorObservationSpace(
shape=observation_space.shape[0],
low=observation_space.low,
high=observation_space.high
)
else:
raise screen.error("Failed to instantiate Gym environment class %s with observation space type %s" %
(env_class, observation_space_type), crash=True)
if 'desired_goal' in state_space.keys():
self.goal_space = self.state_space['desired_goal']