diff --git a/rl_coach/environments/carla_environment.py b/rl_coach/environments/carla_environment.py index c33e4ff..ed95055 100644 --- a/rl_coach/environments/carla_environment.py +++ b/rl_coach/environments/carla_environment.py @@ -153,15 +153,6 @@ class CarlaEnvironment(Environment): self.camera_width = camera_width self.camera_height = camera_height - # state space - self.state_space = StateSpace({ - "measurements": VectorObservationSpace(4, measurements_names=["forward_speed", "x", "y", "z"]) - }) - for camera in self.cameras: - self.state_space[camera.value] = ImageObservationSpace( - shape=np.array([self.camera_height, self.camera_width, 3]), - high=255) - # setup server settings self.experiment_suite = experiment_suite self.config = config @@ -195,17 +186,27 @@ class CarlaEnvironment(Environment): self.game = CarlaClient(self.host, self.port, timeout=99999999) self.game.connect() if self.experiment_suite: - self.current_experiment = self.experiment_suite.get_experiments()[0] - scene = self.game.load_settings(self.current_experiment.conditions) + self.current_experiment_idx = 0 + self.current_experiment = self.experiment_suite.get_experiments()[self.current_experiment_idx] + self.scene = self.game.load_settings(self.current_experiment.conditions) else: - scene = self.game.load_settings(self.settings) + self.scene = self.game.load_settings(self.settings) # get available start positions - self.positions = scene.player_start_spots + self.positions = self.scene.player_start_spots self.num_positions = len(self.positions) self.current_start_position_idx = 0 self.current_pose = 0 + # state space + self.state_space = StateSpace({ + "measurements": VectorObservationSpace(4, measurements_names=["forward_speed", "x", "y", "z"]) + }) + for camera in self.scene.sensors: + self.state_space[camera.name] = ImageObservationSpace( + shape=np.array([self.camera_height, self.camera_width, 3]), + high=255) + # action space self.action_space = BoxActionSpace(shape=2, low=np.array([-1, -1]), high=np.array([1, 1])) @@ -342,8 +343,8 @@ class CarlaEnvironment(Environment): measurements, sensor_data = self.game.read_data() self.state = {} - for camera in self.cameras: - self.state[camera.value] = sensor_data[camera.value].data + for camera in self.scene.sensors: + self.state[camera.name] = sensor_data[camera.name].data self.location = [measurements.player_measurements.transform.location.x, measurements.player_measurements.transform.location.y, @@ -398,12 +399,25 @@ class CarlaEnvironment(Environment): self.game.send_control(self.control) + def _load_experiment(self, experiment_idx): + self.current_experiment = self.experiment_suite.get_experiments()[experiment_idx] + self.scene = self.game.load_settings(self.current_experiment.conditions) + self.positions = self.scene.player_start_spots + self.num_positions = len(self.positions) + self.current_start_position_idx = 0 + self.current_pose = 0 + def _restart_environment_episode(self, force_environment_reset=False): # select start and end positions if self.experiment_suite: # if an expeirent suite is available, follow its given poses + if self.current_pose >= len(self.current_experiment.poses): + # load a new experiment + self.current_experiment_idx = (self.current_experiment_idx + 1) % len(self.experiment_suite.get_experiments()) + self._load_experiment(self.current_experiment_idx) + self.current_start_position_idx = self.current_experiment.poses[self.current_pose][0] - self.current_goal = self.current_experiment.poses[self.current_pose][1] + self.current_goal = self.positions[self.current_experiment.poses[self.current_pose][1]] self.current_pose += 1 else: # go over all the possible positions in a cyclic manner @@ -428,6 +442,6 @@ class CarlaEnvironment(Environment): This can be different from the observation. For example, mujoco's observation is a measurements vector. :return: numpy array containing the image that will be rendered to the screen """ - image = [self.state[camera.value] for camera in self.cameras] + image = [self.state[camera.name] for camera in self.scene.sensors] image = np.vstack(image) return image diff --git a/rl_coach/presets/CARLA_CIL.py b/rl_coach/presets/CARLA_CIL.py index cdc1f36..e4bf20c 100644 --- a/rl_coach/presets/CARLA_CIL.py +++ b/rl_coach/presets/CARLA_CIL.py @@ -40,7 +40,7 @@ agent_params = CILAgentParameters() # forward camera and measurements input agent_params.network_wrappers['main'].input_embedders_parameters = { - 'forward_camera': InputEmbedderParameters(scheme=[Conv2d([32, 5, 2]), + 'CameraRGB': InputEmbedderParameters(scheme=[Conv2d([32, 5, 2]), Conv2d([32, 3, 1]), Conv2d([64, 3, 2]), Conv2d([64, 3, 1]), @@ -80,13 +80,13 @@ agent_params.network_wrappers['main'].learning_rate = 0.0002 # crop and rescale the image + use only the forward speed measurement agent_params.input_filter = InputFilter() -agent_params.input_filter.add_observation_filter('forward_camera', 'cropping', +agent_params.input_filter.add_observation_filter('CameraRGB', 'cropping', ObservationCropFilter(crop_low=np.array([115, 0, 0]), crop_high=np.array([510, -1, -1]))) -agent_params.input_filter.add_observation_filter('forward_camera', 'rescale', +agent_params.input_filter.add_observation_filter('CameraRGB', 'rescale', ObservationRescaleToSizeFilter( ImageObservationSpace(np.array([88, 200, 3]), high=255))) -agent_params.input_filter.add_observation_filter('forward_camera', 'to_uint8', ObservationToUInt8Filter(0, 255)) +agent_params.input_filter.add_observation_filter('CameraRGB', 'to_uint8', ObservationToUInt8Filter(0, 255)) agent_params.input_filter.add_observation_filter( 'measurements', 'select_speed', ObservationReductionBySubPartsNameFilter( @@ -113,7 +113,7 @@ agent_params.memory.num_classes = 4 ############### env_params = CarlaEnvironmentParameters() env_params.level = 'town1' -env_params.cameras = [CameraTypes.FRONT] +env_params.cameras = ['CameraRGB'] env_params.camera_height = 600 env_params.camera_width = 800 env_params.allow_braking = False diff --git a/rl_coach/utilities/carla_dataset_to_replay_buffer.py b/rl_coach/utilities/carla_dataset_to_replay_buffer.py index 2a67006..207b57e 100644 --- a/rl_coach/utilities/carla_dataset_to_replay_buffer.py +++ b/rl_coach/utilities/carla_dataset_to_replay_buffer.py @@ -58,7 +58,7 @@ if __name__ == "__main__": for transition_idx in range(file_length): transition = Transition( state={ - 'forward_camera': observations[transition_idx], + 'CameraRGB': observations[transition_idx], 'measurements': measurements[transition_idx], 'high_level_command': high_level_commands[transition_idx] },