mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
using the CoRL2017 experiment suite for CARLA_CIL
This commit is contained in:
@@ -29,6 +29,7 @@ class CILAlgorithmParameters(AlgorithmParameters):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.collect_new_data = False
|
self.collect_new_data = False
|
||||||
|
self.state_key_with_the_class_index = 'high_level_command'
|
||||||
|
|
||||||
|
|
||||||
class CILNetworkParameters(NetworkParameters):
|
class CILNetworkParameters(NetworkParameters):
|
||||||
@@ -63,7 +64,7 @@ class CILAgent(ImitationAgent):
|
|||||||
self.current_high_level_control = 0
|
self.current_high_level_control = 0
|
||||||
|
|
||||||
def choose_action(self, curr_state):
|
def choose_action(self, curr_state):
|
||||||
self.current_high_level_control = curr_state['high_level_command']
|
self.current_high_level_control = curr_state[self.ap.algorithm.state_key_with_the_class_index]
|
||||||
return super().choose_action(curr_state)
|
return super().choose_action(curr_state)
|
||||||
|
|
||||||
def extract_action_values(self, prediction):
|
def extract_action_values(self, prediction):
|
||||||
@@ -74,7 +75,7 @@ class CILAgent(ImitationAgent):
|
|||||||
|
|
||||||
target_values = self.networks['main'].online_network.predict({**batch.states(network_keys)})
|
target_values = self.networks['main'].online_network.predict({**batch.states(network_keys)})
|
||||||
|
|
||||||
branch_to_update = batch.states(['high_level_command'])['high_level_command']
|
branch_to_update = batch.states([self.ap.algorithm.state_key_with_the_class_index])[self.ap.algorithm.state_key_with_the_class_index]
|
||||||
for idx, branch in enumerate(branch_to_update):
|
for idx, branch in enumerate(branch_to_update):
|
||||||
target_values[branch][idx] = batch.actions()[idx]
|
target_values[branch][idx] = batch.actions()[idx]
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ try:
|
|||||||
from carla.sensor import Camera
|
from carla.sensor import Camera
|
||||||
from carla.client import VehicleControl
|
from carla.client import VehicleControl
|
||||||
from carla.planner.planner import Planner
|
from carla.planner.planner import Planner
|
||||||
|
from carla.driving_benchmark.experiment_suites.experiment_suite import ExperimentSuite
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from rl_coach.logger import failed_imports
|
from rl_coach.logger import failed_imports
|
||||||
failed_imports.append("CARLA")
|
failed_imports.append("CARLA")
|
||||||
@@ -103,7 +104,8 @@ class CarlaEnvironmentParameters(EnvironmentParameters):
|
|||||||
self.server_width = 720
|
self.server_width = 720
|
||||||
self.camera_height = 128
|
self.camera_height = 128
|
||||||
self.camera_width = 180
|
self.camera_width = 180
|
||||||
self.config = None #'environments/CarlaSettings.ini' # TODO: remove the config to prevent confusion
|
self.experiment_suite = None # an optional CARLA experiment suite to use
|
||||||
|
self.config = None
|
||||||
self.level = 'town1'
|
self.level = 'town1'
|
||||||
self.quality = self.Quality.LOW
|
self.quality = self.Quality.LOW
|
||||||
self.cameras = [CameraTypes.FRONT]
|
self.cameras = [CameraTypes.FRONT]
|
||||||
@@ -126,7 +128,7 @@ class CarlaEnvironment(Environment):
|
|||||||
seed: int, frame_skip: int, human_control: bool, custom_reward_threshold: Union[int, float],
|
seed: int, frame_skip: int, human_control: bool, custom_reward_threshold: Union[int, float],
|
||||||
visualization_parameters: VisualizationParameters,
|
visualization_parameters: VisualizationParameters,
|
||||||
server_height: int, server_width: int, camera_height: int, camera_width: int,
|
server_height: int, server_width: int, camera_height: int, camera_width: int,
|
||||||
verbose: bool, config: str, episode_max_time: int,
|
verbose: bool, experiment_suite: ExperimentSuite, config: str, episode_max_time: int,
|
||||||
allow_braking: bool, quality: CarlaEnvironmentParameters.Quality,
|
allow_braking: bool, quality: CarlaEnvironmentParameters.Quality,
|
||||||
cameras: List[CameraTypes], weather_id: List[int], experiment_path: str,
|
cameras: List[CameraTypes], weather_id: List[int], experiment_path: str,
|
||||||
num_speedup_steps: int, max_speed: float, **kwargs):
|
num_speedup_steps: int, max_speed: float, **kwargs):
|
||||||
@@ -161,6 +163,7 @@ class CarlaEnvironment(Environment):
|
|||||||
high=255)
|
high=255)
|
||||||
|
|
||||||
# setup server settings
|
# setup server settings
|
||||||
|
self.experiment_suite = experiment_suite
|
||||||
self.config = config
|
self.config = config
|
||||||
if self.config:
|
if self.config:
|
||||||
# load settings from file
|
# load settings from file
|
||||||
@@ -191,12 +194,17 @@ class CarlaEnvironment(Environment):
|
|||||||
# open the client
|
# open the client
|
||||||
self.game = CarlaClient(self.host, self.port, timeout=99999999)
|
self.game = CarlaClient(self.host, self.port, timeout=99999999)
|
||||||
self.game.connect()
|
self.game.connect()
|
||||||
scene = self.game.load_settings(self.settings)
|
if self.experiment_suite:
|
||||||
|
self.current_experiment = self.experiment_suite.get_experiments()[0]
|
||||||
|
scene = self.game.load_settings(self.current_experiment.conditions)
|
||||||
|
else:
|
||||||
|
scene = self.game.load_settings(self.settings)
|
||||||
|
|
||||||
# get available start positions
|
# get available start positions
|
||||||
self.positions = scene.player_start_spots
|
self.positions = scene.player_start_spots
|
||||||
self.num_pos = len(self.positions)
|
self.num_positions = len(self.positions)
|
||||||
self.iterator_start_positions = 0
|
self.current_start_position_idx = 0
|
||||||
|
self.current_pose = 0
|
||||||
|
|
||||||
# action space
|
# action space
|
||||||
self.action_space = BoxActionSpace(shape=2, low=np.array([-1, -1]), high=np.array([1, 1]))
|
self.action_space = BoxActionSpace(shape=2, low=np.array([-1, -1]), high=np.array([1, 1]))
|
||||||
@@ -391,18 +399,24 @@ class CarlaEnvironment(Environment):
|
|||||||
self.game.send_control(self.control)
|
self.game.send_control(self.control)
|
||||||
|
|
||||||
def _restart_environment_episode(self, force_environment_reset=False):
|
def _restart_environment_episode(self, force_environment_reset=False):
|
||||||
self.iterator_start_positions += 1
|
# select start and end positions
|
||||||
if self.iterator_start_positions >= self.num_pos:
|
if self.experiment_suite:
|
||||||
self.iterator_start_positions = 0
|
# if an expeirent suite is available, follow its given poses
|
||||||
|
self.current_start_position_idx = self.current_experiment.poses[self.current_pose][0]
|
||||||
|
self.current_goal = self.current_experiment.poses[self.current_pose][1]
|
||||||
|
self.current_pose += 1
|
||||||
|
else:
|
||||||
|
# go over all the possible positions in a cyclic manner
|
||||||
|
self.current_start_position_idx = (self.current_start_position_idx + 1) % self.num_positions
|
||||||
|
|
||||||
|
# choose a random goal destination TODO: follow the CoRL destinations and start positions
|
||||||
|
self.current_goal = random.choice(self.positions)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.game.start_episode(self.iterator_start_positions)
|
self.game.start_episode(self.current_start_position_idx)
|
||||||
except:
|
except:
|
||||||
self.game.connect()
|
self.game.connect()
|
||||||
self.game.start_episode(self.iterator_start_positions)
|
self.game.start_episode(self.current_start_position_idx)
|
||||||
|
|
||||||
# choose a random goal destination TODO: follow the CoRL destinations and start positions
|
|
||||||
self.current_goal = random.choice(self.positions)
|
|
||||||
|
|
||||||
# start the game with some initial speed
|
# start the game with some initial speed
|
||||||
for i in range(self.num_speedup_steps):
|
for i in range(self.num_speedup_steps):
|
||||||
|
|||||||
@@ -1,10 +1,17 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
# make sure you have $CARLA_ROOT/PythonClient in your PYTHONPATH
|
||||||
|
from carla.driving_benchmark.experiment_suites import CoRL2017
|
||||||
|
|
||||||
|
from rl_coach.agents.cil_agent import CILAgentParameters
|
||||||
from rl_coach.architectures.tensorflow_components.architecture import Conv2d, Dense
|
from rl_coach.architectures.tensorflow_components.architecture import Conv2d, Dense
|
||||||
|
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||||
|
from rl_coach.architectures.tensorflow_components.heads.cil_head import RegressionHeadParameters
|
||||||
from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
|
from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
|
||||||
|
from rl_coach.base_parameters import VisualizationParameters
|
||||||
|
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||||
from rl_coach.environments.carla_environment import CarlaEnvironmentParameters, CameraTypes
|
from rl_coach.environments.carla_environment import CarlaEnvironmentParameters, CameraTypes
|
||||||
|
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
|
||||||
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
|
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
|
||||||
from rl_coach.filters.filter import InputFilter
|
from rl_coach.filters.filter import InputFilter
|
||||||
from rl_coach.filters.observation.observation_crop_filter import ObservationCropFilter
|
from rl_coach.filters.observation.observation_crop_filter import ObservationCropFilter
|
||||||
@@ -12,17 +19,10 @@ from rl_coach.filters.observation.observation_reduction_by_sub_parts_name_filter
|
|||||||
ObservationReductionBySubPartsNameFilter
|
ObservationReductionBySubPartsNameFilter
|
||||||
from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter
|
from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter
|
||||||
from rl_coach.filters.observation.observation_to_uint8_filter import ObservationToUInt8Filter
|
from rl_coach.filters.observation.observation_to_uint8_filter import ObservationToUInt8Filter
|
||||||
from rl_coach.schedules import ConstantSchedule
|
|
||||||
from rl_coach.spaces import ImageObservationSpace
|
|
||||||
|
|
||||||
from rl_coach.agents.cil_agent import CILAgentParameters
|
|
||||||
from rl_coach.architectures.tensorflow_components.heads.cil_head import RegressionHeadParameters
|
|
||||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||||
from rl_coach.base_parameters import VisualizationParameters
|
from rl_coach.schedules import ConstantSchedule
|
||||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
from rl_coach.spaces import ImageObservationSpace
|
||||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
|
||||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# Graph Scheduling #
|
# Graph Scheduling #
|
||||||
@@ -116,8 +116,9 @@ env_params.level = 'town1'
|
|||||||
env_params.cameras = [CameraTypes.FRONT]
|
env_params.cameras = [CameraTypes.FRONT]
|
||||||
env_params.camera_height = 600
|
env_params.camera_height = 600
|
||||||
env_params.camera_width = 800
|
env_params.camera_width = 800
|
||||||
env_params.allow_braking = True
|
env_params.allow_braking = False
|
||||||
env_params.quality = CarlaEnvironmentParameters.Quality.EPIC
|
env_params.quality = CarlaEnvironmentParameters.Quality.EPIC
|
||||||
|
env_params.experiment_suite = CoRL2017('Town01')
|
||||||
|
|
||||||
vis_params = VisualizationParameters()
|
vis_params = VisualizationParameters()
|
||||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||||
|
|||||||
Reference in New Issue
Block a user