diff --git a/rl_coach/environments/gym_environment.py b/rl_coach/environments/gym_environment.py index 956185c..569c537 100644 --- a/rl_coach/environments/gym_environment.py +++ b/rl_coach/environments/gym_environment.py @@ -63,7 +63,7 @@ class GymEnvironmentParameters(EnvironmentParameters): super().__init__(level=level) self.random_initialization_steps = 0 self.max_over_num_frames = 1 - self.additional_simulator_parameters = None + self.additional_simulator_parameters = {} @property def path(self): @@ -178,7 +178,7 @@ class MaxOverFramesAndFrameskipEnvWrapper(gym.Wrapper): # Environment class GymEnvironment(Environment): def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters, - additional_simulator_parameters: Dict[str, Any] = None, seed: Union[None, int]=None, + additional_simulator_parameters: Dict[str, Any] = {}, seed: Union[None, int]=None, human_control: bool=False, custom_reward_threshold: Union[int, float]=None, random_initialization_steps: int=1, max_over_num_frames: int=1, **kwargs): super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, @@ -218,10 +218,12 @@ class GymEnvironment(Environment): env_class = gym.envs.registration.load(self.env_id) # instantiate the environment - if self.additional_simulator_parameters: + try: self.env = env_class(**self.additional_simulator_parameters) - else: - self.env = env_class() + except: + screen.error("Failed to instantiate Gym environment class %s with arguments %s" % + (env_class, self.additional_simulator_parameters), crash=False) + raise else: self.env = gym.make(self.env_id)