Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 19:20:19 +01:00)
Batch RL Tutorial (#372)
This commit is contained in:
@@ -168,15 +168,18 @@ class Dense(layers.Dense):
|
||||
def __init__(self, units: int):
|
||||
super(Dense, self).__init__(units=units)
|
||||
|
||||
def __call__(self, input_layer, name: str=None, kernel_initializer=None, activation=None, is_training=None):
|
||||
def __call__(self, input_layer, name: str=None, kernel_initializer=None, bias_initializer=None,
|
||||
activation=None, is_training=None):
|
||||
"""
|
||||
returns a tensorflow dense layer
|
||||
:param input_layer: previous layer
|
||||
:param name: layer name
|
||||
:return: dense layer
|
||||
"""
|
||||
if bias_initializer is None:
|
||||
bias_initializer = tf.zeros_initializer()
|
||||
return tf.layers.dense(input_layer, self.units, name=name, kernel_initializer=kernel_initializer,
|
||||
activation=activation)
|
||||
activation=activation, bias_initializer=bias_initializer)
|
||||
|
||||
@staticmethod
|
||||
@reg_to_tf_instance(layers.Dense)
|
||||
@@ -199,7 +202,8 @@ class NoisyNetDense(layers.NoisyNetDense):
|
||||
def __init__(self, units: int):
|
||||
super(NoisyNetDense, self).__init__(units=units)
|
||||
|
||||
def __call__(self, input_layer, name: str, kernel_initializer=None, activation=None, is_training=None):
|
||||
def __call__(self, input_layer, name: str, kernel_initializer=None, activation=None, is_training=None,
|
||||
bias_initializer=None):
|
||||
"""
|
||||
returns a NoisyNet dense layer
|
||||
:param input_layer: previous layer
|
||||
@@ -233,10 +237,12 @@ class NoisyNetDense(layers.NoisyNetDense):
|
||||
kernel_stddev_initializer = tf.random_uniform_initializer(-stddev * self.sigma0, stddev * self.sigma0)
|
||||
else:
|
||||
kernel_mean_initializer = kernel_stddev_initializer = kernel_initializer
|
||||
if bias_initializer is None:
|
||||
bias_initializer = tf.zeros_initializer()
|
||||
with tf.variable_scope(None, default_name=name):
|
||||
weight_mean = tf.get_variable('weight_mean', shape=(num_inputs, num_outputs),
|
||||
initializer=kernel_mean_initializer)
|
||||
bias_mean = tf.get_variable('bias_mean', shape=(num_outputs,), initializer=tf.zeros_initializer())
|
||||
bias_mean = tf.get_variable('bias_mean', shape=(num_outputs,), initializer=bias_initializer)
|
||||
|
||||
weight_stddev = tf.get_variable('weight_stddev', shape=(num_inputs, num_outputs),
|
||||
initializer=kernel_stddev_initializer)
|
||||
|
||||
Reference in New Issue
Block a user