%%capture
!pip install kornia
!pip install kornia-rs
Data Augmentation 2D
Basic
2D
Data augmentation
kornia.augmentation
A showcase of the data augmentation operations available in Kornia for images.
These are simple examples showing the augmentations available in Kornia.
For more information check the docs: https://kornia.readthedocs.io/en/latest/augmentation.module.html
import kornia
import matplotlib.pyplot as plt
from kornia.augmentation import (
CenterCrop,
ColorJiggle,
ColorJitter,
PadTo,
RandomAffine,
RandomBoxBlur,
RandomBrightness,
RandomChannelShuffle,
RandomContrast,
RandomCrop,
RandomCutMixV2,
RandomElasticTransform,
RandomEqualize,
RandomErasing,
RandomFisheye,
RandomGamma,
RandomGaussianBlur,
RandomGaussianNoise,
RandomGrayscale,
RandomHorizontalFlip,
RandomHue,
RandomInvert,
RandomJigsaw,
RandomMixUpV2,
RandomMosaic,
RandomMotionBlur,
RandomPerspective,
RandomPlanckianJitter,
RandomPlasmaBrightness,
RandomPlasmaContrast,
RandomPlasmaShadow,
RandomPosterize,
RandomResizedCrop,
RandomRGBShift,
RandomRotation,
RandomSaturation,
RandomSharpness,
RandomSolarize,
RandomThinPlateSpline,
RandomVerticalFlip, )
Load an Image
The augmentations expect an image with shape BxCxHxW
import io
import requests
def download_image(url: str, filename: str = "") -> str:
    """Download ``url`` and store its content in a local file.

    Args:
        url: Address of the file to fetch.
        filename: Destination path. When empty, the last path component of
            ``url`` is used instead.

    Returns:
        The path the downloaded content was written to.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    filename = filename or url.split("/")[-1]
    # Download; fail loudly rather than saving an HTTP error page as the image.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    # Save file
    with open(filename, "wb") as outfile:
        outfile.write(response.content)
    return filename
# Fetch the sample image used throughout the notebook.
url = "https://raw.githubusercontent.com/kornia/data/main/panda.jpg"
download_image(url)
'panda.jpg'
# Load the downloaded file on the CPU; indexing with [None] prepends a batch axis.
img_type = kornia.io.ImageLoadType.RGB32
img = kornia.io.load_image("panda.jpg", img_type, "cpu")[None]
def plot_tensor(data, title=""):
    """Show every image of a batch side by side in one figure.

    Args:
        data: Batch of images; its shape is unpacked as ``(b, c, h, w)``.
        title: Title placed above the whole figure.
    """
    b, _, h, w = data.shape

    fig, axes = plt.subplots(1, b, dpi=150, subplot_kw={"aspect": "equal"})
    if b == 1:
        # A single subplot comes back as a bare Axes; wrap it so the loop works.
        axes = [axes]

    for idx, ax in enumerate(axes):
        ax.imshow(kornia.utils.tensor_to_image(data[idx, ...]))
        # y runs from h down to 0 so the image origin stays at the top-left.
        ax.set_ylim(h, 0)
        ax.set_xlim(0, w)
        ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)

    fig.suptitle(title)
    plt.show()


plot_tensor(img, "panda")
2D transforms
Sometimes you may wish to apply the exact same transformation to all the elements in one batch. For this, every random generator accepts a same_on_batch
keyword. When it is enabled, the exact same parameters are generated once and used across the whole batch, instead of generating parameters element-wise.
# Create a batched input by repeating the image along the batch axis
num_samples = 2

inpt = img.repeat(num_samples, 1, 1, 1)
Intensity
Random Planckian Jitter
# Build the augmentation, then plot the augmented batch.
randomplanckianjitter = RandomPlanckianJitter("blackbody", same_on_batch=False, keepdim=False, p=1.0)

plot_tensor(randomplanckianjitter(inpt), "Planckian Jitter")
Random Plasma Shadow
# Build the augmentation, then plot the augmented batch.
randomplasmashadow = RandomPlasmaShadow(
    roughness=(0.1, 0.7),
    shade_intensity=(-1.0, 0.0),
    shade_quantity=(0.0, 1.0),
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randomplasmashadow(inpt), "Plasma Shadow")
Random Plasma Brightness
# Build the augmentation, then plot the augmented batch.
randomplasmabrightness = RandomPlasmaBrightness(
    roughness=(0.1, 0.7),
    intensity=(0.0, 1.0),
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)
plot_tensor(randomplasmabrightness(inpt), "Plasma Brightness")
Random Plasma Contrast
# Build the augmentation, then plot the augmented batch.
randomplasmacontrast = RandomPlasmaContrast(roughness=(0.1, 0.7), same_on_batch=False, keepdim=False, p=1.0)

plot_tensor(randomplasmacontrast(inpt), "Plasma Contrast")
Color Jiggle
# Build the augmentation, then plot the augmented batch.
colorjiggle = ColorJiggle(0.3, 0.3, 0.3, 0.3, same_on_batch=False, keepdim=False, p=1.0)

plot_tensor(colorjiggle(inpt), "Color Jiggle")
Color Jitter
# Build the augmentation, then plot the augmented batch.
colorjitter = ColorJitter(0.3, 0.3, 0.3, 0.3, same_on_batch=False, keepdim=False, p=1.0)

plot_tensor(colorjitter(inpt), "Color Jitter")
Random Box Blur
# Build the augmentation, then plot the augmented batch.
randomboxblur = RandomBoxBlur((21, 5), "reflect", same_on_batch=False, keepdim=False, p=1.0)

plot_tensor(randomboxblur(inpt), "Box Blur")
Random Brightness
# Build the augmentation, then plot the augmented batch.
randombrightness = RandomBrightness(
    brightness=(0.8, 1.2), clip_output=True, same_on_batch=False, keepdim=False, p=1.0
)

plot_tensor(randombrightness(inpt), "Random Brightness")
Random Channel Shuffle
# Build the augmentation, then plot the augmented batch.
randomchannelshuffle = RandomChannelShuffle(same_on_batch=False, keepdim=False, p=1.0)

plot_tensor(randomchannelshuffle(inpt), "Random Channel Shuffle")
Random Contrast
# Build the augmentation, then plot the augmented batch.
randomcontrast = RandomContrast(
    contrast=(0.8, 1.2), clip_output=True, same_on_batch=False, keepdim=False, p=1.0
)

plot_tensor(randomcontrast(inpt), "Random Contrast")
Random Equalize
# Build the augmentation, then plot the augmented batch.
randomequalize = RandomEqualize(same_on_batch=False, keepdim=False, p=1.0)

plot_tensor(randomequalize(inpt), "Random Equalize")
Random Gamma
# Build the augmentation, then plot the augmented batch.
randomgamma = RandomGamma((0.2, 1.3), (1.0, 1.5), same_on_batch=False, keepdim=False, p=1.0)

plot_tensor(randomgamma(inpt), "Random Gamma")
Random Grayscale
# Build the augmentation, then plot the augmented batch.
randomgrayscale = RandomGrayscale(same_on_batch=False, keepdim=False, p=1.0)

plot_tensor(randomgrayscale(inpt), "Random Grayscale")
Random Gaussian Blur
# Build the augmentation, then plot the augmented batch.
randomgaussianblur = RandomGaussianBlur(
    (21, 21), (0.2, 1.3), "reflect", same_on_batch=False, keepdim=False, p=1.0
)

plot_tensor(randomgaussianblur(inpt), "Random Gaussian Blur")
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Random Gaussian Noise
# Build the augmentation, then plot the augmented batch.
randomgaussiannoise = RandomGaussianNoise(mean=0.2, std=0.7, same_on_batch=False, keepdim=False, p=1.0)

plot_tensor(randomgaussiannoise(inpt), "Random Gaussian Noise")
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Random Hue
# Build the augmentation, then plot the augmented batch.
randomhue = RandomHue((-0.2, 0.4), same_on_batch=False, keepdim=False, p=1.0)

plot_tensor(randomhue(inpt), "Random Hue")
Random Motion Blur
# Build the augmentation, then plot the augmented batch.
randommotionblur = RandomMotionBlur(
    (7, 7), 35.0, 0.5, "reflect", "nearest", same_on_batch=False, keepdim=False, p=1.0
)

plot_tensor(randommotionblur(inpt), "Random Motion Blur")
Random Posterize
# Build the augmentation, then plot the augmented batch.
randomposterize = RandomPosterize(bits=3, same_on_batch=False, keepdim=False, p=1.0)

plot_tensor(randomposterize(inpt), "Random Posterize")
Random RGB Shift
# Build the augmentation, then plot the augmented batch.
randomrgbshift = RandomRGBShift(
    r_shift_limit=0.5,
    g_shift_limit=0.5,
    b_shift_limit=0.5,
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)
plot_tensor(randomrgbshift(inpt), "Random RGB Shift")
Random Saturation
# Build the augmentation, then plot the augmented batch.
randomsaturation = RandomSaturation((1.0, 1.0), same_on_batch=False, keepdim=False, p=1.0)

plot_tensor(randomsaturation(inpt), "Random Saturation")
Random Sharpness
# Build the augmentation, then plot the augmented batch.
randomsharpness = RandomSharpness((0.5, 1.0), same_on_batch=False, keepdim=False, p=1.0)

plot_tensor(randomsharpness(inpt), "Random Sharpness")
Random Solarize
# Build the augmentation, then plot the augmented batch.
randomsolarize = RandomSolarize(0.3, 0.1, same_on_batch=False, keepdim=False, p=1.0)

plot_tensor(randomsolarize(inpt), "Random Solarize")
Geometric
Center Crop
# Build the augmentation, then plot the augmented batch.
centercrop = CenterCrop(
    150, resample="nearest", cropping_mode="resample", align_corners=True, keepdim=False, p=1.0
)

plot_tensor(centercrop(inpt), "Center Crop")
Pad To
# Build the augmentation, then plot the augmented batch.
padto = PadTo((500, 500), "constant", 1, keepdim=False)

plot_tensor(padto(inpt), "Pad To")
Random Affine
# Build the augmentation, then plot the augmented batch.
randomaffine = RandomAffine(
    (-15.0, 5.0),
    (0.3, 1.0),
    (0.4, 1.3),
    0.5,
    resample="nearest",
    padding_mode="reflection",
    align_corners=True,
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)
plot_tensor(randomaffine(inpt), "Random Affine")
Random Crop
# Build the augmentation, then plot the augmented batch.
randomcrop = RandomCrop(
    (150, 150),
    10,
    True,
    1,
    "constant",
    "nearest",
    cropping_mode="resample",
    same_on_batch=False,
    align_corners=True,
    keepdim=False,
    p=1.0,
)

plot_tensor(randomcrop(inpt), "Random Crop")
Random Erasing
# Build the augmentation, then plot the augmented batch.
randomerasing = RandomErasing(
    scale=(0.02, 0.33), ratio=(0.3, 3.3), value=1, same_on_batch=False, keepdim=False, p=1.0
)

plot_tensor(randomerasing(inpt), "Random Erasing")
Random Elastic Transform
# Build the augmentation, then plot the augmented batch.
randomelastictransform = RandomElasticTransform(
    (27, 27),
    (33, 31),
    (0.5, 1.5),
    align_corners=True,
    padding_mode="reflection",
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randomelastictransform(inpt), "Random Elastic Transform")
Random Fish Eye
# The same tensor `c` is passed as the first two arguments and `g` as the third.
c = kornia.core.tensor([-0.3, 0.3])
g = kornia.core.tensor([0.9, 1.0])
randomfisheye = RandomFisheye(c, c, g, same_on_batch=False, keepdim=False, p=1.0)

plot_tensor(randomfisheye(inpt), "Random Fish Eye")
Random Horizontal Flip
# Build the augmentation, then plot the augmented batch.
randomhorizontalflip = RandomHorizontalFlip(same_on_batch=False, keepdim=False, p=0.7)

plot_tensor(randomhorizontalflip(inpt), "Random Horizontal Flip")
Random Invert
# Build the augmentation, then plot the augmented batch.
randominvert = RandomInvert(same_on_batch=False, keepdim=False, p=1.0)

plot_tensor(randominvert(inpt), "Random Invert")
Random Perspective
# Build the augmentation, then plot the augmented batch.
randomperspective = RandomPerspective(
    0.5, "nearest", align_corners=True, same_on_batch=False, keepdim=False, p=1.0
)

plot_tensor(randomperspective(inpt), "Random Perspective")
Random Resized Crop
# Build the augmentation, then plot the augmented batch.
randomresizedcrop = RandomResizedCrop(
    (200, 200),
    (0.4, 1.0),
    (2.0, 2.0),
    "nearest",
    align_corners=True,
    cropping_mode="resample",
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)

plot_tensor(randomresizedcrop(inpt), "Random Resized Crop")
Random Rotation
# Build the augmentation, then plot the augmented batch.
randomrotation = RandomRotation(
    15.0, "nearest", align_corners=True, same_on_batch=False, keepdim=False, p=1.0
)

plot_tensor(randomrotation(inpt), "Random Rotation")
Random Vertical Flip
# Build the augmentation, then plot the augmented batch.
randomverticalflip = RandomVerticalFlip(same_on_batch=False, keepdim=False, p=0.6, p_batch=1.0)

plot_tensor(randomverticalflip(inpt), "Random Vertical Flip")
Random Thin Plate Spline
# Build the augmentation, then plot the augmented batch.
randomthinplatespline = RandomThinPlateSpline(
    0.6, align_corners=True, same_on_batch=False, keepdim=False, p=1.0
)

# Plot the output of the thin plate spline transform itself, so the figure
# matches its title.
plot_tensor(randomthinplatespline(inpt), "Random Thin Plate Spline")
Mix
Random Cut Mix
# Build the augmentation, then plot the augmented batch.
randomcutmixv2 = RandomCutMixV2(4, (0.2, 0.9), 0.1, same_on_batch=False, keepdim=False, p=1.0)

plot_tensor(randomcutmixv2(inpt), "Random Cut Mix")
Random Mix Up
# Build the augmentation, then plot the augmented batch.
randommixupv2 = RandomMixUpV2((0.1, 0.9), same_on_batch=False, keepdim=False, p=1.0)

plot_tensor(randommixupv2(inpt), "Random Mix Up")
Random Mosaic
# Build the augmentation, then plot the augmented batch.
randommosaic = RandomMosaic(
    (250, 125),
    (4, 4),
    (0.3, 0.7),
    align_corners=True,
    cropping_mode="resample",
    padding_mode="reflect",
    resample="nearest",
    keepdim=False,
    p=1.0,
)
plot_tensor(randommosaic(inpt), "Random Mosaic")
Random Jigsaw
# randomjigsaw = RandomJigsaw((2, 2), ensure_perm=False, same_on_batch=False, keepdim=False, p=1.0)
# plot_tensor(randomjigsaw(inpt), "Random Jigsaw")