1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

fixes to rainbow dqn + a cartpole based golden test (#253)

This commit is contained in:
Gal Leibovich
2019-03-21 12:57:56 +02:00
committed by GitHub
parent 83741fa92a
commit abec59f367
6 changed files with 127 additions and 23 deletions

View File

@@ -17,8 +17,6 @@
import math
from types import FunctionType
from typing import Any
import tensorflow as tf
from rl_coach.architectures import layers
@@ -56,9 +54,10 @@ def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dr
# define global dictionary for storing layer type to layer implementation mapping
tf_layer_dict = dict()
tf_layer_class_dict = dict()
def reg_to_tf(layer_type) -> FunctionType:
def reg_to_tf_instance(layer_type) -> FunctionType:
""" function decorator that registers layer implementation
:return: decorated function
"""
@@ -69,9 +68,20 @@ def reg_to_tf(layer_type) -> FunctionType:
return reg_impl_decorator
def reg_to_tf_class(layer_type) -> FunctionType:
    """ function decorator that registers a factory returning the TF class for a layer type

    The decorated function is stored in the module-level ``tf_layer_class_dict``
    under ``layer_type`` and is returned unchanged. ``convert_layer_class`` looks
    the function up by layer type and calls it with no arguments.

    :param layer_type: key under which the decorated function is registered
    :return: decorator that performs the registration and returns the function as-is
    :raises AssertionError: if ``layer_type`` has already been registered
    """
    def reg_impl_decorator(func):
        # Each layer type may be registered only once; a second registration
        # would silently replace the first mapping, so fail loudly instead.
        assert layer_type not in tf_layer_class_dict, \
            "layer type {} is already registered in tf_layer_class_dict".format(layer_type)
        tf_layer_class_dict[layer_type] = func
        return func
    return reg_impl_decorator
def convert_layer(layer):
"""
If layer is callable, return layer, otherwise convert to TF type
If layer instance is callable (meaning this is already a concrete TF class), return layer, otherwise convert to TF type
:param layer: layer to be converted
:return: converted layer if not callable, otherwise layer itself
"""
@@ -80,6 +90,18 @@ def convert_layer(layer):
return tf_layer_dict[type(layer)](layer)
def convert_layer_class(layer_class):
    """
    If layer class already exposes ``to_tf_instance`` (as the TF layer classes in this
    module do), return it unchanged, otherwise look up the function registered for it
    in ``tf_layer_class_dict`` and return the result of calling that function
    :param layer_class: layer class to be converted
    :return: layer_class itself if it has a ``to_tf_instance`` attribute, otherwise the
        value returned by the function registered for layer_class
    :raises KeyError: if layer_class has no ``to_tf_instance`` attribute and nothing was
        registered for it in ``tf_layer_class_dict``
    """
    if hasattr(layer_class, 'to_tf_instance'):
        return layer_class
    # Registered entries are zero-argument functions (see reg_to_tf_class usage),
    # so the lookup result has to be called to obtain the actual class.
    return tf_layer_class_dict[layer_class]()
class Conv2d(layers.Conv2d):
def __init__(self, num_filters: int, kernel_size: int, strides: int):
super(Conv2d, self).__init__(num_filters=num_filters, kernel_size=kernel_size, strides=strides)
@@ -95,12 +117,17 @@ class Conv2d(layers.Conv2d):
strides=self.strides, data_format='channels_last', name=name)
@staticmethod
@reg_to_tf(layers.Conv2d)
def to_tf(base: layers.Conv2d):
return Conv2d(
num_filters=base.num_filters,
kernel_size=base.kernel_size,
strides=base.strides)
@reg_to_tf_instance(layers.Conv2d)
def to_tf_instance(base: layers.Conv2d):
return Conv2d(
num_filters=base.num_filters,
kernel_size=base.kernel_size,
strides=base.strides)
@staticmethod
@reg_to_tf_class(layers.Conv2d)
def to_tf_class():
return Conv2d
class BatchnormActivationDropout(layers.BatchnormActivationDropout):
@@ -121,12 +148,17 @@ class BatchnormActivationDropout(layers.BatchnormActivationDropout):
is_training=is_training, name=name)
@staticmethod
@reg_to_tf(layers.BatchnormActivationDropout)
def to_tf(base: layers.BatchnormActivationDropout):
return BatchnormActivationDropout(
batchnorm=base.batchnorm,
activation_function=base.activation_function,
dropout_rate=base.dropout_rate)
@reg_to_tf_instance(layers.BatchnormActivationDropout)
def to_tf_instance(base: layers.BatchnormActivationDropout):
    """ build the TF BatchnormActivationDropout layer from its framework-agnostic description

    :param base: layer description providing batchnorm, activation_function and dropout_rate
    :return: a BatchnormActivationDropout instance configured like ``base``
    """
    # Return only the instance, like the other to_tf_instance converters in this
    # module; previously a (class, instance) tuple was returned, which convert_layer
    # would hand back to its caller in place of a layer.
    return BatchnormActivationDropout(
        batchnorm=base.batchnorm,
        activation_function=base.activation_function,
        dropout_rate=base.dropout_rate)
@staticmethod
@reg_to_tf_class(layers.BatchnormActivationDropout)
def to_tf_class():
return BatchnormActivationDropout
class Dense(layers.Dense):
@@ -144,10 +176,15 @@ class Dense(layers.Dense):
activation=activation)
@staticmethod
@reg_to_tf(layers.Dense)
def to_tf(base: layers.Dense):
@reg_to_tf_instance(layers.Dense)
def to_tf_instance(base: layers.Dense):
return Dense(units=base.units)
@staticmethod
@reg_to_tf_class(layers.Dense)
def to_tf_class():
return Dense
class NoisyNetDense(layers.NoisyNetDense):
"""
@@ -210,6 +247,11 @@ class NoisyNetDense(layers.NoisyNetDense):
return activation(tf.matmul(input_layer, weight) + bias)
@staticmethod
@reg_to_tf(layers.NoisyNetDense)
def to_tf(base: layers.NoisyNetDense):
@reg_to_tf_instance(layers.NoisyNetDense)
def to_tf_instance(base: layers.NoisyNetDense):
return NoisyNetDense(units=base.units)
@staticmethod
@reg_to_tf_class(layers.NoisyNetDense)
def to_tf_class():
return NoisyNetDense