mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Channel order transpose, for image embedder. Updated unit test. (#87)
This commit is contained in:
@@ -70,7 +70,6 @@ class ImageEmbedder(InputEmbedder):
|
|||||||
:param x: image representing environment state, of shape (batch_size, in_channels, height, width).
|
:param x: image representing environment state, of shape (batch_size, in_channels, height, width).
|
||||||
:return: embedding of environment state, of shape (batch_size, channels).
|
:return: embedding of environment state, of shape (batch_size, channels).
|
||||||
"""
|
"""
|
||||||
if len(x.shape) != 4 and self.scheme != EmbedderScheme.Empty:
|
# convert from NHWC to NCHW (default for MXNet Convolutions)
|
||||||
raise ValueError("Image embedders expect the input size to have 4 dimensions. The given size is: {}"
|
x = x.transpose((0,3,1,2))
|
||||||
.format(x.shape))
|
|
||||||
return super(ImageEmbedder, self).hybrid_forward(F, x, *args, **kwargs)
|
return super(ImageEmbedder, self).hybrid_forward(F, x, *args, **kwargs)
|
||||||
|
|||||||
@@ -15,7 +15,8 @@ def test_image_embedder():
|
|||||||
params = InputEmbedderParameters(scheme=EmbedderScheme.Medium)
|
params = InputEmbedderParameters(scheme=EmbedderScheme.Medium)
|
||||||
emb = ImageEmbedder(params=params)
|
emb = ImageEmbedder(params=params)
|
||||||
emb.initialize()
|
emb.initialize()
|
||||||
input_data = mx.nd.random.uniform(low=0, high=1, shape=(10, 3, 244, 244))
|
# input is NHWC, and not MXNet default NCHW
|
||||||
|
input_data = mx.nd.random.uniform(low=0, high=1, shape=(10, 244, 244, 3))
|
||||||
output = emb(input_data)
|
output = emb(input_data)
|
||||||
assert len(output.shape) == 2 # since last block was flatten
|
assert len(output.shape) == 2 # since last block was flatten
|
||||||
assert output.shape[0] == 10 # since batch_size is 10
|
assert output.shape[0] == 10 # since batch_size is 10
|
||||||
|
|||||||
Reference in New Issue
Block a user