Implementing the paper - `A Simple Baseline for Fast and Accurate Depth Estimation on Mobile Devices`
implementing a paper
- About the paper
- The implementation
- The dataset
- Training our model
- Training a resnet101 backbone to compare
- Let's see if knowledge distillation helps
Resources
The paper proposes a simple encoder-decoder network for fast and accurate depth estimation on mobile devices. Depth estimation is a computer vision task that is relevant to robotics, autonomous driving, scene understanding, and 3D reconstruction.
Most state-of-the-art depth estimation models are built around CNN/transformer architectures. These models can be challenging to deploy on mobile devices because of their large computational requirements.
In this paper, a MobileNet-based encoder-decoder architecture is proposed that strikes a good compromise between computational cost and model performance. To further improve performance, knowledge distillation is also used.
For this implementation, the encoder is built with the timm library.
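A quick way to see what the timm encoder exposes is to create the backbone with `features_only=True` and inspect its feature metadata. This is a minimal sketch and assumes the `mobilenetv3_rw` backbone used later is available in your timm version:
import torch
from timm import create_model

# Build the backbone as a multi-scale feature extractor rather than a classifier.
encoder = create_model('mobilenetv3_rw', features_only=True, pretrained=False)

# Channel counts and strides (reductions) of the feature maps the decoder will receive.
print(encoder.feature_info.channels())   # one entry per feature stage
print(encoder.feature_info.reduction())  # typically strides like [2, 4, 8, 16, 32]

# The encoder returns a list of feature maps, ordered shallowest to deepest.
feats = encoder(torch.randn(1, 3, 128, 160))
print([f.shape for f in feats])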
'''
Some of this code is based on https://gist.github.com/rwightman/f8b24f4e6f5504aba03e999e02460d31
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

from timm import create_model
# NOTE: DepthwiseSeparableConv ships with timm's EfficientNet blocks; the exact import
# path may differ across timm versions.
from timm.models.efficientnet_blocks import DepthwiseSeparableConv
class Conv2dBnAct(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 padding=0,
                 stride=1,
                 act_layer=nn.ReLU,
                 norm_layer=nn.BatchNorm2d
                 ):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False)
        self.bn = norm_layer(out_channels)
        self.act = act_layer(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        return x
class FeatureFusionModule(nn.Module):
    """Fuses an encoder feature map with the previous decoder output."""
    def __init__(self,
                 enc_in_channels,
                 enc_out_channels,
                 dec_in_channels,
                 out_channels
                 ):
        super().__init__()
        # encoder branch
        self.enc_conv1 = nn.Conv2d(enc_in_channels, enc_out_channels, kernel_size=1, stride=1, padding='same')
        # kernel_size=1 with stride=1 keeps the spatial size; the encoder and decoder
        # features arriving here are already at the same resolution
        self.enc_up = nn.ConvTranspose2d(enc_out_channels, enc_out_channels, kernel_size=1)
        self.enc_dconv = DepthwiseSeparableConv(enc_out_channels, enc_out_channels)
        self.enc_conv2 = nn.Conv2d(enc_out_channels, enc_out_channels, kernel_size=1, stride=1, padding='same')
        # decoder branch (after concatenation with the encoder branch)
        self.dec_dconv = DepthwiseSeparableConv(enc_out_channels + dec_in_channels, enc_out_channels + dec_in_channels)
        self.dec_conv1 = nn.Conv2d(enc_out_channels + dec_in_channels, out_channels, kernel_size=1, stride=1, padding='same')

    def forward(self, enc_x, dec_x):
        enc_x = self.enc_conv1(enc_x)
        enc_x = self.enc_up(enc_x)
        enc_x = self.enc_dconv(enc_x)
        enc_x = self.enc_conv2(enc_x)
        x = torch.cat([enc_x, dec_x], dim=1)
        dec_x = self.dec_dconv(x)
        dec_x = self.dec_conv1(dec_x)
        return dec_x
class DecoderBlock(nn.Module):
    def __init__(self,
                 enc_channels,
                 dec_prev_channels,
                 dec_channels,
                 act_layer=nn.ReLU,
                 norm_layer=nn.BatchNorm2d,
                 ffm=True,
                 ):
        super().__init__()
        conv_args = dict(kernel_size=3, padding=1, act_layer=act_layer)
        self.ffm = None
        if ffm:
            self.ffm = FeatureFusionModule(enc_channels, enc_channels, dec_prev_channels, dec_channels)
        # the fusion module already projects down to dec_channels; without it we start
        # from the raw encoder feature
        in_channels = dec_channels if ffm else enc_channels
        self.conv1 = Conv2dBnAct(in_channels, dec_channels, norm_layer=norm_layer, **conv_args)
        self.conv2 = Conv2dBnAct(dec_channels, dec_channels, norm_layer=norm_layer, **conv_args)

    def forward(self, x_enc, x_dec):
        # fuse the encoder feature with the previous decoder output, then upsample and refine
        x = self.ffm(x_enc, x_dec) if self.ffm is not None else x_enc
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv1(x)
        x = self.conv2(x)
        return x
class UnetDecoder(nn.Module):
    def __init__(self,
                 encoder_channels,
                 decoder_channels=(256, 128, 64, 32, 16),
                 final_channels=3,
                 norm_layer=nn.BatchNorm2d,
                 ):
        super().__init__()
        self.decoders = nn.ModuleList()
        for i, (e_ch, d_ch) in enumerate(zip(encoder_channels, decoder_channels)):
            if i == 0:
                self.decoders.append(DecoderBlock(enc_channels=e_ch,
                                                  dec_prev_channels=None,
                                                  dec_channels=d_ch,
                                                  act_layer=nn.ReLU,
                                                  norm_layer=nn.BatchNorm2d,
                                                  ffm=False,
                                                  ))
            else:
                self.decoders.append(DecoderBlock(enc_channels=e_ch,
                                                  dec_prev_channels=decoder_channels[i-1],
                                                  dec_channels=d_ch,
                                                  act_layer=nn.ReLU,
                                                  norm_layer=nn.BatchNorm2d,
                                                  ffm=True,
                                                  ))
        self.final_conv = nn.Conv2d(decoder_channels[-1], final_channels, kernel_size=(1, 1))
        self._init_weight()

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        enc_outs_r = x
        dec_out = None
        for i, each in enumerate(self.decoders):
            dec_out = each(enc_outs_r[i], dec_out)
        x = self.final_conv(dec_out)
        return x
class Unet(nn.Module):
    def __init__(self,
                 backbone='resnet50',
                 backbone_kwargs=None,
                 backbone_indices=None,
                 decoder_use_batchnorm=True,
                 decoder_channels=(256, 128, 64, 32, 16),
                 in_chans=3,
                 num_classes=3,
                 norm_layer=nn.BatchNorm2d,
                 pretrained=True,
                 ):
        super().__init__()
        backbone_kwargs = backbone_kwargs or {}
        # NOTE some models need different backbone indices specified based on the alignment of features
        # and some models won't have a full enough range of feature strides to work properly.
        encoder = create_model(
            backbone, features_only=True, out_indices=backbone_indices, in_chans=in_chans,
            pretrained=pretrained, **backbone_kwargs)
        encoder_channels = encoder.feature_info.channels()[::-1]
        self.encoder = encoder
        self.decoder = UnetDecoder(
            encoder_channels=encoder_channels,
            decoder_channels=decoder_channels,
            final_channels=num_classes,
            norm_layer=norm_layer,
        )

    def forward(self, x: torch.Tensor):
        x = self.encoder(x)
        x.reverse()
        x = self.decoder(x)
        return x
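As a quick sanity check (a sketch, using the same input size as the CamVid dataloader below), we can push a random batch through the model and confirm the output resolution and channel count:
model = Unet('mobilenetv3_rw', num_classes=32, pretrained=False)
with torch.no_grad():
    out = model(torch.randn(2, 3, 128, 160))
print(out.shape)  # expected: torch.Size([2, 32, 128, 160])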
To test this, we will use the CAMVID dataset, which has 32 classes. The code for the dataloader is from Zach's WalkWithFastai tutorial.
from fastai.vision.all import *
import numpy as np

path = untar_data(URLs.CAMVID)
path.ls()
codes = np.loadtxt(path/'codes.txt', dtype=str)
codes, len(codes)
fnames = get_image_files(path/"images")
def label_func(fn): return path/"labels"/f"{fn.stem}_P{fn.suffix}"
name2id = {v:k for k,v in enumerate(codes)}
void_code = name2id['Void']
def acc_camvid(inp, targ):
    targ = targ.squeeze(1)
    mask = targ != void_code
    return np.mean(inp.argmax(dim=1)[mask].cpu().numpy() == targ[mask].cpu().numpy())
dls = SegmentationDataLoaders.from_label_func(path,
                                              bs=8,
                                              fnames=fnames,
                                              label_func=label_func,
                                              codes=codes,
                                              item_tfms=Resize((128, 160)))
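Before building the model, it can be worth grabbing one batch to confirm what the dataloaders yield; `one_batch` is standard fastai, and the shapes assume the 128x160 resize above.
xb, yb = dls.one_batch()
# images are (batch, 3, H, W); masks are (batch, H, W) integer class ids
print(xb.shape, yb.shape, yb.max())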
model = Unet('mobilenetv3_rw',
             num_classes=32)
learn = Learner(dls,
                model,
                metrics=acc_camvid)
We will use the default CrossEntropy loss and the Adam optimiser.
learn.loss_func
learn.opt_func
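As far as I can tell, for mask targets fastai defaults to `CrossEntropyLossFlat(axis=1)` and the `Adam` optimiser, so spelling the Learner out explicitly would look like the sketch below (`learn_explicit` is just an illustrative name, not used elsewhere):
# Explicit equivalent of the defaults above (a sketch; not needed for training).
learn_explicit = Learner(dls,
                         model,
                         loss_func=CrossEntropyLossFlat(axis=1),  # flattens (N, C, H, W) logits against (N, H, W) masks
                         opt_func=Adam,
                         metrics=acc_camvid)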
learn.fine_tune(25, 1e-3, freeze_epochs=3)
We are getting around 66% accuracy. Let's see how a U-Net with a large ResNet backbone does.
learn_r101 = unet_learner(dls,
                          resnet101,
                          metrics=acc_camvid)
learn_r101.fine_tune(5, 1e-3, freeze_epochs=3)
ours = total_params(learn)[0]
r101 = total_params(learn_r101)[0]
print(f"The resnet101 model is {r101/ours:.2f} times larger than the implemented model")
Knowledge distillation is a training technique used to distill/condense the "knowledge" of a large model into a smaller one. This is especially useful in situations where models have to be deployed at the edge, where computation is brought to the point at which data is produced. Below is an illustration of the knowledge distillation technique.
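Concretely, the standard distillation objective (following Hinton et al.) softens both student and teacher logits with a temperature T and mixes the resulting KL term with the ordinary supervised loss. The sketch below uses assumed names (`student_logits`, `teacher_logits`, `ce_loss`) and is only meant to show the shape of the objective, not the exact loss class used later in this post.
import torch.nn.functional as F

def kd_objective(student_logits, teacher_logits, ce_loss, T=20.0, alpha=0.7):
    # Softened teacher probabilities and softened student log-probabilities.
    soft_targets = F.softmax(teacher_logits / T, dim=1)
    soft_preds = F.log_softmax(student_logits / T, dim=1)
    # KL divergence between the softened distributions; T*T keeps gradient
    # magnitudes comparable across temperatures.
    kd = F.kl_div(soft_preds, soft_targets, reduction='batchmean') * (T * T)
    # Weighted mix of the hard-label loss and the distillation term.
    return alpha * ce_loss + (1 - alpha) * kd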
class DistillationLoss(nn.Module):
    def __init__(self):
        super(DistillationLoss, self).__init__()
        self.distillation_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self,
                student_preds,
                teacher_preds,
                actual_target,
                T,
                alpha
                ):
        # actual_target and alpha are accepted here but unused; the weighting between
        # the hard-label loss and this term is applied in the callback below.
        # KLDivLoss expects log-probabilities as the input and probabilities as the
        # target, both softened by the temperature T.
        return self.distillation_loss(F.log_softmax(student_preds / T, dim=1).reshape(-1),
                                      F.softmax(teacher_preds / T, dim=1).reshape(-1))
class KnowledgeDistillation(Callback):
    def __init__(self,
                 teacher: Learner,
                 T: float = 20.,
                 a: float = 0.7):
        super(KnowledgeDistillation, self).__init__()
        self.teacher = teacher
        self.teacher.model.eval()  # keep the teacher in eval mode
        self.T, self.a = T, a
        self.distillation_loss = DistillationLoss()

    def after_loss(self):
        # the teacher only provides soft targets, so no gradients are needed for it
        with torch.no_grad():
            teacher_preds = self.teacher.model(self.learn.xb[0])
        student_loss = self.learn.loss_grad * self.a
        distillation_loss = self.distillation_loss(self.learn.pred,  # student preds
                                                   teacher_preds,    # teacher preds
                                                   self.learn.yb,    # ground truth
                                                   self.T,
                                                   self.a) * (1 - self.a)
        self.learn.loss_grad = student_loss + distillation_loss
We will use the resnet101 model as the teacher network.
model = Unet('mobilenetv3_rw',
             num_classes=32)
student_learn = Learner(dls,
                        model,
                        metrics=acc_camvid,
                        cbs=[KnowledgeDistillation(teacher=learn_r101)])
student_learn.fine_tune(25, 1e-3, freeze_epochs=3)
The results show about the same accuracy, but knowledge distillation generally requires longer training before its benefit shows up.