In this blog, we will learn and apply self-supervised learning and soft-labeling/pseudo-labeling techniques to tackle the Plant Pathology Kaggle competition. I was introduced to soft-labeling when I came across this amazing blog by Isaac Flath. Around the same time, I came across Jeremy's blog on self-supervised learning (SSL). Both techniques were interesting, and I feel they are going to be important in computer vision going forward. Hence, I decided to blog my learnings of these intriguing techniques.

A word on the dataset

For this blog, we will be using the Plant Pathology dataset. The task at hand is to classify leaf images into different disease categories.

There are some challenges with the dataset:

  1. Imbalanced data - one of the classes contains very few samples
  2. Noisy labels - some of the images are mislabeled. The winner of this competition used soft-labeling in the winning model to deal with this

We will do this in two parts:

  1. Self-supervised learning

    1. What is self-supervised learning
    2. Different types of self-supervised learning
    3. Applying rotation-based learning (rotnet)
    4. Applying SimCLR learning
  2. Soft-labeling / pseudo-labeling

    1. Generating pseudo-labels via transfer learning from the rotnet and SimCLR models
    2. Final (downstream) task to predict plant diseases using transfer learning (from the models trained in part 1) and soft-labeling/progressive pseudo-labeling

Part 1: Self-Supervised Learning

What is Self-supervised learning?

Supervised learning has brought tremendous success to the field of computer vision. Although supervised learning requires huge amounts of data, by using transfer learning one can reduce the data requirement by about 1000x.

Models for transfer learning are largely pretrained on the ImageNet dataset, which contains about 1.4 million images across 1,000 classes. For some domains, such as medical imaging, transfer learning from ImageNet might not work that well. Given this constraint, and the fact that generating labels for large amounts of data in these domains is costly, self-supervised learning (SSL) has proven to be effective. In SSL, a model is trained on unlabeled data in a supervised manner.

So what is self-supervised learning? SSL makes use of labels that are naturally part of the input data. SSL has been widely used in NLP: training a language model often involves predicting the next word of a sentence, and the label (the next word) is a natural part of the input data. In learning to predict the next word, the model picks up a great deal about the nature of language. Such training is now common in most NLP tasks.
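As a toy illustration (hypothetical, not part of the competition code), next-word prediction turns raw text into (input, label) pairs without any manual annotation:

# Toy illustration: next-word prediction yields (context, label) pairs for free
text = "self supervised learning uses naturally occurring labels".split()
pairs = [(text[:i], text[i]) for i in range(1, len(text))]
pairs[0], pairs[1]
# ((['self'], 'supervised'), (['self', 'supervised'], 'learning'))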

In computer vision, SSL starts with a "pretext task". In the pretext task, we train a model from scratch using labels that come naturally with the input data. Once pretrained, the model can be fine-tuned on the "downstream task".

Further Readings

https://www.fast.ai/2020/01/13/self_supervised/

https://lilianweng.github.io/lil-log/2019/11/10/self-supervised-learning.html

Different ways of doing self-supervised learning

Different techniques of SSL can be categorized into three groups:

  1. Pretext task based

    1. Relative positioning
    2. Colorization
    3. Rotation
    4. Multiple pretext
    5. Jigsaw puzzle
  2. Generative model based

    1. Autoencoders
    2. Split brain autoencoder
    3. Neural scene representation
    4. Context encoder
    5. Semantic inpainting
    6. BiGAN
  3. Discriminative based contrastive learning

    1. SimCLR
    2. SimCLRv2
    3. MoCo
    4. MoCo v2
    5. BYOL
    6. SwAV

SSL and its different techniques are well explained in this lecture series by Anuj Shah. Taking a look at the SOTA page for SSL on PapersWithCode, the discriminative contrastive learning methods are leading the pack.

Rotation-based pretext tasks

In this pretext task, we rotate the image by one of four angles (0, 90, 180 or 270 degrees) and train the model to predict which rotation was applied. In learning to classify the degree of rotation, the model picks up many semantic concepts.

(Figure: the rotation prediction pretext task. Image source: Gidaris et al. 2018, https://arxiv.org/abs/1803.07728)

Now, let's take a look at how we can apply this pretext task using fastai. For a more detailed explanation, please refer to this amazing blog by Amar Saini.

from fastai.vision.all import *
from fastai.vision.learner import _update_first_layer
import timm
import torchvision

path = Path('/content/drive/MyDrive/colab_notebooks/fastai/plant_pathology/data')
path_img = path/'images'
train = pd.read_csv(path/'train.csv')
train.head(5)
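Before going further, we can quickly verify the class imbalance mentioned above; since the labels are one-hot encoded columns, this is a one-liner (output omitted):

# count how many images carry each label; multiple_diseases is the rare class
train[['healthy', 'multiple_diseases', 'rust', 'scab']].sum()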

We will use seresnext50_32x4d from the timm library as our architecture of choice.

def create_timm_body(arch:str, pretrained=False, cut=None, n_in=3):
    "Creates a body from any model in the `timm` library."
    model = timm.create_model(arch, pretrained=pretrained, num_classes=0, global_pool='')
    _update_first_layer(model, n_in, pretrained)
    if cut is None:
        ll = list(enumerate(model.children()))
        cut = next(i for i,o in reversed(ll) if has_pool_type(o))
    if isinstance(cut, int): return nn.Sequential(*list(model.children())[:cut])
    elif callable(cut): return cut(model)
    else: raise ValueError("cut must be either integer or function")
arch = create_timm_body('seresnext50_32x4d')
nf = num_features_model(arch)
head = create_head(nf, 4, concat_pool=True)
net = nn.Sequential(arch, head)
tensorToImage = torchvision.transforms.ToPILImage()
imageToTensor = torchvision.transforms.ToTensor()

Once we have selected the encoder/base architecture, it's time to build our PyTorch-style dataset. There is nothing fancy going on here, except that each training example is rotated randomly by 0, 90, 180 or 270 degrees. The index of the rotation also serves as the label.

class Custom_Dataset_PP():
    # Code from Amar Saini's (__Epoching__) blog
    
    def __init__(self, fns):
        
        self.fns = fns
                    
    def __len__(self):
        return len(self.fns)
    
    def __getitem__(self, idx):
        
        # 4 classes for rotation
        degrees = [0, 90, 180, 270]
        rand_choice = random.randint(0, len(degrees)-1)
        
        img = PILImage.create(self.fns[idx])
        img = img.resize((256, 256))
        img = img.rotate(degrees[rand_choice])
        img = imageToTensor(img)
        
        return img, torch.tensor(rand_choice).long()
    
    def show_batch(self, n=3):
        fig, axs = plt.subplots(n, n, figsize=(12,10))
        fig.tight_layout()
        for i in range(n):
            for j in range(n):
                rand_idx = random.randint(0, len(self)-1)
                img, label = self.__getitem__(rand_idx)
                axs[i, j].imshow(tensorToImage(img), cmap='gray')
                axs[i, j].set_title('Label: {0} ({1} Degrees)'.format(label.item(), label.item()*90))
                axs[i, j].axis('off')

ds = Custom_Dataset_PP(path_img.ls())

As we can see from show_batch, a rotation-based pretext task might not be ideal for this dataset, since leaves can appear in any orientation. But since our objective is to learn the technique, this is still good practice.

ds.show_batch()

From here, we use standard fastai practice to train the model.

split = int(len(path_img.ls())*0.8)
train_fns = path_img.ls()[:split]
valid_fns = path_img.ls()[split:]
train_ds = Custom_Dataset_PP(train_fns)
valid_ds = Custom_Dataset_PP(valid_fns)
dls = DataLoaders.from_dsets(train_ds, valid_ds).cuda()
learn = Learner(dls, net, loss_func=CrossEntropyLossFlat(), splitter=default_split, metrics=accuracy)
learn.fit_one_cycle(25, 1e-4)
epoch train_loss valid_loss accuracy time
0 1.569923 1.409840 0.271635 06:39
1 1.523075 1.378337 0.320913 03:21
2 1.461118 1.298200 0.352163 03:20
3 1.364290 1.330265 0.362981 03:21
4 1.270127 1.191683 0.418269 03:20
5 1.204014 2.085269 0.301683 03:19
6 1.129306 1.249018 0.411058 03:19
7 1.054507 1.541547 0.355769 03:18
8 0.996718 2.167537 0.304087 03:20
9 0.930008 2.989300 0.269231 03:20
10 0.888594 0.878343 0.522837 03:19
11 0.852690 0.867087 0.533654 03:20
12 0.810870 0.848070 0.554087 03:20
13 0.786548 1.054648 0.491587 03:21
14 0.739632 1.170150 0.473558 03:21
15 0.707019 0.832262 0.542067 03:21
16 0.671642 0.802649 0.585337 03:21
17 0.653890 0.827771 0.569712 03:21
18 0.621239 1.211577 0.491587 03:20
19 0.600797 0.804843 0.573317 03:22
20 0.581002 0.796838 0.580529 03:21
21 0.564937 0.779816 0.587740 03:20
22 0.564841 0.760423 0.579327 03:20
23 0.549947 0.773734 0.584135 03:20
24 0.533107 0.759421 0.608173 03:20

We will save the weights of the encoder.

torch.save(learn.model[0], f'{path}/seresnext50_32x4d_rotnetencoader.pth')

SimCLR: A Simple Framework for Contrastive Learning of Visual Representations

As we saw earlier on the PapersWithCode SOTA page for SSL, discriminative contrastive learning methods are among the most successful SSL techniques. SimCLR is one of these techniques; it was introduced by Google Research.

Before we go further, let's take a look at the definition of contrastive learning.

What is contrastive learning? It is an SSL technique that learns features of a dataset without labels, by teaching the model which data points are similar to and which are different from each other.

The SimCLR framework proposes four key components for contrastive learning

  1. Data Augmentation

    Here, an image is randomly transformed into two correlated views. SimCLR sequentially applies the following augmentations - random cropping followed by a resize back to the original size, random color distortion, and random Gaussian blur. Two views coming from the same image form a positive pair, while two views coming from different images form a negative pair.

  2. A neural network base encoder

    A model backbone that extracts the feature maps. In our case, we will be using seresnext50_32x4d from timm

  3. Neural network projection head

    This maps the feature maps into a vector space representation where the contrastive loss is applied. The original paper used a multilayer perceptron with one hidden layer.

  4. Contrastive loss

    A loss function designed to evaluate how well the network identifies the positive pair for each view among the negatives.
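For reference, this is the contrastive loss from the SimCLR paper (NT-Xent) for a positive pair of views $(i, j)$ in a batch of $2N$ augmented images:

$$\ell_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i, z_k)/\tau)}$$

where $\mathrm{sim}(u, v) = u^\top v / (\lVert u \rVert \lVert v \rVert)$ is the cosine similarity and $\tau$ is the temperature hyperparameter. We will meet this loss again in code form (SimCLRLoss) below.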

Let's apply SimCLR in fastai using this self-supervised library by Kerem Turgutlu.

from self_supervised.simclr import *

First, let's make a regular dataloader as usual.

train['labels'] = train.iloc[:, 1:].idxmax(axis=1)
fns = train.image_id.values
def label_fn(fn):
    return train[train['image_id'] == fn]['labels'].values[0]
    
label_fn(fns[0])
'scab'
sz = 256
ds = Datasets(fns, [[lambda x: f'{path_img}/{x}.jpg', PILImage.create], [label_fn, Categorize()]], splits=RandomSplitter()(fns))
dls = ds.dataloaders(bs=32, after_item=[ToTensor(), Resize(sz), IntToFloatTensor()])
dls.show_batch()

Next, we will build the model, which consists of an encoder and a projection head. For this, we will make use of the self-supervised library with some changes: the encoder will come from timm (using the create_timm_body function we defined earlier), and we will attach the projection head to it.

class MLP(Module):
    "MLP module as described in paper"
    def __init__(self, dim, concat_pool=True, projection_size=128, hidden_size=256):
        self.pool = AdaptiveConcatPool2d() if concat_pool else nn.AdaptiveAvgPool2d(1)
        self.flatten = Flatten()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, projection_size)
        )

    def forward(self, x):
        x = self.pool(x)
        x = self.flatten(x)
        x = self.net(x) 
        return x
def create_simclr_model(arch='seresnext50_32x4d', n_in=3, pretrained=False, cut=None, concat_pool=True,
                      hidden_size=256, projection_size=128):
    "Create SimCLR model from a given timm arch"
    encoder = create_timm_body(arch, pretrained, cut, n_in)
    with torch.no_grad(): representation = encoder(torch.randn((2, n_in, 128, 128)))
    # AdaptiveConcatPool2d in the MLP head doubles the feature count, hence the *2
    projector = MLP(representation.size(1)*2, projection_size=projection_size, hidden_size=hidden_size)
    apply_init(projector)
    return SimCLRModel(encoder, projector)
model = create_simclr_model('seresnext50_32x4d', pretrained=False)

Once we have the dataloader and the model, we will make a Learner. The Learner uses the SimCLRLoss and the SimCLR callback. We will look at their code to understand what is going on.

learn = Learner(dls, model, 
                SimCLRLoss(temp=0.5),
                cbs=[SimCLR(size=128, color=False, stats=None)])

Below is the code for the SimCLR callback. Upon initialization, the callback creates two augmentation pipelines - aug1 and aug2 - which are used to make the two views of each input image. As you can see from the get_aug_pipe function, the augmentations use RandomResizedCrop, RandomHorizontalFlip, ColorJitter, and RandomGrayscale. In before_batch, the callback makes the two views, concatenates them, and uses the result as the input batch. The labels are also changed: assuming an initial batch size of 5 (10 images after concatenation), the labels would be [5, 6, 7, 8, 9, 0, 1, 2, 3, 4], i.e. the first image's positive pair is image 5 (its other view), the second image's positive pair is image 6, and so on.

def get_aug_pipe(size, stats=imagenet_stats, s=.6, color=True, xtra_tfms=[]):
    "SimCLR augmentations"
    tfms = []
    tfms += [kornia.augmentation.RandomResizedCrop((size, size), scale=(0.2, 1.0), ratio=(3/4, 4/3))]
    tfms += [kornia.augmentation.RandomHorizontalFlip()]

    if color: tfms += [kornia.augmentation.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)]
    if color: tfms += [kornia.augmentation.RandomGrayscale(p=0.2)]
    tfms += xtra_tfms
    if stats is not None: tfms += [Normalize.from_stats(*stats)]

    pipe = Pipeline(tfms)
    pipe.split_idx = 0
    return pipe
class SimCLR(Callback):
    "SimCLR callback"
    def __init__(self, size=256, **aug_kwargs):
        self.aug1 = get_aug_pipe(size, **aug_kwargs)
        self.aug2 = get_aug_pipe(size, **aug_kwargs)

    def before_batch(self):
        xi,xj = self.aug1(self.x), self.aug2(self.x)
        self.learn.xb = (torch.cat([xi, xj]),)
        bs = self.learn.xb[0].shape[0]
        self.learn.yb = (torch.arange(bs, device=self.dls.device).roll(bs//2),)

    def show_one(self):
        xb = TensorImage(self.learn.xb[0])
        bs = len(xb)//2
        i = np.random.choice(bs)
        xb = self.aug1.decode(xb.to('cpu').clone()).clamp(0,1)
        images = [xb[i], xb[bs+i]]
        show_images(images)
b = dls.one_batch()
learn._split(b)
learn('before_batch')

The modified image inputs; the initial batch of 32 is doubled.

learn.xb[0].shape
torch.Size([64, 3, 128, 128])

Let's take a look at the modified labels.

learn.yb[0]
tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
        50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,  0,  1,  2,  3,
         4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
        22, 23, 24, 25, 26, 27, 28, 29, 30, 31], device='cuda:0')

Now that we have looked at the callback, let's take a look at the loss function.

class SimCLRLoss(Module):
    "SimCLR loss function"
    def __init__(self, temp=0.1):
        self.temp = temp

    def forward(self, inp, targ):
        bs,feat = inp.shape
        csim = F.cosine_similarity(inp, inp.unsqueeze(dim=1), dim=-1)/self.temp
        csim = remove_diag(csim)
        targ = remove_diag(torch.eye(targ.shape[0], device=inp.device)[targ]).nonzero()[:,-1]
        return F.cross_entropy(csim, targ)

The loss function calculates the cosine similarity between every pair of projected vectors and divides it by the temperature hyperparameter temp. We remove the diagonal (each vector's similarity with itself) from the similarity matrix and shift the targets accordingly. Then we calculate the cross-entropy, treating each row of similarities as logits over the candidate pairs.
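To make this concrete, here is a minimal, self-contained sketch of the same computation on a toy batch of 2 images x 2 views (remove_diag comes from the self-supervised library; we inline its logic with a boolean mask):

import torch
import torch.nn.functional as F

# toy projections: 2 images x 2 views = 4 vectors
z = F.normalize(torch.randn(4, 16), dim=-1)
temp = 0.5
csim = F.cosine_similarity(z, z.unsqueeze(1), dim=-1) / temp   # (4, 4) similarity matrix

# drop the diagonal (each vector's similarity with itself)
mask = ~torch.eye(4, dtype=torch.bool)
csim = csim[mask].view(4, 3)

# labels as the SimCLR callback builds them: arange(4).roll(2) = [2, 3, 0, 1]
targ = torch.arange(4).roll(2)
# shift the target indices to account for the removed diagonal
targ = torch.eye(4)[targ][mask].view(4, 3).nonzero()[:, -1]    # tensor([1, 2, 0, 1])
loss = F.cross_entropy(csim, targ)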

Let's fit_one_cycle.

learn.fit_one_cycle(25, 1e-4)
epoch train_loss valid_loss time
0 4.004565 4.004914 01:20
1 3.730691 3.483200 01:21
2 3.505424 3.356370 01:21
3 3.350685 3.181578 01:21
4 3.256876 3.138497 01:21
5 3.188638 3.078935 01:21
6 3.124358 3.023555 01:22
7 3.088448 3.021066 01:22
8 3.056739 3.010895 01:21
9 3.032721 3.060869 01:21
10 3.009059 3.009850 01:22
11 2.992094 2.977447 01:22
12 2.992314 2.997800 01:21
13 2.974692 3.005455 01:21
14 2.962943 2.972467 01:22
15 2.954012 2.910474 01:21
16 2.947802 2.915481 01:21
17 2.944147 2.911163 01:21
18 2.933424 2.912643 01:21
19 2.931176 2.908795 01:21
20 2.928176 2.912873 01:21
21 2.927793 2.922826 01:21
22 2.928530 2.908896 01:21
23 2.922103 2.900265 01:21
24 2.918781 2.921983 01:22

As with rotnet, we will save the weights of the encoder.

torch.save(learn.model.encoder, f'{path}/seresnext50_32x4d_simclrencoader.pth')

Part 2: Soft-Labeling

Now that we have pretrained our network using SSL techniques, we will move on to the downstream task, where we will also use soft-labeling to help our training.

Soft-labeling is very helpful in cases where the labels are noisy, which is exactly the problem with the Plant Pathology dataset. The winning solution for this competition used soft-labeling. I came across soft-labeling through Isaac's blog; for a more comprehensive application of soft-labeling, please refer to Isaac's blog.
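Concretely, as we will see in the SoftLabelCB callback later, the training target becomes a weighted blend of the given one-hot label and the model-generated pseudo-label:

soft_label = y_true_weight * y_true + (1 - y_true_weight) * y_pred

With y_true_weight = 0.5 (the default we use), a mislabeled image no longer contributes a pure, wrong one-hot target, which softens the impact of the label noise.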

We will do this part in two steps:

Step 1: Generating soft-labels

Step 2: Applying soft-labels in the downstream task

Generating soft-labels

from sklearn.model_selection import StratifiedKFold

To apply soft-labeling, we first have to generate the pseudo-labels. We will use the SSL-trained models to generate them. These are the steps to make pseudo-labels:

  1. Create k folds. In our case, we will use k=2
  2. Initialize the model with the SSL-trained weights and fit on the training folds with the given labels
  3. After sufficient training, generate predictions on the validation fold. These predictions will be used as the pseudo-labels
  4. Repeat steps 2 and 3 for each of the k validation folds
train.head(3)
image_id healthy multiple_diseases rust scab
0 Train_0 0 0 0 1
1 Train_1 0 1 0 0
2 Train_2 1 0 0 0
train['labels'] = train.iloc[:, 1:].idxmax(1)
N_FOLDS = 2
train['fold'] = -1

strat_kfold = StratifiedKFold(n_splits=N_FOLDS, random_state=42, shuffle=True)
for i, (_, test_index) in enumerate(strat_kfold.split(train.image_id.values, train['labels'].values)):
    train.iloc[test_index, -1] = i
    
train['fold'] = train['fold'].astype('int')
def get_dls(df, size, fold, bs):

    batch_tfms = [Normalize.from_stats(*imagenet_stats)]
    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                       splitter=IndexSplitter(df.loc[df.fold==fold].index),
                       getters=[
                          ColReader('image_id', pref=path/'images', suff='.jpg'),
                          ColReader('labels')
                              ],
                       item_tfms=[RandomResizedCrop(size, min_scale=0.35), FlipItem(0.5)],
                       batch_tfms=batch_tfms)
    return dblock.dataloaders(df, bs=bs)
def get_model(df, rotnet=True):

    if rotnet:
        # load the encoder pretrained on the rotation pretext task
        arch = torch.load(f'{path}/seresnext50_32x4d_rotnetencoader.pth')
    else:
        # load the encoder pretrained with SimCLR
        arch = torch.load(f'{path}/seresnext50_32x4d_simclrencoader.pth')

    head = create_head(num_features_model(arch), df['labels'].nunique())
    apply_init(head)
    model = nn.Sequential(arch, head)

    return model
preds, targs, preds_c = [], [], []
items = pd.DataFrame(columns = train.columns)

for i in range(N_FOLDS):

    dls = get_dls(train, 228, i, bs=16)
    model = get_model(train, rotnet=False)
    learn = Learner(dls, model, metrics=[accuracy, RocAuc()])
    learn.fine_tune(15, reset_opt=True)
    
    # store predictions
    p, t, c = learn.get_preds(ds_idx=1, with_decoded=True)
    preds.append(p); targs.append(t); preds_c.append(c); 
    items = pd.concat([items, dls.valid.items])
epoch train_loss valid_loss accuracy roc_auc_score time
0 2.245435 2.073756 0.407245 0.591565 01:27
epoch train_loss valid_loss accuracy roc_auc_score time
0 1.630546 1.246244 0.506037 0.681683 01:28
1 1.536798 1.205274 0.535675 0.697780 01:27
2 1.576910 1.541064 0.473106 0.684111 01:27
3 1.486408 1.448349 0.553238 0.696692 01:27
4 1.408332 1.150345 0.531284 0.731727 01:27
5 1.292639 0.882109 0.663008 0.804221 01:26
6 1.103595 1.071286 0.611416 0.795798 01:28
7 0.991075 0.743046 0.710209 0.824435 01:28
8 0.845321 0.623146 0.765093 0.851047 01:27
9 0.790943 0.619872 0.750823 0.874319 01:31
10 0.686515 0.568958 0.791438 0.891836 01:28
11 0.655953 0.519997 0.818880 0.897981 01:30
12 0.597022 0.477010 0.830955 0.910269 01:28
13 0.531461 0.475933 0.829857 0.908447 01:27
14 0.502115 0.476578 0.828760 0.906857 01:27
epoch train_loss valid_loss accuracy roc_auc_score time
0 2.179506 2.657610 0.390110 0.607213 01:26
epoch train_loss valid_loss accuracy roc_auc_score time
0 1.620119 1.272247 0.509890 0.683388 01:26
1 1.551659 1.186581 0.525275 0.723610 01:25
2 1.457448 1.184272 0.515385 0.723678 01:24
3 1.449472 1.069529 0.583516 0.748897 01:27
4 1.385499 1.160630 0.567033 0.734892 01:26
5 1.177015 1.050339 0.606593 0.790950 01:25
6 1.101527 0.718146 0.713187 0.833402 01:25
7 1.001735 0.937197 0.668132 0.813787 01:25
8 0.915719 0.711719 0.735165 0.857576 01:26
9 0.798974 0.606332 0.772527 0.870674 01:26
10 0.707872 0.535294 0.794505 0.914215 01:26
11 0.652735 0.533704 0.804396 0.901796 01:25
12 0.597695 0.484282 0.832967 0.915541 01:25
13 0.559858 0.475080 0.826374 0.920336 01:25
14 0.542785 0.470118 0.825275 0.920716 01:24
imgs = L(o for o in items.image_id.values)
y_true = L(o for o in items.labels.values) 
y_targ = L(dls.vocab[o] for o in torch.cat(targs)) 
y_pred = L(dls.vocab[o] for o in torch.cat(preds_c)) 
p_max = torch.cat(preds).max(dim=1)[0]
res = pd.DataFrame({'imgs':imgs,'y_true':y_true,'y_pred':y_pred}).set_index('imgs')
print(res.shape)
print(train.shape)
res.sample(5)
(1821, 2)
(1821, 7)
y_true y_pred
imgs
Train_1305 scab scab
Train_544 healthy healthy
Train_1286 healthy healthy
Train_814 rust rust
Train_1694 healthy healthy
res.to_csv(f'{path}/train_simclr_sl.csv')

Final downstream task

For our final task, we will train using the SSL-pretrained models and use the pseudo-labels we generated to reduce the impact of noisy labels.

# dataframe with softlabels
rotnet_df = pd.read_csv(f'{path}/train_rotnet_sl.csv')
simclr_df = pd.read_csv(f'{path}/train_simclr_sl.csv')
rotnet_df.columns = ['image_id', 'labels', 'softlabels']
simclr_df.columns = ['image_id', 'labels', 'softlabels']
from sklearn.model_selection import StratifiedKFold
import gc
N_FOLDS = 3
rotnet_df['fold'] = -1

strat_kfold = StratifiedKFold(n_splits=N_FOLDS, random_state=42, shuffle=True)
for i, (_, test_index) in enumerate(strat_kfold.split(rotnet_df.image_id.values, rotnet_df['labels'].values)):
    rotnet_df.iloc[test_index, -1] = i
    
rotnet_df['fold'] = rotnet_df['fold'].astype('int')
N_FOLDS = 3
simclr_df['fold'] = -1

strat_kfold = StratifiedKFold(n_splits=N_FOLDS, random_state=42, shuffle=True)
for i, (_, test_index) in enumerate(strat_kfold.split(simclr_df.image_id.values, simclr_df['labels'].values)):
    simclr_df.iloc[test_index, -1] = i
    
simclr_df['fold'] = simclr_df['fold'].astype('int')
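The SoftLabelCB callback below is adapted from Isaac's blog. On each training batch it looks up the current items in the pseudo-label dataframe (the labels and softlabels columns are one-hot encoded with pd.get_dummies) and replaces the batch targets with the weighted blend described earlier, using y_true_weight=0.5 by default. Note that the blend is only applied when self.training is True; validation batches keep the original labels.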
class SoftLabelCB(Callback):
    def __init__(self, df_preds, y_true_weight = 0.5): 
        '''df_preds is a pandas dataframe sharing its index with the dataframe
           used to build the DataLoaders. It must have `labels` (ground truth)
           and `softlabels` (pseudo-label) columns, which get one-hot encoded below
        '''
        
        self.y_true_weight = y_true_weight
        self.y_pred_weight = 1 - y_true_weight
        self.df = pd.get_dummies(df_preds, columns=['labels', 'softlabels'])

    def before_train(self):
        if type(self.dl.items)==type(pd.DataFrame()): self.idx_list = L(o for o in self.dl.items.index.values)
        if is_listy(self.dl.items): self.imgs_list = L(self.dl.items)      
    
    def before_validate(self):
        if type(self.dl.items)==type(pd.DataFrame()): self.idx_list = L(o for o in self.dl.items.index.values)
        if is_listy(self.dl.items): self.imgs_list = L(self.dl.items)       
    
    def before_batch(self):
        # get the images' names for the current batch
        idx = self.idx_list[self.dl._DataLoader__idxs[self.iter*self.dl.bs:self.iter*self.dl.bs+self.dl.bs]]
        
        # get soft labels
        df = self.df
        soft_labels = df.loc[idx,df.columns.str.startswith('labels')].values
        
        if self.training:
            soft_labels = soft_labels*self.y_true_weight + df.loc[idx,df.columns.str.startswith('softlabels')].values*self.y_pred_weight
        self.learn.yb = (Tensor(soft_labels).cuda(),)
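Since the targets are now soft probability distributions rather than class indices, we also need a cross-entropy loss that accepts such targets, and a version of accuracy that takes the argmax of the targets as well: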

class CrossEntropyLossOneHot(nn.Module):
    def __init__(self):
        super(CrossEntropyLossOneHot, self).__init__()
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, preds, labels):
        return torch.mean(torch.sum(-labels * self.log_softmax(preds), -1))

def accuracy(inp, targ, axis=-1):
    "Compute accuracy with `targ` when `pred` is bs * n_classes"
    pred,targ = flatten_check(inp.argmax(dim=axis), targ.argmax(dim=axis))
    return (pred == targ).float().mean()
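As a quick sanity check (a sketch, not from the original notebook): on hard one-hot targets this loss reduces to the standard cross-entropy.

# hypothetical sanity check: one-hot targets recover standard cross-entropy
loss_fn = CrossEntropyLossOneHot()
preds = torch.randn(8, 4)
hard = torch.eye(4)[torch.randint(0, 4, (8,))]   # 8 random one-hot targets
assert torch.isclose(loss_fn(preds, hard), F.cross_entropy(preds, hard.argmax(dim=-1)))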
for i in range(N_FOLDS):
    
    gc.collect()

    dls = get_dls(rotnet_df, 228, i, bs=32)
    model = get_model(rotnet_df, True)

    sm = SaveModelCallback(monitor='valid_loss', fname=f'{path}/model/rotnet_downstm_fold_{i}_best')

    learn = Learner(dls, 
                    model,
                    loss_func=CrossEntropyLossOneHot(),
                    metrics=[accuracy, RocAuc()],
                    cbs=[SoftLabelCB(rotnet_df)])

    learn.fine_tune(25, base_lr=1e-3, reset_opt=True, cbs=sm, freeze_epochs=2)
epoch train_loss valid_loss accuracy roc_auc_score time
0 2.180979 1.798362 0.400330 0.617794 06:34
1 1.896754 1.784741 0.444811 0.667562 01:35
Better model found at epoch 0 with valid_loss value: 1.798362374305725.
Better model found at epoch 1 with valid_loss value: 1.7847411632537842.
epoch train_loss valid_loss accuracy roc_auc_score time
0 1.484100 1.153402 0.556837 0.730138 01:35
1 1.320654 1.144631 0.570017 0.770639 01:34
2 1.314283 1.177314 0.561779 0.756559 01:35
3 1.311820 1.222077 0.563427 0.743319 01:35
4 1.308478 1.284149 0.558484 0.765968 01:35
5 1.241711 1.096887 0.596376 0.773927 01:35
6 1.205534 1.040225 0.611203 0.790580 01:33
7 1.109477 1.162204 0.619440 0.773357 01:35
8 1.038942 0.975802 0.647446 0.827652 01:35
9 0.926236 0.790048 0.728171 0.870287 01:35
10 0.850851 0.666998 0.757825 0.879508 01:35
11 0.766376 0.633047 0.775947 0.898495 01:35
12 0.690296 0.546291 0.785832 0.924340 01:35
13 0.646431 0.558037 0.790774 0.916405 01:36
14 0.617734 0.487432 0.823723 0.928350 01:35
15 0.578832 0.504875 0.827018 0.922318 01:34
16 0.555530 0.594202 0.812191 0.901316 01:35
17 0.524157 0.443095 0.836903 0.941674 01:35
18 0.490369 0.416304 0.851730 0.942821 01:35
19 0.476614 0.419605 0.855025 0.944162 01:36
20 0.453619 0.403086 0.850082 0.946465 01:35
21 0.438267 0.398012 0.863262 0.945645 01:36
22 0.421713 0.392551 0.858320 0.947091 01:36
23 0.406884 0.391196 0.858320 0.947049 01:36
24 0.416799 0.393447 0.855025 0.948277 01:34
Better model found at epoch 0 with valid_loss value: 1.1534024477005005.
Better model found at epoch 1 with valid_loss value: 1.1446309089660645.
Better model found at epoch 5 with valid_loss value: 1.096887469291687.
Better model found at epoch 6 with valid_loss value: 1.0402253866195679.
Better model found at epoch 8 with valid_loss value: 0.9758020043373108.
Better model found at epoch 9 with valid_loss value: 0.7900478839874268.
Better model found at epoch 10 with valid_loss value: 0.6669976115226746.
Better model found at epoch 11 with valid_loss value: 0.6330470442771912.
Better model found at epoch 12 with valid_loss value: 0.546291172504425.
Better model found at epoch 14 with valid_loss value: 0.48743197321891785.
Better model found at epoch 17 with valid_loss value: 0.4430951774120331.
Better model found at epoch 18 with valid_loss value: 0.41630420088768005.
Better model found at epoch 20 with valid_loss value: 0.40308573842048645.
Better model found at epoch 21 with valid_loss value: 0.3980117738246918.
Better model found at epoch 22 with valid_loss value: 0.392551451921463.
Better model found at epoch 23 with valid_loss value: 0.39119648933410645.
epoch train_loss valid_loss accuracy roc_auc_score time
0 2.077682 1.358871 0.421746 0.630542 01:35
1 1.890515 1.920299 0.433278 0.624589 01:35
Better model found at epoch 0 with valid_loss value: 1.3588707447052002.
epoch train_loss valid_loss accuracy roc_auc_score time
0 1.721737 1.420372 0.471170 0.660035 01:34
1 1.556469 1.232172 0.560132 0.708895 01:35
2 1.505628 1.309458 0.560132 0.715855 01:35
3 1.521256 1.261295 0.537068 0.709441 01:35
4 1.445887 1.404649 0.542010 0.675589 01:35
5 1.421579 1.613186 0.505766 0.688969 01:35
6 1.341298 1.260396 0.542010 0.698050 01:34
7 1.266212 1.263892 0.599671 0.739613 01:35
8 1.192686 1.052656 0.624382 0.782726 01:35
9 1.123260 0.921180 0.640857 0.814421 01:35
10 1.071517 0.980215 0.649094 0.796601 01:35
11 1.003980 0.835629 0.698517 0.825559 01:35
12 0.914666 1.014547 0.685338 0.805289 01:35
13 0.815322 0.875604 0.696870 0.832545 01:35
14 0.755339 0.692518 0.754530 0.861947 01:35
15 0.682022 0.717943 0.764415 0.860634 01:35
16 0.627511 0.689097 0.752883 0.853841 01:35
17 0.552456 0.604317 0.797364 0.871979 01:36
18 0.551212 0.627598 0.803954 0.880846 01:35
19 0.544026 0.636285 0.803954 0.873143 01:35
20 0.505864 0.585278 0.808896 0.886988 01:35
21 0.483512 0.540007 0.828666 0.894359 01:35
22 0.474091 0.579923 0.820428 0.883195 01:35
23 0.456476 0.580129 0.825371 0.882503 01:35
24 0.450166 0.562739 0.827018 0.885534 01:35
Better model found at epoch 0 with valid_loss value: 1.4203722476959229.
Better model found at epoch 1 with valid_loss value: 1.232171893119812.
Better model found at epoch 8 with valid_loss value: 1.052655577659607.
Better model found at epoch 9 with valid_loss value: 0.9211796522140503.
Better model found at epoch 11 with valid_loss value: 0.8356290459632874.
Better model found at epoch 14 with valid_loss value: 0.6925175189971924.
Better model found at epoch 16 with valid_loss value: 0.6890971064567566.
Better model found at epoch 17 with valid_loss value: 0.6043166518211365.
Better model found at epoch 20 with valid_loss value: 0.5852775573730469.
Better model found at epoch 21 with valid_loss value: 0.5400067567825317.
epoch train_loss valid_loss accuracy roc_auc_score time
0 2.081328 1.790862 0.342669 0.594329 01:33
1 1.812289 1.978368 0.461285 0.698615 01:35
Better model found at epoch 0 with valid_loss value: 1.7908616065979004.
epoch train_loss valid_loss accuracy roc_auc_score time
0 1.681780 1.351755 0.423394 0.678330 01:35
1 1.628255 1.253813 0.533773 0.726907 01:35
2 1.514754 1.179286 0.543657 0.744522 01:35
3 1.446746 1.149792 0.537068 0.747809 01:35
4 1.452345 1.244986 0.546952 0.685498 01:35
5 1.401141 1.286189 0.560132 0.746649 01:35
6 1.342116 1.166848 0.556837 0.741916 01:35
7 1.276286 1.104091 0.593081 0.772318 01:33
8 1.199893 1.050359 0.614498 0.784405 01:35
9 1.111698 0.914566 0.640857 0.823923 01:35
10 1.059557 0.892816 0.665568 0.839800 01:35
11 0.988732 0.863787 0.688633 0.842417 01:35
12 0.886488 0.689515 0.746293 0.871834 01:35
13 0.803669 0.739485 0.744646 0.869706 01:35
14 0.719890 0.716404 0.757825 0.902299 01:35
15 0.649968 0.674304 0.775947 0.899190 01:34
16 0.613759 0.650748 0.761120 0.902635 01:33
17 0.573795 0.488263 0.835255 0.932428 01:35
18 0.549245 0.499743 0.831960 0.928918 01:35
19 0.518839 0.460697 0.835255 0.935363 01:35
20 0.480734 0.425933 0.851730 0.945728 01:35
21 0.465403 0.426006 0.855025 0.945700 01:35
22 0.442722 0.419363 0.853377 0.947554 01:35
23 0.466067 0.409462 0.861615 0.949312 01:35
24 0.467920 0.411176 0.866557 0.949608 01:35
Better model found at epoch 0 with valid_loss value: 1.3517550230026245.
Better model found at epoch 1 with valid_loss value: 1.2538130283355713.
Better model found at epoch 2 with valid_loss value: 1.1792856454849243.
Better model found at epoch 3 with valid_loss value: 1.1497923135757446.
Better model found at epoch 7 with valid_loss value: 1.104090929031372.
Better model found at epoch 8 with valid_loss value: 1.0503588914871216.
Better model found at epoch 9 with valid_loss value: 0.9145664572715759.
Better model found at epoch 10 with valid_loss value: 0.8928159475326538.
Better model found at epoch 11 with valid_loss value: 0.8637869358062744.
Better model found at epoch 12 with valid_loss value: 0.6895149350166321.
Better model found at epoch 15 with valid_loss value: 0.6743040680885315.
Better model found at epoch 16 with valid_loss value: 0.650747537612915.
Better model found at epoch 17 with valid_loss value: 0.48826295137405396.
Better model found at epoch 19 with valid_loss value: 0.4606971740722656.
Better model found at epoch 20 with valid_loss value: 0.4259330928325653.
Better model found at epoch 22 with valid_loss value: 0.419363409280777.
Better model found at epoch 23 with valid_loss value: 0.4094623327255249.

Those are pretty amazing results!