diff --git a/environments/doom_environment_wrapper.py b/environments/doom_environment_wrapper.py index 997a067..c4a06ac 100644 --- a/environments/doom_environment_wrapper.py +++ b/environments/doom_environment_wrapper.py @@ -146,6 +146,11 @@ class DoomEnvironmentWrapper(EnvironmentWrapper): def _preprocess_observation(self, observation): if observation is None: return None + + # for the last step we get no new observation, so we shouldn't preprocess it + if self.done: + return observation + # move the channel to the last axis observation = np.transpose(observation, (1, 2, 0)) return observation