diff --git a/rl_coach/agents/cil_agent.py b/rl_coach/agents/cil_agent.py
new file mode 100644
index 0000000..b6ca57e
--- /dev/null
+++ b/rl_coach/agents/cil_agent.py
@@ -0,0 +1,84 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Union
+
+from rl_coach.agents.imitation_agent import ImitationAgent
+from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
+from rl_coach.architectures.tensorflow_components.heads.cil_head import RegressionHeadParameters
+from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
+from rl_coach.base_parameters import AgentParameters, MiddlewareScheme, NetworkParameters, AlgorithmParameters
+from rl_coach.exploration_policies.e_greedy import EGreedyParameters
+from rl_coach.memories.non_episodic.balanced_experience_replay import BalancedExperienceReplayParameters
+
+
+class CILAlgorithmParameters(AlgorithmParameters):
+    def __init__(self):
+        super().__init__()
+        self.collect_new_data = False
+
+
+class CILNetworkParameters(NetworkParameters):
+    def __init__(self):
+        super().__init__()
+        self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
+        self.middleware_parameters = FCMiddlewareParameters(scheme=MiddlewareScheme.Medium)
+        self.heads_parameters = [RegressionHeadParameters()]
+        self.loss_weights = [1.0]
+        self.optimizer_type = 'Adam'
+        self.batch_size = 32
+        self.replace_mse_with_huber_loss = False
+        self.create_target_network = False
+
+
+class CILAgentParameters(AgentParameters):
+    def __init__(self):
+        super().__init__(algorithm=CILAlgorithmParameters(),
+                         exploration=EGreedyParameters(),
+                         memory=BalancedExperienceReplayParameters(),
+                         networks={"main": CILNetworkParameters()})
+
+    @property
+    def path(self):
+        return 'rl_coach.agents.cil_agent:CILAgent'
+
+
+# Conditional Imitation Learning Agent: https://arxiv.org/abs/1710.02410
+class CILAgent(ImitationAgent):
+    def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
+        super().__init__(agent_parameters, parent)
+        self.current_high_level_control = 0
+
+    def choose_action(self, curr_state):
+        self.current_high_level_control = curr_state['high_level_command']
+        return super().choose_action(curr_state)
+
+    def extract_action_values(self, prediction):
+        return prediction[self.current_high_level_control].squeeze()
+
+    def learn_from_batch(self, batch):
+        network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()
+
+        target_values = self.networks['main'].online_network.predict({**batch.states(network_keys)})
+
+        branch_to_update = batch.states(['high_level_command'])['high_level_command']
+        for idx, branch in enumerate(branch_to_update):
+            target_values[branch][idx] = batch.actions()[idx]
+
+        result = self.networks['main'].train_and_sync_networks({**batch.states(network_keys)}, target_values)
+        total_loss, losses, unclipped_grads = result[:3]
+
+        return total_loss, losses, unclipped_grads
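Note on the branch masking in learn_from_batch above: the online network has one regression branch per high-level command, the training targets are initialized to the network's own predictions, and only the branch matching each sample's command is overwritten with the demonstrated action, so the regression error (and therefore the gradient) is zero for every unselected branch. A minimal numpy sketch of that idea, with made-up shapes (4 branches, batch of 3, 2 action dimensions):

import numpy as np

num_branches, batch_size, num_actions = 4, 3, 2
predictions = [np.random.randn(batch_size, num_actions) for _ in range(num_branches)]  # one output per branch
commands = np.array([0, 2, 1])                             # high-level command per sample
demo_actions = np.random.randn(batch_size, num_actions)    # demonstrated actions from the dataset

targets = [p.copy() for p in predictions]
for sample_idx, branch in enumerate(commands):
    targets[branch][sample_idx] = demo_actions[sample_idx]  # only the selected branch gets a non-zero error

errors = [t - p for t, p in zip(targets, predictions)]      # zero everywhere except the selected branch entries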
diff --git a/rl_coach/architectures/tensorflow_components/heads/cil_head.py b/rl_coach/architectures/tensorflow_components/heads/cil_head.py
new file mode 100644
index 0000000..3a699d7
--- /dev/null
+++ b/rl_coach/architectures/tensorflow_components/heads/cil_head.py
@@ -0,0 +1,56 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tensorflow as tf
+
+from rl_coach.architectures.tensorflow_components.architecture import Dense
+
+from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
+from rl_coach.base_parameters import AgentParameters
+from rl_coach.core_types import QActionStateValue
+from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace
+
+
+class RegressionHeadParameters(HeadParameters):
+    def __init__(self, activation_function: str='relu', name: str='q_head_params', dense_layer=Dense):
+        super().__init__(parameterized_class=RegressionHead, activation_function=activation_function, name=name,
+                         dense_layer=dense_layer)
+
+
+class RegressionHead(Head):
+    def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
+                 head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
+                 dense_layer=Dense):
+        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
+                         dense_layer=dense_layer)
+        self.name = 'regression_head'
+        if isinstance(self.spaces.action, BoxActionSpace):
+            self.num_actions = self.spaces.action.shape[0]
+        elif isinstance(self.spaces.action, DiscreteActionSpace):
+            self.num_actions = len(self.spaces.action.actions)
+        self.return_type = QActionStateValue
+        if agent_parameters.network_wrappers[self.network_name].replace_mse_with_huber_loss:
+            self.loss_type = tf.losses.huber_loss
+        else:
+            self.loss_type = tf.losses.mean_squared_error
+
+    def _build_module(self, input_layer):
+        self.fc1 = self.dense_layer(256)(input_layer)
+        self.fc2 = self.dense_layer(256)(self.fc1)
+        self.output = self.dense_layer(self.num_actions)(self.fc2, name='output')
+
+
diff --git a/rl_coach/memories/non_episodic/balanced_experience_replay.py b/rl_coach/memories/non_episodic/balanced_experience_replay.py
new file mode 100644
index 0000000..24f2c19
--- /dev/null
+++ b/rl_coach/memories/non_episodic/balanced_experience_replay.py
@@ -0,0 +1,171 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import operator
+import random
+from enum import Enum
+from typing import List, Tuple, Any, Union
+
+import numpy as np
+
+from rl_coach.core_types import Transition
+from rl_coach.memories.memory import MemoryGranularity
+from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters, ExperienceReplay
+from rl_coach.schedules import Schedule, ConstantSchedule
+
+
+class BalancedExperienceReplayParameters(ExperienceReplayParameters):
+    def __init__(self):
+        super().__init__()
+        self.max_size = (MemoryGranularity.Transitions, 1000000)
+        self.allow_duplicates_in_batch_sampling = False
+        self.num_classes = 0
+        self.state_key_with_the_class_index = 'class'
+
+    @property
+    def path(self):
+        return 'rl_coach.memories.non_episodic.balanced_experience_replay:BalancedExperienceReplay'
+
+
+class BalancedExperienceReplay(ExperienceReplay):
+    """
+    A replay buffer which allows sampling batches which are balanced in terms of the classes that are sampled
+    """
+    def __init__(self, max_size: Tuple[MemoryGranularity, int], allow_duplicates_in_batch_sampling: bool=True,
+                 num_classes: int=0, state_key_with_the_class_index: Any='class'):
+        """
+        :param max_size: the maximum number of transitions or episodes to hold in the memory
+        :param allow_duplicates_in_batch_sampling: allow having the same transition multiple times in a batch
+        :param num_classes: the number of classes in the replayed data
+        :param state_key_with_the_class_index: the class index is assumed to be a value in the state dictionary.
+                                               this parameter determines the key to retrieve the class index value
+        """
+        super().__init__(max_size, allow_duplicates_in_batch_sampling)
+        self.current_class_to_sample_from = 0
+        self.num_classes = num_classes
+        self.state_key_with_the_class_index = state_key_with_the_class_index
+        self.transitions = [[] for _ in range(self.num_classes)]
+        self.transitions_order = []
+
+        if self.num_classes < 2:
+            raise ValueError("The number of classes for a balanced replay buffer should be at least 2. "
+                             "The number of classes that was defined is: {}".format(self.num_classes))
+
+    def store(self, transition: Transition, lock: bool=True) -> None:
+        """
+        Store a new transition in the memory.
+        :param transition: a transition to store
+        :param lock: if true, will lock the readers writers lock. this can cause a deadlock if an inheriting class
+                     locks and then calls store with lock = True
+        :return: None
+        """
+        if lock:
+            self.reader_writer_lock.lock_writing_and_reading()
+
+        self._num_transitions += 1
+
+        if self.state_key_with_the_class_index not in transition.state.keys():
+            raise ValueError("The class index was not present in the state of the transition under the given key ({})"
+                             .format(self.state_key_with_the_class_index))
+
+        class_idx = transition.state[self.state_key_with_the_class_index]
+
+        if class_idx >= self.num_classes:
+            raise ValueError("The given class index is outside the defined number of classes for the replay buffer. "
+                             "The given class was: {} and the number of classes defined is: {}"
+                             .format(class_idx, self.num_classes))
+
+        self.transitions[class_idx].append(transition)
+        self.transitions_order.append(class_idx)
+        self._enforce_max_length()
+
+        if lock:
+            self.reader_writer_lock.release_writing_and_reading()
+    def sample(self, size: int) -> List[Transition]:
+        """
+        Sample a batch of transitions from the replay buffer, with an equal number of transitions drawn from each
+        class. A ValueError is raised if the requested size is not a multiple of the number of classes, or, when
+        duplicates are not allowed, if some class does not yet hold enough transitions.
+        :param size: the size of the batch to sample
+        :return: a batch (list) of selected transitions from the replay buffer
+        """
+        self.reader_writer_lock.lock_writing()
+
+        if size % self.num_classes != 0:
+            raise ValueError("Sampling batches from a balanced replay buffer should be done only using batch sizes "
+                             "which are a multiple of the number of classes. The number of classes defined is: {} "
+                             "and the batch size requested is: {}".format(self.num_classes, size))
+
+        batch_size_from_each_class = size // self.num_classes
+
+        if self.allow_duplicates_in_batch_sampling:
+            transitions_idx = [np.random.randint(len(class_transitions), size=batch_size_from_each_class)
+                               for class_transitions in self.transitions]
+
+        else:
+            for class_idx, class_transitions in enumerate(self.transitions):
+                if len(class_transitions) < batch_size_from_each_class:
+                    raise ValueError("The replay buffer cannot be sampled since there are not enough transitions yet. "
+                                     "There are currently {} transitions for class {}"
+                                     .format(len(class_transitions), class_idx))
+
+            transitions_idx = [np.random.choice(len(class_transitions), size=batch_size_from_each_class, replace=False)
+                               for class_transitions in self.transitions]
+
+        batch = []
+        for class_idx, class_transitions_idx in enumerate(transitions_idx):
+            batch += [self.transitions[class_idx][i] for i in class_transitions_idx]
+
+        self.reader_writer_lock.release_writing()
+
+        return batch
+
+    def remove_transition(self, transition_index: int, lock: bool=True) -> None:
+        raise ValueError("It is not possible to remove specific transitions with a balanced replay buffer")
+
+    def get_transition(self, transition_index: int, lock: bool=True) -> Union[None, Transition]:
+        raise ValueError("It is not possible to access specific transitions with a balanced replay buffer")
+
+    def _enforce_max_length(self) -> None:
+        """
+        Make sure that the size of the replay buffer does not exceed the maximum size allowed.
+        If it passes the max size, the oldest transition in the replay buffer will be removed.
+        This function does not use locks since it is only called internally
+        :return: None
+        """
+        granularity, size = self.max_size
+        if granularity == MemoryGranularity.Transitions:
+            while size != 0 and self.num_transitions() > size:
+                self._num_transitions -= 1
+                del self.transitions[self.transitions_order[0]][0]
+                del self.transitions_order[0]
+        else:
+            raise ValueError("The granularity of the replay buffer can only be set in terms of transitions")
+
+    def clean(self, lock: bool=True) -> None:
+        """
+        Clean the memory by removing all the transitions
+        :return: None
+        """
+        if lock:
+            self.reader_writer_lock.lock_writing_and_reading()
+
+        self.transitions = [[] for _ in range(self.num_classes)]
+        self.transitions_order = []
+        self._num_transitions = 0
+
+        if lock:
+            self.reader_writer_lock.release_writing_and_reading()
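For reference, a rough usage sketch of the balanced buffer defined above, assuming the constructor and defaults from this file: the class index travels inside each transition's state dictionary under the configured key, and the requested batch size must be a multiple of the number of classes.

from rl_coach.core_types import Transition
from rl_coach.memories.memory import MemoryGranularity
from rl_coach.memories.non_episodic.balanced_experience_replay import BalancedExperienceReplay

memory = BalancedExperienceReplay(max_size=(MemoryGranularity.Transitions, 1000),
                                  num_classes=4,
                                  state_key_with_the_class_index='high_level_command')
for i in range(400):
    # each stored transition carries its class index in the state dictionary
    memory.store(Transition(state={'observation': [i], 'high_level_command': i % 4}, action=0.0, reward=0))
batch = memory.sample(120)   # 30 transitions from each of the 4 classes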
diff --git a/rl_coach/presets/CARLA_CIL.py b/rl_coach/presets/CARLA_CIL.py
new file mode 100644
index 0000000..147c0b7
--- /dev/null
+++ b/rl_coach/presets/CARLA_CIL.py
@@ -0,0 +1,127 @@
+import os
+import sys
+
+import numpy as np
+from rl_coach.architectures.tensorflow_components.architecture import Conv2d, Dense
+from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
+from rl_coach.environments.carla_environment import CarlaEnvironmentParameters, CameraTypes
+from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
+from rl_coach.filters.filter import InputFilter
+from rl_coach.filters.observation.observation_crop_filter import ObservationCropFilter
+from rl_coach.filters.observation.observation_reduction_by_sub_parts_name_filter import \
+    ObservationReductionBySubPartsNameFilter
+from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter
+from rl_coach.filters.observation.observation_to_uint8_filter import ObservationToUInt8Filter
+from rl_coach.schedules import ConstantSchedule
+from rl_coach.spaces import ImageObservationSpace
+
+from rl_coach.agents.cil_agent import CILAgentParameters
+from rl_coach.architectures.tensorflow_components.heads.cil_head import RegressionHeadParameters
+from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
+from rl_coach.graph_managers.graph_manager import ScheduleParameters
+from rl_coach.base_parameters import VisualizationParameters
+from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
+from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
+from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
+
+####################
+# Graph Scheduling #
+####################
+schedule_params = ScheduleParameters()
+schedule_params.improve_steps = TrainingSteps(10000000000)
+schedule_params.steps_between_evaluation_periods = TrainingSteps(500)
+schedule_params.evaluation_steps = EnvironmentEpisodes(5)
+schedule_params.heatup_steps = EnvironmentSteps(0)
+
+################
+# Agent Params #
+################
+agent_params = CILAgentParameters()
+
+# forward camera and measurements input
+agent_params.network_wrappers['main'].input_embedders_parameters = {
+    'forward_camera': InputEmbedderParameters(scheme=[Conv2d([32, 5, 2]),
+                                                      Conv2d([32, 3, 1]),
+                                                      Conv2d([64, 3, 2]),
+                                                      Conv2d([64, 3, 1]),
+                                                      Conv2d([128, 3, 2]),
+                                                      Conv2d([128, 3, 1]),
+                                                      Conv2d([256, 3, 1]),
+                                                      Conv2d([256, 3, 1]),
+                                                      Dense([512]),
+                                                      Dense([512])],
+                                              dropout=True,
+                                              batchnorm=True),
+    'measurements': InputEmbedderParameters(scheme=[Dense([128]),
+                                                    Dense([128])])
+}
+
+# TODO: batch norm is currently applied to the fc layers which is not desired
+# TODO: dropout should be configured differently per layer [1.0] * 8 + [0.7] * 2 + [0.5] * 2 + [0.5] * 1 + [0.5, 1.] * 5
+
+# simple fc middleware
+agent_params.network_wrappers['main'].middleware_parameters = FCMiddlewareParameters(scheme=[Dense([512])])
+
+# output branches
+agent_params.network_wrappers['main'].heads_parameters = [
+    RegressionHeadParameters(),
+    RegressionHeadParameters(),
+    RegressionHeadParameters(),
+    RegressionHeadParameters()
+]
+# agent_params.network_wrappers['main'].num_output_head_copies = 4  # follow lane, left, right, straight
+agent_params.network_wrappers['main'].rescale_gradient_from_head_by_factor = [1, 1, 1, 1]
+agent_params.network_wrappers['main'].loss_weights = [1, 1, 1, 1]
+# TODO: there should be another head predicting the speed which is connected directly to the forward camera embedding
+
+agent_params.network_wrappers['main'].batch_size = 120
+agent_params.network_wrappers['main'].learning_rate = 0.0002
+
+
+# crop and rescale the image + use only the forward speed measurement
+agent_params.input_filter = InputFilter()
+agent_params.input_filter.add_observation_filter('forward_camera', 'cropping',
+                                                 ObservationCropFilter(crop_low=np.array([115, 0, 0]),
+                                                                       crop_high=np.array([510, -1, -1])))
+agent_params.input_filter.add_observation_filter('forward_camera', 'rescale',
+                                                 ObservationRescaleToSizeFilter(
+                                                     ImageObservationSpace(np.array([88, 200, 3]), high=255)))
+agent_params.input_filter.add_observation_filter('forward_camera', 'to_uint8', ObservationToUInt8Filter(0, 255))
+agent_params.input_filter.add_observation_filter(
+    'measurements', 'select_speed',
+    ObservationReductionBySubPartsNameFilter(
+        ["forward_speed"], reduction_method=ObservationReductionBySubPartsNameFilter.ReductionMethod.Keep))
+
+# no exploration is used
+agent_params.exploration = AdditiveNoiseParameters()
+agent_params.exploration.noise_percentage_schedule = ConstantSchedule(0)
+agent_params.exploration.evaluation_noise_percentage = 0
+
+# no playing during the training phase
+agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(0)
+
+# the CARLA dataset should be downloaded through the following repository:
+# https://github.com/carla-simulator/imitation-learning
+# the dataset should then be converted to the Coach format using the script
+# rl_coach/utilities/carla_dataset_to_replay_buffer.py
+# the path to the converted dataset should be updated below
+agent_params.memory.load_memory_from_file_path = "./datasets/carla_train_set_replay_buffer.p"
+agent_params.memory.state_key_with_the_class_index = 'high_level_command'
+agent_params.memory.num_classes = 4
+
+###############
+# Environment #
+###############
+env_params = CarlaEnvironmentParameters()
+env_params.level = 'town1'
+env_params.cameras = [CameraTypes.FRONT]
+env_params.camera_height = 600
+env_params.camera_width = 800
+env_params.allow_braking = True
+env_params.quality = CarlaEnvironmentParameters.Quality.EPIC
+
+vis_params = VisualizationParameters()
+vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
+vis_params.dump_mp4 = True
+
+graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
+                                    schedule_params=schedule_params, vis_params=vis_params)
diff --git a/rl_coach/utilities/__init__.py b/rl_coach/utilities/__init__.py
new file mode 100644
index 0000000..e69de29
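With the converted replay buffer file in place at the path configured in the preset above, the preset is expected to be runnable like any other Coach preset (for example, coach -p CARLA_CIL with the launcher from this repository); the conversion itself is handled by the script added below.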
diff --git a/rl_coach/utilities/carla_dataset_to_replay_buffer.py b/rl_coach/utilities/carla_dataset_to_replay_buffer.py
new file mode 100644
index 0000000..2a67006
--- /dev/null
+++ b/rl_coach/utilities/carla_dataset_to_replay_buffer.py
@@ -0,0 +1,71 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import argparse
+
+import h5py
+import os
+import sys
+import numpy as np
+from rl_coach.utils import ProgressBar
+from rl_coach.core_types import Transition
+from rl_coach.memories.memory import MemoryGranularity
+from rl_coach.memories.non_episodic.experience_replay import ExperienceReplay
+
+
+if __name__ == "__main__":
+    argparser = argparse.ArgumentParser(description=__doc__)
+    argparser.add_argument('-d', '--dataset_root', help='The path to the CARLA dataset root folder')
+    argparser.add_argument('-o', '--output_path', help='The path to save the resulting replay buffer',
+                           default='carla_train_set_replay_buffer.p')
+    args = argparser.parse_args()
+
+    train_set_root = os.path.join(args.dataset_root, 'SeqTrain')
+    validation_set_root = os.path.join(args.dataset_root, 'SeqVal')
+
+    # training set extraction
+    memory = ExperienceReplay(max_size=(MemoryGranularity.Transitions, sys.maxsize))
+    train_set_files = sorted(os.listdir(train_set_root))
+    print("found {} files".format(len(train_set_files)))
+    progress_bar = ProgressBar(len(train_set_files))
+    for file_idx, file in enumerate(train_set_files[:3000]):
+        progress_bar.update(file_idx, "extracting file {}".format(file))
+        train_set = h5py.File(os.path.join(train_set_root, file), 'r')
+        observations = train_set['rgb'][:]  # forward camera
+        measurements = np.expand_dims(train_set['targets'][:, 10], -1)  # forward speed
+        actions = train_set['targets'][:, :3]  # steer, gas, brake
+        actions[:, 1] -= actions[:, 2]  # combine gas and brake into a single signed acceleration value
+        actions = actions[:, :2][:, ::-1]  # keep [steer, acceleration] and reorder to [acceleration, steer]
+
+        high_level_commands = train_set['targets'][:, 24].astype('int') - 2  # follow lane, left, right, straight
+
+        file_length = train_set['rgb'].len()
+        assert train_set['rgb'].len() == train_set['targets'].len()
+
+        for transition_idx in range(file_length):
+            transition = Transition(
+                state={
+                    'forward_camera': observations[transition_idx],
+                    'measurements': measurements[transition_idx],
+                    'high_level_command': high_level_commands[transition_idx]
+                },
+                action=actions[transition_idx],
+                reward=0
+            )
+            memory.store(transition)
+    progress_bar.close()
+    print("Saving pickle file to {}".format(args.output_path))
+    memory.save(args.output_path)
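The column indices used above (10 for forward speed, 24 for the high-level command, 0-2 for steer/gas/brake) follow the layout of the CARLA imitation-learning dataset; a quick sanity check along these lines can confirm the layout of a downloaded file before running the full conversion (the file name below is illustrative):

import h5py
import numpy as np

with h5py.File('SeqTrain/data_00000.h5', 'r') as f:     # hypothetical file from the downloaded dataset
    print(f['rgb'].shape)                                # expected: (200, 88, 200, 3) camera frames
    print(f['targets'].shape)                            # expected: (200, 28) measurement vector per frame
    print(np.unique(f['targets'][:, 24]))                # expected: high-level commands in {2, 3, 4, 5}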