
(Image: my code and output)

I can't figure out why it's giving me 9 gray images in a 3x3 grid instead of one color image (the original image is not grayscale; it has RGB channels). I have spent 5 hours on this. Thanks for the help.

Here is my code:

import matplotlib.pyplot as plt
import torchvision

test_path = "asl_data/test/"  # path to the folder
test_data = torchvision.datasets.ImageFolder(test_path, transform=torchvision.transforms.ToTensor())

def test32():
    for x, y in test_data:
        print(x.shape)
        x = x.reshape(533, 800, 3)
        plt.axis("off")
        plt.imshow(x)
        plt.show()
        plt.axis("off")
        plt.imshow(x[:176, :267, :])
        break

test32()

1 Answer


Classic.

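You reshape instead of permute. ToTensor() returns a tensor of shape (C, H, W), presumably (3, 533, 800) for your image, while plt.imshow expects (H, W, C). reshape does not reorder anything: it takes the values in their flat order (the whole red plane, then the whole green plane, then the whole blue plane) and regroups them into the new shape. As a result each displayed pixel is built from three neighboring values of the same channel, so it looks gray, and each channel fills its own horizontal band of the output, where it appears as three shrunken copies side by side. That is your 3x3 grid. permute instead reorders the axes, so each pixel gets one value from each channel.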

See this thread on the crucial difference between the two.
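A minimal sketch of the difference on a tiny made-up tensor (the 3x2x2 shape and its values are arbitrary, chosen only so you can see which channel each number comes from):

```python
import torch

# Stand-in for a ToTensor() result: shape (C, H, W) = (3, 2, 2).
# Channel 0 holds 0..3, channel 1 holds 4..7, channel 2 holds 8..11.
chw = torch.arange(12).reshape(3, 2, 2)

# permute reorders the axes: pixel (h, w) collects one value from every channel.
hwc_permute = chw.permute(1, 2, 0)

# reshape keeps the flat order and just regroups it into triples,
# so a "pixel" is made of three neighboring values of the SAME channel.
hwc_reshape = chw.reshape(2, 2, 3)

print(hwc_permute[0, 0].tolist())  # one value from each channel
print(hwc_reshape[0, 0].tolist())  # three values from channel 0
```

Both results have shape (2, 2, 3), which is why the wrong one raises no error; only the contents differ.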

Fix:

x = x.permute((1, 2, 0))  # (C, H, W) -> (H, W, C)
plt.imshow(x)
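If you need this conversion in several places, a small helper can wrap it. This is only a sketch; the name to_imshow_array is hypothetical and not part of torchvision or matplotlib. It also detaches the tensor and moves it to the CPU, because .numpy() refuses tensors that require grad or live on a GPU.

```python
import torch

def to_imshow_array(img: torch.Tensor):
    """Turn a (C, H, W) image tensor, e.g. from ToTensor(), into an (H, W, C) NumPy array.

    Hypothetical helper for illustration; it only wraps detach/cpu/permute/numpy.
    """
    if img.ndim != 3:
        raise ValueError(f"expected a (C, H, W) tensor, got shape {tuple(img.shape)}")
    # .numpy() needs a CPU tensor that does not require grad
    hwc = img.detach().cpu().permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
    if hwc.shape[2] == 1:
        # single-channel image: drop the channel axis so imshow gets a 2-D array
        hwc = hwc[:, :, 0]
    return hwc.numpy()


# Usage with the question's data (not run here):
#   x, y = test_data[0]
#   plt.imshow(to_imshow_array(x))
```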

A simple visual example:

import matplotlib.pyplot as plt

x, y = test_data[0]  # take one image
x.shape  # torch.Size([3, 223, 320])

# see the difference
fig, ax = plt.subplots(1, 2)
ax[0].imshow(x.numpy().reshape(223, 320, 3))
ax[0].set_title('Wrong reshape instead of permute')

ax[1].imshow(x.permute((1, 2, 0)))
ax[1].set_title('correctly permuting')
plt.show()

(Image: the two panels produced by the code above, side by side.)
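To see where the question's 3x3 grid of gray images comes from, here is a sketch on a hypothetical 3x6x6 image whose three channels are constant planes (red 0.9, green 0.5, blue 0.1), so each channel is easy to track after reshaping:

```python
import torch

C, H, W = 3, 6, 6  # hypothetical tiny image; H divisible by 3 keeps the bands exact
# Constant planes make each channel easy to recognize after reshaping.
chw = torch.stack([torch.full((H, W), v) for v in (0.9, 0.5, 0.1)])  # (C, H, W)

wrong = chw.reshape(H, W, C)   # what the question's code does
right = chw.permute(1, 2, 0)   # the (H, W, C) layout imshow expects

band = H // C  # number of output rows filled from a single channel
for c in range(C):
    rows = wrong[c * band:(c + 1) * band]
    # every "pixel" in this band has three equal values, so it renders as gray
    print(f"band {c}: values {torch.unique(rows).tolist()}")

# with permute each pixel keeps one value per channel, i.e. its real color
print("pixel (0, 0) after permute:", right[0, 0].tolist())
```

With a real photo the bands are not flat: each output row holds W*3 values, i.e. three consecutive input rows of one channel, so every band shows that channel three times side by side, downsampled by 3. One band per channel, three copies per band: that is the 3x3 grid.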

Shai
Ah, thank you for the answer. I spent so much time reading the documentation of torchvision.datasets and torchvision.transforms. I feel so stupid now. – Altaïr Ibn-La'Ahad Mar 02 '22 at 22:23