mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Add tensor input type for arbitrary dimensional observation (#125)
* Allow arbitrary-dimensional observations (neither vector nor image) * Create a PlanarMapsObservationSpace in GymEnvironment when the number of channels is not 1 or 3
This commit is contained in:
committed by
Gal Leibovich
parent
7ba1a4393f
commit
67a90ee87e
@@ -16,6 +16,7 @@
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from enum import IntEnum
|
||||
import scipy.ndimage
|
||||
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -44,7 +45,7 @@ from typing import Dict, Any, Union
|
||||
from rl_coach.core_types import RunPhase, EnvironmentSteps
|
||||
from rl_coach.environments.environment import Environment, EnvironmentParameters, LevelSelection
|
||||
from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace, ImageObservationSpace, VectorObservationSpace, \
|
||||
StateSpace, RewardSpace
|
||||
PlanarMapsObservationSpace, TensorObservationSpace, StateSpace, RewardSpace
|
||||
from rl_coach.filters.filter import NoInputFilter, NoOutputFilter
|
||||
from rl_coach.filters.reward.reward_clipping_filter import RewardClippingFilter
|
||||
from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter
|
||||
@@ -176,11 +177,26 @@ class MaxOverFramesAndFrameskipEnvWrapper(gym.Wrapper):
|
||||
|
||||
|
||||
# Environment
|
||||
class ObservationSpaceType(IntEnum):
    """Explicit choice of observation space for a Gym environment.

    Accepted by ``GymEnvironment`` through its ``observation_space_type``
    argument. That argument defaults to ``None``, in which case the space is
    inferred from the number of dimensions of the observation instead.
    """

    # Arbitrary input tensor which does not necessarily represent an image.
    Tensor = 0
    # Image observation; selects the image branch of GymEnvironment.
    Image = 1
    # Vector observation; selects the vector branch of GymEnvironment.
    Vector = 2
|
||||
|
||||
|
||||
class GymEnvironment(Environment):
|
||||
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
|
||||
target_success_rate: float=1.0, additional_simulator_parameters: Dict[str, Any] = {}, seed: Union[None, int]=None,
|
||||
human_control: bool=False, custom_reward_threshold: Union[int, float]=None,
|
||||
random_initialization_steps: int=1, max_over_num_frames: int=1, **kwargs):
|
||||
def __init__(self,
|
||||
level: LevelSelection,
|
||||
frame_skip: int,
|
||||
visualization_parameters: VisualizationParameters,
|
||||
target_success_rate: float=1.0,
|
||||
additional_simulator_parameters: Dict[str, Any] = {},
|
||||
seed: Union[None, int] = None,
|
||||
human_control: bool=False,
|
||||
custom_reward_threshold: Union[int, float]=None,
|
||||
random_initialization_steps: int=1,
|
||||
max_over_num_frames: int=1,
|
||||
observation_space_type: ObservationSpaceType=None,
|
||||
**kwargs):
|
||||
"""
|
||||
:param level: (str)
|
||||
A string representing the gym level to run. This can also be a LevelSelection object.
|
||||
@@ -215,6 +231,11 @@ class GymEnvironment(Environment):
|
||||
This value will be used for merging multiple frames into a single frame by taking the maximum value for each
|
||||
of the pixels in the frame. This is particularly used in Atari games, where the frames flicker, and objects
|
||||
can be seen in one frame but disappear in the next.
|
||||
|
||||
:param observation_space_type:
|
||||
This value will be used for generating observation space. Allows a custom space. Should be one of
|
||||
ObservationSpaceType. If not specified, observation space is inferred from the number of dimensions
|
||||
of the observation: 1D: Vector space, 3D: Image space if 1 or 3 channels, PlanarMaps space otherwise.
|
||||
"""
|
||||
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold,
|
||||
visualization_parameters, target_success_rate)
|
||||
@@ -305,20 +326,40 @@ class GymEnvironment(Environment):
|
||||
state_space = self.env.observation_space.spaces
|
||||
|
||||
for observation_space_name, observation_space in state_space.items():
|
||||
if len(observation_space.shape) == 3:
|
||||
if observation_space_type == ObservationSpaceType.Tensor:
|
||||
# we consider arbitrary input tensor which does not necessarily represent images
|
||||
self.state_space[observation_space_name] = TensorObservationSpace(
|
||||
shape=np.array(observation_space.shape),
|
||||
low=observation_space.low,
|
||||
high=observation_space.high
|
||||
)
|
||||
elif observation_space_type == ObservationSpaceType.Image or len(observation_space.shape) == 3:
|
||||
# we assume gym has image observations (with arbitrary number of channels) where their values are
|
||||
# within 0-255, and where the channel dimension is the last dimension
|
||||
self.state_space[observation_space_name] = ImageObservationSpace(
|
||||
shape=np.array(observation_space.shape),
|
||||
high=255,
|
||||
channels_axis=-1
|
||||
)
|
||||
else:
|
||||
if observation_space.shape[-1] in [1, 3]:
|
||||
self.state_space[observation_space_name] = ImageObservationSpace(
|
||||
shape=np.array(observation_space.shape),
|
||||
high=255,
|
||||
channels_axis=-1
|
||||
)
|
||||
else:
|
||||
# For any number of channels other than 1 or 3, use the generic PlanarMaps space
|
||||
self.state_space[observation_space_name] = PlanarMapsObservationSpace(
|
||||
shape=np.array(observation_space.shape),
|
||||
low=0,
|
||||
high=255,
|
||||
channels_axis=-1
|
||||
)
|
||||
elif observation_space_type == ObservationSpaceType.Vector or len(observation_space.shape) == 1:
|
||||
self.state_space[observation_space_name] = VectorObservationSpace(
|
||||
shape=observation_space.shape[0],
|
||||
low=observation_space.low,
|
||||
high=observation_space.high
|
||||
)
|
||||
else:
|
||||
raise screen.error("Failed to instantiate Gym environment class %s with observation space type %s" %
|
||||
(env_class, observation_space_type), crash=True)
|
||||
|
||||
if 'desired_goal' in state_space.keys():
|
||||
self.goal_space = self.state_space['desired_goal']
|
||||
|
||||
|
||||
Reference in New Issue
Block a user