%%capture
!pip install kornia
!pip install kornia-rs
!pip install kornia_moons --no-deps
!pip install opencv-python --upgrade
Image matching example with DISK local features
Intermediate
DISK
LAF
Image matching
kornia.feature
In this tutorial we are going to show how to perform image matching using DISK local features.
First, we will install everything needed:
- fresh version of kornia for DISK
- fresh version of OpenCV for MAGSAC++ geometry estimation
- kornia_moons for the conversions and visualization
Docs: kornia.feature.DISK
Now let’s download an image pair
import io
import requests
def download_image(url: str, filename: str = "") -> str:
    """Download the file at ``url`` and save it to disk.

    Args:
        url: Address of the file to fetch.
        filename: Destination path. When empty, the last path component
            of ``url`` is used.

    Returns:
        The path the file was written to.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    if not filename:
        filename = url.split("/")[-1]
    # Download
    response = requests.get(url, timeout=60)
    # Fail loudly rather than saving an HTML error page under an image name.
    response.raise_for_status()
    # Save file
    with open(filename, "wb") as outfile:
        outfile.write(response.content)
    return filename
# The two views of the same church that will be matched against each other.
url_a = "https://github.com/kornia/data/raw/main/matching/kn_church-2.jpg"
url_b = "https://github.com/kornia/data/raw/main/matching/kn_church-8.jpg"

download_image(url_a)
download_image(url_b)
'kn_church-8.jpg'
First, imports.
import cv2
import kornia as K
import kornia.feature as KF
import matplotlib.pyplot as plt
import numpy as np
import torch
from kornia.feature.adalam import AdalamFilter
from kornia_moons.viz import *
# Device selected by kornia's helper; all tensors and the model below are placed on it.
device = K.utils.get_cuda_or_mps_device_if_available()
print(device)
cuda:0
# %%capture
# Local paths of the two images saved by the download step.
fname1 = "kn_church-2.jpg"
fname2 = "kn_church-8.jpg"
# Start from kornia's default AdaLAM settings and override a few entries.
adalam_config = KF.adalam.get_adalam_default_config()
# Optional overrides, left disabled here:
# adalam_config['orientation_difference_threshold'] = None
# adalam_config['scale_rate_threshold'] = None
adalam_config["force_seed_mnn"] = False
adalam_config["search_expansion"] = 16
adalam_config["ransac_iters"] = 256
# Load both images as RGB32 on `device`; `[None, ...]` prepends a batch dimension.
img1 = K.io.load_image(fname1, K.io.ImageLoadType.RGB32, device=device)[None, ...]
img2 = K.io.load_image(fname2, K.io.ImageLoadType.RGB32, device=device)[None, ...]
# Feature count passed to DISK for each image.
num_features = 2048
# Pretrained DISK model ("depth" checkpoint), moved to the selected device.
disk = KF.DISK.from_pretrained("depth").to(device)
# Trailing image dimensions (everything after the first two axes), passed to AdaLAM as hw1/hw2.
hw1 = torch.tensor(img1.shape[2:], device=device)
hw2 = torch.tensor(img2.shape[2:], device=device)
# Select the matcher: AdaLAM when True, KF.match_smnn otherwise.
match_with_adalam = True
with torch.inference_mode():
    # Run DISK on both images in a single batched forward pass.
    inp = torch.cat([img1, img2], dim=0)
    features1, features2 = disk(inp, num_features, pad_if_not_divisible=True)
    kps1, descs1 = features1.keypoints, features1.descriptors
    kps2, descs2 = features2.keypoints, features2.descriptors
    if match_with_adalam:
        # AdaLAM takes local affine frames (LAFs), so build them from the
        # keypoint centers using a constant scale of 96 for every keypoint.
        lafs1 = KF.laf_from_center_scale_ori(kps1[None], 96 * torch.ones(1, len(kps1), 1, 1, device=device))
        lafs2 = KF.laf_from_center_scale_ori(kps2[None], 96 * torch.ones(1, len(kps2), 1, 1, device=device))

        dists, idxs = KF.match_adalam(descs1, descs2, lafs1, lafs2, hw1=hw1, hw2=hw2, config=adalam_config)
    else:
        dists, idxs = KF.match_smnn(descs1, descs2, 0.98)

print(f"{idxs.shape[0]} tentative matches with DISK AdaLAM")
222 tentative matches with DISK AdaLAM
def get_matching_keypoints(kp1, kp2, idxs):
    """Select the matched keypoints from both images.

    Args:
        kp1: Keypoints of the first image, indexable along the first axis.
        kp2: Keypoints of the second image, indexable along the first axis.
        idxs: Two-column index array; column 0 indexes ``kp1`` and
            column 1 indexes ``kp2``.

    Returns:
        A tuple ``(mkpts1, mkpts2)`` in which row ``i`` of each entry
        belongs to the same match.
    """
    mkpts1 = kp1[idxs[:, 0]]
    mkpts2 = kp2[idxs[:, 1]]
    return mkpts1, mkpts2
mkpts1, mkpts2 = get_matching_keypoints(kps1, kps2, idxs)

# Estimate the fundamental matrix with MAGSAC++.
# Positional arguments after the method flag: reprojection threshold (1.0),
# confidence (0.999) and maximum number of iterations (100000).
Fm, inliers = cv2.findFundamentalMat(
    mkpts1.detach().cpu().numpy(), mkpts2.detach().cpu().numpy(), cv2.USAC_MAGSAC, 1.0, 0.999, 100000
)
# Turn the returned mask into a boolean array.
inliers = inliers > 0
print(f"{inliers.sum()} inliers with DISK")
103 inliers with DISK
Let’s draw the inliers in green and tentative correspondences in yellow
# Plot both images side by side with the tentative matches and the inliers.
draw_LAF_matches(
    KF.laf_from_center_scale_ori(kps1[None].cpu()),
    KF.laf_from_center_scale_ori(kps2[None].cpu()),
    idxs.cpu(),
    K.tensor_to_image(img1.cpu()),
    K.tensor_to_image(img2.cpu()),
    inliers,
    draw_dict={"inlier_color": (0.2, 1, 0.2), "tentative_color": (1, 1, 0.2, 0.3), "feature_color": None, "vertical": False},
)