mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
updated CARLA to allow using actions of size 3 + automatic downloading of the CARLA imitation dataset
This commit is contained in:
@@ -15,26 +15,47 @@
|
||||
#
|
||||
|
||||
import argparse
|
||||
|
||||
import h5py
|
||||
import os
|
||||
import sys
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
from rl_coach.utils import ProgressBar
|
||||
|
||||
from rl_coach.core_types import Transition
|
||||
from rl_coach.memories.memory import MemoryGranularity
|
||||
from rl_coach.memories.non_episodic.experience_replay import ExperienceReplay
|
||||
from rl_coach.utils import ProgressBar, start_shell_command_and_wait
|
||||
from rl_coach.logger import screen
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
argparser = argparse.ArgumentParser(description=__doc__)
|
||||
argparser.add_argument('-d', '--dataset_root', help='The path to the CARLA dataset root folder')
|
||||
argparser.add_argument('-o', '--output_path', help='The path to save the resulting replay buffer',
|
||||
default='carla_train_set_replay_buffer.p')
|
||||
args = argparser.parse_args()
|
||||
def maybe_download(dataset_root):
|
||||
if not dataset_root or not os.path.exists(dataset_root):
|
||||
screen.log_title("Downloading the CARLA dataset. This might take a while.")
|
||||
|
||||
train_set_root = os.path.join(args.dataset_root, 'SeqTrain')
|
||||
validation_set_root = os.path.join(args.dataset_root, 'SeqVal')
|
||||
google_drive_download_id = "1hloAeyamYn-H6MfV1dRtY1gJPhkR55sY"
|
||||
filename_to_save = "datasets/CARLA_dataset.tar.gz"
|
||||
download_command = 'wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=' \
|
||||
'$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies ' \
|
||||
'--no-check-certificate \"https://docs.google.com/uc?export=download&id={}\" -O- | ' \
|
||||
'sed -rn \'s/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p\')&id={}" -O {} && rm -rf /tmp/cookies.txt'\
|
||||
.format(google_drive_download_id, google_drive_download_id, filename_to_save)
|
||||
|
||||
# start downloading and wait for it to finish
|
||||
start_shell_command_and_wait(download_command)
|
||||
|
||||
screen.log_title("Unzipping the dataset")
|
||||
unzip_command = 'tar -xzf {}'.format(filename_to_save)
|
||||
if dataset_root is not None:
|
||||
unzip_command += " -C {}".format(dataset_root)
|
||||
|
||||
start_shell_command_and_wait(unzip_command)
|
||||
|
||||
|
||||
def create_dataset(dataset_root, output_path):
|
||||
maybe_download(dataset_root)
|
||||
|
||||
train_set_root = os.path.join(dataset_root, 'SeqTrain')
|
||||
validation_set_root = os.path.join(dataset_root, 'SeqVal')
|
||||
|
||||
# training set extraction
|
||||
memory = ExperienceReplay(max_size=(MemoryGranularity.Transitions, sys.maxsize))
|
||||
@@ -44,11 +65,13 @@ if __name__ == "__main__":
|
||||
for file_idx, file in enumerate(train_set_files[:3000]):
|
||||
progress_bar.update(file_idx, "extracting file {}".format(file))
|
||||
train_set = h5py.File(os.path.join(train_set_root, file), 'r')
|
||||
observations = train_set['rgb'][:] # forward camera
|
||||
measurements = np.expand_dims(train_set['targets'][:, 10], -1) # forward speed
|
||||
actions = train_set['targets'][:, :3] # steer, gas, brake
|
||||
actions[:, 1] -= actions[:, 2]
|
||||
actions = actions[:, :2][:, ::-1]
|
||||
observations = train_set['rgb'][:] # forward camera
|
||||
measurements = np.expand_dims(train_set['targets'][:, 10], -1) # forward speed
|
||||
actions = train_set['targets'][:, :3] # steer, gas, brake
|
||||
# actions[:, :2] = actions[:, 1:3]
|
||||
# actions[:, 2] = train_set['targets'][:, 0] # gas, brake, steer
|
||||
# actions[:, 1] -= actions[:, 2]
|
||||
# actions = actions[:, :2][:, ::-1]
|
||||
|
||||
high_level_commands = train_set['targets'][:, 24].astype('int') - 2 # follow lane, left, right, straight
|
||||
|
||||
@@ -67,5 +90,15 @@ if __name__ == "__main__":
|
||||
)
|
||||
memory.store(transition)
|
||||
progress_bar.close()
|
||||
print("Saving pickle file to {}".format(args.output_path))
|
||||
memory.save(args.output_path)
|
||||
print("Saving pickle file to {}".format(output_path))
|
||||
memory.save(output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
argparser = argparse.ArgumentParser(description=__doc__)
|
||||
argparser.add_argument('-d', '--dataset_root', help='The path to the CARLA dataset root folder')
|
||||
argparser.add_argument('-o', '--output_path', help='The path to save the resulting replay buffer',
|
||||
default='carla_train_set_replay_buffer.p')
|
||||
args = argparser.parse_args()
|
||||
|
||||
create_dataset(args.dataset_root, args.output_path)
|
||||
Reference in New Issue
Block a user