mirror of
https://github.com/gryf/coach.git
synced 2026-01-08 23:04:15 +01:00
batchnorm fixes + disabling batchnorm in DDPG (#353)
Co-authored-by: James Casbon <casbon+gh@gmail.com>
This commit is contained in:
@@ -25,17 +25,20 @@ def test_embedder(reset):
|
||||
with pytest.raises(ValueError):
|
||||
embedder = ImageEmbedder(np.array([10, 100, 100, 100]), name="test")
|
||||
|
||||
# creating a simple image embedder
|
||||
embedder = ImageEmbedder(np.array([100, 100, 10]), name="test")
|
||||
|
||||
# make sure the ops where not created yet
|
||||
assert len(tf.get_default_graph().get_operations()) == 0
|
||||
is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
|
||||
pre_ops = len(tf.get_default_graph().get_operations())
|
||||
# creating a simple image embedder
|
||||
embedder = ImageEmbedder(np.array([100, 100, 10]), name="test", is_training=is_training)
|
||||
|
||||
# make sure the only the is_training op is creates
|
||||
assert len(tf.get_default_graph().get_operations()) == pre_ops
|
||||
|
||||
# call the embedder
|
||||
input_ph, output_ph = embedder()
|
||||
|
||||
# make sure that now the ops were created
|
||||
assert len(tf.get_default_graph().get_operations()) > 0
|
||||
assert len(tf.get_default_graph().get_operations()) > pre_ops
|
||||
|
||||
# try feeding a batch of one example
|
||||
input = np.random.rand(1, 100, 100, 10)
|
||||
@@ -55,7 +58,9 @@ def test_embedder(reset):
|
||||
@pytest.mark.unit_test
|
||||
def test_complex_embedder(reset):
|
||||
# creating a deep vector embedder
|
||||
embedder = ImageEmbedder(np.array([100, 100, 10]), name="test", scheme=EmbedderScheme.Deep)
|
||||
is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
|
||||
embedder = ImageEmbedder(np.array([100, 100, 10]), name="test", scheme=EmbedderScheme.Deep,
|
||||
is_training=is_training)
|
||||
|
||||
# call the embedder
|
||||
embedder()
|
||||
@@ -71,8 +76,9 @@ def test_complex_embedder(reset):
|
||||
@pytest.mark.unit_test
|
||||
def test_activation_function(reset):
|
||||
# creating a deep image embedder with relu
|
||||
is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
|
||||
embedder = ImageEmbedder(np.array([100, 100, 10]), name="relu", scheme=EmbedderScheme.Deep,
|
||||
activation_function=tf.nn.relu)
|
||||
activation_function=tf.nn.relu, is_training=is_training)
|
||||
|
||||
# call the embedder
|
||||
embedder()
|
||||
@@ -86,7 +92,7 @@ def test_activation_function(reset):
|
||||
|
||||
# creating a deep image embedder with tanh
|
||||
embedder_tanh = ImageEmbedder(np.array([100, 100, 10]), name="tanh", scheme=EmbedderScheme.Deep,
|
||||
activation_function=tf.nn.tanh)
|
||||
activation_function=tf.nn.tanh, is_training=is_training)
|
||||
|
||||
# call the embedder
|
||||
embedder_tanh()
|
||||
|
||||
@@ -22,16 +22,19 @@ def test_embedder(reset):
|
||||
embedder = VectorEmbedder(np.array([10, 10]), name="test")
|
||||
|
||||
# creating a simple vector embedder
|
||||
embedder = VectorEmbedder(np.array([10]), name="test")
|
||||
is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
|
||||
pre_ops = len(tf.get_default_graph().get_operations())
|
||||
|
||||
embedder = VectorEmbedder(np.array([10]), name="test", is_training=is_training)
|
||||
|
||||
# make sure the ops where not created yet
|
||||
assert len(tf.get_default_graph().get_operations()) == 0
|
||||
assert len(tf.get_default_graph().get_operations()) == pre_ops
|
||||
|
||||
# call the embedder
|
||||
input_ph, output_ph = embedder()
|
||||
|
||||
# make sure that now the ops were created
|
||||
assert len(tf.get_default_graph().get_operations()) > 0
|
||||
assert len(tf.get_default_graph().get_operations()) > pre_ops
|
||||
|
||||
# try feeding a batch of one example
|
||||
input = np.random.rand(1, 10)
|
||||
@@ -51,7 +54,8 @@ def test_embedder(reset):
|
||||
@pytest.mark.unit_test
|
||||
def test_complex_embedder(reset):
|
||||
# creating a deep vector embedder
|
||||
embedder = VectorEmbedder(np.array([10]), name="test", scheme=EmbedderScheme.Deep)
|
||||
is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
|
||||
embedder = VectorEmbedder(np.array([10]), name="test", scheme=EmbedderScheme.Deep, is_training=is_training)
|
||||
|
||||
# call the embedder
|
||||
embedder()
|
||||
@@ -67,8 +71,9 @@ def test_complex_embedder(reset):
|
||||
@pytest.mark.unit_test
|
||||
def test_activation_function(reset):
|
||||
# creating a deep vector embedder with relu
|
||||
is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
|
||||
embedder = VectorEmbedder(np.array([10]), name="relu", scheme=EmbedderScheme.Deep,
|
||||
activation_function=tf.nn.relu)
|
||||
activation_function=tf.nn.relu, is_training=is_training)
|
||||
|
||||
# call the embedder
|
||||
embedder()
|
||||
@@ -82,7 +87,7 @@ def test_activation_function(reset):
|
||||
|
||||
# creating a deep vector embedder with tanh
|
||||
embedder_tanh = VectorEmbedder(np.array([10]), name="tanh", scheme=EmbedderScheme.Deep,
|
||||
activation_function=tf.nn.tanh)
|
||||
activation_function=tf.nn.tanh, is_training=is_training)
|
||||
|
||||
# call the embedder
|
||||
embedder_tanh()
|
||||
|
||||
Reference in New Issue
Block a user