%%capture
%matplotlib inline
!pip install kornia
!pip install kornia-rs
!pip install opencv-python
Data Augmentation Semantic Segmentation
Tags: Basic, 2D, Segmentation, Data augmentation, kornia.augmentation
In this tutorial we will show how we can quickly perform data augmentation for semantic segmentation using the kornia.augmentation API.
Install and get data
We install Kornia and some dependencies, and download a simple data sample.
import io

import requests


def download_image(url: str, filename: str = "") -> str:
    filename = url.split("/")[-1] if len(filename) == 0 else filename
    # Download
    bytesio = io.BytesIO(requests.get(url).content)
    # Save file
    with open(filename, "wb") as outfile:
        outfile.write(bytesio.getbuffer())
    return filename


url = "https://github.com/kornia/data/raw/main/causevic16semseg3.png"
download_image(url)
'causevic16semseg3.png'
# import the libraries
import kornia as K
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
Define Augmentation pipeline
We define a class that implements our augmentation pipeline as an nn.Module
class MyAugmentation(nn.Module):
    def __init__(self):
        super().__init__()
        # we define and cache our operators as class members
        self.k1 = K.augmentation.ColorJitter(0.15, 0.25, 0.25, 0.25)
        self.k2 = K.augmentation.RandomAffine([-45.0, 45.0], [0.0, 0.15], [0.5, 1.5], [0.0, 0.15])

    def forward(self, img: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # 1. apply the color transform only to the image
        # 2. apply the geometric transform
        img_out = self.k2(self.k1(img))

        # 3. reuse the sampled geometry params on the mask
        # TODO: this will change in the future so that there is no need to infer params
        mask_out = self.k2(mask, self.k2._params)

        return img_out, mask_out
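Note how the geometric transform runs twice with the same sampled parameters: self.k2 caches what it drew for the image in self.k2._params, and passing that back replays the identical affine warp on the mask. Below is a minimal, standalone sketch of this replay idea; the dummy tensor and the p=1.0 setting are illustrative assumptions, not part of the tutorial.

import torch

import kornia as K

# Replaying cached parameters applies the identical random transform twice.
affine = K.augmentation.RandomAffine(degrees=[-45.0, 45.0], p=1.0)

dummy = torch.rand(1, 1, 8, 8)  # hypothetical input, for illustration only
out_first = affine(dummy)                     # samples fresh parameters
out_replay = affine(dummy, affine._params)    # reuses the cached parameters
assert torch.allclose(out_first, out_replay)  # same warp both times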
Load the data and apply the transforms
def load_data(data_path: str) -> tuple[torch.Tensor, torch.Tensor]:
    data_t: torch.Tensor = K.io.load_image(data_path, K.io.ImageLoadType.RGB32)[None, ...]  # BxCxHxW
    # the sample stores the image and its labels side by side, so split along the width
    img, labels = data_t[..., :571], data_t[..., 572:]
    return img, labels
# load data (B, C, H, W)
img, labels = load_data("causevic16semseg3.png")

# create augmentation instance
aug = MyAugmentation()

# apply the augmentation pipeline to our batch of data
img_aug, labels_aug = aug(img, labels)

# visualize the image next to its labels
img_out = torch.cat([img, labels], dim=-1)
plt.imshow(K.tensor_to_image(img_out))
plt.axis("off")
# generate several samples
num_samples: int = 10

for img_id in range(num_samples):
    # generate data
    img_aug, labels_aug = aug(img, labels)
    img_out = torch.cat([img_aug, labels_aug], dim=-1)

    # save data
    plt.figure()
    plt.imshow(K.tensor_to_image(img_out))
    plt.axis("off")
    # plt.savefig(f"img_{img_id}.png", bbox_inches="tight")
    plt.show()
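The TODO in MyAugmentation hints that this manual parameter bookkeeping can be avoided: Kornia also provides kornia.augmentation.AugmentationSequential, a container that shares the geometric parameters between image and mask and skips intensity ops on the mask. A rough sketch of how the pipeline above might look with it (assuming a Kornia version that ships AugmentationSequential; not part of the original tutorial):

# the container takes care of image/mask consistency for us
aug_seq = K.augmentation.AugmentationSequential(
    K.augmentation.ColorJitter(0.15, 0.25, 0.25, 0.25),
    K.augmentation.RandomAffine([-45.0, 45.0], [0.0, 0.15], [0.5, 1.5], [0.0, 0.15]),
    data_keys=["input", "mask"],
)

img_aug, labels_aug = aug_seq(img, labels)  # color jitter hits only the image; the affine warp is shared

With this container there is no need to pass self.k2._params around by hand.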