mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
initial CIL implementation (WIP)
84
rl_coach/agents/cil_agent.py
Normal file
@@ -0,0 +1,84 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Union

from rl_coach.agents.imitation_agent import ImitationAgent
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
from rl_coach.architectures.tensorflow_components.heads.cil_head import RegressionHeadParameters
from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
from rl_coach.base_parameters import AgentParameters, MiddlewareScheme, NetworkParameters, AlgorithmParameters
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
from rl_coach.memories.non_episodic.balanced_experience_replay import BalancedExperienceReplayParameters


class CILAlgorithmParameters(AlgorithmParameters):
    def __init__(self):
        super().__init__()
        self.collect_new_data = False


class CILNetworkParameters(NetworkParameters):
    def __init__(self):
        super().__init__()
        self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
        self.middleware_parameters = FCMiddlewareParameters(scheme=MiddlewareScheme.Medium)
        self.heads_parameters = [RegressionHeadParameters()]
        self.loss_weights = [1.0]
        self.optimizer_type = 'Adam'
        self.batch_size = 32
        self.replace_mse_with_huber_loss = False
        self.create_target_network = False


class CILAgentParameters(AgentParameters):
    def __init__(self):
        super().__init__(algorithm=CILAlgorithmParameters(),
                         exploration=EGreedyParameters(),
                         memory=BalancedExperienceReplayParameters(),
                         networks={"main": CILNetworkParameters()})

    @property
    def path(self):
        return 'rl_coach.agents.cil_agent:CILAgent'


# Conditional Imitation Learning Agent: https://arxiv.org/abs/1710.02410
class CILAgent(ImitationAgent):
    def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
        super().__init__(agent_parameters, parent)
        self.current_high_level_control = 0

    def choose_action(self, curr_state):
        self.current_high_level_control = curr_state['high_level_command']
        return super().choose_action(curr_state)

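    # use the prediction of the output branch matching the high-level command recorded in choose_action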
    def extract_action_values(self, prediction):
        return prediction[self.current_high_level_control].squeeze()

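    # Conditional branch training: the targets start out as the network's own predictions, so every branch
    # except the one matching a transition's high-level command sees zero error; only the matching branch's
    # target is replaced with the demonstrated action.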
    def learn_from_batch(self, batch):
        network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()

        target_values = self.networks['main'].online_network.predict({**batch.states(network_keys)})

        branch_to_update = batch.states(['high_level_command'])['high_level_command']
        for idx, branch in enumerate(branch_to_update):
            target_values[branch][idx] = batch.actions()[idx]

        result = self.networks['main'].train_and_sync_networks({**batch.states(network_keys)}, target_values)
        total_loss, losses, unclipped_grads = result[:3]

        return total_loss, losses, unclipped_grads
56
rl_coach/architectures/tensorflow_components/heads/cil_head.py
Normal file
@@ -0,0 +1,56 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import tensorflow as tf

from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import QActionStateValue
from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace


class RegressionHeadParameters(HeadParameters):
    def __init__(self, activation_function: str='relu', name: str='q_head_params', dense_layer=Dense):
        super().__init__(parameterized_class=RegressionHead, activation_function=activation_function, name=name,
                         dense_layer=dense_layer)


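# A simple regression head: two 256-unit fully connected layers followed by a linear output of num_actions
# values, trained with an MSE loss (or a Huber loss when replace_mse_with_huber_loss is set) against the
# demonstrated actions. The CIL network instantiates one such head per high-level command.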
class RegressionHead(Head):
    def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
                 head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
                 dense_layer=Dense):
        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
                         dense_layer=dense_layer)
        self.name = 'regression_head'
        if isinstance(self.spaces.action, BoxActionSpace):
            self.num_actions = self.spaces.action.shape[0]
        elif isinstance(self.spaces.action, DiscreteActionSpace):
            self.num_actions = len(self.spaces.action.actions)
        self.return_type = QActionStateValue
        if agent_parameters.network_wrappers[self.network_name].replace_mse_with_huber_loss:
            self.loss_type = tf.losses.huber_loss
        else:
            self.loss_type = tf.losses.mean_squared_error

    def _build_module(self, input_layer):
        self.fc1 = self.dense_layer(256)(input_layer)
        self.fc2 = self.dense_layer(256)(self.fc1)
        self.output = self.dense_layer(self.num_actions)(self.fc2, name='output')
171
rl_coach/memories/non_episodic/balanced_experience_replay.py
Normal file
@@ -0,0 +1,171 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import operator
import random
from enum import Enum
from typing import List, Tuple, Any, Union

import numpy as np

from rl_coach.core_types import Transition
from rl_coach.memories.memory import MemoryGranularity
from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters, ExperienceReplay
from rl_coach.schedules import Schedule, ConstantSchedule


class BalancedExperienceReplayParameters(ExperienceReplayParameters):
    def __init__(self):
        super().__init__()
        self.max_size = (MemoryGranularity.Transitions, 1000000)
        self.allow_duplicates_in_batch_sampling = False
        self.num_classes = 0
        self.state_key_with_the_class_index = 'class'

    @property
    def path(self):
        return 'rl_coach.memories.non_episodic.balanced_experience_replay:BalancedExperienceReplay'


"""
A replay buffer which allows sampling batches which are balanced in terms of the classes that are sampled
"""
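# Rough usage sketch (illustrative values only):
#   memory = BalancedExperienceReplay(max_size=(MemoryGranularity.Transitions, 10000), num_classes=4,
#                                     allow_duplicates_in_batch_sampling=False,
#                                     state_key_with_the_class_index='high_level_command')
#   memory.store(transition)   # transition.state['high_level_command'] must be an int in [0, 4)
#   batch = memory.sample(32)  # returns 8 transitions from each of the 4 classes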
class BalancedExperienceReplay(ExperienceReplay):
    def __init__(self, max_size: Tuple[MemoryGranularity, int], allow_duplicates_in_batch_sampling: bool=True,
                 num_classes: int=0, state_key_with_the_class_index: Any='class'):
        """
        :param max_size: the maximum number of transitions or episodes to hold in the memory
        :param allow_duplicates_in_batch_sampling: allow having the same transition multiple times in a batch
        :param num_classes: the number of classes in the replayed data
        :param state_key_with_the_class_index: the class index is assumed to be a value in the state dictionary.
                                               this parameter determines the key to retrieve the class index value
        """
        super().__init__(max_size, allow_duplicates_in_batch_sampling)
        self.current_class_to_sample_from = 0
        self.num_classes = num_classes
        self.state_key_with_the_class_index = state_key_with_the_class_index
        self.transitions = [[] for _ in range(self.num_classes)]
        self.transitions_order = []

        if self.num_classes < 2:
            raise ValueError("The number of classes for a balanced replay buffer should be at least 2. "
                             "The number of classes that were defined is: {}".format(self.num_classes))

    def store(self, transition: Transition, lock: bool=True) -> None:
        """
        Store a new transition in the memory.
        :param transition: a transition to store
        :param lock: if true, will lock the readers/writers lock. this can cause a deadlock if an inheriting class
                     locks and then calls store with lock=True
        :return: None
        """
        if lock:
            self.reader_writer_lock.lock_writing_and_reading()

        self._num_transitions += 1

        if self.state_key_with_the_class_index not in transition.state.keys():
            raise ValueError("The class index was not present in the state of the transition under the given key ({})"
                             .format(self.state_key_with_the_class_index))

        class_idx = transition.state[self.state_key_with_the_class_index]

        if class_idx >= self.num_classes:
            raise ValueError("The given class index is outside the defined number of classes for the replay buffer. "
                             "The given class was: {} and the number of classes defined is: {}"
                             .format(class_idx, self.num_classes))

        self.transitions[class_idx].append(transition)
        self.transitions_order.append(class_idx)
        self._enforce_max_length()

        if lock:
            self.reader_writer_lock.release_writing_and_reading()

    def sample(self, size: int) -> List[Transition]:
        """
        Sample a batch of transitions from the replay buffer. The batch size must be a multiple of the number
        of classes, and an equal number of transitions is sampled from each class.
        :param size: the size of the batch to sample
        :return: a batch (list) of selected transitions from the replay buffer
        """
        self.reader_writer_lock.lock_writing()

        if size % self.num_classes != 0:
            raise ValueError("Sampling batches from a balanced replay buffer should be done only using batch sizes "
                             "which are a multiple of the number of classes. The number of classes defined is: {} "
                             "and the batch size requested is: {}".format(self.num_classes, size))

        batch_size_from_each_class = size // self.num_classes

        if self.allow_duplicates_in_batch_sampling:
            transitions_idx = [np.random.randint(len(class_transitions), size=batch_size_from_each_class)
                               for class_transitions in self.transitions]

        else:
            for class_idx, class_transitions in enumerate(self.transitions):
                if len(class_transitions) < batch_size_from_each_class:
                    raise ValueError("The replay buffer cannot be sampled since there are not enough transitions yet. "
                                     "There are currently {} transitions for class {}"
                                     .format(len(class_transitions), class_idx))

            transitions_idx = [np.random.choice(len(class_transitions), size=batch_size_from_each_class, replace=False)
                               for class_transitions in self.transitions]

        batch = []
        for class_idx, class_transitions_idx in enumerate(transitions_idx):
            batch += [self.transitions[class_idx][i] for i in class_transitions_idx]

        self.reader_writer_lock.release_writing()

        return batch

    def remove_transition(self, transition_index: int, lock: bool=True) -> None:
        raise ValueError("It is not possible to remove specific transitions with a balanced replay buffer")

    def get_transition(self, transition_index: int, lock: bool=True) -> Union[None, Transition]:
        raise ValueError("It is not possible to access specific transitions with a balanced replay buffer")

    def _enforce_max_length(self) -> None:
        """
        Make sure that the size of the replay buffer does not pass the maximum size allowed.
        If it passes the max size, the oldest transition in the replay buffer will be removed.
        This function does not use locks since it is only called internally
        :return: None
        """
        granularity, size = self.max_size
        if granularity == MemoryGranularity.Transitions:
            while size != 0 and self.num_transitions() > size:
                self._num_transitions -= 1
                del self.transitions[self.transitions_order[0]][0]
                del self.transitions_order[0]
        else:
            raise ValueError("The granularity of the replay buffer can only be set in terms of transitions")

    def clean(self, lock: bool=True) -> None:
        """
        Clean the memory by removing all of the stored transitions
        :return: None
        """
        if lock:
            self.reader_writer_lock.lock_writing_and_reading()

        self.transitions = [[] for _ in range(self.num_classes)]
        self.transitions_order = []
        self._num_transitions = 0

        if lock:
            self.reader_writer_lock.release_writing_and_reading()
127
rl_coach/presets/CARLA_CIL.py
Normal file
@@ -0,0 +1,127 @@
import os
import sys

import numpy as np
from rl_coach.architectures.tensorflow_components.architecture import Conv2d, Dense
from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
from rl_coach.environments.carla_environment import CarlaEnvironmentParameters, CameraTypes
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
from rl_coach.filters.filter import InputFilter
from rl_coach.filters.observation.observation_crop_filter import ObservationCropFilter
from rl_coach.filters.observation.observation_reduction_by_sub_parts_name_filter import \
    ObservationReductionBySubPartsNameFilter
from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter
from rl_coach.filters.observation.observation_to_uint8_filter import ObservationToUInt8Filter
from rl_coach.schedules import ConstantSchedule
from rl_coach.spaces import ImageObservationSpace

from rl_coach.agents.cil_agent import CILAgentParameters
from rl_coach.architectures.tensorflow_components.heads.cil_head import RegressionHeadParameters
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.base_parameters import VisualizationParameters
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod

####################
# Graph Scheduling #
####################
schedule_params = ScheduleParameters()
schedule_params.improve_steps = TrainingSteps(10000000000)
schedule_params.steps_between_evaluation_periods = TrainingSteps(500)
schedule_params.evaluation_steps = EnvironmentEpisodes(5)
schedule_params.heatup_steps = EnvironmentSteps(0)

################
# Agent Params #
################
agent_params = CILAgentParameters()

# forward camera and measurements input
agent_params.network_wrappers['main'].input_embedders_parameters = {
    'forward_camera': InputEmbedderParameters(scheme=[Conv2d([32, 5, 2]),
                                                      Conv2d([32, 3, 1]),
                                                      Conv2d([64, 3, 2]),
                                                      Conv2d([64, 3, 1]),
                                                      Conv2d([128, 3, 2]),
                                                      Conv2d([128, 3, 1]),
                                                      Conv2d([256, 3, 1]),
                                                      Conv2d([256, 3, 1]),
                                                      Dense([512]),
                                                      Dense([512])],
                                              dropout=True,
                                              batchnorm=True),
    'measurements': InputEmbedderParameters(scheme=[Dense([128]),
                                                    Dense([128])])
}

# TODO: batch norm is currently applied to the fc layers as well, which is not desired
# TODO: dropout should be configured differently per layer: [1.0] * 8 + [0.7] * 2 + [0.5] * 2 + [0.5] * 1 + [0.5, 1.] * 5

# simple fc middleware
agent_params.network_wrappers['main'].middleware_parameters = FCMiddlewareParameters(scheme=[Dense([512])])

# output branches
agent_params.network_wrappers['main'].heads_parameters = [
    RegressionHeadParameters(),
    RegressionHeadParameters(),
    RegressionHeadParameters(),
    RegressionHeadParameters()
]
# agent_params.network_wrappers['main'].num_output_head_copies = 4  # follow lane, left, right, straight
agent_params.network_wrappers['main'].rescale_gradient_from_head_by_factor = [1, 1, 1, 1]
agent_params.network_wrappers['main'].loss_weights = [1, 1, 1, 1]
# TODO: there should be another head predicting the speed which is connected directly to the forward camera embedding

agent_params.network_wrappers['main'].batch_size = 120
agent_params.network_wrappers['main'].learning_rate = 0.0002


# crop and rescale the image + use only the forward speed measurement
agent_params.input_filter = InputFilter()
agent_params.input_filter.add_observation_filter('forward_camera', 'cropping',
                                                 ObservationCropFilter(crop_low=np.array([115, 0, 0]),
                                                                       crop_high=np.array([510, -1, -1])))
agent_params.input_filter.add_observation_filter('forward_camera', 'rescale',
                                                 ObservationRescaleToSizeFilter(
                                                     ImageObservationSpace(np.array([88, 200, 3]), high=255)))
agent_params.input_filter.add_observation_filter('forward_camera', 'to_uint8', ObservationToUInt8Filter(0, 255))
agent_params.input_filter.add_observation_filter(
    'measurements', 'select_speed',
    ObservationReductionBySubPartsNameFilter(
        ["forward_speed"], reduction_method=ObservationReductionBySubPartsNameFilter.ReductionMethod.Keep))

# no exploration is used
agent_params.exploration = AdditiveNoiseParameters()
agent_params.exploration.noise_percentage_schedule = ConstantSchedule(0)
agent_params.exploration.evaluation_noise_percentage = 0

# no playing during the training phase
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(0)

# the CARLA dataset should be downloaded through the following repository:
# https://github.com/carla-simulator/imitation-learning
# the dataset should then be converted to the Coach format using the script
# rl_coach/utilities/carla_dataset_to_replay_buffer.py, and the path to the converted dataset
# should be updated below
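# Example conversion + training commands (paths are illustrative):
#   python rl_coach/utilities/carla_dataset_to_replay_buffer.py -d <path_to_CARLA_dataset_root> -o ./datasets/carla_train_set_replay_buffer.p
#   coach -p CARLA_CIL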
agent_params.memory.load_memory_from_file_path = "./datasets/carla_train_set_replay_buffer.p"
agent_params.memory.state_key_with_the_class_index = 'high_level_command'
agent_params.memory.num_classes = 4

###############
# Environment #
###############
env_params = CarlaEnvironmentParameters()
env_params.level = 'town1'
env_params.cameras = [CameraTypes.FRONT]
env_params.camera_height = 600
env_params.camera_width = 800
env_params.allow_braking = True
env_params.quality = CarlaEnvironmentParameters.Quality.EPIC

vis_params = VisualizationParameters()
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
vis_params.dump_mp4 = True

graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
                                    schedule_params=schedule_params, vis_params=vis_params)
0
rl_coach/utilities/__init__.py
Normal file
71
rl_coach/utilities/carla_dataset_to_replay_buffer.py
Normal file
@@ -0,0 +1,71 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse

import h5py
import os
import sys
import numpy as np
from rl_coach.utils import ProgressBar
from rl_coach.core_types import Transition
from rl_coach.memories.memory import MemoryGranularity
from rl_coach.memories.non_episodic.experience_replay import ExperienceReplay


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(description=__doc__)
    argparser.add_argument('-d', '--dataset_root', help='The path to the CARLA dataset root folder')
    argparser.add_argument('-o', '--output_path', help='The path to save the resulting replay buffer',
                           default='carla_train_set_replay_buffer.p')
    args = argparser.parse_args()

    train_set_root = os.path.join(args.dataset_root, 'SeqTrain')
    validation_set_root = os.path.join(args.dataset_root, 'SeqVal')

    # training set extraction
    memory = ExperienceReplay(max_size=(MemoryGranularity.Transitions, sys.maxsize))
    train_set_files = sorted(os.listdir(train_set_root))
    print("found {} files".format(len(train_set_files)))
    progress_bar = ProgressBar(len(train_set_files))
    for file_idx, file in enumerate(train_set_files[:3000]):
        progress_bar.update(file_idx, "extracting file {}".format(file))
        train_set = h5py.File(os.path.join(train_set_root, file), 'r')
        observations = train_set['rgb'][:]  # forward camera
        measurements = np.expand_dims(train_set['targets'][:, 10], -1)  # forward speed
        actions = train_set['targets'][:, :3]  # steer, gas, brake
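        # targets columns are [steer, gas, brake]: fold brake into gas as a single signed throttle value,
        # then keep [steer, throttle] and reverse it to [throttle, steer] (presumably the action ordering
        # expected by the CARLA environment wrapper)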
        actions[:, 1] -= actions[:, 2]
        actions = actions[:, :2][:, ::-1]

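        # column 24 holds the high-level command (2=follow lane, 3=left, 4=right, 5=straight);
        # shifting by -2 maps it to 0..3 so it can index the output branches and the balanced replay classes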
        high_level_commands = train_set['targets'][:, 24].astype('int') - 2  # follow lane, left, right, straight

        file_length = train_set['rgb'].len()
        assert train_set['rgb'].len() == train_set['targets'].len()

        for transition_idx in range(file_length):
            transition = Transition(
                state={
                    'forward_camera': observations[transition_idx],
                    'measurements': measurements[transition_idx],
                    'high_level_command': high_level_commands[transition_idx]
                },
                action=actions[transition_idx],
                reward=0
            )
            memory.store(transition)
    progress_bar.close()
    print("Saving pickle file to {}".format(args.output_path))
    memory.save(args.output_path)