Deep Transfer Learning Tutorial in PyTorch on Animals-10 Dataset

Prediction on Animals-10 dataset using deep transfer learning 

Deep learning on images is extremely popular, and every few months a new algorithm pushes the limits of prediction accuracy. Just a few years ago, these state-of-the-art algorithms and the knowledge behind how to build them were available only to a small number of research experts. The landscape is changing, however, thanks to increasingly accessible, low-cost computational resources, as well as open-access code that anyone, anywhere in the world, can run. There are some fantastic folks willing to share their knowledge and expertise for the greater good.

In this tutorial, I walk through the application of one such state-of-the-art algorithm, ResNet-18, to predict animals from 10 classes using the popular deep learning framework PyTorch, developed by Facebook's AI Research lab. I'll show how transfer learning can achieve very high accuracy (97% on test data here).

Importing and Formatting

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, random_split
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt

res_18_model = models.resnet18(pretrained=True)

Notice that at the end I load the resnet18 model from torchvision and set pretrained=True. This is because for transfer learning, we typically reuse (some or most of) the weights the network learned on other data, and then modify a few layers to fit our own task.
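In this tutorial I fine-tune all of the weights, but a common alternative is to freeze the pretrained backbone and train only the replacement head added later. A minimal sketch of that option (not used in the rest of this post):

# Optional: freeze every pretrained parameter; only layers replaced
# afterwards (such as the new fc head) would then receive gradient updates
for param in res_18_model.parameters():
    param.requires_grad = False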

T = transforms.Compose([
     transforms.Resize((224,224)),
     transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

We need to download the images from Kaggle and unzip them. Next, we transform the dataset into a standardized format using PyTorch's built-in transforms framework. This lets us resize each image to a standard size (224x224 here) and do other pre-processing. For transfer learning models, it is important to standardize the input data to match what the original model was trained on (as much as possible); otherwise you risk horrible performance (I speak from experience :) ). For ResNet-18, the official PyTorch documentation suggests transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]). The first three-element array is the mean across the three RGB channels, and the second is the standard deviation across the same three channels.
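As a quick sanity check, you can apply T to a single image and confirm the output shape and scale. This snippet is my addition, and the file path is just a hypothetical example:

from PIL import Image

img = Image.open('./archive/raw-img/cane/1.jpeg').convert('RGB')  # hypothetical path
x = T(img)
print(x.shape)             # torch.Size([3, 224, 224])
print(x.mean(dim=(1, 2)))  # per-channel means, roughly centred near 0 after normalization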

Loading and Pre-processing Data

#download the dataset from here: https://www.kaggle.com/alessiocorrado99/animals10 

dataset = ImageFolder('./archive/raw-img/', transform=T)
n_train = int(len(dataset) * 0.8)
train_set, val_set = random_split(dataset, [n_train, len(dataset) - n_train])
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)  # reshuffle training batches each epoch
test_loader = DataLoader(val_set, batch_size=64)

Next, we load the data from the image folder, apply the transformations, and randomly split the dataset into 80% train and 20% test. A few quick sanity checks on the result are shown below.
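These prints (my addition) are a cheap way to confirm everything loaded correctly:

print(len(dataset))                  # total number of images
print(dataset.classes)               # the 10 class (folder) names
print(len(train_set), len(val_set))  # roughly an 80/20 split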

Now we are almost ready to train. There's an issue, though. The ResNet-18 architecture is fairly complex: 18 weighted layers, interleaved with operations like batch norm, ReLU, and max pooling. That complexity is good, because we want an expressive model, right? (As long as it doesn't take too long to run.) However, if you look at the last layer, it has 1000 output features. This is because ResNet-18 was originally trained on ImageNet to predict 1000 classes, which does not match our animals dataset with its 10 classes.

Last 2 Resnet-18 layers
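Since the architecture figure isn't reproduced here, you can inspect those last two layers directly:

print(res_18_model.avgpool)  # AdaptiveAvgPool2d(output_size=(1, 1))
print(res_18_model.fc)       # Linear(in_features=512, out_features=1000, bias=True)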

So, at the very least, we need to change this last layer. It turns out to be quite simple:

res_18_model.fc = nn.Linear(512, 10)  # 512 backbone features in, 10 classes out
Resnet-18 after modifying last layer to have 10 features corresponding to 10 classes
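Printing the head again confirms the change, while everything else in the network stays untouched:

print(res_18_model.fc)  # Linear(in_features=512, out_features=10, bias=True)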

Alright, now we can start the training!

Model Training and Evaluation

# Use the GPU if one is available; otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = res_18_model.to(device)

optimiser = optim.SGD(model.parameters(), lr=1e-2)
loss = nn.CrossEntropyLoss()

I have an NVIDIA GPU, so the model is moved onto it for training; with the device check, the same code still runs (slowly) on a CPU. The two important objects here are the optimizer and the loss function. For the optimizer I use PyTorch's stochastic gradient descent. For the loss I use cross-entropy, as it is a commonly used loss function for multi-class classification.
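To make the loss concrete, here is a tiny, self-contained example (dummy logits, not from our model) of what nn.CrossEntropyLoss computes for a batch:

logits = torch.randn(2, 10)     # fake scores for 2 samples, 10 classes
targets = torch.tensor([3, 7])  # the true class indices
print(nn.CrossEntropyLoss()(logits, targets))  # mean negative log-likelihood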

# My training and validation loops
nb_epochs = 5
acc_tot=np.zeros(nb_epochs)
for epoch in range(nb_epochs):
    losses = list()
    accuracies = list()
    model.train()     
    for batch in train_loader:

        x, y = batch
        x, y = x.to(device), y.to(device)  # move the batch to the same device as the model


        # 1 forward
        l = model(x) # l: logits

        #2 compute the objective function
        J = loss(l,y)

        # 3 cleaning the gradients
        model.zero_grad()
        # optimiser.zero_grad()
        # params.grad.zero_()

        # 4 accumulate the partial derivatives of J wrt params
        J.backward()

        # 5 step in the opposite direction of the gradient
        optimiser.step()



        losses.append(J.item())
        accuracies.append(y.eq(l.detach().argmax(dim=1)).float().mean().item())  # batch accuracy as a plain float

    print(f'Epoch {epoch + 1}', end=', ')
    print(f'training loss: {torch.tensor(losses).mean():.2f}', end=', ')
    print(f'training accuracy: {torch.tensor(accuracies).mean():.2f}')


    losses = list()
    accuracies = list() 
    model.eval()
    for batch in test_loader:
        x, y = batch
        x, y = x.to(device), y.to(device)

        with torch.no_grad(): 
            l = model(x)

        #2 compute the objective function
        J = loss(l,y)

        losses.append(J.item())
        accuracies.append(y.eq(l.detach().argmax(dim=1)).float().mean().item())

    print(f'Epoch {epoch + 1}',end=', ')
    print(f'validation loss: {torch.tensor(losses).mean():.2f}', end=', ')
    print(f'validation accuracy: {torch.tensor(accuracies).mean():.2f}')
    acc_tot[epoch] = np.mean(accuracies)  # store the per-epoch validation accuracy
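acc_tot now holds the validation accuracy for each epoch. A quick plot (my addition) makes it easy to see whether training has converged:

plt.plot(range(1, nb_epochs + 1), acc_tot)
plt.xlabel('Epoch')
plt.ylabel('Validation accuracy')
plt.show()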

Visualization

Great! Now, how can we visualize our amazing model? Since this is a dataset of images, visualizing the results is particularly powerful. First, I define a function that undoes the initial normalization so the images display with their original colors.

def imformat(inp):
    """Undo the normalization so a tensor displays correctly with imshow."""
    inp = inp.numpy().transpose((1, 2, 0))   # C x H x W  ->  H x W x C
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean                   # invert Normalize(mean, std)
    inp = np.clip(inp, 0, 1)
    return inp

It turns out this dataset's labels are not in English: the class folders are named in Italian. The dataset comes with a translation dictionary, which I use below, cleaned up so that it maps each of the 10 Italian labels to English.

class_names = dataset.classes  # the Italian folder names
translate = {"cane": "dog", "cavallo": "horse", "elefante": "elephant",
             "farfalla": "butterfly", "gallina": "chicken", "gatto": "cat",
             "mucca": "cow", "pecora": "sheep", "ragno": "spider",
             "scoiattolo": "squirrel"}

Finally, let us visualize the results!

train_loader2 = DataLoader(train_set, batch_size=9)

plt.figure(figsize=(15, 13))

inputs, classes = next(iter(train_loader2))
model.eval()
with torch.no_grad():  # no gradients needed for prediction
    preds = model(inputs.to(device)).argmax(dim=1).cpu()

for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    img = imformat(inputs[i])
    ax.imshow(img)
    ax.set_title(f'True: {translate[class_names[classes[i]]]}'
                 f'    Pred: {translate[class_names[preds[i]]]}')
    ax.axis("off")  # hide the axis ticks on every subplot

And success! It is quite powerful to see the results of your well-trained algorithm; it means all the parts have been put together correctly! Here's the code on GitHub. Happy deep transfer learning!

Sources:

  1. https://www.youtube.com/watch?v=OMDn66kM9Qc&ab_channel=PyTorchLightning
  2. https://pytorch.org/vision/main/generated/torchvision.transforms.Normalize.html
  3. https://www.kaggle.com/alessiocorrado99/animals10

Thanks for reading! For Data Science and Machine Learning mentoring, please contact us! We develop custom learning pathways for individual clients and enable cutting-edge AI-based research. We also provide access to high-end computational servers based on your needs.