Data Augmentation 2D

Basic
2D
Data augmentation
kornia.augmentation
A showcase of the data augmentation operations available in Kornia for images.
Author

João Gustavo A. Amorim

Published

February 4, 2023

Open in Google Colab

Some simple examples showing the augmentations available in Kornia.

For more information, check the docs: https://kornia.readthedocs.io/en/latest/augmentation.module.html

%%capture
!pip install kornia
!pip install kornia-rs
import kornia
import matplotlib.pyplot as plt
from kornia.augmentation import (
    CenterCrop,
    ColorJiggle,
    ColorJitter,
    PadTo,
    RandomAffine,
    RandomBoxBlur,
    RandomBrightness,
    RandomChannelShuffle,
    RandomContrast,
    RandomCrop,
    RandomCutMixV2,
    RandomElasticTransform,
    RandomEqualize,
    RandomErasing,
    RandomFisheye,
    RandomGamma,
    RandomGaussianBlur,
    RandomGaussianNoise,
    RandomGrayscale,
    RandomHorizontalFlip,
    RandomHue,
    RandomInvert,
    RandomJigsaw,
    RandomMixUpV2,
    RandomMosaic,
    RandomMotionBlur,
    RandomPerspective,
    RandomPlanckianJitter,
    RandomPlasmaBrightness,
    RandomPlasmaContrast,
    RandomPlasmaShadow,
    RandomPosterize,
    RandomResizedCrop,
    RandomRGBShift,
    RandomRotation,
    RandomSaturation,
    RandomSharpness,
    RandomSolarize,
    RandomThinPlateSpline,
    RandomVerticalFlip,
)

Load an Image

The augmentations expect an image tensor with shape BxCxHxW (batch, channels, height, width).

import io

import requests


def download_image(url: str, filename: str = "", timeout: float = 30.0) -> str:
    """Download ``url`` and write its raw bytes to a local file.

    Args:
        url: Address of the file to fetch.
        filename: Destination path. When empty, the last ``/``-separated
            component of ``url`` is used (e.g. ``panda.jpg``).
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The path the content was written to.

    Raises:
        requests.HTTPError: If the server answers with an error status. Without
            this check an HTML error page would be silently saved under an
            image filename and only fail later when it is decoded.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    filename = filename or url.split("/")[-1]

    response = requests.get(url, timeout=timeout)
    response.raise_for_status()

    # The payload is already bytes; no intermediate buffer is needed.
    with open(filename, "wb") as outfile:
        outfile.write(response.content)

    return filename


# Fetch the sample image used throughout this notebook. With no explicit
# filename, the helper saves it as the last URL component, "panda.jpg".
url = "https://raw.githubusercontent.com/kornia/data/main/panda.jpg"
download_image(url=url)
'panda.jpg'
# Load the downloaded file as an RGB image on the CPU. Then prepend a batch
# axis of size 1 so the tensor has the BxCxHxW layout the augmentations expect.
img_type = kornia.io.ImageLoadType.RGB32
img = kornia.io.load_image("panda.jpg", img_type, "cpu").unsqueeze(0)
def plot_tensor(data, title=""):
    """Show every image of a batch side by side in a single matplotlib figure.

    Args:
        data: Image batch with shape ``(B, C, H, W)``. Each element is converted
            for display with ``kornia.utils.tensor_to_image``.
        title: Figure-level title drawn above all subplots.
    """
    batch_size, _, height, width = data.shape

    # squeeze=False always returns a 2-D array of Axes (here 1 x B), so the
    # single-image case needs no special handling.
    fig, axes = plt.subplots(
        1, batch_size, dpi=150, subplot_kw={"aspect": "equal"}, squeeze=False
    )

    for idx, ax in enumerate(axes[0]):
        ax.imshow(kornia.utils.tensor_to_image(data[idx, ...]))
        # Pin the view to the image size; (height, 0) keeps row 0 at the top.
        ax.set_ylim(height, 0)
        ax.set_xlim(0, width)
        # Draw the pixel-coordinate ticks along the top edge instead of the bottom.
        ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
    fig.suptitle(title)
    plt.show()
# Display the untouched input image for reference.
plot_tensor(img, title="panda")

2D transforms

Sometimes you may wish to apply exactly the same transformation to every element of a batch. For this, all random augmentations accept a same_on_batch keyword. Instead of generating parameters element-wise, it generates the exact same parameters for the whole batch. The examples below use same_on_batch=False, so each element of the batch is augmented independently.

# Build a batch by tiling the single image num_samples times along the batch
# axis. It is named `inpt` to avoid shadowing the built-in `input`.
num_samples = 2
inpt = img.repeat((num_samples, 1, 1, 1))

Intensity

Random Planckian Jitter

# Planckian jitter in "blackbody" mode, applied to every sample (p=1.0).
randomplanckianjitter = RandomPlanckianJitter(
    "blackbody",
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randomplanckianjitter(inpt), title="Planckian Jitter")

Random Plasma Shadow

# Plasma shadow with (min, max) ranges for roughness, shade intensity and
# shade quantity.
randomplasmashadow = RandomPlasmaShadow(
    roughness=(0.1, 0.7),
    shade_intensity=(-1.0, 0.0),
    shade_quantity=(0.0, 1.0),
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randomplasmashadow(inpt), title="Plasma Shadow")

Random Plasma Brightness

# Plasma brightness with (min, max) ranges for roughness and intensity.
randomplasmabrightness = RandomPlasmaBrightness(
    roughness=(0.1, 0.7),
    intensity=(0.0, 1.0),
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)
plot_tensor(randomplasmabrightness(inpt), title="Plasma Brightness")

Random Plasma Contrast

# Plasma contrast; only the roughness range is customised.
randomplasmacontrast = RandomPlasmaContrast(
    roughness=(0.1, 0.7),
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randomplasmacontrast(inpt), title="Plasma Contrast")

Color Jiggle

# ColorJiggle with the same strength, 0.3, for all four colour components.
colorjiggle = ColorJiggle(
    0.3, 0.3, 0.3, 0.3,
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(colorjiggle(inpt), title="Color Jiggle")

Color Jitter

# ColorJitter with the same arguments as the ColorJiggle cell, for comparison.
colorjitter = ColorJitter(
    0.3, 0.3, 0.3, 0.3,
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(colorjitter(inpt), title="Color Jitter")

Random Box Blur

# Box blur with a 21x5 kernel and "reflect" border handling.
randomboxblur = RandomBoxBlur(
    (21, 5),
    "reflect",
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randomboxblur(inpt), title="Box Blur")

Random Brightness

# Brightness factor drawn from [0.8, 1.2]; clip_output=True clips the result.
randombrightness = RandomBrightness(
    brightness=(0.8, 1.2),
    clip_output=True,
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randombrightness(inpt), title="Random Brightness")

Random Channel Shuffle

# Channel shuffle takes no augmentation-specific arguments.
randomchannelshuffle = RandomChannelShuffle(
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randomchannelshuffle(inpt), title="Random Channel Shuffle")

Random Contrast

# Contrast factor drawn from [0.8, 1.2]; clip_output=True clips the result.
randomcontrast = RandomContrast(
    contrast=(0.8, 1.2),
    clip_output=True,
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randomcontrast(inpt), title="Random Contrast")

Random Equalize

# Histogram equalization takes no augmentation-specific arguments.
randomequalize = RandomEqualize(
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randomequalize(inpt), title="Random Equalize")

Random Gamma

# Gamma correction with a gamma range of (0.2, 1.3) and a gain range of (1.0, 1.5).
randomgamma = RandomGamma(
    (0.2, 1.3),
    (1.0, 1.5),
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randomgamma(inpt), title="Random Gamma")

Random Grayscale

# Grayscale conversion, applied to every sample (p=1.0).
randomgrayscale = RandomGrayscale(
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randomgrayscale(inpt), title="Random Grayscale")

Random Gaussian Blur

# Gaussian blur with a 21x21 kernel, a sigma range of (0.2, 1.3) and "reflect" borders.
randomgaussianblur = RandomGaussianBlur(
    (21, 21),
    (0.2, 1.3),
    "reflect",
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randomgaussianblur(inpt), title="Random Gaussian Blur")
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

Random Gaussian Noise

# Strong additive noise (mean 0.2, std 0.7) pushes many pixels outside [0, 1],
# so matplotlib warns that it clipped the values when plotting.
randomgaussiannoise = RandomGaussianNoise(
    mean=0.2,
    std=0.7,
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randomgaussiannoise(inpt), title="Random Gaussian Noise")
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

Random Hue

# Hue shift drawn from the range (-0.2, 0.4).
randomhue = RandomHue(
    (-0.2, 0.4),
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randomhue(inpt), title="Random Hue")

Random Motion Blur

# Motion blur with "reflect" borders and "nearest" resampling.
randommotionblur = RandomMotionBlur(
    (7, 7),
    35.0,
    0.5,
    "reflect",
    "nearest",
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randommotionblur(inpt), title="Random Motion Blur")

Random Posterize

# Posterize, keeping 3 bits per channel.
randomposterize = RandomPosterize(
    bits=3,
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randomposterize(inpt), title="Random Posterize")

Random RGB Shift

# Per-channel shift, using the same limit (0.5) for red, green and blue.
randomrgbshift = RandomRGBShift(
    r_shift_limit=0.5,
    g_shift_limit=0.5,
    b_shift_limit=0.5,
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)
plot_tensor(randomrgbshift(inpt), title="Random RGB Shift")

Random Saturation

# Saturation factor drawn per sample from (0.5, 2.0). A factor of 1.0 returns
# the original image, so the range must differ from (1.0, 1.0) for the
# augmentation to have any visible effect.
randomsaturation = RandomSaturation((0.5, 2.0), same_on_batch=False, keepdim=False, p=1.0)

plot_tensor(randomsaturation(inpt), "Random Saturation")

Random Sharpness

# Sharpness factor drawn from the range (0.5, 1.0).
randomsharpness = RandomSharpness(
    (0.5, 1.0),
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randomsharpness(inpt), title="Random Sharpness")

Random Solarize

# Solarize with positional arguments 0.3 and 0.1.
randomsolarize = RandomSolarize(
    0.3,
    0.1,
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randomsolarize(inpt), title="Random Solarize")

Geometric

Center Crop

# Deterministic 150-pixel centre crop; it has no same_on_batch argument.
centercrop = CenterCrop(
    150,
    p=1.0,
    keepdim=False,
    resample="nearest",
    align_corners=True,
    cropping_mode="resample",
)

plot_tensor(centercrop(inpt), title="Center Crop")

Pad To

# Pad every image to 500x500 using "constant" mode with the value 1.
padto = PadTo(
    (500, 500),
    "constant",
    1,
    keepdim=False,
)

plot_tensor(padto(inpt), title="Pad To")

Random Affine

# Affine warp: the four positional arguments set the rotation, translation,
# scale and shear ranges. Uncovered areas are filled by reflection.
randomaffine = RandomAffine(
    (-15.0, 5.0),
    (0.3, 1.0),
    (0.4, 1.3),
    0.5,
    p=1.0,
    keepdim=False,
    same_on_batch=False,
    resample="nearest",
    align_corners=True,
    padding_mode="reflection",
)
plot_tensor(randomaffine(inpt), title="Random Affine")

Random Crop

# 150x150 random crop from a padded input, using "constant" padding and
# "nearest" resampling.
randomcrop = RandomCrop(
    (150, 150),
    10,
    True,
    1,
    "constant",
    "nearest",
    p=1.0,
    keepdim=False,
    same_on_batch=False,
    align_corners=True,
    cropping_mode="resample",
)

plot_tensor(randomcrop(inpt), title="Random Crop")

Random Erasing

# Erase a random rectangle (area scale 0.02-0.33, aspect ratio 0.3-3.3) and
# fill it with the value 1.
randomerasing = RandomErasing(
    scale=(0.02, 0.33),
    ratio=(0.3, 3.3),
    value=1,
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randomerasing(inpt), title="Random Erasing")

Random Elastic Transform

# Elastic deformation; the borders are filled by reflection.
randomelastictransform = RandomElasticTransform(
    (27, 27),
    (33, 31),
    (0.5, 1.5),
    p=1.0,
    keepdim=False,
    same_on_batch=False,
    align_corners=True,
    padding_mode="reflection",
)

plot_tensor(randomelastictransform(inpt), title="Random Elastic Transform")

Random Fish Eye

# Fisheye distortion: `c` is used for both centre coordinates and `g` is the
# gamma range.
c = kornia.core.tensor([-0.3, 0.3])
g = kornia.core.tensor([0.9, 1.0])
randomfisheye = RandomFisheye(
    c,
    c,
    g,
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randomfisheye(inpt), title="Random Fish Eye")

Random Horizontal Flip

# With p=0.7, some samples in the batch may be left unflipped.
randomhorizontalflip = RandomHorizontalFlip(
    same_on_batch=False,
    keepdim=False,
    p=0.7,
)

plot_tensor(randomhorizontalflip(inpt), title="Random Horizontal Flip")

Random Invert

# Colour inversion, applied to every sample (p=1.0).
randominvert = RandomInvert(
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randominvert(inpt), title="Random Invert")

Random Perspective

# Perspective warp with strength 0.5 and "nearest" resampling.
randomperspective = RandomPerspective(
    0.5,
    "nearest",
    p=1.0,
    keepdim=False,
    same_on_batch=False,
    align_corners=True,
)

plot_tensor(randomperspective(inpt), title="Random Perspective")

Random Resized Crop

# Random crop resized to 200x200: scale range (0.4, 1.0), fixed ratio 2.0.
randomresizedcrop = RandomResizedCrop(
    (200, 200),
    (0.4, 1.0),
    (2.0, 2.0),
    "nearest",
    p=1.0,
    keepdim=False,
    same_on_batch=False,
    cropping_mode="resample",
    align_corners=True,
)

plot_tensor(randomresizedcrop(inpt), title="Random Resized Crop")

Random Rotation

# Rotation of up to 15 degrees with "nearest" resampling.
randomrotation = RandomRotation(
    15.0,
    "nearest",
    p=1.0,
    keepdim=False,
    same_on_batch=False,
    align_corners=True,
)

plot_tensor(randomrotation(inpt), title="Random Rotation")

Random Vertical Flip

# With p=0.6, some samples in the batch may be left unflipped.
randomverticalflip = RandomVerticalFlip(
    same_on_batch=False,
    keepdim=False,
    p=0.6,
    p_batch=1.0,
)

plot_tensor(randomverticalflip(inpt), title="Random Vertical Flip")

Random Thin Plate Spline

# Thin-plate-spline warp, applied to every sample (p=1.0).
randomthinplatespline = RandomThinPlateSpline(0.6, align_corners=True, same_on_batch=False, keepdim=False, p=1.0)

# Plot the output of the thin-plate-spline augmentation built just above.
plot_tensor(randomthinplatespline(inpt), "Random Thin Plate Spline")

Mix

Random Cut Mix

# CutMix pastes patches between the images of the batch.
randomcutmixv2 = RandomCutMixV2(
    4,
    (0.2, 0.9),
    0.1,
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randomcutmixv2(inpt), title="Random Cut Mix")

Random Mix Up

# MixUp blends images of the batch, with a mixing range of (0.1, 0.9).
randommixupv2 = RandomMixUpV2(
    (0.1, 0.9),
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randommixupv2(inpt), title="Random Mix Up")

Random Mosaic

# Mosaic built from the batch images, output at (250, 125), with "reflect"
# padding and "nearest" resampling.
randommosaic = RandomMosaic(
    (250, 125),
    (4, 4),
    (0.3, 0.7),
    p=1.0,
    keepdim=False,
    resample="nearest",
    padding_mode="reflect",
    cropping_mode="resample",
    align_corners=True,
)
plot_tensor(randommosaic(inpt), title="Random Mosaic")

Random Jigsaw

# NOTE(review): this example is disabled even though RandomJigsaw is imported,
# and the reason is not stated. Confirm it works on the (2, H, W) `inpt` batch
# before re-enabling; it may have size constraints relative to the grid.
# randomjigsaw = RandomJigsaw((2, 2), ensure_perm=False, same_on_batch=False, keepdim=False, p=1.0)


# plot_tensor(randomjigsaw(inpt), "Random Jigsaw")