So, I have a U-Net model and I'm feeding images of 5000x5000x3 into the model, and I am getting the error above.

So here is my model.

import torch
import torch.nn as nn


def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True)
    )


class UNeT(nn.Module):
    def __init__(self, n_class):
        super().__init__()
        # Encoder: successive double convolutions, downsampled by max-pooling
        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear',
                                    align_corners=True)
        # Decoder: double convolutions on the upsampled features concatenated
        # with the matching encoder features (skip connections)
        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)
        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, x):
        # Encoder path: keep intermediate activations for the skip connections
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)
        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)
        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)
        x = self.dconv_down4(x)
        # Decoder path: upsample, concatenate the matching encoder features, double conv
        x = self.upsample(x)
        x = torch.cat([x, conv3], dim=1)
        x = self.dconv_up3(x)
        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)
        x = self.dconv_up2(x)
        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)
        x = self.dconv_up1(x)
        out = self.conv_last(x)
        return out


I tried to do `model(inputs.unsqueeze_(0))` but I got a different error.

1 Answer

The order of dimensions in PyTorch is different from what you expect. Your input tensor has a shape of 4x5000x5000x3, which you interpret as a batch of size 4, with images of 5000x5000 pixels where each pixel has 3 channels. That is, your dimensions are batch-height-width-channel.

However, PyTorch expects the tensor dimensions to be in a different order: batch-channel-height-width. That is, the channel dimension should precede the height and width spatial dimensions.

You need to permute the dimensions of your input tensor to solve your problem:

model(inputs.permute(0, 3, 1, 2))

For more information, see the documentation of nn.Conv2d.
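
For example, a minimal sketch of the fix (the 2 classes and the smaller 256x256 spatial size are arbitrary choices for illustration; any size divisible by 8 works with this architecture, including your 5000x5000):

import torch

# Hypothetical batch of 4 images in batch-height-width-channel (NHWC) order,
# as you would get from NumPy/PIL; 256x256 is used instead of 5000x5000 for brevity.
inputs = torch.rand(4, 256, 256, 3)

model = UNeT(n_class=2)

# Rearrange to batch-channel-height-width (NCHW), which nn.Conv2d expects
nchw = inputs.permute(0, 3, 1, 2)
print(nchw.shape)   # torch.Size([4, 3, 256, 256])

out = model(nchw)
print(out.shape)    # torch.Size([4, 2, 256, 256])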

  • Late to the conversation but why can't you use `torch.reshape` instead of `permute`? Not sure I understand what permute does differently – turnip Sep 22 '22 at 13:49
  • @turnip there's a HUGE difference between `reshape`/`view` and `permute`. Read [this answer](https://stackoverflow.com/a/71329851/1714410) to learn more. – Shai Sep 28 '22 at 05:10
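
As a quick illustration of that difference (arbitrary toy tensor): `reshape` only re-reads the same memory in a new shape, while `permute` actually moves the channel values into their own planes:

import torch

x = torch.arange(24).reshape(1, 2, 3, 4)   # toy NHWC tensor: batch=1, H=2, W=3, C=4

p = x.permute(0, 3, 1, 2)   # moves each channel into its own plane
r = x.reshape(1, 4, 2, 3)   # same target shape, but values stay in memory order

print(p.shape == r.shape)   # True  -- the shapes match
print(torch.equal(p, r))    # False -- the contents are arranged differently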