%%capture
# Install kornia and kornia-rs; %%capture suppresses the cell's pip output.
!pip install kornia
!pip install kornia-rs
Image anti-alias with local features
Basic
HardNet
Patches
Local features
kornia.feature
In this example, we show the benefits of anti-aliased patch extraction with kornia.
import io
import requests
def download_image(url: str, filename: str = "", timeout: float = 30.0) -> str:
    """Download ``url`` to a local file and return the local file name.

    Args:
        url: Address of the file to fetch.
        filename: Local path to write to. When empty (or None), the last
            path segment of ``url`` is used, e.g. ``.../drslump.jpg`` ->
            ``drslump.jpg``.
        timeout: Seconds to wait for the server before giving up; without
            it ``requests`` can block indefinitely.

    Returns:
        The path the downloaded bytes were written to.

    Raises:
        requests.HTTPError: If the server answers with an error status, so an
            error page is never silently saved under the target file name.
    """
    filename = filename or url.split("/")[-1]

    # Download
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()

    # Save file
    with open(filename, "wb") as outfile:
        outfile.write(response.content)

    return filename
# Fetch the sample image used throughout this notebook; the returned local
# file name is displayed as the cell output.
url = "https://github.com/kornia/data/raw/main/drslump.jpg"
download_image(url)
'drslump.jpg'
First, let's load an image.
%matplotlib inline
import kornia as K
import kornia.feature as KF
import matplotlib.pyplot as plt
import torch
device = torch.device("cpu")

# Load the downloaded file as an RGB image (ImageLoadType.RGB32) and add a
# leading batch axis with [None, ...], giving a (1, C, H, W) tensor.
img_original = K.io.load_image("drslump.jpg", K.io.ImageLoadType.RGB32, device=device)[None, ...]

plt.figure()
plt.imshow(K.tensor_to_image(img_original))
B, CH, H, W = img_original.shape

# Shrink the image by an integer factor. "area" interpolation averages over
# each source region, so the small image itself is not aliased; any
# difference in extracted patches below therefore comes from the patch
# extraction step.
DOWNSAMPLE = 4
img_small = K.geometry.resize(img_original, (H // DOWNSAMPLE, W // DOWNSAMPLE), interpolation="area")

plt.figure()
plt.imshow(K.tensor_to_image(img_small))
Now, let's define a keypoint with a large support region.
def show_lafs(img, lafs, idx=0, color="r", figsize=(10, 7)):
    """Plot an image in a new figure with local affine frames (LAFs) drawn on top.

    Args:
        img: Image to show. A ``torch.Tensor`` is converted with
            ``K.tensor_to_image``; anything else is passed to ``plt.imshow``
            unchanged.
        lafs: Local affine frames to draw.
        idx: Index forwarded to ``KF.laf.get_laf_pts_to_draw`` to choose which
            element of ``lafs`` is drawn.
        color: Matplotlib format string for the frame outlines.
        figsize: Size of the new figure, in inches.
    """
    x, y = KF.laf.get_laf_pts_to_draw(lafs, idx)
    img_show = K.tensor_to_image(img) if isinstance(img, torch.Tensor) else img

    plt.figure(figsize=figsize)
    plt.imshow(img_show)
    plt.plot(x, y, color)
# A single local affine frame, viewed as (1, 1, 2, 3) to match kornia's
# (B, N, 2, 3) LAF layout: one image, one keypoint. Assuming the usual
# [A | center] layout, the 2x2 part gives a large support region (scale 150)
# and the last column places it at (x=180, y=280).
laf_orig = torch.tensor([[150.0, 0, 180], [0, 150, 280]]).float().view(1, 1, 2, 3)

# The same keypoint in the downsampled image: dividing the whole frame shrinks
# both its support region and its center coordinates by DOWNSAMPLE.
laf_small = laf_orig / float(DOWNSAMPLE)

show_lafs(img_original, laf_orig, figsize=(6, 4))
show_lafs(img_small, laf_small, figsize=(6, 4))
Now let's compare how the patch looks when it is extracted naively and when it is extracted from a scale pyramid.
PS = 32  # side length of every extracted patch, in pixels

with torch.no_grad():
    # Same keypoint, two extraction strategies, two image sizes:
    # "pyr" samples from the appropriate level of a scale pyramid,
    # "simple" samples the given image directly without smoothing.
    patches_pyr_orig = KF.extract_patches_from_pyramid(img_original, laf_orig.to(device), PS)
    patches_simple_orig = KF.extract_patches_simple(img_original, laf_orig.to(device), PS)

    patches_pyr_small = KF.extract_patches_from_pyramid(img_small, laf_small.to(device), PS)
    patches_simple_small = KF.extract_patches_simple(img_small, laf_small.to(device), PS)
# Now we will glue all the patches together:
def vert_cat_with_margin(p1, p2, margin=3):
    """Place two patch tensors side by side, separated by a vertical strip of ones.

    Inputs are 5-D, (B, N, C, H, W); they must agree on every axis except
    the last. Concatenation is along the width axis (dim 4), so "vert"
    names the orientation of the separator, not the stacking direction.

    Args:
        p1: Left patches.
        p2: Right patches.
        margin: Width in pixels of the separator, which is filled with 1.0
            (white for images in [0, 1]).

    Returns:
        Tensor of shape (B, N, C, H, W1 + margin + W2).
    """
    b, n, ch, h, _ = p1.shape
    # Build the separator on p1's own device and dtype instead of relying on a
    # notebook-global ``device``, so the helper works wherever its inputs live.
    separator = p1.new_ones((b, n, ch, h, margin))
    return torch.cat([p1, separator, p2], dim=4)
def horiz_cat_with_margin(p1, p2, margin=3):
    """Stack two patch tensors top to bottom, separated by a horizontal strip of ones.

    Inputs are 5-D, (B, N, C, H, W); they must agree on every axis except
    H. Concatenation is along the height axis (dim 3), so "horiz"
    names the orientation of the separator, not the stacking direction.

    Args:
        p1: Top patches.
        p2: Bottom patches.
        margin: Height in pixels of the separator, which is filled with 1.0
            (white for images in [0, 1]).

    Returns:
        Tensor of shape (B, N, C, H1 + margin + H2, W).
    """
    b, n, ch, _, w = p1.shape
    # Build the separator on p1's own device and dtype instead of relying on a
    # notebook-global ``device``, so the helper works wherever its inputs live.
    separator = p1.new_ones((b, n, ch, margin, w))
    return torch.cat([p1, separator, p2], dim=3)
# Each row: patch from the original image on the left, from the downsampled
# image on the right.
patches_pyr = vert_cat_with_margin(patches_pyr_orig, patches_pyr_small)
patches_naive = vert_cat_with_margin(patches_simple_orig, patches_simple_small)

# Stack the two rows: naive extraction on top, pyramid extraction below.
patches_all = horiz_cat_with_margin(patches_naive, patches_pyr)
Now let's show the result. In each row, the left patch comes from the original image and the right one from the downsampled image. The top row is what you get when extracting patches without any anti-aliasing — note how much the patches extracted from the two image sizes differ.
The bottom row shows patches extracted from the same two images using a scale pyramid. They are still not identical, but the difference is much smaller.
plt.figure(figsize=(10, 10))
# patches_all is (B, N, C, H, W); show the mosaic for the first image and keypoint.
plt.imshow(K.tensor_to_image(patches_all[0, 0]))
Let's check how much this affects the performance of a local descriptor such as HardNet.
# Pretrained HardNet in inference mode, placed on the same device as the
# patches so the comparison still runs if ``device`` is changed.
hardnet = KF.HardNet(pretrained=True).to(device).eval()

# Batch the four patches in the order pyr_orig, pyr_small, simple_orig,
# simple_small; drop the keypoint axis and average the colour channels into a
# single grayscale channel, as HardNet takes one-channel input.
all_patches = (
    torch.cat([patches_pyr_orig, patches_pyr_small, patches_simple_orig, patches_simple_small], dim=0)
    .squeeze(1)
    .mean(dim=1, keepdim=True)
)

with torch.no_grad():
    descs = hardnet(all_patches)
    # Pairwise L2 distances between descriptors; rows and columns follow the
    # batch order above, so [0, 1] compares the pyramid patches and [2, 3]
    # the naively extracted ones.
    distances = torch.cdist(descs, descs)

print(distances.cpu().numpy())
[[0. 0.16867691 0.8070452 0.52112377]
[0.16867691 0. 0.7973113 0.48472866]
[0.8070452 0.7973113 0. 0.59267515]
[0.52112377 0.48472866 0.59267515 0. ]]
So the descriptor distance between the anti-aliased (pyramid) patches is about 0.17, while between the naively extracted ones it is about 0.59.