# Source code for osc.viz.backbone

"""
Clustering backbone features.
"""
from typing import Tuple, Union

import numpy as np
import torch
from einops import rearrange
from IPython.core.display import Image
from IPython.core.display_functions import display
from matplotlib import pyplot as plt
from sklearn.cluster import KMeans
from torch import Tensor


def kmeans_backbone(
    images, f_backbone, num_patches: Tuple[int, int], n_clusters: int = 11
):
    """Cluster backbone patch features with K-Means and plot them per image.

    For every image in the batch, the patch features of all its augmentations
    are clustered jointly in a single K-Means fit, so a cluster ID refers to
    the same cluster in every augmentation of that image. One figure is
    produced per image: the top row shows the augmented images, the bottom
    row the cluster ID of each patch.

    Each figure is saved as ``img-{b}-vit-kmeans.png`` (relative path) and
    then displayed through IPython.

    Args:
        images: tensor laid out as ``[A, B, C, H, W]`` (augmentations, batch,
            channels, height, width), as implied by the rearrange pattern.
        f_backbone: tensor laid out as ``[(A B), P_hw, C]``, i.e. one row of
            ``P_hw`` patch features per (augmentation, image) pair.
        num_patches: patch grid used to reshape the ``P_hw`` axis for
            plotting; ``prod(num_patches)`` must equal ``f_backbone.shape[1]``.
            Presumably ``(rows, cols)`` since it is reversed to build the
            ``(width, height)`` figure size.
        n_clusters: number of K-Means clusters fitted per image.

    Raises:
        AssertionError: if ``f_backbone.shape[1]`` does not match
            ``num_patches``.
    """
    assert f_backbone.shape[1] == np.prod(num_patches), (
        "f_backbone.shape[1] must equal the number of patches in num_patches"
    )
    num_augs = images.shape[0]
    batch_size = images.shape[1]

    # Group all augmentations of one image together so they share one fit.
    f_backbone = rearrange(
        f_backbone.detach().cpu().numpy(),
        "(A B) P_hw C -> B (A P_hw) C",
        A=num_augs,
        B=batch_size,
    )
    # Channels-last layout, as expected by imshow.
    images = rearrange(
        images.detach().cpu().numpy(),
        "A B C H W -> B A H W C",
    )

    for b in range(batch_size):
        # squeeze=False keeps `axs` two-dimensional even when num_augs == 1;
        # otherwise matplotlib returns a 1-D array and axs[0, a] fails.
        fig, axs = plt.subplots(
            2,
            num_augs,
            figsize=np.array(num_patches[::-1]) / 2,
            squeeze=False,
        )
        kmeans = KMeans(
            init="k-means++", n_clusters=n_clusters, n_init=4, random_state=0
        )
        clust = kmeans.fit_predict(f_backbone[b]).reshape(num_augs, *num_patches)

        for a in range(num_augs):
            axs[0, a].imshow(images[b, a])
            axs[0, a].set_title(f"Aug {a}")
            axs[1, a].imshow(clust[a], cmap="tab20")

        # Keep x ticks only on the bottom row and y ticks only on the
        # left-most column.
        for ax in axs[:-1, :].flat:
            ax.set_xticks([])
        for ax in axs[:, 1:].flat:
            ax.set_yticks([])

        fig.suptitle(f"Image {b} - backbone features K-Means")
        fig.tight_layout()
        fig.set_facecolor("white")
        fig.savefig(f"img-{b}-vit-kmeans.png", dpi=200)
        # Close before displaying so the figure is not also rendered inline.
        plt.close(fig)
        display(Image(url=f"img-{b}-vit-kmeans.png", width=600))


def kmeans_clusters(
    x: Union[Tensor, np.ndarray], n_clusters: int = 11
) -> np.ndarray:
    """Batched K-means clustering.

    Every ``[N, C]`` slice addressed by the leading batch dimensions is
    clustered independently with its own K-Means fit (``random_state=0``).

    Args:
        x: N samples of C-dimensional features with leading batch dimensions,
            e.g. shape [..., N, C]. Torch tensors are detached and moved to
            CPU before clustering.
        n_clusters: desired number of clusters

    Returns:
        int array of cluster IDs, one per sample, shape [..., N]

    Raises:
        ValueError: if ``x`` has fewer than two dimensions.
    """
    if torch.is_tensor(x):
        x = x.detach().cpu().numpy()
    if x.ndim < 2:
        raise ValueError(
            f"x must have shape [..., N, C], got {x.ndim}-dimensional input"
        )

    # One cluster ID per sample: drop the trailing feature dimension.
    res = np.empty(x.shape[:-1], dtype=int)
    # Iterate over every index of the leading batch dimensions; for a plain
    # [N, C] input this yields the single empty index ().
    for b in np.ndindex(x.shape[:-2]):
        kmeans = KMeans(
            init="k-means++", n_clusters=n_clusters, n_init=4, random_state=0
        )
        res[b] = kmeans.fit_predict(x[b])
    return res