%%capture
%matplotlib inline
# Install latest kornia
!pip install kornia
!pip install kornia-rs
Extracting and Combining Tensor Patches
Basic
Patches
kornia.contrib
In this tutorial we will show how you can extract and combine tensor patches using kornia
Install and get data
import io
import requests
def download_image(url: str, filename: str = "") -> str:
    """Download the file at ``url`` and save it to disk.

    Args:
        url: Address of the file to fetch.
        filename: Destination path. When empty, the last ``/``-separated
            segment of ``url`` is used.

    Returns:
        The path the file was written to.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    if not filename:
        filename = url.split("/")[-1]
    # Download; fail loudly instead of saving an HTTP error page as an image.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    # Save file
    with open(filename, "wb") as outfile:
        outfile.write(response.content)
    return filename
# Sample image used by the visualization cells below.
url = "https://raw.githubusercontent.com/kornia/data/main/panda.jpg"
download_image(url)
'panda.jpg'
import kornia as K
import matplotlib.pyplot as plt
import torch
from kornia.contrib import (
CombineTensorPatches,
ExtractTensorPatches,
combine_tensor_patches,
compute_padding,
extract_tensor_patches, )
Using Modules
# Round trip with the module API: extract patches, then merge them back.
h, w = 8, 8
win = 4
pad = 2

image = torch.randn(2, 3, h, w)
print(image.shape)
# stride == window size, so the windows do not overlap.
tiler = ExtractTensorPatches(window_size=win, stride=win, padding=pad)
# The merger is given the same value as `unpadding` that the tiler used as `padding`.
merger = CombineTensorPatches(original_size=(h, w), window_size=win, stride=win, unpadding=pad)
image_tiles = tiler(image)
print(image_tiles.shape)
new_image = merger(image_tiles)
print(new_image.shape)
# The merged result must reproduce the input.
assert torch.allclose(image, new_image)
torch.Size([2, 3, 8, 8])
torch.Size([2, 9, 3, 4, 4])
torch.Size([2, 3, 8, 8])
Using Functions
# Same round trip as above, using the functional API.
h, w = 8, 8
win = 4
pad = 2

image = torch.randn(1, 1, h, w)
print(image.shape)
patches = extract_tensor_patches(image, window_size=win, stride=win, padding=pad)
print(patches.shape)
restored_img = combine_tensor_patches(patches, original_size=(h, w), window_size=win, stride=win, unpadding=pad)
print(restored_img.shape)
# The merged result must reproduce the input.
assert torch.allclose(image, restored_img)
torch.Size([1, 1, 8, 8])
torch.Size([1, 9, 1, 4, 4])
torch.Size([1, 1, 8, 8])
Padding
All parameters of extract and combine functions accept a single int or tuple of two ints. Since padding is an integral part of these functions, it’s important to note the following:
- If padding is `p` -> it means both height and width are padded by `2*p`
- If padding is `(ph, pw)` -> it means height is padded by `2*ph` and width is padded by `2*pw`
It is recommended to use the existing function `compute_padding` to ensure the required padding is added.
Examples
def extract_and_combine(image, window_size, padding):
    """Split ``image`` into patches, merge them back, and check the round trip.

    The stride is set equal to ``window_size``. The shapes of the patch tensor
    and of the merged image are printed.

    Args:
        image: Tensor whose last two dimensions are taken as height and width.
        window_size: Patch size, also used as the stride.
        padding: Padding passed to the extractor; the same value is passed to
            the combiner as ``unpadding``.

    Returns:
        The merged image.

    Raises:
        AssertionError: If the merged image is not close to ``image``.
    """
    h, w = image.shape[-2:]
    tiler = ExtractTensorPatches(window_size=window_size, stride=window_size, padding=padding)
    merger = CombineTensorPatches(original_size=(h, w), window_size=window_size, stride=window_size, unpadding=padding)
    image_tiles = tiler(image)
    print(f"Shape of tensor patches = {image_tiles.shape}")
    merged_image = merger(image_tiles)
    print(f"Shape of merged image = {merged_image.shape}")
    assert torch.allclose(image, merged_image)
    return merged_image
# 9x9 input with 4x4 windows and (2, 2) padding; the return value is discarded.
image = torch.randn(2, 3, 9, 9)
_ = extract_and_combine(image, window_size=(4, 4), padding=(2, 2))
Shape of tensor patches = torch.Size([2, 9, 3, 4, 4])
Shape of merged image = torch.Size([2, 3, 9, 9])
These functions also work with rectangular images
# Non-square input: height 8, width 6.
rect_image = torch.randn(1, 1, 8, 6)
print(rect_image.shape)
torch.Size([1, 1, 8, 6])
# Let compute_padding derive the padding for an 8x6 image and a window size of 4.
restored_image = extract_and_combine(rect_image, window_size=(4, 4), padding=compute_padding((8, 6), 4))
Shape of tensor patches = torch.Size([1, 4, 1, 4, 4])
Shape of merged image = torch.Size([1, 1, 8, 6])
Recall that when padding is a tuple of ints `(ph, pw)`, the height and width are padded by `2*ph` and `2*pw` respectively.
# Confirm that the original image and restored image are the same.
# This is an exact element-wise comparison, not a tolerance-based one.
assert (restored_image == rect_image).all()
Let’s now visualize how extraction and combining works.
# Load sample image
img_tensor = K.io.load_image("panda.jpg", K.io.ImageLoadType.RGB32)[None, ...]  # BxCxHxW
h, w = img_tensor.shape[-2:]
print(f"Shape of image = {img_tensor.shape}")

# Show the image without axis ticks.
plt.axis("off")
plt.imshow(K.tensor_to_image(img_tensor))
plt.show()
Shape of image = torch.Size([1, 3, 510, 1020])
We will use `window_size = (400, 400)` with `stride = 200` to extract 15 overlapping tiles of shape `(400, 400)` and visualize them.
# Set window size
win = 400
# Set stride
stride = 200
# Calculate required padding
pad = compute_padding(original_size=(510, 1020), window_size=win)

# stride < window size, so neighbouring tiles overlap.
tiler = ExtractTensorPatches(window_size=win, stride=stride, padding=pad)
image_tiles = tiler(img_tensor)
print(f"Shape of image tiles = {image_tiles.shape}")
Shape of image tiles = torch.Size([1, 15, 3, 400, 400])
# Create the plot
fig, axs = plt.subplots(5, 3, figsize=(8, 8))
# Flatten the 5x3 grid of axes so it can be indexed by tile number.
axs = axs.ravel()

# image_tiles[0] holds the tiles of the first (only) image in the batch.
for i, tile in enumerate(image_tiles[0]):
    axs[i].axis("off")
    axs[i].imshow(K.tensor_to_image(tile))

plt.show()
Finally, let’s combine the patches and visualize the resulting image
# Merge the overlapping tiles back into a single image of the original size.
merger = CombineTensorPatches(original_size=(h, w), window_size=win, stride=stride, unpadding=pad)
merged_image = merger(image_tiles)
print(f"Shape of restored image = {merged_image.shape}")

# Drop the batch dimension before converting to an image for display.
plt.imshow(K.tensor_to_image(merged_image[0]))
plt.axis("off")
plt.show()
Shape of restored image = torch.Size([1, 3, 510, 1020])