mirror of
https://github.com/gryf/coach.git
synced 2026-05-01 05:04:10 +02:00
network_imporvements branch merge
This commit is contained in:
@@ -1,19 +1,19 @@
|
||||
import numpy as np
|
||||
import os
|
||||
from logger import screen
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
# make sure you have $CARLA_ROOT/PythonClient in your PYTHONPATH
|
||||
from carla.driving_benchmark.experiment_suites import CoRL2017
|
||||
from rl_coach.logger import screen
|
||||
|
||||
from rl_coach.agents.cil_agent import CILAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Conv2d, Dense
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.architectures.tensorflow_components.heads.cil_head import RegressionHeadParameters
|
||||
from rl_coach.architectures.tensorflow_components.layers import Conv2d, Dense, BatchnormActivationDropout
|
||||
from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.carla_environment import CarlaEnvironmentParameters, CameraTypes
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.carla_environment import CarlaEnvironmentParameters
|
||||
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.observation.observation_crop_filter import ObservationCropFilter
|
||||
@@ -27,7 +27,6 @@ from rl_coach.schedules import ConstantSchedule
|
||||
from rl_coach.spaces import ImageObservationSpace
|
||||
from rl_coach.utilities.carla_dataset_to_replay_buffer import create_dataset
|
||||
|
||||
|
||||
####################
|
||||
# Graph Scheduling #
|
||||
####################
|
||||
@@ -44,38 +43,64 @@ agent_params = CILAgentParameters()
|
||||
|
||||
# forward camera and measurements input
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters = {
|
||||
'CameraRGB': InputEmbedderParameters(scheme=[Conv2d([32, 5, 2]),
|
||||
Conv2d([32, 3, 1]),
|
||||
Conv2d([64, 3, 2]),
|
||||
Conv2d([64, 3, 1]),
|
||||
Conv2d([128, 3, 2]),
|
||||
Conv2d([128, 3, 1]),
|
||||
Conv2d([256, 3, 1]),
|
||||
Conv2d([256, 3, 1]),
|
||||
Dense([512]),
|
||||
Dense([512])],
|
||||
dropout=True,
|
||||
batchnorm=True),
|
||||
'measurements': InputEmbedderParameters(scheme=[Dense([128]),
|
||||
Dense([128])])
|
||||
'CameraRGB': InputEmbedderParameters(
|
||||
scheme=[
|
||||
Conv2d(32, 5, 2),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Conv2d(32, 3, 1),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Conv2d(64, 3, 2),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Conv2d(64, 3, 1),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Conv2d(128, 3, 2),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Conv2d(128, 3, 1),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Conv2d(256, 3, 1),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Conv2d(256, 3, 1),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Dense(512),
|
||||
BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.3),
|
||||
Dense(512),
|
||||
BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.3)
|
||||
],
|
||||
activation_function='none' # we define the activation function for each layer explicitly
|
||||
),
|
||||
'measurements': InputEmbedderParameters(
|
||||
scheme=[
|
||||
Dense(128),
|
||||
BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.5),
|
||||
Dense(128),
|
||||
BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.5)
|
||||
],
|
||||
activation_function='none' # we define the activation function for each layer explicitly
|
||||
)
|
||||
}
|
||||
|
||||
# TODO: batch norm is currently applied to the fc layers which is not desired
|
||||
# TODO: dropout should be configured differenetly per layer [1.0] * 8 + [0.7] * 2 + [0.5] * 2 + [0.5] * 1 + [0.5, 1.] * 5
|
||||
|
||||
# simple fc middleware
|
||||
agent_params.network_wrappers['main'].middleware_parameters = FCMiddlewareParameters(scheme=[Dense([512])])
|
||||
agent_params.network_wrappers['main'].middleware_parameters = \
|
||||
FCMiddlewareParameters(
|
||||
scheme=[
|
||||
Dense(512),
|
||||
BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.5)
|
||||
],
|
||||
activation_function='none'
|
||||
)
|
||||
|
||||
# output branches
|
||||
agent_params.network_wrappers['main'].heads_parameters = [
|
||||
RegressionHeadParameters(),
|
||||
RegressionHeadParameters(),
|
||||
RegressionHeadParameters(),
|
||||
RegressionHeadParameters()
|
||||
RegressionHeadParameters(
|
||||
scheme=[
|
||||
Dense(256),
|
||||
BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.5),
|
||||
Dense(256),
|
||||
BatchnormActivationDropout(activation_function=tf.tanh)
|
||||
],
|
||||
num_output_head_copies=4 # follow lane, left, right, straight
|
||||
)
|
||||
]
|
||||
# agent_params.network_wrappers['main'].num_output_head_copies = 4 # follow lane, left, right, straight
|
||||
agent_params.network_wrappers['main'].rescale_gradient_from_head_by_factor = [1, 1, 1, 1]
|
||||
agent_params.network_wrappers['main'].loss_weights = [1, 1, 1, 1]
|
||||
# TODO: there should be another head predicting the speed which is connected directly to the forward camera embedding
|
||||
|
||||
agent_params.network_wrappers['main'].batch_size = 120
|
||||
@@ -125,7 +150,6 @@ if not os.path.exists(agent_params.memory.load_memory_from_file_path):
|
||||
# Environment #
|
||||
###############
|
||||
env_params = CarlaEnvironmentParameters()
|
||||
env_params.level = 'town1'
|
||||
env_params.cameras = ['CameraRGB']
|
||||
env_params.camera_height = 600
|
||||
env_params.camera_width = 800
|
||||
@@ -134,9 +158,5 @@ env_params.allow_braking = True
|
||||
env_params.quality = CarlaEnvironmentParameters.Quality.EPIC
|
||||
env_params.experiment_suite = CoRL2017('Town01')
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST)]
|
||||
vis_params.dump_mp4 = True
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
Reference in New Issue
Block a user