%%capture
!pip install kornia
!pip install kornia-rs
Connected Components Algorithm
Basic
Segmentation
Labeling
Unsupervised
kornia.contrib
In this tutorial we will learn how to segment small objects in an image using Kornia's implementation of a classic computer vision technique called connected-component labelling (CCL).
import io
import requests
def download_image(url: str, filename: str = "", timeout: float = 30.0) -> str:
    """Download ``url`` to a local file and return the local file name.

    Args:
        url: Address of the file to fetch.
        filename: Local path to write to. When empty, the last ``/``-separated
            segment of ``url`` is used.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The path the downloaded content was written to.

    Raises:
        ValueError: If no file name is given and none can be derived from ``url``.
        requests.HTTPError: If the server answers with an error status code.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    if not filename:
        filename = url.split("/")[-1]
    if not filename:
        # e.g. a URL ending in "/": fail before touching the network.
        raise ValueError(f"cannot derive a file name from URL {url!r}; pass filename explicitly")

    # Without a timeout, requests can block forever on an unresponsive server.
    response = requests.get(url, timeout=timeout)
    # Fail loudly instead of silently saving an HTML error page as the "image".
    response.raise_for_status()

    with open(filename, "wb") as outfile:
        outfile.write(response.content)
    return filename
# Binary image of cells that is segmented throughout this tutorial.
url = "https://github.com/kornia/data/raw/main/cells_binary.png"
download_image(url=url)
'cells_binary.png'
from __future__ import annotations
import kornia as K
import matplotlib.pyplot as plt
import torch
We define some utility functions to visualize the segmentation results.
def create_random_labels_map(
    classes: int | torch.Tensor | list[int],
    *,
    generator: torch.Generator | None = None,
) -> dict[int, tuple[int, int, int]]:
    """Assign a random RGB colour to every label id.

    Args:
        classes: The label ids to colour, e.g. ``torch.unique(labels)`` or any
            iterable of ids. An ``int`` ``n`` is interpreted as the ids ``0 .. n-1``.
        generator: Optional random generator, for reproducible colours.

    Returns:
        Mapping from integer label id to an ``(r, g, b)`` tuple with every channel
        in ``[0, 255]``. Label ``0`` is always present and mapped to black.
    """
    label_ids = range(classes) if isinstance(classes, int) else classes
    labels_map: dict[int, tuple[int, int, int]] = {}
    for label in label_ids:
        # randint's upper bound is exclusive, so 256 makes 255 reachable.
        r, g, b = torch.randint(0, 256, (3,), generator=generator).tolist()
        # int() matters: tensor elements hash by identity, so without it the
        # background entry below would be added as a duplicate key instead of
        # replacing the random colour of label 0.
        labels_map[int(label)] = (r, g, b)
    labels_map[0] = (0, 0, 0)
    return labels_map
def labels_to_image(img_labels: torch.Tensor, labels_map: dict[int, tuple[int, int, int]]) -> torch.Tensor:
    """Render a 2-D label image as an RGB image for visualisation.

    Args:
        img_labels: Tensor of shape ``(H, W)`` holding a label id per pixel.
        labels_map: Mapping from label id to an ``(r, g, b)`` colour.

    Returns:
        ``uint8`` tensor of shape ``(3, H, W)``. Pixels whose label is not in
        ``labels_map`` are black.

    Raises:
        ValueError: If ``img_labels`` is not 2-D.
    """
    if img_labels.dim() != 2:
        raise ValueError(f"img_labels must be 2-D (H, W), got shape {tuple(img_labels.shape)}")
    height, width = img_labels.shape
    # zeros, not empty: unmapped pixels must not show uninitialised memory.
    out = torch.zeros(3, height, width, dtype=torch.uint8)
    for label_id, label_val in labels_map.items():
        mask = img_labels == label_id
        for channel in range(3):
            # out[channel] is a view, so the in-place fill writes into out.
            out[channel].masked_fill_(mask, label_val[channel])
    return out
def show_components(img, labels):
    """Plot an image next to a colour-coded rendering of its connected components.

    Args:
        img: Image shown on the left; passed directly to ``imshow``. Presumably a
            2-D (grayscale) array such as a squeezed binary mask — confirm for
            other inputs.
        labels: 2-D tensor of per-pixel label ids, e.g. the squeezed output of
            ``connected_components``. Each distinct id gets a random colour.
    """
    color_ids = torch.unique(labels)
    labels_map = create_random_labels_map(color_ids)
    labels_img = labels_to_image(labels, labels_map)

    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 12))

    # Showing original image; a gray colormap keeps a binary mask black/white
    # (matplotlib ignores cmap for RGB(A) input).
    ax1.imshow(img, cmap="gray")
    ax1.axis("off")
    ax1.set_title("Original Image")

    # Showing image after component labeling: (3, H, W) -> (H, W, 3) for imshow.
    ax2.imshow(labels_img.permute(1, 2, 0).numpy())
    ax2.axis("off")
    ax2.set_title("Component Labeling")

    plt.show()
We load the image using Kornia.
# Load the mask as a grayscale (GRAY32) image, then add a leading batch
# dimension so the tensor is laid out as (B, C, H, W).
img_t = K.io.load_image("cells_binary.png", K.io.ImageLoadType.GRAY32).unsqueeze(0)
print(img_t.shape)
torch.Size([1, 1, 602, 602])
Apply the connected-component labelling algorithm using `kornia.contrib.connected_components`. The `num_iterations` parameter controls how many iterations the algorithm runs; it must be large enough for the labels to converge to a solution.
# Label the connected regions of the batched mask with Kornia.
labels_out = K.contrib.connected_components(
    img_t,
    num_iterations=150,
)
print(labels_out.shape)
torch.Size([1, 1, 602, 602])
# Left: the input mask; right: each labelled component in a random colour.
show_components(img_t.squeeze().numpy(), labels_out.squeeze())
We can also inspect the unique label values.
# Distinct label ids in the output; 0 presumably marks the background.
print(labels_out.unique())
tensor([ 0., 13235., 24739., 31039., 32177., 44349., 59745., 61289.,
66209., 69449., 78869., 94867., 101849., 102217., 102319., 115227.,
115407., 137951., 138405., 150047., 158715., 162179., 170433., 170965.,
174279., 177785., 182867., 210145., 212647., 215451., 216119., 221291.,
222367., 226183., 226955., 248757., 252823., 255153., 263337., 265505.,
270299., 270649., 277725., 282775., 296897., 298545., 299793., 300517.,
313961., 316217., 321259., 322235., 335599., 337037., 340289., 347363.,
352235., 352721., 360801., 360903., 360965., 361073., 361165., 361197.])