diff --git a/rl_coach/agents/rainbow_dqn_agent.py b/rl_coach/agents/rainbow_dqn_agent.py
index 4973670..08e417c 100644
--- a/rl_coach/agents/rainbow_dqn_agent.py
+++ b/rl_coach/agents/rainbow_dqn_agent.py
@@ -60,9 +60,12 @@ class RainbowDQNAgentParameters(CategoricalDQNAgentParameters):
     def __init__(self):
         super().__init__()
         self.algorithm = RainbowDQNAlgorithmParameters()
+
+        # ParameterNoiseParameters modifies the network wrapper parameters, so network_wrappers must be set before it.
+        self.network_wrappers = {"main": RainbowDQNNetworkParameters()}
+
         self.exploration = ParameterNoiseParameters(self)
         self.memory = PrioritizedExperienceReplayParameters()
-        self.network_wrappers = {"main": RainbowDQNNetworkParameters()}
 
     @property
     def path(self):
diff --git a/rl_coach/architectures/tensorflow_components/embedders/embedder.py b/rl_coach/architectures/tensorflow_components/embedders/embedder.py
index b3c1924..5bcdef2 100644
--- a/rl_coach/architectures/tensorflow_components/embedders/embedder.py
+++ b/rl_coach/architectures/tensorflow_components/embedders/embedder.py
@@ -59,6 +59,7 @@ class InputEmbedder(object):
         # layers order is conv -> batchnorm -> activation -> dropout
         if isinstance(self.scheme, EmbedderScheme):
             self.layers_params = copy.copy(self.schemes[self.scheme])
+            self.layers_params = [convert_layer(l) for l in self.layers_params]
         else:
             # if scheme is specified directly, convert to TF layer if it's not a callable object
             # NOTE: if layer object is callable, it must return a TF tensor when invoked
diff --git a/rl_coach/architectures/tensorflow_components/heads/head.py b/rl_coach/architectures/tensorflow_components/heads/head.py
index 956bd5c..397c8ab 100644
--- a/rl_coach/architectures/tensorflow_components/heads/head.py
+++ b/rl_coach/architectures/tensorflow_components/heads/head.py
@@ -18,8 +18,8 @@
 from typing import Type
 import numpy as np
 import tensorflow as tf
 from tensorflow.python.ops.losses.losses_impl import Reduction
-from rl_coach.architectures.tensorflow_components.layers import Dense
-from rl_coach.base_parameters import AgentParameters, Parameters, NetworkComponentParameters
+from rl_coach.architectures.tensorflow_components.layers import Dense, convert_layer_class
+from rl_coach.base_parameters import AgentParameters
 from rl_coach.spaces import SpacesDefinition
 from rl_coach.utils import force_list
@@ -63,6 +63,8 @@ class Head(object):
         self.dense_layer = dense_layer
         if self.dense_layer is None:
             self.dense_layer = Dense
+        else:
+            self.dense_layer = convert_layer_class(self.dense_layer)
 
     def __call__(self, input_layer):
         """
diff --git a/rl_coach/architectures/tensorflow_components/layers.py b/rl_coach/architectures/tensorflow_components/layers.py
index 39e9980..81a1992 100644
--- a/rl_coach/architectures/tensorflow_components/layers.py
+++ b/rl_coach/architectures/tensorflow_components/layers.py
@@ -17,8 +17,6 @@
 import math
 from types import FunctionType
 
-from typing import Any
-
 import tensorflow as tf
 
 from rl_coach.architectures import layers
@@ -56,9 +54,10 @@ def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dr
 # define global dictionary for storing layer type to layer implementation mapping
 tf_layer_dict = dict()
+tf_layer_class_dict = dict()
 
 
-def reg_to_tf(layer_type) -> FunctionType:
+def reg_to_tf_instance(layer_type) -> FunctionType:
     """ function decorator that registers layer implementation
     :return: decorated function
     """
     def reg_impl_decorator(func):
@@ -69,9 +68,20 @@
     return reg_impl_decorator
 
 
+def reg_to_tf_class(layer_type) -> FunctionType:
+    """ function decorator that registers layer type
+    :return: decorated function
+    """
+    def reg_impl_decorator(func):
+        assert layer_type not in tf_layer_class_dict
+        tf_layer_class_dict[layer_type] = func
+        return func
+    return reg_impl_decorator
+
+
 def convert_layer(layer):
     """
-    If layer is callable, return layer, otherwise convert to TF type
+    If the layer instance is callable (i.e. it is already a concrete TF layer), return it, otherwise convert it to the TF type
     :param layer: layer to be converted
     :return: converted layer if not callable, otherwise layer itself
     """
@@ -80,6 +90,18 @@
     return tf_layer_dict[type(layer)](layer)
 
 
+def convert_layer_class(layer_class):
+    """
+    If the given class is already a TF layer class, return it, otherwise convert it to the matching TF class
+    :param layer_class: layer class to be converted
+    :return: the TF layer class
+    """
+    if hasattr(layer_class, 'to_tf_instance'):
+        return layer_class
+    else:
+        return tf_layer_class_dict[layer_class]()
+
+
 class Conv2d(layers.Conv2d):
     def __init__(self, num_filters: int, kernel_size: int, strides: int):
         super(Conv2d, self).__init__(num_filters=num_filters, kernel_size=kernel_size, strides=strides)
@@ -95,12 +117,17 @@
                                 strides=self.strides, data_format='channels_last', name=name)
 
     @staticmethod
-    @reg_to_tf(layers.Conv2d)
-    def to_tf(base: layers.Conv2d):
-        return Conv2d(
-            num_filters=base.num_filters,
-            kernel_size=base.kernel_size,
-            strides=base.strides)
+    @reg_to_tf_instance(layers.Conv2d)
+    def to_tf_instance(base: layers.Conv2d):
+        return Conv2d(
+            num_filters=base.num_filters,
+            kernel_size=base.kernel_size,
+            strides=base.strides)
+
+    @staticmethod
+    @reg_to_tf_class(layers.Conv2d)
+    def to_tf_class():
+        return Conv2d
 
 
 class BatchnormActivationDropout(layers.BatchnormActivationDropout):
@@ -121,12 +148,17 @@
                                             is_training=is_training, name=name)
 
     @staticmethod
-    @reg_to_tf(layers.BatchnormActivationDropout)
-    def to_tf(base: layers.BatchnormActivationDropout):
-        return BatchnormActivationDropout(
-            batchnorm=base.batchnorm,
-            activation_function=base.activation_function,
-            dropout_rate=base.dropout_rate)
+    @reg_to_tf_instance(layers.BatchnormActivationDropout)
+    def to_tf_instance(base: layers.BatchnormActivationDropout):
+        return BatchnormActivationDropout(
+            batchnorm=base.batchnorm,
+            activation_function=base.activation_function,
+            dropout_rate=base.dropout_rate)
+
+    @staticmethod
+    @reg_to_tf_class(layers.BatchnormActivationDropout)
+    def to_tf_class():
+        return BatchnormActivationDropout
 
 
 class Dense(layers.Dense):
@@ -144,10 +176,15 @@
                                activation=activation)
 
     @staticmethod
-    @reg_to_tf(layers.Dense)
-    def to_tf(base: layers.Dense):
+    @reg_to_tf_instance(layers.Dense)
+    def to_tf_instance(base: layers.Dense):
         return Dense(units=base.units)
+
+    @staticmethod
+    @reg_to_tf_class(layers.Dense)
+    def to_tf_class():
+        return Dense
 
 
 class NoisyNetDense(layers.NoisyNetDense):
     """
@@ -210,6 +247,11 @@
         return activation(tf.matmul(input_layer, weight) + bias)
 
     @staticmethod
-    @reg_to_tf(layers.NoisyNetDense)
-    def to_tf(base: layers.NoisyNetDense):
+    @reg_to_tf_instance(layers.NoisyNetDense)
+    def to_tf_instance(base: layers.NoisyNetDense):
         return NoisyNetDense(units=base.units)
+
+    @staticmethod
+    @reg_to_tf_class(layers.NoisyNetDense)
+    def to_tf_class():
+        return NoisyNetDense
diff --git a/rl_coach/architectures/tensorflow_components/middlewares/middleware.py b/rl_coach/architectures/tensorflow_components/middlewares/middleware.py
index 6fe0727..64c578f 100644
--- a/rl_coach/architectures/tensorflow_components/middlewares/middleware.py
+++ b/rl_coach/architectures/tensorflow_components/middlewares/middleware.py
@@ -49,6 +49,7 @@ class Middleware(object):
         # layers order is conv -> batchnorm -> activation -> dropout
         if isinstance(self.scheme, MiddlewareScheme):
             self.layers_params = copy.copy(self.schemes[self.scheme])
+            self.layers_params = [convert_layer(l) for l in self.layers_params]
         else:
             # if scheme is specified directly, convert to TF layer if it's not a callable object
             # NOTE: if layer object is callable, it must return a TF tensor when invoked
diff --git a/rl_coach/presets/CartPole_Rainbow.py b/rl_coach/presets/CartPole_Rainbow.py
new file mode 100644
index 0000000..9389712
--- /dev/null
+++ b/rl_coach/presets/CartPole_Rainbow.py
@@ -0,0 +1,55 @@
+from rl_coach.agents.rainbow_dqn_agent import RainbowDQNAgentParameters
+from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
+from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
+from rl_coach.environments.gym_environment import GymVectorEnvironment
+from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
+from rl_coach.graph_managers.graph_manager import ScheduleParameters
+from rl_coach.memories.memory import MemoryGranularity
+from rl_coach.schedules import LinearSchedule
+
+####################
+# Graph Scheduling #
+####################
+
+schedule_params = ScheduleParameters()
+schedule_params.improve_steps = TrainingSteps(10000000000)
+schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10)
+schedule_params.evaluation_steps = EnvironmentEpisodes(1)
+schedule_params.heatup_steps = EnvironmentSteps(1000)
+
+#########
+# Agent #
+#########
+agent_params = RainbowDQNAgentParameters()
+
+# DQN params
+agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(100)
+agent_params.algorithm.discount = 0.99
+agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(1)
+
+# NN configuration
+agent_params.network_wrappers['main'].learning_rate = 0.00025
+agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False
+
+# ER size
+agent_params.memory.max_size = (MemoryGranularity.Transitions, 40000)
+
+agent_params.memory.beta = LinearSchedule(0.4, 1, 10000)
+agent_params.memory.alpha = 0.5
+
+###############
+# Environment #
+###############
+env_params = GymVectorEnvironment(level='CartPole-v0')
+
+########
+# Test #
+########
+preset_validation_params = PresetValidationParameters()
+preset_validation_params.test = True
+preset_validation_params.min_reward_threshold = 150
+preset_validation_params.max_episodes_to_achieve_reward = 250
+
+graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
+                                    schedule_params=schedule_params, vis_params=VisualizationParameters(),
+                                    preset_validation_params=preset_validation_params)
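
Note on the conversion helpers added in layers.py: reg_to_tf_instance/convert_layer map a framework-agnostic layer instance (the path taken by InputEmbedder and Middleware when expanding a scheme) to its TensorFlow implementation, while reg_to_tf_class/convert_layer_class map the layer class itself, which is what Head now does with its dense_layer argument. The sketch below is not part of the diff; it only illustrates the intended use, and the units keyword for the generic Dense is an assumption based on the base.units attribute referenced above.

    # Usage sketch (not part of this diff) of the two registries added in layers.py.
    from rl_coach.architectures import layers
    from rl_coach.architectures.tensorflow_components.layers import (
        Dense as TfDense, convert_layer, convert_layer_class)

    # framework-agnostic description of a dense layer (units kwarg is assumed)
    generic_dense = layers.Dense(units=64)

    # instance conversion, as done by InputEmbedder/Middleware for scheme entries:
    tf_dense = convert_layer(generic_dense)            # expected: a TfDense instance

    # class conversion, as done by Head for its dense_layer parameter:
    tf_dense_cls = convert_layer_class(layers.Dense)   # expected: the TfDense class
    assert tf_dense_cls is TfDense

    # already-converted classes pass through unchanged (they expose to_tf_instance):
    assert convert_layer_class(TfDense) is TfDense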
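
The new CartPole_Rainbow preset should be runnable like any other Coach preset, either through the coach launcher (coach -p CartPole_Rainbow) or programmatically as sketched below; the TaskParameters arguments shown are assumptions and may need adjusting to the local setup.

    # Sketch (not part of this diff): driving the new preset as a library.
    from rl_coach.base_parameters import TaskParameters
    from rl_coach.presets.CartPole_Rainbow import graph_manager

    # experiment_path is assumed here; point it wherever checkpoints/logs should go
    task_parameters = TaskParameters(experiment_path='./experiments/cartpole_rainbow')
    graph_manager.create_graph(task_parameters)

    # heatup + improve/evaluate loop according to schedule_params defined in the preset
    graph_manager.improve()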