mirror of
https://github.com/gryf/coach.git
synced 2026-02-15 13:35:55 +01:00
Trace tests update
This commit is contained in:
@@ -106,7 +106,7 @@ class CarlaEnvironment(Environment):
|
||||
server_height: int, server_width: int, camera_height: int, camera_width: int,
|
||||
verbose: bool, config: str, episode_max_time: int,
|
||||
allow_braking: bool, quality: CarlaEnvironmentParameters.Quality,
|
||||
cameras: List[CameraTypes], weather_id: List[int], **kwargs):
|
||||
cameras: List[CameraTypes], weather_id: List[int], experiment_path: str, **kwargs):
|
||||
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters)
|
||||
|
||||
# server configuration
|
||||
@@ -115,6 +115,7 @@ class CarlaEnvironment(Environment):
|
||||
self.port = get_open_port()
|
||||
self.host = 'localhost'
|
||||
self.map = self.env_id
|
||||
self.experiment_path = experiment_path
|
||||
|
||||
# client configuration
|
||||
self.verbose = verbose
|
||||
@@ -150,8 +151,11 @@ class CarlaEnvironment(Environment):
|
||||
NumberOfVehicles=15,
|
||||
NumberOfPedestrians=30,
|
||||
WeatherId=random.choice(force_list(self.weather_id)),
|
||||
QualityLevel=self.quality.value)
|
||||
self.settings.randomize_seeds()
|
||||
QualityLevel=self.quality.value,
|
||||
SeedVehicles=seed,
|
||||
SeedPedestrians=seed)
|
||||
if seed is None:
|
||||
self.settings.randomize_seeds()
|
||||
|
||||
self.settings = self._add_cameras(self.settings, self.cameras, self.camera_width, self.camera_height)
|
||||
|
||||
@@ -260,8 +264,10 @@ class CarlaEnvironment(Environment):
|
||||
return settings
|
||||
|
||||
def _open_server(self):
|
||||
# TODO: get experiment path
|
||||
log_path = path.join('./logs/', "CARLA_LOG_{}.txt".format(self.port))
|
||||
log_path = path.join(self.experiment_path if self.experiment_path is not None else '.', 'logs',
|
||||
"CARLA_LOG_{}.txt".format(self.port))
|
||||
if not os.path.exists(os.path.dirname(log_path)):
|
||||
os.makedirs(os.path.dirname(log_path))
|
||||
with open(log_path, "wb") as out:
|
||||
cmd = [path.join(environ.get('CARLA_ROOT'), 'CarlaUE4.sh'), self.map,
|
||||
"-benchmark", "-carla-server", "-fps={}".format(30 / self.frame_skip),
|
||||
|
||||
@@ -75,7 +75,7 @@ class ControlSuiteEnvironment(Environment):
|
||||
|
||||
# load and initialize environment
|
||||
domain_name, task_name = self.env_id.split(":")
|
||||
self.env = suite.load(domain_name=domain_name, task_name=task_name)
|
||||
self.env = suite.load(domain_name=domain_name, task_name=task_name, task_kwargs={'random': seed})
|
||||
|
||||
if observation_type != ObservationType.Measurements:
|
||||
self.env = pixels.Wrapper(self.env, pixels_only=observation_type == ObservationType.Image)
|
||||
|
||||
@@ -101,6 +101,7 @@ class EnvironmentParameters(Parameters):
|
||||
self.custom_reward_threshold = None
|
||||
self.default_input_filter = None
|
||||
self.default_output_filter = None
|
||||
self.experiment_path = None
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
|
||||
Reference in New Issue
Block a user