%%capture
!pip install kornia
!pip install kornia-rs
Image Registration by Direct Optimization
The images are courtesy of Dennis Sakva
import io
import requests
def download_data(url: str, filename: str = "") -> str:
    filename = url.split("/")[-1] if len(filename) == 0 else filename
    # Download
    bytesio = io.BytesIO(requests.get(url).content)
    # Save file
    with open(filename, "wb") as outfile:
        outfile.write(bytesio.getbuffer())
    return filename


download_data("http://cmp.felk.cvut.cz/~mishkdmy/bee.zip")
'bee.zip'
%%capture
!unzip bee.zip
Import needed libraries
import os
from copy import deepcopy
from typing import List
import imageio
import kornia as K
import kornia.geometry as KG
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
def get_data_directory(base):
    path = os.path.join("../", base)
    if os.path.isdir(os.path.join(path, "data")):
        return os.path.join(path, "data/")
    return get_data_directory(path)
Images preview
Let’s check our images. There are almost 100 of them, so we will show only every 10th.
fnames = os.listdir("bee")
fnames = [f"bee/{x}" for x in sorted(fnames) if x.endswith("JPG")]

fig, axis = plt.subplots(2, 5, figsize=(12, 4), sharex="all", sharey="all", frameon=False)
for i, fname in enumerate(fnames):
    if i % 10 != 0:
        continue
    j = i // 10
    img = K.io.load_image(fname, K.io.ImageLoadType.RGB8)
    axis[j // 5][j % 5].imshow(K.tensor_to_image(img), aspect="auto")
plt.subplots_adjust(wspace=0.05, hspace=0.05)
fig.tight_layout()
The focus goes from back to front, so we have to match and merge the images in that same order.
Image registration
We will need an ImageRegistrator object to do the matching. Because the photos are taken so that only a slight rotation, shift and scale change between them is possible, we will use the similarity mode, which models exactly this.
use_cuda: bool = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
registrator = KG.ImageRegistrator("similarity", loss_fn=F.mse_loss, lr=8e-4, pyramid_levels=3, num_iterations=500).to(device)
print(device)
cuda
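For intuition, the similarity model being optimized corresponds to a 3x3 matrix combining a scale s, a rotation angle theta and a shift (tx, ty); the registered transforms are handled as 1x3x3 tensors below. A minimal sketch of building such a matrix by hand (the numbers are made up for illustration):

import math

# Hypothetical similarity parameters: scale s, rotation theta, translation (tx, ty)
s, theta, tx, ty = 1.05, math.radians(2.0), 0.02, -0.01
similarity = torch.tensor(
    [
        [s * math.cos(theta), -s * math.sin(theta), tx],
        [s * math.sin(theta), s * math.cos(theta), ty],
        [0.0, 0.0, 1.0],
    ]
)[None]  # shape (1, 3, 3), the same layout as the transforms used with KG.homography_warp below
print(similarity.shape)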
We will register the images sequentially with ImageRegistrator.
%%capture
models = []
for i, fname in tqdm(enumerate(fnames)):
    if i == 0:
        continue
    prev_img = K.io.load_image(fnames[i - 1], K.io.ImageLoadType.RGB32, device=device)[None, ...]
    curr_img = K.io.load_image(fname, K.io.ImageLoadType.RGB32, device=device)[None, ...]
    model = registrator.register(prev_img, curr_img)
    models.append(deepcopy(model.detach()))
Let’s take the final (most close-focused) image as the reference. This means we have to convert our image transforms from the (between i and i+1) form into the (between i and last) form. We can do that by matrix multiplication.
models_to_final = [torch.eye(3, device=device)[None]]
for m in models[::-1]:
    models_to_final.append(m @ models_to_final[-1])
models_to_final = models_to_final[::-1]
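The resulting list has one entry per frame: models_to_final[i] maps frame i onto the last frame, and the last entry is the identity. A small check of this indexing (a sketch using the lists built above):

print(len(fnames), len(models), len(models_to_final))  # N, N - 1, N
assert torch.allclose(models_to_final[-1], torch.eye(3, device=device)[None])
for i in range(len(models)):
    # Each accumulated transform is the step transform chained with the next accumulated one
    assert torch.allclose(models_to_final[i], models[i] @ models_to_final[i + 1])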
Let’s check what we got.
fig, axis = plt.subplots(2, 5, figsize=(12, 4), sharex="all", sharey="all", frameon=False)
for i, fname in enumerate(fnames):
    if i % 10 != 0:
        continue
    timg = K.io.load_image(fname, K.io.ImageLoadType.RGB32, device=device)[None, ...]
    j = i // 10
    timg_dst = KG.homography_warp(timg, models_to_final[i], timg.shape[-2:])
    axis[j // 5][j % 5].imshow(K.tensor_to_image(timg_dst * 255.0).astype(np.uint8), aspect="auto")
plt.subplots_adjust(wspace=0.05, hspace=0.05)
fig.tight_layout()
Finally, we will merge the image sequence into a single image. The idea is to detect the image parts that are in focus in the current image and blend them into the final image. To get the sharp image parts we can use kornia.filters.laplacian. Then we reproject image1 into image2 and merge them using the mask we created.
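To see what this sharpness mask looks like on its own, here is a minimal sketch (reusing fnames and device from above) that computes and shows the normalized Laplacian response for a single frame:

with torch.no_grad():
    timg = K.io.load_image(fnames[0], K.io.ImageLoadType.RGB32, device=device)[None, ...]
    sharpness = K.filters.laplacian(K.color.rgb_to_grayscale(timg), 7).abs()
    sharpness = (sharpness - sharpness.min()) / (sharpness.max() - sharpness.min())
    plt.imshow(K.tensor_to_image(sharpness), cmap="gray")
    plt.title("Normalized Laplacian sharpness mask")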
def merge_sharp1_into2(timg1, timg2, trans1to2, verbose=False):
    curr_img = timg2.clone()
    # Warp timg1 into the frame of timg2
    warped = KG.homography_warp(timg1, torch.inverse(trans1to2), timg1.shape[-2:])
    # Sharpness mask from the Laplacian response of timg1
    mask1 = K.filters.laplacian(K.color.rgb_to_grayscale(timg1), 7).abs()
    mask1_norm = (mask1 - mask1.min()) / (mask1.max() - mask1.min())
    mask1_blur = K.filters.gaussian_blur2d(mask1_norm, (9, 9), (1.6, 1.6))
    mask1_blur = mask1_blur / mask1_blur.max()
    # Warp the mask with the same transform and blend the sharp parts into curr_img
    warped_mask = KG.homography_warp(mask1_blur.float(), torch.inverse(trans1to2), timg1.shape[-2:])
    curr_img = warped_mask * warped + (1 - warped_mask) * curr_img
    if verbose:
        fig, axis = plt.subplots(1, 4, figsize=(15, 6), sharex="all", sharey="all", frameon=False)
        axis[0].imshow(K.tensor_to_image(timg1))
        axis[1].imshow(K.tensor_to_image(mask1_blur))
        axis[2].imshow(K.tensor_to_image(timg2))
        axis[3].imshow(K.tensor_to_image(curr_img))
        axis[0].set_title("Img1")
        axis[1].set_title("Sharp mask on img1")
        axis[2].set_title("Img2")
        axis[3].set_title("Blended image")
    return curr_img


timg1 = K.io.load_image(fnames[50], K.io.ImageLoadType.RGB32, device=device)[None, ...]
timg2 = K.io.load_image(fnames[-1], K.io.ImageLoadType.RGB32, device=device)[None, ...]
out = merge_sharp1_into2(timg1, timg2, models_to_final[50], True)
The blending does not look really good, but that is because we are trying to merge non-consecutive images with very different focus. Let’s apply it sequentially and see what happens.
We will also create a video of our sharpening process.
%%capture
base_img = K.io.load_image(fnames[-1], K.io.ImageLoadType.RGB32, device=device)[None, ...]
curr_img = deepcopy(base_img)

try:
    video_writer = imageio.get_writer("sharpening.avi", fps=8)
    video_writer.append_data((K.tensor_to_image(curr_img) * 255.0).astype(np.uint8))
    video_ok = True
except:
    video_ok = False

with torch.no_grad():
    for i, fname in tqdm(enumerate(fnames)):
        timg = K.io.load_image(fname, K.io.ImageLoadType.RGB32, device=device)[None, ...]
        curr_img = merge_sharp1_into2(timg.to(device), curr_img.to(device), models_to_final[i].to(device))
        if video_ok:
            video_writer.append_data((K.tensor_to_image(curr_img) * 255.0).astype(np.uint8))
if video_ok:
    video_writer.close()
plt.imshow(K.tensor_to_image(curr_img.float()))
plt.title("Final result")
Text(0.5, 1.0, 'Final result')
Now we can play the video of our sharpening. The code is a bit ugly so that it can also run from Google Colab (as shown here).
from base64 import b64encode
from IPython.display import HTML
if video_ok:
    mp4 = open("sharpening.avi", "rb").read()
else:
    mp4 = open(get_data_directory("") + "sharpening.mp4", "rb").read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()

HTML(f"""
<video width=400 controls>
<source src="{data_url}" type="video/mp4">
</video>
"""
)
The result looks quite nice and more detailed, although a bit soft. You can try different blending parameters yourself (e.g. the blur kernel size) in order to improve the final result.
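As a starting point for such experiments, here is a hypothetical variant of merge_sharp1_into2 (the names merge_sharp1_into2_tunable, laplacian_ksize, blur_ksize and blur_sigma are introduced here for illustration) that exposes the blending parameters, so the stacking loop above can be rerun with different settings:

def merge_sharp1_into2_tunable(timg1, timg2, trans1to2, laplacian_ksize=7, blur_ksize=9, blur_sigma=1.6):
    # Same blending as above, but with the mask parameters exposed for experimentation
    warped = KG.homography_warp(timg1, torch.inverse(trans1to2), timg1.shape[-2:])
    mask1 = K.filters.laplacian(K.color.rgb_to_grayscale(timg1), laplacian_ksize).abs()
    mask1 = (mask1 - mask1.min()) / (mask1.max() - mask1.min())
    mask1 = K.filters.gaussian_blur2d(mask1, (blur_ksize, blur_ksize), (blur_sigma, blur_sigma))
    mask1 = mask1 / mask1.max()
    warped_mask = KG.homography_warp(mask1.float(), torch.inverse(trans1to2), timg1.shape[-2:])
    return warped_mask * warped + (1 - warped_mask) * timg2


# Example: a larger Laplacian kernel and a stronger blur on the mask
# curr_img = merge_sharp1_into2_tunable(timg, curr_img, models_to_final[i], laplacian_ksize=9, blur_ksize=15, blur_sigma=2.5)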