%%capture
# Install the libraries used by this notebook; the %%capture cell magic suppresses the cell's output.
!pip install kornia
!pip install kornia-rs
Denoise an image using total variation
Advanced
Denoising
In this tutorial we will learn how to denoise an image using the differentiable `total_variation` loss.
import io
import requests
def download_image(url: str, filename: str = "", timeout: float = 30.0) -> str:
    """Download ``url`` and save the response body to a local file.

    Args:
        url: address of the file to fetch.
        filename: local path to write to. When empty (or ``None``), the last
            ``/``-separated component of ``url`` is used.
        timeout: seconds to wait for the server before giving up; without a
            timeout ``requests`` can block indefinitely.

    Returns:
        The path the file was written to.

    Raises:
        requests.HTTPError: if the server answers with an error status, so an
            HTML error page is never silently saved as the requested file.
    """
    filename = filename or url.split("/")[-1]
    # Download
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # Save file (the raw bytes can be written directly; no BytesIO needed)
    with open(filename, "wb") as outfile:
        outfile.write(response.content)
    return filename
# Fetch the sample image used throughout this tutorial into the working directory.
url = "https://github.com/kornia/data/raw/main/doraemon.png"
download_image(url)
'doraemon.png'
import kornia as K
import matplotlib.pyplot as plt
import torch
import torchvision
def imshow(input: torch.Tensor, nrow: int = 2, padding: int = 5):
    """Display an image tensor (or a batch of them) with matplotlib.

    The tensor is tiled with ``torchvision.utils.make_grid``, converted to a
    NumPy image with ``kornia.utils.tensor_to_image`` and shown without axes.

    Args:
        input: tensor accepted by ``make_grid``; in this notebook it is a single
            CxHxW image. (The name shadows the builtin but is kept so keyword
            callers keep working.)
        nrow: number of images per grid row.
        padding: padding in pixels between grid cells.
    """
    grid = torchvision.utils.make_grid(input, nrow=nrow, padding=padding)
    plt.imshow(K.utils.tensor_to_image(grid))
    plt.axis("off")
    plt.show()
# read the image with kornia and add a random noise to it
img = K.io.load_image("doraemon.png", K.io.ImageLoadType.RGB32)  # CxHxW
# additive Gaussian noise (mean 0, std 0.1), then clamp the result to [0, 1]
noisy_image = (img + torch.normal(torch.zeros_like(img), 0.1)).clamp(0, 1)
imshow(noisy_image)
We define the total variation denoising network and the optimizer.
# define the total variation denoising network
class TVDenoise(torch.nn.Module):
    """Denoiser whose only trainable parameter is the image estimate itself.

    The objective is ``MSE(clean_image, noisy_image) + tv_weight * TV(clean_image)``:
    the MSE term keeps the estimate close to the observation, while the
    total-variation term (``kornia.losses.TotalVariation``) rewards smoothness.

    Args:
        noisy_image: observed image; a clone of it initialises the estimate.
        tv_weight: weight of the total-variation term. Defaults to ``0.0001``,
            the value that was previously hard-coded.
    """

    def __init__(self, noisy_image: torch.Tensor, tv_weight: float = 0.0001):
        super().__init__()
        self.l2_term = torch.nn.MSELoss(reduction="mean")
        self.regularization_term = K.losses.TotalVariation()
        self.tv_weight = tv_weight
        # create the variable which will be optimized to produce the noise free image
        self.clean_image = torch.nn.Parameter(data=noisy_image.clone(), requires_grad=True)
        # A buffer (not a parameter): it moves with ``.to(device)`` together with
        # ``clean_image`` but is not returned by ``parameters()``, so the optimizer
        # still updates only ``clean_image``.
        self.register_buffer("noisy_image", noisy_image)

    def forward(self) -> torch.Tensor:
        """Return the denoising objective for the current ``clean_image``.

        NOTE(review): the result's shape follows kornia's ``TotalVariation``
        output, which may be per-channel rather than scalar; callers reduce it
        with ``.sum()`` — confirm before relying on the shape.
        """
        data_fidelity = self.l2_term(self.clean_image, self.noisy_image)
        smoothness = self.regularization_term(self.clean_image)
        return data_fidelity + self.tv_weight * smoothness

    def get_clean_image(self) -> torch.Tensor:
        """Return the parameter holding the current denoised estimate."""
        return self.clean_image
tv_denoiser = TVDenoise(noisy_image)

# define the optimizer; the only parameter of tv_denoiser is the image estimate (clean_image)
optimizer = torch.optim.SGD(tv_denoiser.parameters(), lr=0.1, momentum=0.9)
Run the optimization loop.
num_iters: int = 500
log_every: int = 50  # print the loss every `log_every` iterations
for i in range(num_iters):
    optimizer.zero_grad()
    # .sum() reduces the module's output to a scalar for backward()
    # (a no-op if the output is already a scalar)
    loss = tv_denoiser().sum()
    if i % log_every == 0:
        print(f"Loss in iteration {i} of {num_iters}: {loss.item():.3f}")
    loss.backward()
    optimizer.step()
Loss in iteration 0 of 500: 3.081
Loss in iteration 50 of 500: 2.723
Loss in iteration 100 of 500: 2.359
Loss in iteration 150 of 500: 2.064
Loss in iteration 200 of 500: 1.828
Loss in iteration 250 of 500: 1.642
Loss in iteration 300 of 500: 1.497
Loss in iteration 350 of 500: 1.384
Loss in iteration 400 of 500: 1.297
Loss in iteration 450 of 500: 1.229
Visualize the noisy image and the resulting cleaned image.
# convert back to numpy. The optimized image is an unconstrained parameter, so
# detach it and clamp to [0, 1]: matplotlib clips float RGB data outside that
# range (and warns) anyway.
img_clean = K.utils.tensor_to_image(tv_denoiser.get_clean_image().detach().clamp(0, 1))

# Create the plot: noisy input on the left, denoised result on the right
fig, axs = plt.subplots(1, 2, figsize=(16, 10))
axs = axs.ravel()

panels = (
    ("Noisy image", K.utils.tensor_to_image(noisy_image)),
    ("Cleaned image", img_clean),
)
for ax, (title, image) in zip(axs, panels):
    ax.axis("off")
    ax.set_title(title)
    ax.imshow(image)

plt.show()