1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-23 22:52:28 +01:00

Channel order transpose, for image embedder. Updated unit test. (#87)

This commit is contained in:
Thom Lane
2018-11-19 05:39:03 -08:00
committed by Gal Novik
parent ff816b347d
commit 7ba1a4393f
2 changed files with 4 additions and 4 deletions

View File

@@ -70,7 +70,6 @@ class ImageEmbedder(InputEmbedder):
:param x: image representing environment state, of shape (batch_size, in_channels, height, width).
:return: embedding of environment state, of shape (batch_size, channels).
"""
if len(x.shape) != 4 and self.scheme != EmbedderScheme.Empty:
raise ValueError("Image embedders expect the input size to have 4 dimensions. The given size is: {}"
.format(x.shape))
# convert from NHWC to NCHW (default for MXNet Convolutions)
x = x.transpose((0,3,1,2))
return super(ImageEmbedder, self).hybrid_forward(F, x, *args, **kwargs)