Data Augmentation for Semantic Segmentation

Categories: Basic, 2D, Segmentation, Data augmentation, kornia.augmentation
In this tutorial we show how to quickly perform data augmentation for semantic segmentation using the kornia.augmentation API.
Author: Edgar Riba

Published: March 27, 2021


Install and get data

We install Kornia and some dependencies, and download a simple data sample.

%%capture
%matplotlib inline
!pip install kornia
!pip install kornia-rs
!pip install opencv-python
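
As a quick optional check (not part of the original tutorial), you can confirm the install picked up a working Kornia version:

# optional sanity check: print the installed Kornia version
import kornia

print(kornia.__version__)
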
import io

import requests


def download_image(url: str, filename: str = "") -> str:
    filename = url.split("/")[-1] if len(filename) == 0 else filename
    # Download
    bytesio = io.BytesIO(requests.get(url).content)
    # Save file
    with open(filename, "wb") as outfile:
        outfile.write(bytesio.getbuffer())

    return filename


url = "https://github.com/kornia/data/raw/main/causevic16semseg3.png"
download_image(url)
'causevic16semseg3.png'
# import the libraries
from typing import Tuple

import kornia as K
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

Define the augmentation pipeline

We define a class that groups our augmentation pipeline as an nn.Module.

class MyAugmentation(nn.Module):
    def __init__(self):
        super().__init__()
        # we define and cache our operators as class members
        self.k1 = K.augmentation.ColorJitter(0.15, 0.25, 0.25, 0.25)
        self.k2 = K.augmentation.RandomAffine([-45.0, 45.0], [0.0, 0.15], [0.5, 1.5], [0.0, 0.15])

    def forward(self, img: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. apply color jitter only to the image
        # 2. apply the geometric transform
        img_out = self.k2(self.k1(img))

        # 3. apply the same geometric transform, with the parameters just sampled, to the mask
        # TODO: this will change in future so that no need to infer params
        mask_out = self.k2(mask, self.k2._params)

        return img_out, mask_out
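
As a side note, recent Kornia releases (0.6 and later) also provide K.augmentation.AugmentationSequential, which removes the need to forward the sampled parameters to the mask by hand (the TODO above). The snippet below is a minimal sketch of the same pipeline under that API; the dummy tensors are only there to make the cell self-contained.

# sketch of the same pipeline with AugmentationSequential: the random parameters
# are sampled once and the geometric part is applied to both the image and the
# mask, while intensity ops such as ColorJitter touch the image only
aug_seq = K.augmentation.AugmentationSequential(
    K.augmentation.ColorJitter(0.15, 0.25, 0.25, 0.25),
    K.augmentation.RandomAffine([-45.0, 45.0], [0.0, 0.15], [0.5, 1.5], [0.0, 0.15]),
    data_keys=["input", "mask"],
)

# dummy image / mask pair, just to show the call signature
x = torch.rand(1, 3, 128, 128)
m = (torch.rand(1, 1, 128, 128) > 0.5).float()
x_aug, m_aug = aug_seq(x, m)
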

Load the data and apply the transforms

def load_data(data_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
    data_t: torch.Tensor = K.io.load_image(data_path, K.io.ImageLoadType.RGB32)[None, ...]  # BxCxHxW
    # the sample file stores the RGB image and its labels side by side,
    # so we split it at a fixed column index
    img, labels = data_t[..., :571], data_t[..., 572:]
    return img, labels
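
If you are curious about the hard-coded columns 571/572, the sample file simply stores the image and its label map next to each other; an optional shape check makes this visible:

# optional: inspect the raw side-by-side sample before splitting it
raw = K.io.load_image("causevic16semseg3.png", K.io.ImageLoadType.RGB32)
print(raw.shape)  # CxHxW, where the width spans both the image and the labels
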


# load data (B, C, H, W)
img, labels = load_data("causevic16semseg3.png")

# create augmentation instance
aug = MyAugmentation()

# apply the augmentation pipeline to our batch of data
img_aug, labels_aug = aug(img, labels)

# visualize the original image and labels side by side
img_out = torch.cat([img, labels], dim=-1)
plt.imshow(K.tensor_to_image(img_out))
plt.axis("off")

# generate several samples
num_samples: int = 10

for img_id in range(num_samples):
    # generate data
    img_aug, labels_aug = aug(img, labels)
    img_out = torch.cat([img_aug, labels_aug], dim=-1)

    # plot the augmented pair (uncomment savefig to write it to disk)
    plt.figure()
    plt.imshow(K.tensor_to_image(img_out))
    plt.axis("off")
    # plt.savefig(f"img_{img_id}.png", bbox_inches="tight")
    plt.show()
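
Because the whole pipeline is a regular nn.Module operating on batched tensors, it can also be moved to the GPU. A minimal sketch, assuming a CUDA device is available:

# run the same augmentation pipeline on the GPU (falls back to CPU if unavailable)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

aug_gpu = MyAugmentation().to(device)
img_aug, labels_aug = aug_gpu(img.to(device), labels.to(device))
print(img_aug.device)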