1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

fixes to rainbow dqn + a cartpole based golden test (#253)

This commit is contained in:
Gal Leibovich
2019-03-21 12:57:56 +02:00
committed by GitHub
parent 83741fa92a
commit abec59f367
6 changed files with 127 additions and 23 deletions

View File

@@ -60,9 +60,12 @@ class RainbowDQNAgentParameters(CategoricalDQNAgentParameters):
def __init__(self):
super().__init__()
self.algorithm = RainbowDQNAlgorithmParameters()
# ParameterNoiseParameters is changing the network wrapper parameters. This line needs to be done first.
self.network_wrappers = {"main": RainbowDQNNetworkParameters()}
self.exploration = ParameterNoiseParameters(self)
self.memory = PrioritizedExperienceReplayParameters()
self.network_wrappers = {"main": RainbowDQNNetworkParameters()}
@property
def path(self):

View File

@@ -59,6 +59,7 @@ class InputEmbedder(object):
# layers order is conv -> batchnorm -> activation -> dropout
if isinstance(self.scheme, EmbedderScheme):
self.layers_params = copy.copy(self.schemes[self.scheme])
self.layers_params = [convert_layer(l) for l in self.layers_params]
else:
# if scheme is specified directly, convert to TF layer if it's not a callable object
# NOTE: if layer object is callable, it must return a TF tensor when invoked

View File

@@ -18,8 +18,8 @@ from typing import Type
import numpy as np
import tensorflow as tf
from tensorflow.python.ops.losses.losses_impl import Reduction
from rl_coach.architectures.tensorflow_components.layers import Dense
from rl_coach.base_parameters import AgentParameters, Parameters, NetworkComponentParameters
from rl_coach.architectures.tensorflow_components.layers import Dense, convert_layer_class
from rl_coach.base_parameters import AgentParameters
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import force_list
@@ -63,6 +63,8 @@ class Head(object):
self.dense_layer = dense_layer
if self.dense_layer is None:
self.dense_layer = Dense
else:
self.dense_layer = convert_layer_class(self.dense_layer)
def __call__(self, input_layer):
"""

View File

@@ -17,8 +17,6 @@
import math
from types import FunctionType
from typing import Any
import tensorflow as tf
from rl_coach.architectures import layers
@@ -56,9 +54,10 @@ def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dr
# define global dictionary for storing layer type to layer implementation mapping
tf_layer_dict = dict()
tf_layer_class_dict = dict()
def reg_to_tf(layer_type) -> FunctionType:
def reg_to_tf_instance(layer_type) -> FunctionType:
""" function decorator that registers layer implementation
:return: decorated function
"""
@@ -69,9 +68,20 @@ def reg_to_tf(layer_type) -> FunctionType:
return reg_impl_decorator
def reg_to_tf_class(layer_type) -> FunctionType:
""" function decorator that registers layer type
:return: decorated function
"""
def reg_impl_decorator(func):
assert layer_type not in tf_layer_class_dict
tf_layer_class_dict[layer_type] = func
return func
return reg_impl_decorator
def convert_layer(layer):
"""
If layer is callable, return layer, otherwise convert to TF type
If layer instance is callable (meaning this is already a concrete TF class), return layer, otherwise convert to TF type
:param layer: layer to be converted
:return: converted layer if not callable, otherwise layer itself
"""
@@ -80,6 +90,18 @@ def convert_layer(layer):
return tf_layer_dict[type(layer)](layer)
def convert_layer_class(layer_class):
"""
If layer instance is callable, return layer, otherwise convert to TF type
:param layer: layer to be converted
:return: converted layer if not callable, otherwise layer itself
"""
if hasattr(layer_class, 'to_tf_instance'):
return layer_class
else:
return tf_layer_class_dict[layer_class]()
class Conv2d(layers.Conv2d):
def __init__(self, num_filters: int, kernel_size: int, strides: int):
super(Conv2d, self).__init__(num_filters=num_filters, kernel_size=kernel_size, strides=strides)
@@ -95,12 +117,17 @@ class Conv2d(layers.Conv2d):
strides=self.strides, data_format='channels_last', name=name)
@staticmethod
@reg_to_tf(layers.Conv2d)
def to_tf(base: layers.Conv2d):
return Conv2d(
num_filters=base.num_filters,
kernel_size=base.kernel_size,
strides=base.strides)
@reg_to_tf_instance(layers.Conv2d)
def to_tf_instance(base: layers.Conv2d):
return Conv2d(
num_filters=base.num_filters,
kernel_size=base.kernel_size,
strides=base.strides)
@staticmethod
@reg_to_tf_class(layers.Conv2d)
def to_tf_class():
return Conv2d
class BatchnormActivationDropout(layers.BatchnormActivationDropout):
@@ -121,12 +148,17 @@ class BatchnormActivationDropout(layers.BatchnormActivationDropout):
is_training=is_training, name=name)
@staticmethod
@reg_to_tf(layers.BatchnormActivationDropout)
def to_tf(base: layers.BatchnormActivationDropout):
return BatchnormActivationDropout(
batchnorm=base.batchnorm,
activation_function=base.activation_function,
dropout_rate=base.dropout_rate)
@reg_to_tf_instance(layers.BatchnormActivationDropout)
def to_tf_instance(base: layers.BatchnormActivationDropout):
return BatchnormActivationDropout, BatchnormActivationDropout(
batchnorm=base.batchnorm,
activation_function=base.activation_function,
dropout_rate=base.dropout_rate)
@staticmethod
@reg_to_tf_class(layers.BatchnormActivationDropout)
def to_tf_class():
return BatchnormActivationDropout
class Dense(layers.Dense):
@@ -144,10 +176,15 @@ class Dense(layers.Dense):
activation=activation)
@staticmethod
@reg_to_tf(layers.Dense)
def to_tf(base: layers.Dense):
@reg_to_tf_instance(layers.Dense)
def to_tf_instance(base: layers.Dense):
return Dense(units=base.units)
@staticmethod
@reg_to_tf_class(layers.Dense)
def to_tf_class():
return Dense
class NoisyNetDense(layers.NoisyNetDense):
"""
@@ -210,6 +247,11 @@ class NoisyNetDense(layers.NoisyNetDense):
return activation(tf.matmul(input_layer, weight) + bias)
@staticmethod
@reg_to_tf(layers.NoisyNetDense)
def to_tf(base: layers.NoisyNetDense):
@reg_to_tf_instance(layers.NoisyNetDense)
def to_tf_instance(base: layers.NoisyNetDense):
return NoisyNetDense(units=base.units)
@staticmethod
@reg_to_tf_class(layers.NoisyNetDense)
def to_tf_class():
return NoisyNetDense

View File

@@ -49,6 +49,7 @@ class Middleware(object):
# layers order is conv -> batchnorm -> activation -> dropout
if isinstance(self.scheme, MiddlewareScheme):
self.layers_params = copy.copy(self.schemes[self.scheme])
self.layers_params = [convert_layer(l) for l in self.layers_params]
else:
# if scheme is specified directly, convert to TF layer if it's not a callable object
# NOTE: if layer object is callable, it must return a TF tensor when invoked

View File

@@ -0,0 +1,55 @@
from rl_coach.agents.rainbow_dqn_agent import RainbowDQNAgentParameters
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
from rl_coach.environments.gym_environment import GymVectorEnvironment
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.memories.memory import MemoryGranularity
from rl_coach.schedules import LinearSchedule
####################
# Graph Scheduling #
####################
schedule_params = ScheduleParameters()
schedule_params.improve_steps = TrainingSteps(10000000000)
schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10)
schedule_params.evaluation_steps = EnvironmentEpisodes(1)
schedule_params.heatup_steps = EnvironmentSteps(1000)
#########
# Agent #
#########
agent_params = RainbowDQNAgentParameters()
# DQN params
agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(100)
agent_params.algorithm.discount = 0.99
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(1)
# NN configuration
agent_params.network_wrappers['main'].learning_rate = 0.00025
agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False
# ER size
agent_params.memory.max_size = (MemoryGranularity.Transitions, 40000)
agent_params.memory.beta = LinearSchedule(0.4, 1, 10000)
agent_params.memory.alpha = 0.5
################
# Environment #
################
env_params = GymVectorEnvironment(level='CartPole-v0')
########
# Test #
########
preset_validation_params = PresetValidationParameters()
preset_validation_params.test = True
preset_validation_params.min_reward_threshold = 150
preset_validation_params.max_episodes_to_achieve_reward = 250
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
schedule_params=schedule_params, vis_params=VisualizationParameters(),
preset_validation_params=preset_validation_params)