%%capture
!pip install kornia
!pip install kornia-rs
!pip install pytorch_lightning torchmetrics
Kornia and PyTorch Lightning GPU data augmentation
Basic
Data augmentation
PyTorch Lightning
kornia.augmentation
In this tutorial we show how one can combine both Kornia and PyTorch Lightning to perform data augmentation to train a model using CPUs and GPUs in batch mode without additional effort.
Install Kornia and PyTorch Lightning
We first install Kornia (together with kornia-rs), PyTorch Lightning, and torchmetrics.
Import the needed libraries
import os
import kornia as K
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchmetrics
from PIL import Image
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
Define the data augmentation module
class DataAugmentation(nn.Module):
    """Module to perform data augmentation using Kornia on torch tensors.

    The base pipeline is ``Normalize(0.0, 255.0)`` followed by a random
    horizontal flip with probability 0.5. ``ColorJitter`` is applied on top
    of that only when ``apply_color_jitter`` is true.

    Args:
        apply_color_jitter: whether to run the color jitter after the base
            transforms.
    """

    def __init__(self, apply_color_jitter: bool = False) -> None:
        super().__init__()
        self._apply_color_jitter = apply_color_jitter

        # Maximum pixel value, passed as the second argument of Normalize.
        self._max_val: float = 255.0

        self.transforms = nn.Sequential(
            K.enhance.Normalize(0.0, self._max_val),
            K.augmentation.RandomHorizontalFlip(p=0.5),
        )

        # Always constructed, but only used when apply_color_jitter is set.
        self.jitter = K.augmentation.ColorJitter(0.5, 0.5, 0.5, 0.5)

    @torch.no_grad()  # disable gradients for efficiency
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the base transforms and, if enabled, the color jitter.

        Args:
            x: batch of images to augment.

        Returns:
            The augmented batch.
        """
        x_out = self.transforms(x)  # BxCxHxW
        if self._apply_color_jitter:
            x_out = self.jitter(x_out)
        return x_out
Define a pre-processing module
class PreProcess(nn.Module):
    """Module to perform pre-process using Kornia on torch tensors.

    Converts a single image into a float tensor so that it can be batched
    by a ``DataLoader``.
    """

    def __init__(self) -> None:
        super().__init__()

    @torch.no_grad()  # disable gradients for efficiency
    def forward(self, x: "Image.Image") -> torch.Tensor:
        """Convert an image to a float tensor.

        Args:
            x: input image; it is converted with ``np.array`` first.

        Returns:
            The image as a ``float32`` tensor (channels first, no batch
            dimension since ``keepdim=True``).
        """
        # ``Image`` is the PIL module in this file, so the class is
        # ``Image.Image``; kept as a string so it is not evaluated eagerly.
        x_tmp: np.ndarray = np.array(x)  # HxWxC
        x_out: torch.Tensor = K.image_to_tensor(x_tmp, keepdim=True)  # CxHxW
        return x_out.float()
Define PyTorch Lightning model
class CoolSystem(pl.LightningModule):
    """Single linear-layer CIFAR10 classifier with Kornia batch augmentation.

    ``PreProcess`` is used as the per-sample dataset transform, while
    ``DataAugmentation`` is applied to whole batches inside ``training_step``.
    """

    def __init__(self):
        super().__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(3 * 32 * 32, 10)

        self.preprocess = PreProcess()  # per-sample, used by the datasets

        self.transform = DataAugmentation()  # per-batch, used while training

        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        """Flatten each sample and return the ReLU of the linear layer output."""
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        """Augment the batch, log accuracy and loss, and return the loss."""
        # REQUIRED
        x, y = batch
        x_aug = self.transform(x)  # => we perform GPU/Batched data augmentation
        logits = self.forward(x_aug)
        loss = F.cross_entropy(logits, y)
        self.log("train_acc_step", self.accuracy(logits.argmax(1), y))
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation accuracy and return the cross-entropy loss."""
        # OPTIONAL
        x, y = batch
        # NOTE(review): unlike training_step, the batch does not go through
        # self.transform here, so the Normalize(0.0, 255.0) step is skipped
        # and the model sees inputs on a different scale -- confirm intended.
        logits = self.forward(x)
        self.log("val_acc_step", self.accuracy(logits.argmax(1), y))
        return F.cross_entropy(logits, y)

    def test_step(self, batch, batch_idx):
        """Compute, log, and return the accuracy for one test batch."""
        # OPTIONAL
        x, y = batch
        # NOTE(review): same un-normalized input as in validation_step.
        logits = self.forward(x)
        acc = self.accuracy(logits.argmax(1), y)
        self.log("test_acc_step", acc)
        return acc

    def configure_optimizers(self):
        """Return the Adam optimizer over all parameters (lr=0.0004)."""
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        # (LBFGS is automatically supported, no need for a closure function)
        return torch.optim.Adam(self.parameters(), lr=0.0004)

    def prepare_data(self):
        """Download both CIFAR10 splits into the current working directory."""
        CIFAR10(os.getcwd(), train=True, download=True, transform=self.preprocess)
        CIFAR10(os.getcwd(), train=False, download=True, transform=self.preprocess)

    def train_dataloader(self):
        """Return a loader over the CIFAR10 training split (batch size 32)."""
        # REQUIRED
        dataset = CIFAR10(os.getcwd(), train=True, download=False, transform=self.preprocess)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        return loader

    def val_dataloader(self):
        """Return the validation loader (batch size 32)."""
        # NOTE(review): this reuses the training split (train=True), so the
        # validation metrics are computed on data seen during training.
        dataset = CIFAR10(os.getcwd(), train=True, download=False, transform=self.preprocess)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        return loader

    def test_dataloader(self):
        """Return a loader over the CIFAR10 test split (batch size 16)."""
        dataset = CIFAR10(os.getcwd(), train=False, download=False, transform=self.preprocess)
        loader = DataLoader(dataset, batch_size=16, num_workers=1)
        return loader
Run training
from pytorch_lightning import Trainer

# init model
model = CoolSystem()

# Initialize a trainer
accelerator = "cpu"  # can be 'gpu'

# Single epoch, with the progress bar disabled.
trainer = Trainer(accelerator=accelerator, max_epochs=1, enable_progress_bar=False)

# Train the model ⚡
trainer.fit(model)
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
| Name | Type | Params
--------------------------------------------------
0 | l1 | Linear | 30.7 K
1 | preprocess | PreProcess | 0
2 | transform | DataAugmentation | 0
3 | accuracy | MulticlassAccuracy | 0
--------------------------------------------------
30.7 K Trainable params
0 Non-trainable params
30.7 K Total params
0.123 Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=1` reached.
Test the model
# Run the test loop on the model fitted above.
trainer.test(model)
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Test metric DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
test_acc_step 0.10000000149011612
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[{'test_acc_step': 0.10000000149011612}]
Visualize
# # Start tensorboard.
# %load_ext tensorboard
# %tensorboard --logdir lightning_logs/