Self-Supervised Learning and Soft-Labeling for Image Classification
Applying self-supervised learning and soft-labeling to tackle the Plant Pathology Kaggle competition
In this blog, we will learn about and apply self-supervised learning and soft-labeling/pseudo-labeling techniques to tackle the Plant Pathology Kaggle competition. I was introduced to soft-labeling when I came across this amazing blog by Isaac Flath. Around the same time, I came across Jeremy's blog on self-supervised learning (SSL). Both techniques were interesting, and I feel they are going to be important in computer vision going forward. Hence, I decided to blog my learnings of these intriguing techniques.
A word on the dataset
For this blog, we will be using the Plant Pathology dataset. The task at hand is to classify images of different plant diseases.
There are some challenges with the dataset:
- Imbalanced data - one of the classes contains very few samples
- Mislabeled data - the labels are noisy/mislabeled; the winner of this competition used soft-labeling in the winning model
We will do this in two parts:
- Self-supervised learning
  - What is self-supervised learning
  - Different types of self-supervised learning
  - Applying rotation-based learning (rotnet)
  - Applying SimCLR learning
- Soft-labeling | pseudo-labeling
  - Generating pseudo-labels using transfer learning from the rotnet and SimCLR models
  - Final (downstream) task to predict plant diseases using transfer learning (with the models trained in part 1) and soft-labeling/progressive pseudo-labeling
Supervised learning has brought tremendous success to the field of computer vision. Although supervised learning requires a huge amount of data, by using transfer learning one can reduce the data requirement by roughly 1000x.
Largely, models for transfer learning are pretrained on the ImageNet dataset, which contains 1.4 million images across 1000 different classes. For some domains, such as medical imaging, transfer learning from ImageNet might not work that well. Given this constraint, and the fact that generating labels for large amounts of data in such domains is costly, self-supervised learning (SSL) has proven to be effective. In SSL, a model is trained on unlabeled data in a supervised manner, using labels derived from the data itself.
So what is self-supervised learning? SSL makes use of labels that are naturally part of the input data. SSL has been widely used in NLP: training a language model often involves predicting the next word of a sentence, and the label, the next word, is a natural part of the input data. While learning to predict the next word, the model must learn a good deal about the nature of language. Such pretraining is now common in most NLP tasks.
In computer vision, SSL starts with "pretext tasks". In a pretext task, we train a model from scratch using labels that come naturally with the input data. Once pretrained, the model can be fine-tuned on the "downstream task".
Further Readings
https://www.fast.ai/2020/01/13/self_supervised/
https://lilianweng.github.io/lil-log/2019/11/10/self-supervised-learning.html
The different techniques of SSL can be categorized into three groups:
- Pretext task based
  - Relative positioning
  - Colorization
  - Rotation
  - Multiple pretext
  - Jigsaw puzzle
- Generative model based
  - Autoencoders
  - Split-brain autoencoder
  - Neural scene representation
  - Context encoder
  - Semantic inpainting
  - BiGAN
- Discriminative contrastive learning
  - SimCLR
  - SimCLRv2
  - MoCo
  - MoCo v2
  - BYOL
  - SwAV
SSL and its different techniques are well explained in this lecture series by Anuj Shah. Taking a look at the SOTA page for SSL on PapersWithCode, the discriminative contrastive learning methods are leading the pack.
Rotation-based pretext task (rotnet)
In this pretext task, we rotate the image by a certain degree (0, 90, 180 or 270) and train the model to predict which rotation was applied. While learning to classify the degree of rotation, the model picks up many semantic concepts.
(Image source: Gidaris et al. 2018) https://arxiv.org/abs/1803.07728
Now, let's take a look at how we can apply this pretext task using fastai. For a more detailed explanation, please refer to this amazing blog by Amar Saini.
from fastai.vision.all import *
from fastai.vision.learner import _update_first_layer
import timm
import torchvision
path = Path('/content/drive/MyDrive/colab_notebooks/fastai/plant_pathology/data')
path_img = path/'images'
train = pd.read_csv(path/'train.csv')
train.head(5)
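Before moving on, we can glance at the class imbalance mentioned earlier. Summing the label columns (which are one-hot encoded in train.csv) gives the per-class counts; the rare class should stand out:
train.iloc[:, 1:].sum()  # count of examples per class, one column per label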
We will use seresnext50_32x4d from timm as the architecture of choice.
def create_timm_body(arch:str, pretrained=False, cut=None, n_in=3):
    "Creates a body from any model in the `timm` library."
    model = timm.create_model(arch, pretrained=pretrained, num_classes=0, global_pool='')
    _update_first_layer(model, n_in, pretrained)
    if cut is None:
        # cut right before the first pooling layer
        ll = list(enumerate(model.children()))
        cut = next(i for i,o in reversed(ll) if has_pool_type(o))
    if isinstance(cut, int): return nn.Sequential(*list(model.children())[:cut])
    elif callable(cut): return cut(model)
    else: raise NameError("cut must be either integer or function")
arch = create_timm_body('seresnext50_32x4d')
nf = num_features_model(arch)
head = create_head(nf, 4, concat_pool=True)
net = nn.Sequential(arch, head)
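As a quick sanity check (a throwaway snippet; shapes assume 256x256 inputs), a dummy batch through the network should yield one logit per rotation class:
with torch.no_grad():
    out = net(torch.randn(2, 3, 256, 256))
out.shape  # expected: torch.Size([2, 4]), one logit per rotation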
tensorToImage = torchvision.transforms.ToPILImage()
imageToTensor = torchvision.transforms.ToTensor()
Once we have selected the encoder/base architecture, it's time to build our PyTorch-style dataset. There is nothing fancy going on here except that each training example is rotated randomly by 0, 90, 180 or 270 degrees. The degree of rotation also serves as the label.
class Custom_Dataset_PP():
    # Codes from Amar Saini's (__Epoching__) blog
    def __init__(self, fns):
        self.fns = fns
    def __len__(self):
        return len(self.fns)
    def __getitem__(self, idx):
        # 4 classes for rotation
        degrees = [0, 90, 180, 270]
        rand_choice = random.randint(0, len(degrees)-1)
        img = PILImage.create(self.fns[idx])
        img = img.resize((256, 256))
        img = img.rotate(degrees[rand_choice])
        img = imageToTensor(img)
        return img, torch.tensor(rand_choice).long()
    def show_batch(self, n=3):
        fig, axs = plt.subplots(n, n, figsize=(12,10))
        fig.tight_layout()
        for i in range(n):
            for j in range(n):
                rand_idx = random.randint(0, len(self)-1)
                img, label = self.__getitem__(rand_idx)
                axs[i, j].imshow(tensorToImage(img), cmap='gray')
                axs[i, j].set_title('Label: {0} ({1} Degrees)'.format(label.item(), label.item()*90))
                axs[i, j].axis('off')
ds = Custom_Dataset_PP(path_img.ls())
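A quick peek at a single item (the image tensor shape follows from the 256x256 resize above):
img, label = ds[0]
img.shape, label  # e.g. (torch.Size([3, 256, 256]), tensor(2))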
As we can see from show_batch, the rotation-based pretext task might not be ideal for this dataset, since leaves can appear in any orientation. But since our objective is to learn the technique, it is still good practice.
ds.show_batch()
From here, we use standard fastai practice to train the model.
split = int(len(path_img.ls())*0.8)
train_fns = path_img.ls()[:split]
valid_fns = path_img.ls()[split:]
train_ds = Custom_Dataset_PP(train_fns)
valid_ds = Custom_Dataset_PP(valid_fns)
dls = DataLoaders.from_dsets(train_ds, valid_ds).cuda()
learn = Learner(dls, net, loss_func=CrossEntropyLossFlat(), splitter=default_split, metrics=accuracy)
learn.fit_one_cycle(25, 1e-4)
We will save the weights of the encoder.
torch.save(learn.model[0], f'{path}/seresnext50_32x4d_rotnetencoader.pth')
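Since we saved the whole module, it can later be reloaded with torch.load. A quick check that the saved encoder loads and produces feature maps (the exact spatial size depends on the input resolution):
enc = torch.load(f'{path}/seresnext50_32x4d_rotnetencoader.pth')
with torch.no_grad():
    fmap = enc(torch.randn(1, 3, 256, 256))
fmap.shape  # feature maps, e.g. torch.Size([1, 2048, 8, 8]) for a 256x256 input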
SimCLR
As we saw earlier on the PapersWithCode SOTA page for SSL, discriminative contrastive learning methods are among the most successful SSL techniques. SimCLR is one of these techniques, introduced by Google Research.
Before we go further, what is contrastive learning? It is an SSL technique that learns features of an unlabeled dataset by teaching the model which data points are similar and which are different.
The SimCLR framework proposes four key components for contrastive learning:
- Data augmentation: an image is randomly transformed into two correlated views, which form a positive pair. SimCLR sequentially applies the following augmentations: random cropping followed by a resize back to the original size, random color distortions, and random Gaussian blur. Two views coming from the same image form a positive pair, while two views coming from different images form a negative pair.
- A neural network base encoder: a model backbone that extracts the feature maps. In our case, we will be using seresnext50_32x4d from timm.
- A neural network projection head: maps the feature maps into a vector-space representation where the contrastive loss is applied. An MLP with one hidden layer was used in the original paper.
- A contrastive loss: a loss function designed to evaluate how well the network pulls positive pairs together and pushes negative pairs apart; see the toy snippet after this list.
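To build intuition, here is a toy illustration (made-up 2-d projections, not the real loss) of what the contrastive loss encourages:
import torch
import torch.nn.functional as F

z1 = torch.tensor([0.9, 0.1])  # projection of view 1 of image A
z2 = torch.tensor([0.8, 0.2])  # projection of view 2 of image A (positive pair)
z3 = torch.tensor([0.1, 0.9])  # projection of a view of image B (negative pair)

F.cosine_similarity(z1, z2, dim=0)  # high -> pulled together
F.cosine_similarity(z1, z3, dim=0)  # low  -> pushed apart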
Let's apply SimCLR in fastai using this self-supervised library by Kerem Turgutlu.
from self_supervised.simclr import *
First, let's make a general dataloader as usual.
train['labels'] = train.iloc[:, 1:].idxmax(axis=1)
fns = train.image_id.values
def label_fn(fn):
    return train[train['image_id'] == fn]['labels'].values[0]
label_fn(fns[0])
sz = 256
ds = Datasets(fns, [[lambda x: f'{path_img}/{x}.jpg', PILImage.create], [label_fn, Categorize()]], splits=RandomSplitter()(fns))
dls = ds.dataloaders(bs=32, after_item=[ToTensor(), Resize(sz), IntToFloatTensor()])
dls.show_batch()
Next, we will build the model, which consists of an encoder and a projection head. For this, we will make use of the self-supervised library with some changes: the encoder comes from timm (via the create_timm_body function defined earlier), and we attach the projection head to it.
class MLP(Module):
    "MLP module as described in paper"
    def __init__(self, dim, concat_pool=True, projection_size=128, hidden_size=256):
        self.pool = AdaptiveConcatPool2d() if concat_pool else nn.AdaptiveAvgPool2d(1)
        self.flatten = Flatten()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, projection_size)
        )
    def forward(self, x):
        x = self.pool(x)
        x = self.flatten(x)
        x = self.net(x)
        return x
def create_simclr_model(arch='seresnext50_32x4d', n_in=3, pretrained=False, cut=None, concat_pool=True,
                        hidden_size=256, projection_size=128):
    "Create SimCLR from a given arch"
    encoder = create_timm_body(arch, pretrained, cut, n_in)
    with torch.no_grad(): representation = encoder(torch.randn((2, n_in, 128, 128)))
    # *2 because the concat pooling in MLP doubles the feature count
    projector = MLP(representation.size(1)*2, projection_size=projection_size, hidden_size=hidden_size)
    apply_init(projector)
    return SimCLRModel(encoder, projector)
model = create_simclr_model('seresnext50_32x4d', pretrained=False)
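Assuming SimCLRModel simply chains the encoder and the projection head (as in the library version used here), a dummy forward pass should return one 128-d projection per image:
with torch.no_grad():
    z = model(torch.randn(2, 3, 128, 128))
z.shape  # expected: torch.Size([2, 128])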
Once we have the dataloader and the model, we will make a Learner. The learner uses the SimCLRLoss loss function and the SimCLR callback. We will look at their code to understand what is going on.
learn = Learner(dls, model,
                SimCLRLoss(temp=0.5),
                cbs=[SimCLR(size=128, color=False, stats=None)])
Below is the code for the SimCLR callback. Upon initialization, the callback creates two augmentation pipelines, aug1 and aug2, which are used to make the two views of the same input image. As you can see from the get_aug_pipe function, the augmentations used are RandomResizedCrop, RandomHorizontalFlip, ColorJitter, and RandomGrayscale. The before_batch callback creates the two views, concatenates them, and uses the result as the input batch. The labels are also changed: assuming an initial batch size of 5, the labels become [5, 6, 7, 8, 9, 0, 1, 2, 3, 4], i.e. the image at index 0 is the positive pair of the image at index 5, the image at index 1 of the image at index 6, and so on.
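To make the label trick concrete:
bs = 10  # the doubled batch: 5 images x 2 views
torch.arange(bs).roll(bs//2)  # -> tensor([5, 6, 7, 8, 9, 0, 1, 2, 3, 4])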
import kornia

def get_aug_pipe(size, stats=imagenet_stats, s=.6, color=True, xtra_tfms=[]):
    "SimCLR augmentations"
    tfms = []
    tfms += [kornia.augmentation.RandomResizedCrop((size, size), scale=(0.2, 1.0), ratio=(3/4, 4/3))]
    tfms += [kornia.augmentation.RandomHorizontalFlip()]
    if color: tfms += [kornia.augmentation.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)]
    if color: tfms += [kornia.augmentation.RandomGrayscale(p=0.2)]
    tfms += xtra_tfms
    if stats is not None: tfms += [Normalize.from_stats(*stats)]
    pipe = Pipeline(tfms)
    pipe.split_idx = 0
    return pipe
class SimCLR(Callback):
    "SimCLR callback"
    def __init__(self, size=256, **aug_kwargs):
        self.aug1 = get_aug_pipe(size, **aug_kwargs)
        self.aug2 = get_aug_pipe(size, **aug_kwargs)
    def before_batch(self):
        # two augmented views of the same batch, concatenated
        xi,xj = self.aug1(self.x), self.aug2(self.x)
        self.learn.xb = (torch.cat([xi, xj]),)
        bs = self.learn.xb[0].shape[0]
        # each view's label points at its positive partner in the other half
        self.learn.yb = (torch.arange(bs, device=self.dls.device).roll(bs//2),)
    def show_one(self):
        xb = TensorImage(self.learn.xb[0])
        bs = len(xb)//2
        i = np.random.choice(bs)
        xb = self.aug1.decode(xb.to('cpu').clone()).clamp(0,1)
        images = [xb[i], xb[bs+i]]
        show_images(images)
b = dls.one_batch()
learn._split(b)
learn('before_batch')
The modified image inputs; the initial batch of 32 is doubled.
learn.xb[0].shape
Let's take a look at the modified labels.
learn.yb[0]
Now that we have seen the callback, let's take a look at the loss function.
class SimCLRLoss(Module):
    "SimCLR loss function"
    def __init__(self, temp=0.1):
        self.temp = temp
    def forward(self, inp, targ):
        bs,feat = inp.shape
        # pairwise cosine similarities, scaled by the temperature
        csim = F.cosine_similarity(inp, inp.unsqueeze(dim=1), dim=-1)/self.temp
        csim = remove_diag(csim)
        # index of each view's positive partner after removing the diagonal
        targ = remove_diag(torch.eye(targ.shape[0], device=inp.device)[targ]).nonzero()[:,-1]
        return F.cross_entropy(csim, targ)
The loss function computes the pairwise cosine similarities of the vector representations from the projection head and divides them by the temperature hyperparameter temp. We remove the diagonals from the cosine-similarity matrix and adjust the targets accordingly. Then we calculate the cross-entropy.
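We can sanity-check the loss on random projections with the rolled targets the callback produces (a throwaway snippet; with random inputs the similarities are all near zero, so the loss should come out around log(7) ≈ 1.95):
inp = torch.randn(8, 128)       # 8 projections: 4 images x 2 views
targ = torch.arange(8).roll(4)  # each view's positive partner
SimCLRLoss(temp=0.5)(inp, targ)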
Let's fit_one_cycle.
learn.fit_one_cycle(25, 1e-4)
Similar to our rotnet, we will save the weights of the encoder.
torch.save(learn.model.encoder, f'{path}/seresnext50_32x4d_simclrencoader.pth')
Soft-labeling | pseudo-labeling
Now that we have pre-trained our network using SSL techniques, we will move on to the downstream task, where we will also use soft-labeling to help our training.
Soft-labeling is very helpful when labels are noisy, and the Plant Pathology dataset we have been using suffers from noisy labels. The winning solution for this competition used soft-labeling. I came across soft-labeling through Isaac's blog; for a more comprehensive application of soft-labeling, please refer to it.
We will do this part in two steps:
Step 1: Generating soft-labels
Step 2: Applying soft-labels in the downstream task
from sklearn.model_selection import StratifiedKFold
To apply soft-labeling, we first have to generate the pseudo-labels, using the SSL-trained models. The steps to make pseudo-labels are:
- Create k folds; in our case, we will use k=2
- Initialize the model with the SSL-trained weights and fit it on the training data with the given labels
- After sufficient training, generate predictions on the validation set; these predictions will be used as the pseudo-labels
- Repeat steps 2 and 3 for all k validation sets
train.head(3)
train['labels'] = train.iloc[:, 1:].idxmax(1)
N_FOLDS = 2
train['fold'] = -1
strat_kfold = StratifiedKFold(n_splits=N_FOLDS, random_state=42, shuffle=True)
for i, (_, test_index) in enumerate(strat_kfold.split(train.image_id.values, train['labels'].values)):
    train.iloc[test_index, -1] = i
train['fold'] = train['fold'].astype('int')
def get_dls(df, size, fold, bs):
    batch_tfms = [Normalize.from_stats(*imagenet_stats)]
    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                       splitter=IndexSplitter(df.loc[df.fold==fold].index),
                       getters=[
                           ColReader('image_id', pref=path/'images', suff='.jpg'),
                           ColReader('labels')
                       ],
                       item_tfms=[RandomResizedCrop(size, min_scale=0.35), FlipItem(0.5)],
                       batch_tfms=batch_tfms)
    return dblock.dataloaders(df, bs=bs)
def get_model(df, rotnet=True):
    # load the SSL pre-trained encoder saved earlier and attach a fresh head
    if rotnet:
        arch = torch.load(f'{path}/seresnext50_32x4d_rotnetencoader.pth')
    else:
        arch = torch.load(f'{path}/seresnext50_32x4d_simclrencoader.pth')
    head = create_head(num_features_model(arch), df['labels'].nunique())
    apply_init(head)
    model = nn.Sequential(arch, head)
    return model
splits, preds, targs, preds_c = [], [], [], []
items = pd.DataFrame(columns=train.columns)
for i in range(N_FOLDS):
    dls = get_dls(train, 228, i, bs=16)
    model = get_model(train, rotnet=False)
    learn = Learner(dls, model, metrics=[accuracy, RocAuc()])
    learn.fine_tune(15, reset_opt=True)
    # store predictions on the held-out fold
    p, t, c = learn.get_preds(ds_idx=1, with_decoded=True)
    preds.append(p); targs.append(t); preds_c.append(c)
    items = pd.concat([items, dls.valid.items])
imgs = L(o for o in items.image_id.values)
y_true = L(o for o in items.labels.values)
y_targ = L(dls.vocab[o] for o in torch.cat(targs))
y_pred = L(dls.vocab[o] for o in torch.cat(preds_c))
p_max = torch.cat(preds).max(dim=1)[0]
res = pd.DataFrame({'imgs':imgs,'y_true':y_true,'y_pred':y_pred}).set_index('imgs')
print(res.shape)
print(train.shape)
res.sample(5)
res.to_csv(f'{path}/train_simclr_sl.csv')
For our final task, we will train using the SSL pre-trained models and use the pseudo-labels we generated to reduce the impact of the noisy labels.
# dataframe with softlabels
rotnet_df = pd.read_csv(f'{path}/train_rotnet_sl.csv')
simclr_df = pd.read_csv(f'{path}/train_simclr_sl.csv')
rotnet_df.columns = ['image_id', 'labels', 'softlabels']
simclr_df.columns = ['image_id', 'labels', 'softlabels']
from sklearn.model_selection import StratifiedKFold
import gc
N_FOLDS = 3
rotnet_df['fold'] = -1
strat_kfold = StratifiedKFold(n_splits=N_FOLDS, random_state=42, shuffle=True)
for i, (_, test_index) in enumerate(strat_kfold.split(rotnet_df.image_id.values, rotnet_df['labels'].values)):
    rotnet_df.iloc[test_index, -1] = i
rotnet_df['fold'] = rotnet_df['fold'].astype('int')
N_FOLDS = 3
simclr_df['fold'] = -1
strat_kfold = StratifiedKFold(n_splits=N_FOLDS, random_state=42, shuffle=True)
for i, (_, test_index) in enumerate(strat_kfold.split(simclr_df.image_id.values, simclr_df['labels'].values)):
    simclr_df.iloc[test_index, -1] = i
simclr_df['fold'] = simclr_df['fold'].astype('int')
class SoftLabelCB(Callback):
    def __init__(self, df_preds, y_true_weight=0.5):
        '''df_preds is a pandas dataframe sharing its index with the training dataframe.
        It must have `labels` (given labels) and `softlabels` (pseudo-labels) columns,
        which are one-hot encoded here (ie labels_0, labels_1, ...).'''
        self.y_true_weight = y_true_weight
        self.y_pred_weight = 1 - y_true_weight
        self.df = pd.get_dummies(df_preds, columns=['labels', 'softlabels'])
    def before_train(self):
        if type(self.dl.items)==type(pd.DataFrame()): self.idx_list = L(o for o in self.dl.items.index.values)
        if is_listy(self.dl.items): self.imgs_list = L(self.dl.items)
    def before_validate(self):
        if type(self.dl.items)==type(pd.DataFrame()): self.idx_list = L(o for o in self.dl.items.index.values)
        if is_listy(self.dl.items): self.imgs_list = L(self.dl.items)
    def before_batch(self):
        # get the dataframe indices for the current batch
        idx = self.idx_list[self.dl._DataLoader__idxs[self.iter*self.dl.bs:self.iter*self.dl.bs+self.dl.bs]]
        # start from the one-hot encoded given labels
        df = self.df
        soft_labels = df.loc[idx, df.columns.str.startswith('labels')].values
        if self.training:
            # during training, blend the given labels with the pseudo-labels
            soft_labels = soft_labels*self.y_true_weight + df.loc[idx, df.columns.str.startswith('softlabels')].values*self.y_pred_weight
        self.learn.yb = (Tensor(soft_labels).cuda(),)
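To see what the callback feeds the loss, here is the blend for a single example where the given label and the pseudo-label disagree (hypothetical values, four classes, the default weight of 0.5):
y_true = torch.tensor([0., 1., 0., 0.])  # given (possibly noisy) one-hot label
y_pred = torch.tensor([1., 0., 0., 0.])  # one-hot pseudo-label from the k-fold model
0.5*y_true + 0.5*y_pred                  # -> tensor([0.5000, 0.5000, 0.0000, 0.0000])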
class CrossEntropyLossOneHot(nn.Module):
    def __init__(self):
        super(CrossEntropyLossOneHot, self).__init__()
        self.log_softmax = nn.LogSoftmax(dim=-1)
    def forward(self, preds, labels):
        # cross-entropy against soft (blended one-hot) targets
        return torch.mean(torch.sum(-labels * self.log_softmax(preds), -1))
def accuracy(inp, targ, axis=-1):
    "Compute accuracy with `targ` when `pred` is bs * n_classes"
    pred,targ = flatten_check(inp.argmax(dim=axis), targ.argmax(dim=axis))
    return (pred == targ).float().mean()
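A tiny check that this overridden accuracy handles one-hot targets:
inp  = torch.tensor([[0.9, 0.1], [0.2, 0.8]])  # predictions (bs x n_classes)
targ = torch.tensor([[1., 0.], [0., 1.]])      # one-hot targets
accuracy(inp, targ)  # -> tensor(1.)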
for i in range(N_FOLDS):
    gc.collect()
    dls = get_dls(rotnet_df, 228, i, bs=32)
    model = get_model(rotnet_df, True)
    sm = SaveModelCallback(monitor='valid_loss', fname=f'{path}/model/rotnet_downstm_fold_{i}_best')
    learn = Learner(dls,
                    model,
                    loss_func=CrossEntropyLossOneHot(),
                    metrics=[accuracy, RocAuc()],
                    cbs=[SoftLabelCB(rotnet_df)])
    learn.fine_tune(25, base_lr=1e-3, reset_opt=True, cbs=sm, freeze_epochs=2)
Those are pretty amazing results!
References:
https://amarsaini.github.io/Epoching-Blog/jupyter/2020/03/23/Self-Supervision-with-FastAI.html
https://github.com/Isaac-Flath/fastblog/blob/master/_notebooks/2021-02-15-PlantPathology.ipynb
https://keremturgutlu.github.io/self_supervised/
https://www.kaggle.com/keremt/progressive-label-correction-paper-implementation
https://paperswithcode.com/method/simclr
https://medium.com/wicds/exploring-the-essence-of-simclr-8e205ebc77af