Generally, for CV tasks, we resize images using bilinear or bicubic interpolation. Resizing is necessary for the model to be efficient but, at times, comes at a cost of model accuracy.

In this paper, a new technique was introduced for resizing images for CV tasks — one in which the resizing is performed by a learned mechanism instead of being hard-coded.

image.png

Source: Learning to Resize Images for Computer Vision Tasks

image.png

Source: Learning to Resize Images for Computer Vision Tasks

As we can see, the proposed method reduces the error rate across different architectures in classification tasks.

image.png

Source: Learning to Resize Images for Computer Vision Tasks

The above image captures the model architecture as well as the process.

In this blog, we will explore this technique using fastai and timm's Swin Transformer model.

The Dataset

We will use the Imagenette dataset from fastai.

# Map each Imagenette synset (WordNet ID) folder name to its readable class label.
lbl_dict = {
    'n01440764': 'tench',
    'n02102040': 'English springer',
    'n02979186': 'cassette player',
    'n03000684': 'chain saw',
    'n03028079': 'church',
    'n03394916': 'French horn',
    'n03417042': 'garbage truck',
    'n03425413': 'gas pump',
    'n03445777': 'golf ball',
    'n03888257': 'parachute',
}

def label_func(fname):
    """Map an image file to its readable label via its parent synset folder name."""
    wnid = parent_label(fname)
    return lbl_dict[wnid]
# Baseline pipeline: images in, category labels out, with a fixed (hard-coded)
# resize to 224x224 — the conventional approach the paper improves upon.
dblock = DataBlock(blocks    = (ImageBlock, CategoryBlock),
                   get_items = get_image_files,
                   get_y     = label_func,  # synset folder name -> readable label
                   splitter  = GrandparentSplitter(valid_name='val'),  # split by directory structure
                   item_tfms = Resize(224), 
                   batch_tfms=Normalize.from_stats(*imagenet_stats))

# NOTE(review): `path` is assumed to be the Imagenette root from untar_data — defined elsewhere.
dls = dblock.dataloaders(path, bs=8)
dls.show_batch()

Training with normal resizer

# Baseline: pretrained Swin-L (224x224 input) fed the conventionally resized images.
model = create_model('swin_large_patch4_window7_224', pretrained=True, num_classes=dls.c)
learn = Learner(dls,
                model,
                # Gradient accumulation simulates a larger effective batch; clipping stabilizes training.
                cbs=[GradientAccumulation(12), GradientClip()],
                metrics=[accuracy, error_rate])
learn.fit_one_cycle(5)
epoch train_loss valid_loss accuracy error_rate time
0 0.829841 0.823315 0.780382 0.219618 29:55
1 0.657223 0.741026 0.798217 0.201783 29:56
2 0.416582 0.399281 0.883567 0.116433 29:55
3 0.249072 0.257628 0.924841 0.075159 30:06
4 0.061879 0.187841 0.951592 0.048408 30:06

Training with learned resizer

# Learned-resizer pipeline: keep images larger (512x512) so the resizer network
# can learn the 512 -> 224 downsampling itself; batch size halved to fit memory.
dblock = DataBlock(blocks    = (ImageBlock, CategoryBlock),
                   get_items = get_image_files,
                   get_y     = label_func,
                   splitter  = GrandparentSplitter(valid_name='val'),
                   item_tfms = Resize(512), 
                   batch_tfms=Normalize.from_stats(*imagenet_stats))

dls = dblock.dataloaders(path, bs=4)
class ResBlock(nn.Module):
    def __init__(self,num_channels=16):
        super(ResBlock,self).__init__()
        
        self.conv1 = nn.Conv2d(num_channels,num_channels,kernel_size=3,stride=1,padding=1)
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.leakyrelu = nn.LeakyReLU(negative_slope=0.2,inplace=True)
        
        self.conv2 = nn.Conv2d(num_channels,num_channels,kernel_size=3,stride=1,padding=1)
        self.bn2 = nn.BatchNorm2d(num_channels)
    
    def forward(self,x):
        residual = x
        
        out = self.conv1(x)
        out = self.bn1(out)
        
        out = self.leakyrelu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        
        return out

def make_block(r, n):
    """Stack `r` ResBlocks of `n` channels each into a single nn.Sequential."""
    return nn.Sequential(*(ResBlock(num_channels=n) for _ in range(r)))

class ResizingNetwork(nn.Module):
    """Learnable image resizer from "Learning to Resize Images for CV Tasks".

    Downsamples an input image to (img_size, img_size) via a small CNN whose
    output is anchored by a plain bilinear resize of the raw input, so the
    network only has to learn a residual correction on top of bilinear.

    Args:
        img_size: target spatial size (square output).
        in_chans: number of image channels (default 3, RGB).
        r: number of residual blocks in the trunk.
        n: channel width of the intermediate feature maps.
    """

    def __init__(self, img_size, in_chans=3, r=1, n=16):
        super().__init__()
        self.img_size = img_size
        # Module creation order is kept identical for reproducible seeded init.
        self.conv1 = nn.Conv2d(in_channels=in_chans, out_channels=n, kernel_size=7, stride=1, padding=3)
        self.leakyrelu1 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.conv2 = nn.Conv2d(n, n, kernel_size=1, stride=1)
        self.leakyrelu2 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.bn1 = nn.BatchNorm2d(n)
        self.resblock = make_block(r, n)
        self.conv3 = nn.Conv2d(n, n, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(n)
        self.conv4 = nn.Conv2d(n, out_channels=in_chans, kernel_size=7, stride=1, padding=3)

    def forward(self, x):
        """Return the learned (img_size, img_size) resize of `x`."""
        target = (self.img_size, self.img_size)
        # Bilinear skip straight from the raw input to the target resolution.
        skip = F.interpolate(x, size=target, mode='bilinear', align_corners=False)

        # Feature extraction at full resolution, then bilinear downsample.
        feat = self.leakyrelu1(self.conv1(x))
        feat = self.bn1(self.leakyrelu2(self.conv2(feat)))
        feat = F.interpolate(feat, size=target, mode='bilinear', align_corners=False)

        # Residual trunk with an inner skip around the res-blocks.
        trunk = self.bn2(self.conv3(self.resblock(feat))) + feat

        # Project back to image channels and add the bilinear anchor.
        return self.conv4(trunk) + skip
class LRSwinModel(Module):
    """Learned-resizer front end followed by a pretrained Swin-L classifier.

    The resizer network maps the large (512x512) input down to the 224x224
    resolution the Swin Transformer expects, and is trained end-to-end with it.

    Args:
        img_size: target size the resizer downsamples to (Swin-L expects 224).
        num_classes: number of output classes; defaults to the notebook-global
            `dls.c` for backward compatibility with the original code, but
            passing it explicitly avoids the hidden global dependency.
    """

    def __init__(self, img_size=224, num_classes=None):
        # Fix: the original read the global `dls` unconditionally inside
        # __init__; keep that as a fallback but allow explicit override.
        if num_classes is None:
            num_classes = dls.c
        self.resizenet = ResizingNetwork(img_size)
        self.swin = create_model('swin_large_patch4_window7_224', pretrained=True, num_classes=num_classes)

    def forward(self, x):
        """Resize `x` with the learned resizer, then classify with Swin."""
        x = self.resizenet(x)
        return self.swin(x)
model = LRSwinModel()
# (captured notebook output: timm downloading the pretrained Swin-L checkpoint)
Downloading: "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth" to /root/.cache/torch/hub/checkpoints/swin_large_patch4_window7_224_22kto1k.pth
learn = Learner(dls,
                model,
                # Accumulation doubled vs. the baseline run to compensate for bs=4 (vs bs=8).
                cbs=[GradientAccumulation(24), GradientClip()],
                metrics=[accuracy, error_rate])
learn.fit_one_cycle(5)
epoch train_loss valid_loss accuracy error_rate time
0 0.675374 0.573788 0.824713 0.175287 38:02
1 0.655572 0.450714 0.867261 0.132739 38:01
2 0.330661 0.315596 0.907261 0.092739 38:02
3 0.115059 0.159794 0.953885 0.046115 37:58
4 0.030658 0.142762 0.961274 0.038726 38:03

As we can see from the above example, using a learned resizer results in a better accuracy/error rate and helps the model reach it faster.