ZCA Whitening

The documentation for ZCA whitening can be found in kornia.enhance.zca.

Install the necessary packages:

%%capture
!pip install kornia numpy matplotlib
# Import required libraries
import kornia as K
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import datasets, transforms
from torchvision.utils import make_grid
# Select runtime device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")
Using cuda:0
ZCA on MNIST
Download and load the MNIST dataset.
%%capture
dataset = datasets.MNIST("./data/mnist", train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
Stack the whole dataset into a single tensor so that ZCA can be fit on the entire dataset.
images = []
for i in range(len(dataset)):
    im, _ = dataset[i]
    images.append(im)
images = torch.stack(images, dim=0).to(device)
Create a ZCA object and fit the transform in the forward pass. Setting include_fit=True is necessary if you need to include the ZCA fitting process in the backward pass.
zca = K.enhance.ZCAWhitening(eps=0.1)
images_zca = zca(images, include_fit=True)
The result shown should enhance the edges of the MNIST digits, because the regularization parameter \(\epsilon\) increases the importance of the higher frequencies, which typically correspond to the lowest eigenvalues in ZCA. The result looks similar to the demo from the Stanford ZCA tutorial.
grid_im = make_grid(images[0:30], nrow=5, normalize=True).cpu().numpy()
grid_zca = make_grid(images_zca[0:30], nrow=5, normalize=True).cpu().numpy()

plt.figure(figsize=(15, 15))
plt.subplot(1, 2, 1)
plt.imshow(np.transpose(grid_im, [1, 2, 0]))
plt.title("Input Images")
plt.xticks([])
plt.yticks([])
plt.subplot(1, 2, 2)
plt.imshow(np.transpose(grid_zca, [1, 2, 0]))
plt.title(r"ZCA Images $\epsilon = 0.1$")
plt.xticks([])
plt.yticks([])
plt.show()
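One quick, illustrative way to probe the effect of \(\epsilon\) numerically (our own addition, not part of the original demo): larger values damp the rescaling of low-variance directions, which shows up as a smaller overall spread in the whitened output.

# Compare the output spread for a few eps values on a subset of MNIST.
for eps in (0.01, 0.1, 1.0):
    out = K.enhance.ZCAWhitening(eps=eps)(images[:1000], include_fit=True)
    print(f"eps={eps}: output std = {out.std().item():.3f}")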
ZCA on CIFAR-10
In the next example, we explore using ZCA on the CIFAR-10 dataset, which is a dataset of color images (i.e. a 4-D tensor \([B, C, H, W]\)). In the cell below, we prepare the dataset.
%%capture
dataset = datasets.CIFAR10("./data/cifar", train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
images = []
for i in range(len(dataset)):
    im, _ = dataset[i]
    images.append(im)
images = torch.stack(images, dim=0).to(device)
We show another way to fit the ZCA transform using the fit method, which is useful when ZCA is included in data augmentation pipelines. Also, setting compute_inv=True enables computation of the inverse ZCA transform in case a reconstruction is required.
zca = K.enhance.ZCAWhitening(eps=0.1, compute_inv=True)
zca.fit(images)
zca_images = zca(images)
image_re = zca.inverse_transform(zca_images)
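As a quick numerical companion to the error image plotted below (our own addition, not part of the original tutorial), the reconstruction quality can be summarized with a single number:

# Summarize how closely the inverse ZCA transform recovers the inputs.
max_err = (images - image_re).abs().max().item()
print(f"Max absolute reconstruction error: {max_err:.2e}")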
Note how the higher-frequency details are more present in the ZCA-normalized images for the CIFAR-10 dataset.
grid_im = make_grid(images[0:30], nrow=5, normalize=True).cpu().numpy()
grid_zca = make_grid(zca_images[0:30], nrow=5, normalize=True).cpu().numpy()
grid_re = make_grid(image_re[0:30], nrow=5, normalize=True).cpu().numpy()

err_grid = grid_re - grid_im  # Compute error image

plt.figure(figsize=(15, 15))
plt.subplot(2, 2, 1)
plt.imshow(np.transpose(grid_im, [1, 2, 0]))
plt.title("Input Images")
plt.xticks([])
plt.yticks([])
plt.subplot(2, 2, 2)
plt.imshow(np.transpose(grid_zca, [1, 2, 0]))
plt.title(r"ZCA Images $\epsilon = 0.1$")
plt.xticks([])
plt.yticks([])
plt.subplot(2, 2, 3)
plt.imshow(np.transpose(grid_re, [1, 2, 0]))
plt.title("Reconstructed Images")
plt.xticks([])
plt.yticks([])
plt.subplot(2, 2, 4)
plt.imshow(np.sum(np.abs(np.transpose(err_grid, [1, 2, 0])), axis=-1))
plt.colorbar()
plt.title("Error Image")
plt.xticks([])
plt.yticks([])
plt.show()
Differentiability of ZCA
We will use a simple Gaussian dataset with mean \((3, 3)\) and a diagonal covariance with entries 1 and 9 to explore the differentiability of ZCA.
num_data = 100  # Number of points in the dataset
torch.manual_seed(1234)
x = torch.cat([torch.randn((num_data, 1), requires_grad=True), 3 * torch.randn((num_data, 1), requires_grad=True)], dim=1) + 3

plt.scatter(x.detach().numpy()[:, 0], x.detach().numpy()[:, 1])
plt.xlim([-10, 10])
plt.ylim([-10, 10])
plt.show()
Here we explore the effect of the detach_transforms option when computing the backward pass.
zca_detach = K.enhance.ZCAWhitening(eps=1e-6, detach_transforms=True)
zca_grad = K.enhance.ZCAWhitening(eps=1e-6, detach_transforms=False)
As a sanity check, the Jacobian between the input and output of the ZCA transform should be the same for all data points in the detached case, since the transform acts as a linear transform (i.e. \(T(X-\mu)\)). In the non-detached case, the Jacobian should vary across data points, since the input affects the computation of the ZCA transform matrix (i.e. \(T(X)(X-\mu(X))\)). As the number of samples increases, the Jacobians in the detached and non-detached cases should be roughly the same, since the influence of a single data point decreases. You can test this by changing num_data. Also note that include_fit=True is passed in the forward pass, since the creation of the transform matrix needs to be included in the forward pass in order to correctly compute the backward pass.
import torch.autograd as autograd
J = autograd.functional.jacobian(lambda x: zca_detach(x, include_fit=True), x)

num_disp = 5
print(f"Jacobian matrices detached for the first {num_disp} points")
for i in range(num_disp):
    print(J[i, :, i, :])

print("\n")

J = autograd.functional.jacobian(lambda x: zca_grad(x, include_fit=True), x)
print(f"Jacobian matrices attached for the first {num_disp} points")
for i in range(num_disp):
    print(J[i, :, i, :])
Jacobian matrices detached for the first 5 points
tensor([[ 1.0177, -0.0018],
[-0.0018, 0.3618]])
tensor([[ 1.0177, -0.0018],
[-0.0018, 0.3618]])
tensor([[ 1.0177, -0.0018],
[-0.0018, 0.3618]])
tensor([[ 1.0177, -0.0018],
[-0.0018, 0.3618]])
tensor([[ 1.0177, -0.0018],
[-0.0018, 0.3618]])
Jacobian matrices attached for the first 5 points
tensor([[ 1.0003, -0.0018],
[-0.0018, 0.3547]])
tensor([[ 1.0006, -0.0027],
[-0.0027, 0.3555]])
tensor([[ 0.9911, -0.0028],
[-0.0028, 0.3506]])
tensor([[9.9281e-01, 4.4671e-04],
[4.4671e-04, 3.5368e-01]])
tensor([[ 1.0072, -0.0019],
[-0.0019, 0.3581]])
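To check the convergence claim above directly, one can compare the two Jacobians and watch the gap shrink as num_data grows. This is a sketch we added, reusing the objects defined above:

# Maximum elementwise gap between the detached and attached Jacobians.
J_detach = autograd.functional.jacobian(lambda x: zca_detach(x, include_fit=True), x)
J_grad = autograd.functional.jacobian(lambda x: zca_grad(x, include_fit=True), x)
print(f"Max |J_detach - J_grad|: {(J_detach - J_grad).abs().max().item():.4f}")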
Lastly, we plot the ZCA-whitened data. Note that setting include_fit to True stores the resulting transformations for future use.
x_zca = zca_detach(x).detach().numpy()
plt.scatter(x_zca[:, 0], x_zca[:, 1])
plt.ylim([-4, 4])
plt.xlim([-4, 4])
plt.show()
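Because the fitted transform is stored on the object, it can be reused on new samples without re-fitting. A small sketch (x_new is a hypothetical fresh draw from the same Gaussian):

# Draw new points from the same distribution and whiten them with the stored fit.
x_new = torch.cat([torch.randn((10, 1)), 3 * torch.randn((10, 1))], dim=1) + 3
x_new_zca = zca_detach(x_new)  # no include_fit: reuses the stored transform
print(x_new_zca.shape)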