1
0
mirror of https://github.com/gryf/coach.git synced 2026-01-08 23:04:15 +01:00

batchnorm fixes + disabling batchnorm in DDPG (#353)

Co-authored-by: James Casbon <casbon+gh@gmail.com>
This commit is contained in:
Gal Leibovich
2019-06-23 11:28:22 +03:00
committed by GitHub
parent 7b5d6a3f03
commit d6795bd524
22 changed files with 105 additions and 50 deletions

View File

@@ -25,17 +25,20 @@ def test_embedder(reset):
with pytest.raises(ValueError):
embedder = ImageEmbedder(np.array([10, 100, 100, 100]), name="test")
# creating a simple image embedder
embedder = ImageEmbedder(np.array([100, 100, 10]), name="test")
# make sure the ops where not created yet
assert len(tf.get_default_graph().get_operations()) == 0
is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
pre_ops = len(tf.get_default_graph().get_operations())
# creating a simple image embedder
embedder = ImageEmbedder(np.array([100, 100, 10]), name="test", is_training=is_training)
# make sure the only the is_training op is creates
assert len(tf.get_default_graph().get_operations()) == pre_ops
# call the embedder
input_ph, output_ph = embedder()
# make sure that now the ops were created
assert len(tf.get_default_graph().get_operations()) > 0
assert len(tf.get_default_graph().get_operations()) > pre_ops
# try feeding a batch of one example
input = np.random.rand(1, 100, 100, 10)
@@ -55,7 +58,9 @@ def test_embedder(reset):
@pytest.mark.unit_test
def test_complex_embedder(reset):
# creating a deep vector embedder
embedder = ImageEmbedder(np.array([100, 100, 10]), name="test", scheme=EmbedderScheme.Deep)
is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
embedder = ImageEmbedder(np.array([100, 100, 10]), name="test", scheme=EmbedderScheme.Deep,
is_training=is_training)
# call the embedder
embedder()
@@ -71,8 +76,9 @@ def test_complex_embedder(reset):
@pytest.mark.unit_test
def test_activation_function(reset):
# creating a deep image embedder with relu
is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
embedder = ImageEmbedder(np.array([100, 100, 10]), name="relu", scheme=EmbedderScheme.Deep,
activation_function=tf.nn.relu)
activation_function=tf.nn.relu, is_training=is_training)
# call the embedder
embedder()
@@ -86,7 +92,7 @@ def test_activation_function(reset):
# creating a deep image embedder with tanh
embedder_tanh = ImageEmbedder(np.array([100, 100, 10]), name="tanh", scheme=EmbedderScheme.Deep,
activation_function=tf.nn.tanh)
activation_function=tf.nn.tanh, is_training=is_training)
# call the embedder
embedder_tanh()

View File

@@ -22,16 +22,19 @@ def test_embedder(reset):
embedder = VectorEmbedder(np.array([10, 10]), name="test")
# creating a simple vector embedder
embedder = VectorEmbedder(np.array([10]), name="test")
is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
pre_ops = len(tf.get_default_graph().get_operations())
embedder = VectorEmbedder(np.array([10]), name="test", is_training=is_training)
# make sure the ops where not created yet
assert len(tf.get_default_graph().get_operations()) == 0
assert len(tf.get_default_graph().get_operations()) == pre_ops
# call the embedder
input_ph, output_ph = embedder()
# make sure that now the ops were created
assert len(tf.get_default_graph().get_operations()) > 0
assert len(tf.get_default_graph().get_operations()) > pre_ops
# try feeding a batch of one example
input = np.random.rand(1, 10)
@@ -51,7 +54,8 @@ def test_embedder(reset):
@pytest.mark.unit_test
def test_complex_embedder(reset):
# creating a deep vector embedder
embedder = VectorEmbedder(np.array([10]), name="test", scheme=EmbedderScheme.Deep)
is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
embedder = VectorEmbedder(np.array([10]), name="test", scheme=EmbedderScheme.Deep, is_training=is_training)
# call the embedder
embedder()
@@ -67,8 +71,9 @@ def test_complex_embedder(reset):
@pytest.mark.unit_test
def test_activation_function(reset):
# creating a deep vector embedder with relu
is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
embedder = VectorEmbedder(np.array([10]), name="relu", scheme=EmbedderScheme.Deep,
activation_function=tf.nn.relu)
activation_function=tf.nn.relu, is_training=is_training)
# call the embedder
embedder()
@@ -82,7 +87,7 @@ def test_activation_function(reset):
# creating a deep vector embedder with tanh
embedder_tanh = VectorEmbedder(np.array([10]), name="tanh", scheme=EmbedderScheme.Deep,
activation_function=tf.nn.tanh)
activation_function=tf.nn.tanh, is_training=is_training)
# call the embedder
embedder_tanh()