diff --git a/rl_coach/environments/gym_environment.py b/rl_coach/environments/gym_environment.py index ad97961..956185c 100644 --- a/rl_coach/environments/gym_environment.py +++ b/rl_coach/environments/gym_environment.py @@ -267,8 +267,9 @@ class GymEnvironment(Environment): state_space = self.env.observation_space.spaces for observation_space_name, observation_space in state_space.items(): - if len(observation_space.shape) == 3 and observation_space.shape[-1] == 3: - # we assume gym has image observations which are RGB and where their values are within 0-255 + if len(observation_space.shape) == 3: + # we assume gym has image observations (with arbitrary number of channels) where their values are + # within 0-255, and where the channel dimension is the last dimension self.state_space[observation_space_name] = ImageObservationSpace( shape=np.array(observation_space.shape), high=255,