Image matching example with KeyNet-AdaLAM

Intermediate
KeyNet
LAF
Adalam
Image matching
kornia.feature
In this tutorial we show how to perform image matching using the KeyNet-AdaLAM algorithm.
Author

Dmytro Mishkin

Published

August 22, 2022

Open in Google Colab

First, we will install everything needed:

Docs: match_adalam

%%capture
!pip install kornia
!pip install kornia-rs
!pip install kornia_moons
!pip install opencv-python --upgrade

Now let’s download an image pair

import io

import requests


def download_image(url: str, filename: str = "") -> str:
    """Download ``url`` and save the response body to a local file.

    Args:
        url: Address of the file to fetch.
        filename: Destination path. When empty, the last ``/``-separated
            component of ``url`` is used.

    Returns:
        The path the file was written to.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within the timeout.
    """
    filename = filename or url.split("/")[-1]
    # Bound the wait so a stalled connection cannot hang the notebook forever.
    response = requests.get(url, timeout=60)
    # Fail loudly instead of saving an HTML error page under an image name.
    response.raise_for_status()
    with open(filename, "wb") as outfile:
        outfile.write(response.content)

    return filename


# The two images that make up the pair to be matched.
url_a = "https://github.com/kornia/data/raw/main/matching/kn_church-2.jpg"
url_b = "https://github.com/kornia/data/raw/main/matching/kn_church-8.jpg"
# Each file is saved in the working directory under the last component of its URL.
download_image(url_a)
download_image(url_b)
'kn_church-8.jpg'

Next, the imports.

import cv2
import kornia as K
import kornia.feature as KF
import matplotlib.pyplot as plt
import numpy as np
import torch
from kornia_moons.viz import *

# The tutorial is pinned to the CPU. To pick an accelerator when one is present,
# replace the assignment below with:
# device = K.utils.get_cuda_or_mps_device_if_available()
device = torch.device("cpu")
%%capture
fname1 = "kn_church-2.jpg"
fname2 = "kn_church-8.jpg"

# Load both images in RGB32 format and prepend a batch dimension.
img1 = K.io.load_image(fname1, K.io.ImageLoadType.RGB32, device=device)[None, ...]
img2 = K.io.load_image(fname2, K.io.ImageLoadType.RGB32, device=device)[None, ...]

# Local feature pipeline: up to 5000 features per image, upright mode enabled.
feature = KF.KeyNetAffNetHardNet(num_features=5000, upright=True).eval().to(device)

# The feature extractor is fed grayscale images; convert once and reuse.
gray1 = K.color.rgb_to_grayscale(img1)
gray2 = K.color.rgb_to_grayscale(img2)

# Not consumed by the matching code below; kept so that any code relying on
# this name keeps working.
input_dict = {
    "image0": gray1,
    "image1": gray2,
}

# Spatial size (height, width) of each image. Each one must come from its own
# image: the pair is not guaranteed to share a resolution.
hw1 = torch.tensor(img1.shape[2:])
hw2 = torch.tensor(img2.shape[2:])

adalam_config = {"device": device}

with torch.inference_mode():
    lafs1, resps1, descs1 = feature(gray1)
    lafs2, resps2, descs2 = feature(gray2)
    dists, idxs = KF.match_adalam(
        descs1.squeeze(0),
        descs2.squeeze(0),
        lafs1,
        lafs2,  # AdaLAM also takes geometric information into account
        config=adalam_config,
        hw1=hw1,
        hw2=hw2,  # AdaLAM also benefits from knowing the image sizes
    )
print(f"{idxs.shape[0]} tentative matches with AdaLAM")
405 tentative matches with AdaLAM
def get_matching_keypoints(lafs1, lafs2, idxs):
    """Return the keypoint centers of the matched features as NumPy arrays.

    Args:
        lafs1: Local affine frames of the first image; assumed to carry a
            leading batch dimension of size 1.
        lafs2: Local affine frames of the second image, same layout.
        idxs: 2-D index tensor; column 0 indexes into ``lafs1`` and column 1
            into ``lafs2``.

    Returns:
        A tuple ``(mkpts1, mkpts2)`` of NumPy arrays with one row per match.
    """
    # Drop only the batch axis. A bare ``squeeze()`` would also remove the
    # keypoint axis when an image has a single keypoint and break the indexing.
    centers1 = KF.get_laf_center(lafs1).squeeze(0)
    centers2 = KF.get_laf_center(lafs2).squeeze(0)
    mkpts1 = centers1[idxs[:, 0]].detach().cpu().numpy()
    mkpts2 = centers2[idxs[:, 1]].detach().cpu().numpy()
    return mkpts1, mkpts2


mkpts1, mkpts2 = get_matching_keypoints(lafs1, lafs2, idxs)

# Robustly fit a fundamental matrix with MAGSAC: inlier threshold 0.75,
# confidence 0.999, at most 100000 iterations.
Fm, inliers = cv2.findFundamentalMat(mkpts1, mkpts2, cv2.USAC_MAGSAC, 0.75, 0.999, 100000)
if inliers is None:
    # No mask is returned when the estimation fails (e.g. too few matches);
    # treat every tentative match as an outlier instead of crashing below.
    inliers = np.zeros((len(mkpts1), 1), dtype=bool)
else:
    inliers = inliers > 0
print(f"{inliers.sum()} inliers with AdaLAM")
195 inliers with AdaLAM

Let’s draw the inliers in green and the tentative correspondences in yellow.

# Visualize the matches: inliers in green, remaining tentative matches in
# semi-transparent yellow, images placed side by side.
draw_LAF_matches(
    lafs1,
    lafs2,
    idxs,
    K.tensor_to_image(img1),
    K.tensor_to_image(img2),
    inliers,
    draw_dict=dict(
        inlier_color=(0.2, 1, 0.2),
        tentative_color=(1, 1, 0.2, 0.3),
        feature_color=None,
        vertical=False,
    ),
)