fixes to rainbow dqn + a cartpole based golden test (#253)
@@ -17,8 +17,6 @@
 import math
 from types import FunctionType
-from typing import Any

 import tensorflow as tf

 from rl_coach.architectures import layers
@@ -56,9 +54,10 @@ def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dr

 # define global dictionary for storing layer type to layer implementation mapping
 tf_layer_dict = dict()
+tf_layer_class_dict = dict()


-def reg_to_tf(layer_type) -> FunctionType:
+def reg_to_tf_instance(layer_type) -> FunctionType:
     """ function decorator that registers layer implementation
     :return: decorated function
     """
@@ -69,9 +68,20 @@ def reg_to_tf(layer_type) -> FunctionType:
     return reg_impl_decorator


+def reg_to_tf_class(layer_type) -> FunctionType:
+    """ function decorator that registers layer type
+    :return: decorated function
+    """
+    def reg_impl_decorator(func):
+        assert layer_type not in tf_layer_class_dict
+        tf_layer_class_dict[layer_type] = func
+        return func
+    return reg_impl_decorator
+
+
 def convert_layer(layer):
     """
-    If layer is callable, return layer, otherwise convert to TF type
+    If layer instance is callable (meaning this is already a concrete TF class), return layer, otherwise convert to TF type
     :param layer: layer to be converted
     :return: converted layer if not callable, otherwise layer itself
     """
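Taken together, reg_to_tf_instance and reg_to_tf_class implement a small registry pattern: each framework-agnostic layer type is mapped once to a factory producing a concrete TF instance, and once to the concrete TF class itself. A minimal self-contained sketch of the same pattern follows; the names (register, GenericDense, ConcreteDense) are hypothetical and not part of the coach API:

from types import FunctionType

# registry mapping a generic layer type to a factory for its concrete implementation
impl_registry = dict()


def register(generic_type) -> FunctionType:
    """Decorator that records func as the implementation factory for generic_type."""
    def decorator(func):
        assert generic_type not in impl_registry  # each type may be registered only once
        impl_registry[generic_type] = func
        return func
    return decorator


class GenericDense:
    def __init__(self, units: int):
        self.units = units


class ConcreteDense(GenericDense):
    """Stands in for a framework-specific implementation."""


@register(GenericDense)
def to_concrete(base: GenericDense) -> ConcreteDense:
    return ConcreteDense(units=base.units)


# resolve a generic definition to its concrete counterpart via the registry
generic = GenericDense(units=64)
concrete = impl_registry[type(generic)](generic)
assert isinstance(concrete, ConcreteDense) and concrete.units == 64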
@@ -80,6 +90,18 @@ def convert_layer(layer):
     return tf_layer_dict[type(layer)](layer)


+def convert_layer_class(layer_class):
+    """
+    If layer instance is callable, return layer, otherwise convert to TF type
+    :param layer: layer to be converted
+    :return: converted layer if not callable, otherwise layer itself
+    """
+    if hasattr(layer_class, 'to_tf_instance'):
+        return layer_class
+    else:
+        return tf_layer_class_dict[layer_class]()
+
+
 class Conv2d(layers.Conv2d):
     def __init__(self, num_filters: int, kernel_size: int, strides: int):
         super(Conv2d, self).__init__(num_filters=num_filters, kernel_size=kernel_size, strides=strides)
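With both registries populated, a caller can hold the framework-agnostic class and resolve it lazily to its TF counterpart. A hedged usage sketch, assuming a coach checkout with TensorFlow installed; the module path rl_coach.architectures.tensorflow_components.layers is inferred from the repo layout rather than stated in this diff:

from rl_coach.architectures import layers
from rl_coach.architectures.tensorflow_components.layers import Dense, convert_layer_class

# resolve the framework-agnostic Dense definition to the concrete TF class
tf_dense_cls = convert_layer_class(layers.Dense)
assert tf_dense_cls is Dense

# an already-concrete class (it defines to_tf_instance) is passed through unchanged
assert convert_layer_class(Dense) is Dense

# the class can now be instantiated with whatever arguments the caller chooses
hidden = tf_dense_cls(units=256)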
@@ -95,12 +117,17 @@ class Conv2d(layers.Conv2d):
                                 strides=self.strides, data_format='channels_last', name=name)

     @staticmethod
-    @reg_to_tf(layers.Conv2d)
-    def to_tf(base: layers.Conv2d):
-        return Conv2d(
-            num_filters=base.num_filters,
-            kernel_size=base.kernel_size,
-            strides=base.strides)
+    @reg_to_tf_instance(layers.Conv2d)
+    def to_tf_instance(base: layers.Conv2d):
+        return Conv2d(
+            num_filters=base.num_filters,
+            kernel_size=base.kernel_size,
+            strides=base.strides)
+
+    @staticmethod
+    @reg_to_tf_class(layers.Conv2d)
+    def to_tf_class():
+        return Conv2d


 class BatchnormActivationDropout(layers.BatchnormActivationDropout):
@@ -121,12 +148,17 @@ class BatchnormActivationDropout(layers.BatchnormActivationDropout):
                                             is_training=is_training, name=name)

     @staticmethod
-    @reg_to_tf(layers.BatchnormActivationDropout)
-    def to_tf(base: layers.BatchnormActivationDropout):
-        return BatchnormActivationDropout(
-            batchnorm=base.batchnorm,
-            activation_function=base.activation_function,
-            dropout_rate=base.dropout_rate)
+    @reg_to_tf_instance(layers.BatchnormActivationDropout)
+    def to_tf_instance(base: layers.BatchnormActivationDropout):
+        return BatchnormActivationDropout, BatchnormActivationDropout(
+            batchnorm=base.batchnorm,
+            activation_function=base.activation_function,
+            dropout_rate=base.dropout_rate)
+
+    @staticmethod
+    @reg_to_tf_class(layers.BatchnormActivationDropout)
+    def to_tf_class():
+        return BatchnormActivationDropout


 class Dense(layers.Dense):
@@ -144,10 +176,15 @@ class Dense(layers.Dense):
                                activation=activation)

     @staticmethod
-    @reg_to_tf(layers.Dense)
-    def to_tf(base: layers.Dense):
+    @reg_to_tf_instance(layers.Dense)
+    def to_tf_instance(base: layers.Dense):
         return Dense(units=base.units)
+
+    @staticmethod
+    @reg_to_tf_class(layers.Dense)
+    def to_tf_class():
+        return Dense


 class NoisyNetDense(layers.NoisyNetDense):
     """
@@ -210,6 +247,11 @@ class NoisyNetDense(layers.NoisyNetDense):
         return activation(tf.matmul(input_layer, weight) + bias)

     @staticmethod
-    @reg_to_tf(layers.NoisyNetDense)
-    def to_tf(base: layers.NoisyNetDense):
+    @reg_to_tf_instance(layers.NoisyNetDense)
+    def to_tf_instance(base: layers.NoisyNetDense):
         return NoisyNetDense(units=base.units)
+
+    @staticmethod
+    @reg_to_tf_class(layers.NoisyNetDense)
+    def to_tf_class():
+        return NoisyNetDense
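The per-class to_tf_instance/to_tf_class pairs above all follow the same shape, and the instance-level resolver mirrors the class-level one. A hedged sketch under the same assumptions as before (module path inferred, TensorFlow available); convert_layer's pass-through-if-callable behavior is taken from its docstring in this diff:

from rl_coach.architectures import layers
from rl_coach.architectures.tensorflow_components.layers import NoisyNetDense, convert_layer

# a generic NoisyNetDense definition is converted to its registered TF implementation
generic = layers.NoisyNetDense(units=512)
tf_layer = convert_layer(generic)
assert isinstance(tf_layer, NoisyNetDense)

# concrete TF layers are callable, so convert_layer returns them untouched
assert convert_layer(tf_layer) is tf_layer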