Kornia and PyTorch Lightning GPU data augmentation

Basic
Data augmentation
Pytorch lightning
kornia.augmentation
In this tutorial we show how to combine Kornia and PyTorch Lightning to perform batched data augmentation while training a model on CPUs or GPUs, with no additional effort.
Author

Edgar Riba

Published

March 18, 2021

Open in Google Colab

Install Kornia and PyTorch Lightning

We first install Kornia, PyTorch Lightning and torchmetrics.

%%capture
!pip install kornia
!pip install kornia-rs
!pip install pytorch_lightning torchmetrics

Import the needed libraries

import os

import kornia as K
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchmetrics
from PIL import Image
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

Define the data augmentation module

class DataAugmentation(nn.Module):
    """Batched data augmentation on torch tensors, built from Kornia modules.

    Each batch is first rescaled with ``Normalize(mean=0, std=255)`` (i.e.
    divided by 255) and randomly flipped left-right with probability 0.5.
    Photometric jitter is applied afterwards only if it was requested when
    the module was constructed.

    Args:
        apply_color_jitter: when ``True``, run ``ColorJitter`` after the
            normalize/flip pipeline.
    """

    def __init__(self, apply_color_jitter: bool = False) -> None:
        super().__init__()
        self._apply_color_jitter = apply_color_jitter
        self._max_val: float = 255.0

        rescale = K.enhance.Normalize(0.0, self._max_val)  # (x - 0) / 255
        hflip = K.augmentation.RandomHorizontalFlip(p=0.5)
        self.transforms = nn.Sequential(rescale, hflip)

        self.jitter = K.augmentation.ColorJitter(0.5, 0.5, 0.5, 0.5)

    @torch.no_grad()  # augmentation is input preparation; keep it out of autograd
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # assumes a batched BxCxHxW tensor, as produced by the DataLoader
        augmented = self.transforms(x)
        return self.jitter(augmented) if self._apply_color_jitter else augmented

Define a pre-processing module

class PreProcess(nn.Module):
    """Convert a single PIL image into a float tensor (used as a dataset ``transform``).

    Pixel values are NOT rescaled: an 8-bit image yields floats in ``[0, 255]``.
    Rescaling is left to the model so it can run batched on the training device.
    """

    @torch.no_grad()  # pure conversion; no gradients needed
    def forward(self, x: Image.Image) -> torch.Tensor:
        x_tmp: np.ndarray = np.array(x)  # HxWxC for an RGB image
        x_out: torch.Tensor = K.image_to_tensor(x_tmp, keepdim=True)  # CxHxW
        return x_out.float()

Define the PyTorch Lightning model

class CoolSystem(pl.LightningModule):
    """Linear CIFAR-10 classifier trained with on-device Kornia augmentation.

    ``PreProcess`` runs inside the DataLoader and yields float tensors holding
    raw pixel values (``[0, 255]``). During training, ``DataAugmentation``
    rescales each batch to ``[0, 1]`` and randomly flips it on the model's
    device. Validation and test batches are only rescaled, so the network sees
    inputs in the same range in every phase.
    """

    def __init__(self):
        super().__init__()
        # not the best model: one linear layer over the flattened 3x32x32 image
        self.l1 = torch.nn.Linear(3 * 32 * 32, 10)

        # per-sample PIL -> tensor conversion, used as the dataset transform
        self.preprocess = PreProcess()

        # batched rescale + augmentation, applied to training batches only
        self.transform = DataAugmentation()

        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        """Return raw class logits of shape ``(B, 10)`` for the batch ``x``.

        No activation is applied. ``F.cross_entropy`` expects unnormalized
        logits, and a ReLU here would clamp negative class scores to zero and
        block their gradients.
        """
        return self.l1(x.view(x.size(0), -1))

    @staticmethod
    def _normalize(x):
        """Rescale raw ``[0, 255]`` pixels to ``[0, 1]`` for non-augmented batches.

        Mirrors the ``Normalize(0.0, 255.0)`` step inside ``DataAugmentation``,
        so evaluation inputs match the range the model was trained on.
        """
        return x / 255.0

    def training_step(self, batch, batch_idx):
        # REQUIRED
        x, y = batch
        x_aug = self.transform(x)  # => we perform GPU/Batched data augmentation
        logits = self.forward(x_aug)
        loss = F.cross_entropy(logits, y)
        self.log("train_acc_step", self.accuracy(logits.argmax(1), y))
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        logits = self.forward(self._normalize(x))  # no augmentation, same input range
        loss = F.cross_entropy(logits, y)
        self.log("val_acc_step", self.accuracy(logits.argmax(1), y))
        # Lightning discards validation_step's return value, so log the loss explicitly
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        logits = self.forward(self._normalize(x))  # no augmentation, same input range
        acc = self.accuracy(logits.argmax(1), y)
        self.log("test_acc_step", acc)
        return acc

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        # (LBFGS it is automatically supported, no need for closure function)
        return torch.optim.Adam(self.parameters(), lr=0.0004)

    def prepare_data(self):
        CIFAR10(os.getcwd(), train=True, download=True, transform=self.preprocess)
        CIFAR10(os.getcwd(), train=False, download=True, transform=self.preprocess)

    def _cifar10_loader(self, train, batch_size, shuffle=False):
        """Build a DataLoader over an already-downloaded CIFAR-10 split in the cwd."""
        dataset = CIFAR10(os.getcwd(), train=train, download=False, transform=self.preprocess)
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1)

    def train_dataloader(self):
        # REQUIRED
        # shuffle so every epoch visits the training samples in a new order
        return self._cifar10_loader(train=True, batch_size=32, shuffle=True)

    def val_dataloader(self):
        # CIFAR-10 ships only train/test splits. Validate on the held-out images
        # rather than on the training images the model is being fit to.
        return self._cifar10_loader(train=False, batch_size=32)

    def test_dataloader(self):
        return self._cifar10_loader(train=False, batch_size=16)

Run training

from pytorch_lightning import Trainer

# init model
model = CoolSystem()

# Initialize a trainer.
# Use the GPU when CUDA is available, so the batched Kornia augmentation in
# `training_step` actually runs on the accelerator; otherwise fall back to CPU.
accelerator = "gpu" if torch.cuda.is_available() else "cpu"

trainer = Trainer(accelerator=accelerator, max_epochs=1, enable_progress_bar=False)

# Train the model ⚡
trainer.fit(model)
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name       | Type               | Params
--------------------------------------------------
0 | l1         | Linear             | 30.7 K
1 | preprocess | PreProcess         | 0     
2 | transform  | DataAugmentation   | 0     
3 | accuracy   | MulticlassAccuracy | 0     
--------------------------------------------------
30.7 K    Trainable params
0         Non-trainable params
30.7 K    Total params
0.123     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=1` reached.

Test the model

# Evaluate on the test split; the returned list of metric dicts is displayed as the cell output.
trainer.test(model=model)
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_acc_step         0.10000000149011612
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[{'test_acc_step': 0.10000000149011612}]

Visualize

# # Start tensorboard.
# %load_ext tensorboard
# %tensorboard --logdir lightning_logs/