Image stitching example with LoFTR

Intermediate
LoFTR
kornia.feature
A showcase of how to perform image stitching with LoFTR from Kornia.
Author

Edgar Riba

Published

November 19, 2021

Open in Google Colab

Open in HF Spaces

First, we will install everything needed:

%%capture
# Install the packages used below; %%capture hides pip's console output.
!pip install kornia
!pip install kornia-rs

Now let’s download an image pair:

import io

import requests


def download_image(url: str, filename: str = "", timeout: float = 30.0) -> str:
    """Download ``url`` and save the response body to a local file.

    Args:
        url: HTTP(S) address of the file to fetch.
        filename: Destination path. When empty, the last path segment of
            ``url`` (with any ``?query`` suffix removed) is used.
        timeout: Seconds to wait on the server before giving up. Without a
            timeout, ``requests`` can block indefinitely on a stalled
            connection.

    Returns:
        The path the file was written to.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status, so an
            error page is never silently saved under an image file name.
    """
    if not filename:
        # e.g. ".../img/hill/1.JPG?raw=true" -> "1.JPG"
        filename = url.split("/")[-1].split("?")[0]
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # The body is already in memory as bytes; write it straight to disk.
    with open(filename, "wb") as outfile:
        outfile.write(response.content)
    return filename


# Fetch the image pair for the first example. No filename is given, so each
# file is saved under the last segment of its URL (prtn00.jpg, prtn01.jpg).
download_image("https://github.com/kornia/data/raw/main/prtn00.jpg")
download_image("https://github.com/kornia/data/raw/main/prtn01.jpg")
'prtn01.jpg'
%%capture
import kornia as K
import kornia.feature as KF
import matplotlib.pyplot as plt
import numpy as np
import torch


def load_images(fnames):
    """Read each file in ``fnames`` with Kornia's loader as a one-image batch.

    Every image is loaded with ``K.io.ImageLoadType.RGB32`` and gets a leading
    singleton batch axis, so each element of the returned list can be passed
    directly to the stitcher.
    """
    batches = []
    for path in fnames:
        image = K.io.load_image(path, K.io.ImageLoadType.RGB32)
        batches.append(image.unsqueeze(0))
    return batches


# Load the two downloaded files as a list of single-image batches.
imgs = load_images(["prtn00.jpg", "prtn01.jpg"])

Stitch them together

from kornia.contrib import ImageStitcher

# LoFTR matcher with its "outdoor" pretrained weights; the stitcher uses a
# RANSAC estimator so that outlier matches are rejected.
matcher = KF.LoFTR(pretrained="outdoor")
IS = ImageStitcher(matcher, estimator="ransac")

# Inference only: no autograd graph is needed.
with torch.no_grad():
    out = IS(*imgs)

panorama = K.tensor_to_image(out)
plt.imshow(panorama)
plt.show()

Another example

# Three images for a multi-image example. The URLs end in "?raw=true", which
# download_image strips when deriving the local names 1.JPG, 2.JPG and 3.JPG.
download_image("https://github.com/daeyun/Image-Stitching/blob/master/img/hill/1.JPG?raw=true")
download_image("https://github.com/daeyun/Image-Stitching/blob/master/img/hill/2.JPG?raw=true")
download_image("https://github.com/daeyun/Image-Stitching/blob/master/img/hill/3.JPG?raw=true")
'3.JPG'
imgs = load_images(["1.JPG", "2.JPG", "3.JPG"])
f, axarr = plt.subplots(1, 3, figsize=(16, 6))

# Show the three inputs side by side, hiding all ticks and tick labels.
for ax, img in zip(axarr, imgs):
    ax.imshow(K.tensor_to_image(img))
    ax.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)

# Keypoint-based alternative to LoFTR: a GFTT/AffNet/HardNet local-feature
# pipeline, with 100 passed as its first argument (presumably the keypoint
# budget; confirm in the Kornia docs). Matches are filtered with the "snn"
# (second-nearest-neighbour ratio) test at threshold 0.8.
matcher = KF.LocalFeatureMatcher(KF.GFTTAffNetHardNet(100), KF.DescriptorMatcher("snn", 0.8))
IS = ImageStitcher(matcher, estimator="ransac")

# Inference only: no autograd graph is needed.
with torch.no_grad():
    out = IS(*imgs)
plt.figure(figsize=(16, 16))
plt.imshow(K.tensor_to_image(out))
# Show explicitly, as the first example does. Otherwise the panorama never
# appears when this runs as a plain script instead of a notebook.
plt.show()