
initial CIL implementation (WIP)

Author: itaicaspi-intel
Date:   2018-09-13 15:29:29 +03:00
Parent: 99649c1626
Commit: d3f97cd93b
6 changed files with 509 additions and 0 deletions


@@ -0,0 +1,84 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Union
from rl_coach.agents.imitation_agent import ImitationAgent
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
from rl_coach.architectures.tensorflow_components.heads.cil_head import RegressionHeadParameters
from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
from rl_coach.base_parameters import AgentParameters, MiddlewareScheme, NetworkParameters, AlgorithmParameters
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
from rl_coach.memories.non_episodic.balanced_experience_replay import BalancedExperienceReplayParameters


class CILAlgorithmParameters(AlgorithmParameters):
    def __init__(self):
        super().__init__()
        self.collect_new_data = False


class CILNetworkParameters(NetworkParameters):
    def __init__(self):
        super().__init__()
        self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
        self.middleware_parameters = FCMiddlewareParameters(scheme=MiddlewareScheme.Medium)
        self.heads_parameters = [RegressionHeadParameters()]
        self.loss_weights = [1.0]
        self.optimizer_type = 'Adam'
        self.batch_size = 32
        self.replace_mse_with_huber_loss = False
        self.create_target_network = False


class CILAgentParameters(AgentParameters):
    def __init__(self):
        super().__init__(algorithm=CILAlgorithmParameters(),
                         exploration=EGreedyParameters(),
                         memory=BalancedExperienceReplayParameters(),
                         networks={"main": CILNetworkParameters()})

    @property
    def path(self):
        return 'rl_coach.agents.cil_agent:CILAgent'


# Conditional Imitation Learning Agent: https://arxiv.org/abs/1710.02410
class CILAgent(ImitationAgent):
    def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
        super().__init__(agent_parameters, parent)
        self.current_high_level_control = 0

    def choose_action(self, curr_state):
        # the high level command selects which output branch (head) to act from
        self.current_high_level_control = curr_state['high_level_command']
        return super().choose_action(curr_state)

    def extract_action_values(self, prediction):
        return prediction[self.current_high_level_control].squeeze()

    def learn_from_batch(self, batch):
        network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()

        # start from the network's own predictions so that only the branch matching each
        # transition's high level command contributes to the regression loss
        target_values = self.networks['main'].online_network.predict({**batch.states(network_keys)})

        branch_to_update = batch.states(['high_level_command'])['high_level_command']
        for idx, branch in enumerate(branch_to_update):
            target_values[branch][idx] = batch.actions()[idx]

        result = self.networks['main'].train_and_sync_networks({**batch.states(network_keys)}, target_values)
        total_loss, losses, unclipped_grads = result[:3]

        return total_loss, losses, unclipped_grads
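
To make the branch masking in learn_from_batch concrete, here is a small NumPy sketch (not part of the commit; the shapes, names and values are illustrative): the targets start out as the network's own predictions, and only the branch selected by each transition's high level command is overwritten with the demonstrated action, so every other branch sees a zero regression error and receives no gradient.

    import numpy as np

    # illustrative setup: 4 branches (follow lane, left, right, straight),
    # a batch of 3 transitions, 2-dimensional actions
    predictions = [np.zeros((3, 2)) for _ in range(4)]          # one prediction array per head
    demo_actions = np.array([[0.5, 0.0], [0.3, -0.2], [0.7, 0.1]])
    high_level_command = np.array([0, 2, 0])                    # branch index per transition

    targets = [p.copy() for p in predictions]
    for idx, branch in enumerate(high_level_command):
        targets[branch][idx] = demo_actions[idx]                # overwrite only the commanded branch

    # every other (branch, transition) pair keeps target == prediction, so its
    # squared error is zero and the loss is driven by the commanded branch alone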


@@ -0,0 +1,56 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import QActionStateValue
from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace


class RegressionHeadParameters(HeadParameters):
    def __init__(self, activation_function: str='relu', name: str='q_head_params', dense_layer=Dense):
        super().__init__(parameterized_class=RegressionHead, activation_function=activation_function, name=name,
                         dense_layer=dense_layer)


class RegressionHead(Head):
    def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
                 head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
                 dense_layer=Dense):
        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
                         dense_layer=dense_layer)
        self.name = 'regression_head'
        if isinstance(self.spaces.action, BoxActionSpace):
            self.num_actions = self.spaces.action.shape[0]
        elif isinstance(self.spaces.action, DiscreteActionSpace):
            self.num_actions = len(self.spaces.action.actions)
        self.return_type = QActionStateValue

        if agent_parameters.network_wrappers[self.network_name].replace_mse_with_huber_loss:
            self.loss_type = tf.losses.huber_loss
        else:
            self.loss_type = tf.losses.mean_squared_error

    def _build_module(self, input_layer):
        self.fc1 = self.dense_layer(256)(input_layer)
        self.fc2 = self.dense_layer(256)(self.fc1)
        self.output = self.dense_layer(self.num_actions)(self.fc2, name='output')
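
The replace_mse_with_huber_loss switch above swaps the mean squared error for the Huber loss, which is quadratic for small residuals and linear for large ones. A plain NumPy sketch of the difference (illustrative only; the head itself uses the tf.losses implementations):

    import numpy as np

    def mse(pred, target):
        return np.mean((pred - target) ** 2)

    def huber(pred, target, delta=1.0):
        # quadratic near zero, linear for residuals larger than delta,
        # which makes the regression less sensitive to outlier actions
        err = np.abs(pred - target)
        quad = np.minimum(err, delta)
        return np.mean(0.5 * quad ** 2 + delta * (err - quad))

    pred = np.array([0.1, 0.2, 5.0])
    target = np.array([0.0, 0.2, 0.0])
    print(mse(pred, target))    # dominated by the outlier: ~8.34
    print(huber(pred, target))  # grows only linearly with it: ~1.50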


@@ -0,0 +1,171 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import operator
import random
from enum import Enum
from typing import List, Tuple, Any, Union
import numpy as np
from rl_coach.core_types import Transition
from rl_coach.memories.memory import MemoryGranularity
from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters, ExperienceReplay
from rl_coach.schedules import Schedule, ConstantSchedule


class BalancedExperienceReplayParameters(ExperienceReplayParameters):
    def __init__(self):
        super().__init__()
        self.max_size = (MemoryGranularity.Transitions, 1000000)
        self.allow_duplicates_in_batch_sampling = False
        self.num_classes = 0
        self.state_key_with_the_class_index = 'class'

    @property
    def path(self):
        return 'rl_coach.memories.non_episodic.balanced_experience_replay:BalancedExperienceReplay'


class BalancedExperienceReplay(ExperienceReplay):
    """
    A replay buffer which allows sampling batches which are balanced in terms of the classes that are sampled
    """
    def __init__(self, max_size: Tuple[MemoryGranularity, int], allow_duplicates_in_batch_sampling: bool=True,
                 num_classes: int=0, state_key_with_the_class_index: Any='class'):
        """
        :param max_size: the maximum number of transitions or episodes to hold in the memory
        :param allow_duplicates_in_batch_sampling: allow having the same transition multiple times in a batch
        :param num_classes: the number of classes in the replayed data
        :param state_key_with_the_class_index: the class index is assumed to be a value in the state dictionary.
                                               this parameter determines the key to retrieve the class index value
        """
        super().__init__(max_size, allow_duplicates_in_batch_sampling)
        self.current_class_to_sample_from = 0
        self.num_classes = num_classes
        self.state_key_with_the_class_index = state_key_with_the_class_index
        self.transitions = [[] for _ in range(self.num_classes)]
        self.transitions_order = []

        if self.num_classes < 2:
            raise ValueError("The number of classes for a balanced replay buffer should be at least 2. "
                             "The number of classes that was defined is: {}".format(self.num_classes))
    def store(self, transition: Transition, lock: bool=True) -> None:
        """
        Store a new transition in the memory.
        :param transition: a transition to store
        :param lock: if true, will lock the readers writers lock. this can cause a deadlock if an inheriting class
                     locks and then calls store with lock = True
        :return: None
        """
        if lock:
            self.reader_writer_lock.lock_writing_and_reading()

        self._num_transitions += 1

        if self.state_key_with_the_class_index not in transition.state.keys():
            raise ValueError("The class index was not present in the state of the transition under the given key ({})"
                             .format(self.state_key_with_the_class_index))
        class_idx = transition.state[self.state_key_with_the_class_index]
        if class_idx >= self.num_classes:
            raise ValueError("The given class index is outside the defined number of classes for the replay buffer. "
                             "The given class was: {} and the number of classes defined is: {}"
                             .format(class_idx, self.num_classes))

        self.transitions[class_idx].append(transition)
        self.transitions_order.append(class_idx)
        self._enforce_max_length()

        if lock:
            self.reader_writer_lock.release_writing_and_reading()
    def sample(self, size: int) -> List[Transition]:
        """
        Sample a batch of transitions from the replay buffer. The requested size must be a multiple of the
        number of classes, and each class contributes an equal share of the batch. If duplicates are not
        allowed and one of the classes does not hold enough transitions to fill its share, a ValueError is raised.
        :param size: the size of the batch to sample
        :return: a batch (list) of selected transitions from the replay buffer
        """
        self.reader_writer_lock.lock_writing()

        if size % self.num_classes != 0:
            raise ValueError("Sampling batches from a balanced replay buffer should be done only using batch sizes "
                             "which are a multiple of the number of classes. The number of classes defined is: {} "
                             "and the batch size requested is: {}".format(self.num_classes, size))

        batch_size_from_each_class = size // self.num_classes

        if self.allow_duplicates_in_batch_sampling:
            transitions_idx = [np.random.randint(len(class_transitions), size=batch_size_from_each_class)
                               for class_transitions in self.transitions]
        else:
            for class_idx, class_transitions in enumerate(self.transitions):
                # every class must hold enough transitions to fill its share of the batch
                if len(class_transitions) < batch_size_from_each_class:
                    raise ValueError("The replay buffer cannot be sampled since there are not enough transitions yet. "
                                     "There are currently {} transitions for class {}"
                                     .format(len(class_transitions), class_idx))

            transitions_idx = [np.random.choice(len(class_transitions), size=batch_size_from_each_class, replace=False)
                               for class_transitions in self.transitions]

        batch = []
        for class_idx, class_transitions_idx in enumerate(transitions_idx):
            batch += [self.transitions[class_idx][i] for i in class_transitions_idx]

        self.reader_writer_lock.release_writing()

        return batch
    def remove_transition(self, transition_index: int, lock: bool=True) -> None:
        raise ValueError("It is not possible to remove specific transitions with a balanced replay buffer")

    def get_transition(self, transition_index: int, lock: bool=True) -> Union[None, Transition]:
        raise ValueError("It is not possible to access specific transitions with a balanced replay buffer")

    def _enforce_max_length(self) -> None:
        """
        Make sure that the size of the replay buffer does not pass the maximum size allowed.
        If it passes the max size, the oldest transition in the replay buffer will be removed.
        This function does not use locks since it is only called internally
        :return: None
        """
        granularity, size = self.max_size
        if granularity == MemoryGranularity.Transitions:
            while size != 0 and self.num_transitions() > size:
                self._num_transitions -= 1
                # transitions_order tracks the class of each stored transition in insertion order,
                # so the oldest transition is the head of the list for its class
                del self.transitions[self.transitions_order[0]][0]
                del self.transitions_order[0]
        else:
            raise ValueError("The granularity of the replay buffer can only be set in terms of transitions")

    def clean(self, lock: bool=True) -> None:
        """
        Clean the memory by removing all the transitions
        :return: None
        """
        if lock:
            self.reader_writer_lock.lock_writing_and_reading()

        self.transitions = [[] for _ in range(self.num_classes)]
        self.transitions_order = []
        self._num_transitions = 0

        if lock:
            self.reader_writer_lock.release_writing_and_reading()
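
A minimal usage sketch for the balanced buffer (not part of the commit; the observation contents and sizes are made up, but the constructor arguments and the Transition layout follow the code in this diff):

    import numpy as np
    from rl_coach.core_types import Transition
    from rl_coach.memories.memory import MemoryGranularity
    from rl_coach.memories.non_episodic.balanced_experience_replay import BalancedExperienceReplay

    # four classes, matching the four high level commands used by the CARLA preset below
    buffer = BalancedExperienceReplay(max_size=(MemoryGranularity.Transitions, 10000),
                                      allow_duplicates_in_batch_sampling=False,
                                      num_classes=4,
                                      state_key_with_the_class_index='high_level_command')

    # store a few dummy transitions per class
    for command in range(4):
        for _ in range(10):
            buffer.store(Transition(state={'observation': np.zeros(3),
                                           'high_level_command': command},
                                    action=np.zeros(2),
                                    reward=0))

    # the batch size must be a multiple of num_classes; each class contributes 8 // 4 = 2 transitions
    batch = buffer.sample(8)
    print(len(batch))  # 8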


@@ -0,0 +1,127 @@
import os
import sys
import numpy as np
from rl_coach.architectures.tensorflow_components.architecture import Conv2d, Dense
from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
from rl_coach.environments.carla_environment import CarlaEnvironmentParameters, CameraTypes
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
from rl_coach.filters.filter import InputFilter
from rl_coach.filters.observation.observation_crop_filter import ObservationCropFilter
from rl_coach.filters.observation.observation_reduction_by_sub_parts_name_filter import \
ObservationReductionBySubPartsNameFilter
from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter
from rl_coach.filters.observation.observation_to_uint8_filter import ObservationToUInt8Filter
from rl_coach.schedules import ConstantSchedule
from rl_coach.spaces import ImageObservationSpace
from rl_coach.agents.cil_agent import CILAgentParameters
from rl_coach.architectures.tensorflow_components.heads.cil_head import RegressionHeadParameters
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.base_parameters import VisualizationParameters
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
####################
# Graph Scheduling #
####################
schedule_params = ScheduleParameters()
schedule_params.improve_steps = TrainingSteps(10000000000)
schedule_params.steps_between_evaluation_periods = TrainingSteps(500)
schedule_params.evaluation_steps = EnvironmentEpisodes(5)
schedule_params.heatup_steps = EnvironmentSteps(0)
################
# Agent Params #
################
agent_params = CILAgentParameters()
# forward camera and measurements input
agent_params.network_wrappers['main'].input_embedders_parameters = {
    'forward_camera': InputEmbedderParameters(scheme=[Conv2d([32, 5, 2]),
                                                      Conv2d([32, 3, 1]),
                                                      Conv2d([64, 3, 2]),
                                                      Conv2d([64, 3, 1]),
                                                      Conv2d([128, 3, 2]),
                                                      Conv2d([128, 3, 1]),
                                                      Conv2d([256, 3, 1]),
                                                      Conv2d([256, 3, 1]),
                                                      Dense([512]),
                                                      Dense([512])],
                                              dropout=True,
                                              batchnorm=True),
    'measurements': InputEmbedderParameters(scheme=[Dense([128]),
                                                    Dense([128])])
}
# TODO: batch norm is currently applied to the fc layers as well, which is not desired
# TODO: dropout should be configured differently per layer: [1.0] * 8 + [0.7] * 2 + [0.5] * 2 + [0.5] * 1 + [0.5, 1.] * 5
# simple fc middleware
agent_params.network_wrappers['main'].middleware_parameters = FCMiddlewareParameters(scheme=[Dense([512])])
# output branches
agent_params.network_wrappers['main'].heads_parameters = [
RegressionHeadParameters(),
RegressionHeadParameters(),
RegressionHeadParameters(),
RegressionHeadParameters()
]
# agent_params.network_wrappers['main'].num_output_head_copies = 4 # follow lane, left, right, straight
agent_params.network_wrappers['main'].rescale_gradient_from_head_by_factor = [1, 1, 1, 1]
agent_params.network_wrappers['main'].loss_weights = [1, 1, 1, 1]
# TODO: there should be another head predicting the speed which is connected directly to the forward camera embedding
agent_params.network_wrappers['main'].batch_size = 120
agent_params.network_wrappers['main'].learning_rate = 0.0002
# crop and rescale the image + use only the forward speed measurement
agent_params.input_filter = InputFilter()
agent_params.input_filter.add_observation_filter('forward_camera', 'cropping',
                                                 ObservationCropFilter(crop_low=np.array([115, 0, 0]),
                                                                       crop_high=np.array([510, -1, -1])))
agent_params.input_filter.add_observation_filter('forward_camera', 'rescale',
                                                 ObservationRescaleToSizeFilter(
                                                     ImageObservationSpace(np.array([88, 200, 3]), high=255)))
agent_params.input_filter.add_observation_filter('forward_camera', 'to_uint8', ObservationToUInt8Filter(0, 255))
agent_params.input_filter.add_observation_filter(
    'measurements', 'select_speed',
    ObservationReductionBySubPartsNameFilter(
        ["forward_speed"], reduction_method=ObservationReductionBySubPartsNameFilter.ReductionMethod.Keep))
# no exploration is used
agent_params.exploration = AdditiveNoiseParameters()
agent_params.exploration.noise_percentage_schedule = ConstantSchedule(0)
agent_params.exploration.evaluation_noise_percentage = 0
# no playing during the training phase
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(0)
# the CARLA dataset should be downloaded through the following repository:
# https://github.com/carla-simulator/imitation-learning
# the dataset should then be converted to the Coach format using the script utils/carla_dataset_to_replay_buffer.py
# the path to the converted dataset should be updated below
agent_params.memory.load_memory_from_file_path = "./datasets/carla_train_set_replay_buffer.p"
agent_params.memory.state_key_with_the_class_index = 'high_level_command'
agent_params.memory.num_classes = 4
###############
# Environment #
###############
env_params = CarlaEnvironmentParameters()
env_params.level = 'town1'
env_params.cameras = [CameraTypes.FRONT]
env_params.camera_height = 600
env_params.camera_width = 800
env_params.allow_braking = True
env_params.quality = CarlaEnvironmentParameters.Quality.EPIC
vis_params = VisualizationParameters()
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
vis_params.dump_mp4 = True
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
schedule_params=schedule_params, vis_params=vis_params)
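
The preset only builds the graph_manager object; it is normally launched through the coach command-line runner. As a rough sketch, assuming the usual rl_coach entry points (TaskParameters, create_graph and improve; these names and the experiment path are assumptions, not part of this commit), training could also be started programmatically:

    from rl_coach.base_parameters import TaskParameters

    if __name__ == '__main__':
        # single-worker run; the experiment path below is a placeholder
        graph_manager.create_graph(TaskParameters(experiment_path='./experiments/carla_cil'))
        graph_manager.improve()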


@@ -0,0 +1,71 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import h5py
import os
import sys
import numpy as np
from rl_coach.utils import ProgressBar
from rl_coach.core_types import Transition
from rl_coach.memories.memory import MemoryGranularity
from rl_coach.memories.non_episodic.experience_replay import ExperienceReplay


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(description=__doc__)
    argparser.add_argument('-d', '--dataset_root', help='The path to the CARLA dataset root folder')
    argparser.add_argument('-o', '--output_path', help='The path to save the resulting replay buffer',
                           default='carla_train_set_replay_buffer.p')
    args = argparser.parse_args()

    train_set_root = os.path.join(args.dataset_root, 'SeqTrain')
    validation_set_root = os.path.join(args.dataset_root, 'SeqVal')

    # training set extraction
    memory = ExperienceReplay(max_size=(MemoryGranularity.Transitions, sys.maxsize))
    train_set_files = sorted(os.listdir(train_set_root))
    print("found {} files".format(len(train_set_files)))
    progress_bar = ProgressBar(len(train_set_files))

    for file_idx, file in enumerate(train_set_files[:3000]):
        progress_bar.update(file_idx, "extracting file {}".format(file))
        train_set = h5py.File(os.path.join(train_set_root, file), 'r')
        observations = train_set['rgb'][:]  # forward camera
        measurements = np.expand_dims(train_set['targets'][:, 10], -1)  # forward speed
        actions = train_set['targets'][:, :3]  # steer, gas, brake
        actions[:, 1] -= actions[:, 2]  # merge gas and brake into a single signed throttle (gas - brake)
        actions = actions[:, :2][:, ::-1]  # drop the brake column and reorder to [throttle, steer]
        high_level_commands = train_set['targets'][:, 24].astype('int') - 2  # follow lane, left, right, straight

        file_length = train_set['rgb'].len()
        assert train_set['rgb'].len() == train_set['targets'].len()

        for transition_idx in range(file_length):
            transition = Transition(
                state={
                    'forward_camera': observations[transition_idx],
                    'measurements': measurements[transition_idx],
                    'high_level_command': high_level_commands[transition_idx]
                },
                action=actions[transition_idx],
                reward=0
            )
            memory.store(transition)
    progress_bar.close()

    print("Saving pickle file to {}".format(args.output_path))
    memory.save(args.output_path)
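
The action preprocessing a few lines above is easy to misread, so here is the same transformation traced on one illustrative targets row (values made up):

    import numpy as np

    # fake CARLA targets row: steer=0.1, gas=0.8, brake=0.3
    row = np.array([[0.1, 0.8, 0.3]])

    actions = row[:, :3]               # [steer, gas, brake]
    actions[:, 1] -= actions[:, 2]     # gas - brake -> net throttle: [[0.1, 0.5, 0.3]]
    actions = actions[:, :2][:, ::-1]  # drop brake, reverse -> [[net_throttle, steer]]
    print(actions)                     # [[0.5, 0.1]]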