Learning to resize images for CV tasks
Learning to Resize Images for Vision Transformer
Generally, for CV tasks, we resize images
using bilinear or bicubic interpolation. Resizing is necessary for the model to be efficient but, at times, comes at the cost of model accuracy.
In this paper, a new technique was introduced for resizing images for CV tasks: the resizing is performed by a learned
mechanism instead of a hard-coded
interpolation.
Source: Learning to Resize Images for Computer Vision Tasks
Source: Learning to Resize Images for Computer Vision Tasks
As we can see, the proposed method reduces the error rate for different architectures in classification tasks.
Source: Learning to Resize Images for Computer Vision Tasks
The above image captures the model architecture as well as the process.
In this blog, we will explore this technique using fastai and timm's Swin Transformer model.
We will use the Imagenette
dataset by fastai.
# Maps each class identifier (used as a lookup key by `label_func`) to a
# human-readable class name.
lbl_dict = {
    'n01440764': 'tench',
    'n02102040': 'English springer',
    'n02979186': 'cassette player',
    'n03000684': 'chain saw',
    'n03028079': 'church',
    'n03394916': 'French horn',
    'n03417042': 'garbage truck',
    'n03425413': 'gas pump',
    'n03445777': 'golf ball',
    'n03888257': 'parachute',
}
def label_func(fname):
    """Return the human-readable class name for the image file `fname`.

    The key is obtained from `parent_label(fname)` (imported outside this
    section; presumably the name of the file's parent directory) and looked
    up in `lbl_dict`. A key missing from `lbl_dict` raises `KeyError`.
    """
    class_id = parent_label(fname)
    return lbl_dict[class_id]
# Data pipeline for the plain Swin model: each image is resized to 224 at
# the item level, and batches are normalized with `imagenet_stats`.
dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=label_func,
    splitter=GrandparentSplitter(valid_name='val'),
    item_tfms=Resize(224),
    batch_tfms=Normalize.from_stats(*imagenet_stats),
)

# `path` is defined outside this section (presumably the dataset root).
dls = dblock.dataloaders(path, bs=8)
dls.show_batch()
# Pretrained Swin classifier whose head is sized to the number of classes
# reported by the dataloaders.
model = create_model(
    'swin_large_patch4_window7_224',
    pretrained=True,
    num_classes=dls.c,
)

# Train for 5 epochs with gradient accumulation and gradient clipping,
# reporting both accuracy and error rate.
learn = Learner(
    dls,
    model,
    metrics=[accuracy, error_rate],
    cbs=[GradientAccumulation(12), GradientClip()],
)
learn.fit_one_cycle(5)
# Data pipeline that resizes each image to 512 at the item level and uses a
# batch size of 4; batches are normalized with `imagenet_stats`.
# Rebinds `dblock` and `dls`.
dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=label_func,
    splitter=GrandparentSplitter(valid_name='val'),
    item_tfms=Resize(512),
    batch_tfms=Normalize.from_stats(*imagenet_stats),
)

# `path` is defined outside this section (presumably the dataset root).
dls = dblock.dataloaders(path, bs=4)
class ResBlock(nn.Module):
    """Residual block: conv -> BN -> LeakyReLU -> conv -> BN, plus identity skip.

    Both convolutions are 3x3 with stride 1 and padding 1 and keep the
    channel count at `num_channels`, so the output has the same shape as the
    input.
    """

    def __init__(self, num_channels=16):
        super().__init__()
        # Settings shared by both convolutions.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(num_channels, num_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.leakyrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.conv2 = nn.Conv2d(num_channels, num_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, x):
        """Apply the two conv/BN stages and add the input back onto the result."""
        hidden = self.leakyrelu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(hidden))
        # In-place add onto the fresh BN output; `x` itself is not modified.
        out += x
        return out
def make_block(r, n):
    """Return an `nn.Sequential` of `r` `ResBlock`s, each with `n` channels.

    With `r == 0` the result is an empty `nn.Sequential`.
    """
    return nn.Sequential(*(ResBlock(num_channels=n) for _ in range(r)))
class ResizingNetwork(nn.Module):
    """Learnable image resizer producing an `img_size` x `img_size` output.

    The input is processed along two paths that are summed at the end:
    a plain bilinear resize of the input, and a convolutional path that
    extracts `n` feature channels, bilinearly resizes them, refines them with
    `r` residual blocks and projects back to `in_chans` channels.

    Args:
        img_size: side length of the square output.
        in_chans: number of channels of the input (and output) image.
        r: number of residual blocks built by `make_block`.
        n: number of feature channels used inside the convolutional path.
    """

    def __init__(self, img_size, in_chans = 3, r=1, n=16):
        super().__init__()
        self.img_size = img_size

        # Feature extraction at the input resolution.
        self.conv1 = nn.Conv2d(in_chans, n, kernel_size=7, stride=1, padding=3)
        self.leakyrelu1 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.conv2 = nn.Conv2d(n, n, kernel_size=1, stride=1)
        self.leakyrelu2 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.bn1 = nn.BatchNorm2d(n)

        # Refinement at the target resolution.
        self.resblock = make_block(r, n)
        self.conv3 = nn.Conv2d(n, n, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(n)

        # Projection back to image channels.
        self.conv4 = nn.Conv2d(n, in_chans, kernel_size=7, stride=1, padding=3)

    def forward(self, x):
        """Resize `x` to `img_size` x `img_size` via the learned and bilinear paths."""
        target_size = (self.img_size, self.img_size)

        # Bilinear baseline of the raw input, added back at the very end.
        image_skip = F.interpolate(x, size=target_size, mode='bilinear', align_corners=False)

        feats = self.leakyrelu1(self.conv1(x))
        feats = self.leakyrelu2(self.conv2(feats))
        feats = self.bn1(feats)

        # Resized features serve both as the residual-block input and as the
        # inner skip connection.
        feat_skip = F.interpolate(feats, size=target_size, mode='bilinear', align_corners=False)

        out = self.bn2(self.conv3(self.resblock(feat_skip)))
        out += feat_skip
        out = self.conv4(out)
        out += image_skip
        return out
class LRSwinModel(Module):
    """Learnable resizer followed by a pretrained Swin Transformer classifier.

    The input batch is first passed through `ResizingNetwork`, which produces
    an `img_size` x `img_size` image, and the result is then classified by the
    timm model `swin_large_patch4_window7_224`.

    Args:
        img_size: side length of the image handed to the Swin model.
        num_classes: size of the classification head. When `None` (the
            default) it falls back to `dls.c` from the module-level
            dataloaders, which is what the class always used before this
            parameter existed.
    """

    def __init__(self, img_size=224, num_classes=None):
        # NOTE(review): `Module` is imported outside this section. If it is
        # fastai's auto-initialising `Module`, this call is a no-op; if it is
        # `torch.nn.Module`, it is required before submodules are assigned.
        super().__init__()
        if num_classes is None:
            # Backward-compatible default: read the class count from the
            # global dataloaders.
            num_classes = dls.c
        self.resizenet = ResizingNetwork(img_size)
        self.swin = create_model(
            'swin_large_patch4_window7_224',
            pretrained=True,
            num_classes=num_classes,
        )

    def forward(self, x):
        """Resize `x` with the learned resizer, then classify it with Swin."""
        x = self.resizenet(x)
        x = self.swin(x)
        return x
# Learned-resizer + Swin model, trained for 5 epochs with gradient
# accumulation and gradient clipping; accuracy and error rate are reported.
model = LRSwinModel()
learn = Learner(
    dls,
    model,
    metrics=[accuracy, error_rate],
    cbs=[GradientAccumulation(24), GradientClip()],
)
learn.fit_one_cycle(5)