%%capture
!pip install kornia
!pip install kornia-rs
Visual Prompter: Segment Anything
This tutorial shows how to use our high-level Visual Prompter API. This API lets you set an image and then run multiple queries on it. These queries can use three types of prompt:
- Points: keypoints with (x, y) coordinates and a respective label, where 0 indicates a background point and 1 indicates a foreground point.
- Boxes: boxes of different regions.
- Masks: Logits generated by the model in a previous run.
Read more on our docs: https://kornia.readthedocs.io/en/latest/models/segment_anything.html
This tutorial's steps:
- Setup the desired SAM model and import the necessary packages
- How to instantiate the Visual Prompter
- Using the Visual Prompter
Setup
First let’s choose the SAM type to be used on our Visual Prompter.
The options are (from smallest to largest):

model_type | official checkpoint
---|---
vit_b | https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
vit_l | https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
vit_h | https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
= "vit_h"
model_type = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" checkpoint
Then let’s import all necessary packages and modules
from __future__ import annotations
import os
import matplotlib.pyplot as plt
import torch
from kornia.contrib.models.sam import SamConfig
from kornia.contrib.visual_prompter import VisualPrompter
from kornia.geometry.boxes import Boxes
from kornia.geometry.keypoints import Keypoints
from kornia.io import ImageLoadType, load_image
from kornia.utils import get_cuda_or_mps_device_if_available, tensor_to_image
device = get_cuda_or_mps_device_if_available()
print(device)
None
Utility functions
import io
import requests
def download_image(url: str, filename: str = "") -> str:
    filename = url.split("/")[-1] if len(filename) == 0 else filename
    # Download
    bytesio = io.BytesIO(requests.get(url).content)
    # Save file
    with open(filename, "wb") as outfile:
        outfile.write(bytesio.getbuffer())
    return filename
= download_image("https://raw.githubusercontent.com/kornia/data/main/soccer.jpg")
soccer_image_path = download_image("https://raw.githubusercontent.com/kornia/data/main/simple_car.jpg")
car_image_path = download_image("https://raw.githubusercontent.com/kornia/data/main/satellite_sentinel2_example.tif")
satellite_image_path soccer_image_path, car_image_path, satellite_image_path
('soccer.jpg', 'simple_car.jpg', 'satellite_sentinel2_example.tif')
def colorize_masks(binary_masks: torch.Tensor, merge: bool = True, alpha: None | float = None) -> list[torch.Tensor]:
    """Convert binary masks (B, C, H, W), boolean tensors, into masks with colors (B, (3, 4), H, W) - RGB or RGBA.

    Where C refers to the number of masks.

    Args:
        binary_masks: a batched boolean tensor (B, C, H, W)
        merge: if True, will join the batch dimension into a unique mask.
        alpha: alpha channel value. If None, will generate RGB images.

    Returns:
        A list of `C` colored masks.
    """
    B, C, H, W = binary_masks.shape
    OUT_C = 4 if alpha else 3

    output_masks = []

    for idx in range(C):
        _out = torch.zeros(B, OUT_C, H, W, device=binary_masks.device, dtype=torch.float32)
        for b in range(B):
            color = torch.rand(1, 3, 1, 1, device=binary_masks.device, dtype=torch.float32)
            if alpha:
                color = torch.cat([color, torch.tensor([[[[alpha]]]], device=binary_masks.device, dtype=torch.float32)], dim=1)

            to_colorize = binary_masks[b, idx, ...].view(1, 1, H, W).repeat(1, OUT_C, 1, 1)
            _out[b, ...] = torch.where(to_colorize, color, _out[b, ...])

        output_masks.append(_out)

    if merge:
        output_masks = [c.max(dim=0)[0] for c in output_masks]

    return output_masks
def show_binary_masks(binary_masks: torch.Tensor, axes) -> None:
    """Plot binary masks, with shape (B, C, H, W), where C refers to the number of masks.

    Will merge the `B` channel into a unique mask.

    Args:
        binary_masks: a batched boolean tensor (B, C, H, W)
        axes: a list of matplotlib axes with length C
    """
    colored_masks = colorize_masks(binary_masks, True, 0.6)

    for ax, mask in zip(axes, colored_masks):
        ax.imshow(tensor_to_image(mask))
def show_boxes(boxes: Boxes, ax) -> None:
    boxes_tensor = boxes.to_tensor(mode="xywh").detach().cpu().numpy()
    for box in boxes_tensor:
        x0, y0, w, h = box
        ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor="orange", facecolor=(0, 0, 0, 0), lw=2))
def show_points(points: tuple[Keypoints, torch.Tensor], ax, marker_size=200):
    coords, labels = points
    pos_points = coords[labels == 1].to_tensor().detach().cpu().numpy()
    neg_points = coords[labels == 0].to_tensor().detach().cpu().numpy()

    ax.scatter(pos_points[:, 0], pos_points[:, 1], color="green", marker="+", s=marker_size, linewidth=2)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color="red", marker="x", s=marker_size, linewidth=2)
from kornia.contrib.models import SegmentationResults
def show_image(image: torch.Tensor):
    plt.imshow(tensor_to_image(image))
    plt.axis("off")
    plt.show()
def show_predictions(
    image: torch.Tensor,
    predictions: SegmentationResults,
    points: tuple[Keypoints, torch.Tensor] | None = None,
    boxes: Boxes | None = None,
) -> None:
    n_masks = predictions.logits.shape[1]

    fig, axes = plt.subplots(1, n_masks, figsize=(21, 16))
    axes = [axes] if n_masks == 1 else axes

    for idx, ax in enumerate(axes):
        score = predictions.scores[:, idx, ...].mean()

        ax.imshow(tensor_to_image(image))
        ax.set_title(f"Mask {idx+1}, Score: {score:.3f}", fontsize=18)

        if points:
            show_points(points, ax)
        if boxes:
            show_boxes(boxes, ax)

        ax.axis("off")

    show_binary_masks(predictions.binary_masks, axes)
    plt.show()
Exploring the Visual Prompter
The VisualPrompter can be initialized from a ModelConfig structure; for now we only support the SAM model, through the SamConfig. Through this config, the VisualPrompter will initialize the SAM model and load the weights (from a path or a URL).
What can the VisualPrompter do?
1. Based on the ModelConfig, besides initializing the model, it sets up the required transformations for the images and prompts using the kornia.augmentation API within the AugmentationSequential container.
2. You can benefit from the torch.compile(...) API (dynamo) for torch >= 2.0.0. To compile with dynamo we provide the method VisualPrompter.compile(...), which optimizes the right parts of the backend model and of the prompter itself (see the sketch after the initialization code below).
3. Caching of the image features and transformations. With the VisualPrompter.set_image(...) method, we transform the image and encode it with the model, caching its embeddings to query later.
4. Querying multiple times with multiple prompts. Using VisualPrompter.predict(...), we query our cached embeddings with Keypoints, Boxes and Masks as prompts.
What do the VisualPrompter and Kornia provide? Easy high-level structures to be used as prompts and returned as prediction results. Using the kornia geometry module you can easily encapsulate Keypoints and Boxes, which makes the API flexible about the desired mode (mainly for boxes, where there are multiple ways of representing them).
The Kornia VisualPrompter and the model config for SAM can be imported as follows:
from kornia.contrib.models.sam import SamConfig
from kornia.contrib.visual_prompter import VisualPrompter
# Setting up a SamConfig with the model type and checkpoint desired
config = SamConfig(model_type, checkpoint)

# Initialize the VisualPrompter
prompter = VisualPrompter(config, device=device)
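Optionally, as mentioned above, the prompter can be compiled with dynamo. A minimal sketch, assuming the no-argument form of VisualPrompter.compile(); additional keyword arguments mirroring torch.compile may be available depending on your kornia version:
# Optional: compile the prompter with torch dynamo (requires torch >= 2.0.0).
# prompter.compile()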
Set image
First, before adding the image to the prompter, we need to read it. For that we can use kornia.io, which internally uses kornia-rs. If you do not have kornia-rs installed, you can install it with pip install kornia_rs. This API implements the DLPack protocol natively in Rust, which reduces the memory footprint during decoding and type conversion and lets us read the image from disk directly into a tensor. Note that the image should be scaled within the range [0, 1].
# Load the image
image = load_image(soccer_image_path, ImageLoadType.RGB32, device)  # 3 x H x W
# Display the loaded image
show_image(image)
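As a quick sanity check (plain tensor attributes, nothing kornia-specific), we can confirm that the loaded tensor is float32 and scaled to [0, 1]:
# Sanity check: ImageLoadType.RGB32 should give a float32 tensor with values in [0, 1]
print(image.shape, image.dtype, image.min().item(), image.max().item())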
With the image loaded onto the same device as the model, and with the right shape 3xHxW, we can now set the image on our visual prompter. Attention: when doing this, the model computes the embeddings of the image, i.e. the image is passed through the encoder, which uses a lot of memory. It is possible to use the largest model (vit_h) on a graphics card (GPU) with at least 8 GB of VRAM.
prompter.set_image(image)
If no error occurred, the features needed to run queries are now cached. If you want to check this, you can see the status of the prompter.is_image_set
property.
prompter.is_image_set
True
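If you are close to the memory limit mentioned above while encoding several images in a row, you can release cached allocator blocks between images. This is plain PyTorch, not a VisualPrompter API:
# Optional: free unused cached GPU memory between heavy encoder runs (no-op on CPU).
if torch.cuda.is_available():
    torch.cuda.empty_cache()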
Examples of prompts
The VisualPrompter output has the same batch size as its prompts. The output shape is (B, C, H, W), where B is the number of input prompts and C is determined by the multimask_output parameter: if multimask_output is True then C=3, otherwise C=1.
Keypoints
Keypoint prompts are a tensor or a Keypoints structure with coordinates in (x, y) format and shape BxNx2.
Each coordinate pair should have a corresponding label, where 0 indicates a background point and 1 indicates a foreground point. These labels should be in a tensor with shape BxN.
The model will try to find an object that includes all the foreground points and excludes the background points. In other words, the foreground points select the desired object, and the background points exclude unwanted regions.
keypoints_tensor = torch.tensor([[[960, 540]]], device=device, dtype=torch.float32)
keypoints = Keypoints(keypoints_tensor)

labels = torch.tensor([[1]], device=device, dtype=torch.float32)
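To mix foreground and background points in a single query, simply use both label values. A sketch with illustrative coordinates (not used later in this tutorial):
# Sketch: two foreground points (label 1) and one background point (label 0) in one query (B=1, N=3).
mixed_keypoints = Keypoints(torch.tensor([[[960.0, 540.0], [1000.0, 560.0], [100.0, 100.0]]], device=device))
mixed_labels = torch.tensor([[1, 1, 0]], device=device, dtype=torch.float32)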
Boxes
Box prompts are a tensor with boxes in "xyxy" format/mode, or a Boxes structure. The tensor should have shape BxNx4.
boxes_tensor = torch.tensor([[[1841.7, 739.0, 1906.5, 890.6]]], device=device, dtype=torch.float32)
boxes = Boxes.from_tensor(boxes_tensor, mode="xyxy")
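Since Boxes handles the mode conversion for you, the same region can also be given in another representation, e.g. "xywh". A sketch with the same box as above:
# Sketch: the same box expressed as (x, y, width, height); Boxes converts it internally.
boxes_xywh = torch.tensor([[[1841.7, 739.0, 64.8, 151.6]]], device=device, dtype=torch.float32)
boxes_alt = Boxes.from_tensor(boxes_xywh, mode="xywh")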
Masks
Mask prompts should be provided from a previous model output, with shape Bx1x256x256.
# first run
predictions = prompter.predict(...)

# use previous results as prompt
predictions = prompter.predict(..., masks=predictions.logits)
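A more concrete version of this pattern, reusing the keypoints and labels defined above (a sketch; it assumes predict accepts the masks keyword shown in the snippet above):
# Sketch: refine a keypoint query by feeding the previous low-resolution logits back in.
first = prompter.predict(keypoints=keypoints, keypoints_labels=labels, multimask_output=False)
refined = prompter.predict(keypoints=keypoints, keypoints_labels=labels, masks=first.logits, multimask_output=False)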
Example of prediction
# using keypoints
prediction_by_keypoint = prompter.predict(keypoints, labels, multimask_output=False)
show_image(prediction_by_keypoint.binary_masks)
# Using boxes
prediction_by_boxes = prompter.predict(boxes=boxes, multimask_output=False)
show_image(prediction_by_boxes.binary_masks)
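As described in the output-shape section above, with multimask_output=False each result carries a single mask channel (C=1):
# Shape check: with multimask_output=False we get C=1 mask per prompt.
print(prediction_by_keypoint.binary_masks.shape)  # (1, 1, H, W)
print(prediction_by_boxes.binary_masks.shape)     # (1, 1, H, W)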
Exploring the prediction result structure
The VisualPrompter prediction structure is a SegmentationResults, which holds the upscaled (by default) logits when output_original_size=True is passed to predict.
The segmentation results have:
- logits: the prediction logits with shape (B, C, H, W), where C refers to the number of predicted masks
- scores: the scores of the logits, with shape (B, C)
- binary_masks: binary masks generated from the logits considering the mask_threshold. The size depends on output_original_size: if False, the binary masks have the same shape as the logits, Bx1x256x256
prediction_by_boxes.scores
tensor([[0.9317]])
prediction_by_boxes.binary_masks.shape
torch.Size([1, 1, 1080, 1920])
prediction_by_boxes.logits.shape
torch.Size([1, 1, 256, 256])
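The binary masks are just the logits thresholded (and optionally upscaled); you can reproduce a low-resolution version yourself. A sketch assuming the default SAM mask threshold of 0.0:
# Sketch: binarize the low-resolution logits manually (0.0 is the default SAM mask threshold).
low_res_masks = prediction_by_boxes.logits > 0.0
print(low_res_masks.shape)  # torch.Size([1, 1, 256, 256])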
Using the Visual Prompter on examples
Soccer players
Using an example image from the dataset: https://www.kaggle.com/datasets/ihelon/football-player-segmentation
Let's segment the players on the field using boxes.
# Prompts
boxes = Boxes.from_tensor(
    torch.tensor(
        [
            [1841.7000, 739.0000, 1906.5000, 890.6000],
            [879.3000, 545.9000, 948.2000, 669.2000],
            [55.7000, 595.0000, 127.4000, 745.9000],
            [1005.4000, 128.7000, 1031.5000, 212.0000],
            [387.4000, 424.1000, 438.2000, 539.0000],
            [921.0000, 377.7000, 963.3000, 483.0000],
            [1213.2000, 885.8000, 1276.2000, 1060.1000],
            [40.8900, 725.9600, 105.4100, 886.5800],
            [848.9600, 283.6200, 896.0600, 368.6200],
            [1109.6500, 499.0400, 1153.0400, 622.1700],
            [576.3000, 860.8000, 671.7000, 1018.8000],
            [1039.8000, 389.9000, 1072.5000, 493.2000],
            [1647.1000, 315.1000, 1694.0000, 406.0000],
            [1231.2000, 214.0000, 1294.1000, 297.3000],
        ],
        device=device,
    ),
    mode="xyxy",
)
# Load the image
image = load_image(soccer_image_path, ImageLoadType.RGB32, device)  # 3 x H x W
# Set the image
prompter.set_image(image)
predictions = prompter.predict(boxes=boxes, multimask_output=True)
Let's see the results. Since we used multimask_output=True, the model output 3 masks.
show_predictions(image, predictions, boxes=boxes)
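Since every box prompt produced three candidate masks, a common follow-up is to keep only the highest-scoring candidate per box. A plain-tensor sketch, assuming scores has shape (B, C) as shown earlier:
# Sketch: select, for each box prompt, the candidate mask with the best score.
best_idx = predictions.scores.argmax(dim=1)                    # (B,)
batch_idx = torch.arange(best_idx.shape[0], device=best_idx.device)
best_masks = predictions.binary_masks[batch_idx, best_idx]     # (B, H, W)
print(best_masks.shape)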
Car parts
Segmenting car parts of an example from the dataset: https://www.kaggle.com/datasets/jessicali9530/stanford-cars-dataset
# Prompts
boxes = Boxes.from_tensor(
    torch.tensor(
        [
            [56.2800, 369.1100, 187.3000, 579.4300],
            [412.5600, 426.5800, 592.9900, 608.1600],
            [609.0800, 366.8200, 682.6400, 431.1700],
            [925.1300, 366.8200, 959.6100, 423.1300],
            [756.1900, 416.2300, 904.4400, 473.7000],
            [489.5600, 285.2200, 676.8900, 343.8300],
        ],
        device=device,
    ),
    mode="xyxy",
)
keypoints = Keypoints(
    torch.tensor(
        [[[535.0, 227.0], [349.0, 215.0], [237.0, 219.0], [301.0, 373.0], [641.0, 397.0], [489.0, 513.0]]], device=device
    )
)

labels = torch.ones(keypoints.shape[:2], device=device, dtype=torch.float32)
# Image
image = load_image(car_image_path, ImageLoadType.RGB32, device)
# Set the image
prompter.set_image(image)
Querying with boxes
predictions = prompter.predict(boxes=boxes, multimask_output=True)
show_predictions(image, predictions, boxes=boxes)
Querying with keypoints
Considering N points in 1 batch
This way, the model will try to find a single object consistent with all the points.
predictions = prompter.predict(keypoints=keypoints, keypoints_labels=labels)
show_predictions(image, predictions, points=(keypoints, labels))
Considering 1 point in N batches
The prompt encoder does not currently work with a batch of points, so below we query each point separately instead.
Considering 1 point per query, over N queries
This way, the model will find an object for each point.
k = 2  # number of times/points to query

for idx in range(min(keypoints.data.size(1), k)):
    print("-" * 79, f"\nQuery {idx}:")
    _kpts = keypoints[:, idx, ...][None, ...]
    _lbl = labels[:, idx, ...][None, ...]

    predictions = prompter.predict(keypoints=_kpts, keypoints_labels=_lbl)

    show_predictions(image, predictions, points=(_kpts, _lbl))
-------------------------------------------------------------------------------
Query 0:
-------------------------------------------------------------------------------
Query 1:
Satellite image
Image from Sentinel-2
Product: a tile of the TCI (10 m per pixel). Product name: S2B_MSIL1C_20230324T130249_N0509_R095_T23KPQ_20230324T174312
# Prompts
keypoints = Keypoints(
    torch.tensor(
        [
            [
                # Urban
                [74.0, 104.5],
                [335, 110],
                [702, 65],
                [636, 479],
                [408, 820],
                # Forest
                [40, 425],
                [680, 566],
                [405, 439],
                [73, 689],
                [865, 460],
                # Ocean/water
                [981, 154],
                [705, 714],
                [357, 683],
                [259, 908],
                [1049, 510],
            ]
        ],
        device=device,
    )
)

labels = torch.ones(keypoints.shape[:2], device=device, dtype=torch.float32)
# Image
image = load_image(satellite_image_path, ImageLoadType.RGB32, device)
# Set the image
prompter.set_image(image)
Query urban points
# Query the prompts
labels_to_query = labels.clone()
labels_to_query[..., 5:] = 0

predictions = prompter.predict(keypoints=keypoints, keypoints_labels=labels_to_query)
show_predictions(image, predictions, points=(keypoints, labels_to_query))
Query Forest points
# Query the prompts
labels_to_query = labels.clone()
labels_to_query[..., :5] = 0
labels_to_query[..., 10:] = 0

predictions = prompter.predict(keypoints=keypoints, keypoints_labels=labels_to_query)
show_predictions(image, predictions, points=(keypoints, labels_to_query))
Query ocean/water points
# Query the prompts
labels_to_query = labels.clone()
labels_to_query[..., :10] = 0

predictions = prompter.predict(keypoints=keypoints, keypoints_labels=labels_to_query)
show_predictions(image, predictions, points=(keypoints, labels_to_query))
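To compare the three land-cover queries in a single figure, you could collect the binary masks from each query and overlay them with the colorize_masks helper defined earlier. A sketch that re-runs the three queries:
# Sketch: run the urban / forest / ocean-water queries again, collect the masks and overlay them.
masks = []
for start, stop in [(0, 5), (5, 10), (10, 15)]:  # urban, forest, ocean/water point groups
    lbl = torch.zeros_like(labels)
    lbl[..., start:stop] = 1
    preds = prompter.predict(keypoints=keypoints, keypoints_labels=lbl, multimask_output=False)
    masks.append(preds.binary_masks)

# Each query yields a (1, 1, H, W) mask; stack them on the batch dim and merge with random colors.
overlay = colorize_masks(torch.cat(masks, dim=0), merge=True, alpha=0.6)[0]
plt.imshow(tensor_to_image(image))
plt.imshow(tensor_to_image(overlay))
plt.axis("off")
plt.show()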