77

How do I display a PyTorch Tensor of shape (3, 224, 224) representing a 224x224 RGB image? Using plt.imshow(image) gives the error:

TypeError: Invalid dimensions for image data

Mateen Ulhaq
Tom Hale

8 Answers

146

Given a Tensor representing the image, use .permute() to put the channels as the last dimension:

plt.imshow(tensor_image.permute(1, 2, 0))

Note: permute() returns a view of the same data, so it does not copy or allocate memory, and from_numpy() doesn't either.
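A minimal sketch of the point above, assuming a random tensor as a stand-in for a real image (the variable names are illustrative only). It checks that the permuted tensor has the channels-last shape, that it shares memory with the original, and that a given pixel/channel value ends up where expected:

```python
import torch

# Hypothetical stand-in for a real image: random values in [0, 1], shape (C, H, W)
tensor_image = torch.rand(3, 224, 224)

# Move the channel axis to the end: (C, H, W) -> (H, W, C)
hwc = tensor_image.permute(1, 2, 0)
print(hwc.shape)  # torch.Size([224, 224, 3])

# permute returns a view: both tensors start at the same memory address
shares_memory = hwc.data_ptr() == tensor_image.data_ptr()

# Channel 1 of the pixel at row 10, column 20 is the same element in both layouts
pixel_matches = bool(hwc[10, 20, 1] == tensor_image[1, 10, 20])
print(shares_memory, pixel_matches)
```

The resulting `hwc` tensor is what gets passed to `plt.imshow`.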

Mateen Ulhaq
Tom Hale
  • Wow thank you... This worked for me... I was trying to do tensor_image.numpy().reshape([224,224,3]) and visualize it using cv2.imshow() But i was not getting the actual image... whats going wrong here?? – Devashish Prasad Jun 04 '20 at 14:50
  • 5
    @DevashishPrasad The problem is that `reshape([224,224,3])` doesn't do the same thing that `permute(1, 2, 0)` does. The `permute` function is similar to transposing a matrix, where rows become columns and columns become rows. The `reshape` function does something totally unrelated that I don't know how to describe concisely. In short, `reshape` is the wrong function. – Tanner Swett Mar 01 '21 at 17:53
  • what is the shape of `tensor_image `? – Charlie Parker Nov 16 '22 at 23:02
  • An arguably more readable alternative is `plt.imshow(torch.einsum('cwh->whc', tensor_image))` – rusheb Dec 28 '22 at 17:27
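To make the difference between `reshape` and `permute` from the comments above concrete, here is a small sketch on a tiny tensor (the 3x2x2 size is chosen only so the values are easy to follow). `permute` reorders the axes, so each pixel keeps its own three channel values; `reshape` just re-chunks the flat memory, so a "pixel" ends up holding three neighboring values from a single channel:

```python
import torch

# Tiny (C=3, H=2, W=2) tensor; channel 0 holds 0..3, channel 1 holds 4..7, channel 2 holds 8..11
chw = torch.arange(12).reshape(3, 2, 2)

permuted = chw.permute(1, 2, 0)  # axes reordered: (H, W, C)
reshaped = chw.reshape(2, 2, 3)  # same flat data, re-chunked: NOT a channel move

# Top-left pixel: one value from each channel
print(permuted[0, 0].tolist())  # [0, 4, 8]

# Top-left "pixel" after reshape: three consecutive values from channel 0
print(reshaped[0, 0].tolist())  # [0, 1, 2]
```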
19

As you can see, matplotlib works fine even without conversion to a NumPy array. But PyTorch image tensors are channels-first, (C, H, W), so to use them with matplotlib you need to permute the axes to (H, W, C). Use permute for this rather than view or reshape, which keep the memory order and therefore scramble the pixels:

Code:

from scipy.datasets import face  # scipy.misc.face in SciPy < 1.10; needs the pooch package
import matplotlib.pyplot as plt
import torch

np_image = face()  # sample RGB image as an (H, W, C) array
print(type(np_image), np_image.shape)
tensor_image = torch.from_numpy(np_image)
print(type(tensor_image), tensor_image.shape)
# permute to channels-first (C, H, W), the layout PyTorch uses:
tensor_image = tensor_image.permute(2, 0, 1)
print(type(tensor_image), tensor_image.shape)

# If you try to plot an image with shape (C, H, W),
# you will get a TypeError:
# plt.imshow(tensor_image)

# So we need to permute it back to (H, W, C):
tensor_image = tensor_image.permute(1, 2, 0)
print(type(tensor_image), tensor_image.shape)

plt.imshow(tensor_image)
plt.show()

Output:

<class 'numpy.ndarray'> (768, 1024, 3)
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
<class 'torch.Tensor'> torch.Size([3, 768, 1024])
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
trsvchn
9

Given the image is loaded as described and stored in the variable image:

import matplotlib.pyplot as plt
from torchvision import transforms

plt.imshow(transforms.ToPILImage()(image), interpolation="bicubic")
# transforms.ToPILImage()(image).show()  # alternatively

Or as Soumith suggested:

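import numpy as np
import matplotlib.pyplot as plt

def show(img):
    npimg = img.numpy()  # (C, H, W) array
    # move channels last, (H, W, C), as matplotlib expects
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')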
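One caveat with calling `.numpy()` directly: it raises an error for tensors that track gradients or live on a GPU. A minimal sketch of a safer conversion, assuming a hypothetical helper name `to_hwc_numpy` (not part of PyTorch):

```python
import numpy as np
import torch

def to_hwc_numpy(img: torch.Tensor) -> np.ndarray:
    """Convert a (C, H, W) tensor to an (H, W, C) NumPy array for plt.imshow."""
    # detach() drops autograd tracking; cpu() moves GPU tensors to host memory
    return img.detach().cpu().numpy().transpose(1, 2, 0)

# A tensor that requires grad: img.numpy() alone would raise a RuntimeError
img = torch.rand(3, 4, 5, requires_grad=True)
arr = to_hwc_numpy(img)
print(arr.shape)  # (4, 5, 3)
```

The returned array can be passed to `plt.imshow` in place of the `np.transpose(...)` expression.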
iacob
Tom Hale
8

PyTorch modules processing image data expect tensors in the format C × H × W [1],
whereas Pillow and Matplotlib expect image arrays in the format H × W × C [2].

You can easily convert tensors to/from this format with a TorchVision transform:

from torchvision.transforms import functional as F

F.to_pil_image(image_tensor)

Or by directly permuting the axes:

image_tensor.permute(1, 2, 0)

  1. PyTorch modules dealing with image data require tensors to be laid out as C × H × W : channels, height, and width, respectively.

  2. Note how we have to use permute to change the order of the axes from C × H × W to H × W × C to match what Matplotlib expects.

amirhe
iacob
4

A complete example given an image pathname img_path:

import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms

image = Image.open(img_path)
plt.imshow(transforms.ToPILImage()(transforms.ToTensor()(image)), interpolation="bicubic")

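Note that transforms.ToTensor and transforms.ToPILImage are classes: the first pair of parentheses constructs a callable transform object, and the second pair applies it to the image. That is the reason for the funky bracketing.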
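The double-parentheses pattern is plain Python rather than anything specific to TorchVision. A minimal sketch with a hypothetical `Scale` class (invented here for illustration, not a TorchVision transform):

```python
class Scale:
    """Hypothetical transform: a class whose instances are callable."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        # Invoked when an instance is called like a function
        return x * self.factor


scale = Scale(2)     # first pair of parentheses: construct the transform
result = scale(21)   # second pair: apply it to the input
same = Scale(2)(21)  # both steps in one expression
print(result, same)  # 42 42
```

`transforms.ToPILImage()(tensor)` follows the same two-step shape: construct, then call.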

iacob
Tom Hale
2

A PyTorch image tensor has shape (channel, height, width), but matplotlib needs (height, width, channel), so permute the axes:

plt.imshow(white_torch.permute(1, 2, 0))

Or convert the tensor directly to a PIL image if you want:

import torchvision.transforms as T
from torchvision.io import read_image

# Notebook shell command: download a sample image
!wget 'https://images.unsplash.com/photo-1553284965-83fd3e82fa5a?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxleHBsb3JlLWZlZWR8NHx8fGVufDB8fHx8&w=1000&q=80'  -O white_horse.jpg

# read_image returns a uint8 tensor of shape (C, H, W)
white_torch = read_image('white_horse.jpg')

# In a notebook, the resulting PIL image is rendered as the cell output
T.ToPILImage()(white_torch)


TheExorcist
0

Use show_image from fastai

from fastai.vision.all import show_image


aravinda_gn
0

I've written a simple function to visualize a PyTorch tensor using matplotlib.

import numpy as np
import matplotlib.pyplot as plt
import torch

def show(*imgs):
    '''
     input imgs can be single or multiple tensor(s), this function uses matplotlib to visualize.
     Single input example:
     show(x) gives the visualization of x, where x should be a torch.Tensor
        if x is a 4D tensor (an image batch of size b(atch)*c(hannel)*h(eight)*w(idth)), this function splits x along the batch dimension and shows b subplots in total, where each subplot displays at most the first 3 channels (3*h*w).
        if x is a 3D tensor, it must have 1 channel (shown as grayscale) or 3 channels (shown as RGB)
        if x is a 2D tensor, it will be shown as grayscale map
     
     Multiple input example:      
     show(x,y,z) produces three windows, displaying x, y, z respectively, where x,y,z can be in any form described above.
    '''
    img_idx = 0
    for img in imgs:
        img_idx +=1
        plt.figure(img_idx)
        if isinstance(img, torch.Tensor):
            img = img.detach().cpu()

            if img.dim()==4: # 4D tensor
                bz = img.shape[0]
                c = img.shape[1]
                if bz==1 and c==1:  # single grayscale image
                    img=img.squeeze()
                elif bz==1 and c==3: # single RGB image
                    img=img.squeeze()
                    img=img.permute(1,2,0)
                elif bz==1 and c > 3: # multiple feature maps
                    img = img[:,0:3,:,:]
                    img = img.permute(0, 2, 3, 1)[:]
                    print('warning: more than 3 channels! only channels 0,1,2 are preserved!')
                elif bz > 1 and c == 1:  # multiple grayscale images
                    img=img.squeeze()
                elif bz > 1 and c == 3:  # multiple RGB images
                    img = img.permute(0, 2, 3, 1)
                elif bz > 1 and c > 3:  # multiple feature maps
                    img = img[:,0:3,:,:]
                    img = img.permute(0, 2, 3, 1)[:]
                    print('warning: more than 3 channels! only channels 0,1,2 are preserved!')
                else:
                    raise Exception("unsupported type!  " + str(img.size()))
            elif img.dim()==3: # 3D tensor
                bz = 1
                c = img.shape[0]
                if c == 1:  # grayscale
                    img=img.squeeze()
                elif c == 3:  # RGB
                    img = img.permute(1, 2, 0)
                else:
                    raise Exception("unsupported type!  " + str(img.size()))
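            elif img.dim()==2:  # 2D tensor: a single grayscale map
                bz = 1  # the plotting code below reads bz, so it must be set here too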
            else:
                raise Exception("unsupported type!  "+str(img.size()))


            img = img.numpy()  # convert to numpy
            img = img.squeeze()
            if bz ==1:
                plt.imshow(img, cmap='gray')
                # plt.colorbar()
                # plt.show()
            else:
                for idx in range(0,bz):
                    plt.subplot(int(bz**0.5),int(np.ceil(bz/int(bz**0.5))),int(idx+1))
                    plt.imshow(img[idx], cmap='gray')

        else:
            raise Exception("unsupported type:  "+str(type(img)))
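The subplot layout used in the batch branch above can be isolated into a small helper to see what grid it produces. This is a sketch with a hypothetical name, `grid_shape`, that reproduces the same arithmetic (rows from the integer square root, columns rounded up):

```python
import math

def grid_shape(batch_size: int) -> tuple:
    """Return (rows, cols) for a near-square grid that fits batch_size images."""
    rows = int(batch_size ** 0.5)        # floor of the square root
    cols = math.ceil(batch_size / rows)  # enough columns to fit every image
    return rows, cols

for b in (2, 5, 8, 9):
    print(b, grid_shape(b))
# 2 (1, 2)
# 5 (2, 3)
# 8 (2, 4)
# 9 (3, 3)
```

The grid always has at least `batch_size` cells, so every image in the batch gets its own subplot.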