"""Attention visualization utilities: single attention maps and rollouts."""
from itertools import product

import numpy as np
import torch
from einops import rearrange, reduce
from IPython.core.display import Image
from IPython.core.display_functions import display
from matplotlib import pyplot as plt

from .rollout import self_attn_rollout, slot_attn_rollout

def viz_vit_attns(images, attns_dict, head_reduction="mean", adjust_residual=True):
    """Plot per-layer ViT attention maps plus the attention rollout.

    One figure is produced per batch element. Its rows are: the input images
    (one per augmentation), one row per attention layer, and a final rollout
    row. Within a layer row, each augmentation gets ``num_heads + 1`` columns:
    the first shows attention summed over queries and heads, the others show
    attention summed over queries for each individual head. Each figure is
    saved to the relative path ``img-{b}-vit-attns.png``, closed, and shown
    with IPython's ``display``.

    Args:
        images: Tensor unpacked with the einops pattern ``A B C H W``
            (augmentations, batch, channels, height, width).
        attns_dict: Mapping from layer key to attention tensor. Layers are
            taken in sorted key order and unpacked as ``(A B) h Q_hw K_hw``;
            the key axis is reshaped into a square, so its length must be a
            perfect square.
        head_reduction: Forwarded unchanged to ``self_attn_rollout``.
        adjust_residual: Forwarded unchanged to ``self_attn_rollout``.
    """
    num_augs, batch_size = images.shape[:2]
    imgs = rearrange(images.detach().cpu().numpy(), "A B C H W -> B A H W C")

    num_layers = len(attns_dict)
    per_layer = [attns_dict[key].detach() for key in sorted(attns_dict)]
    stacked = rearrange(
        per_layer,
        "L (A B) h Q_hw K_hw -> B A L h Q_hw K_hw",
        A=num_augs,
        B=batch_size,
    )
    num_heads = stacked.shape[3]
    side = int(np.sqrt(stacked.shape[5]))

    # Per layer, collapse the query axis so each key gets a single score:
    # once pooled across all heads (slot 0) and once per individual head.
    pooled_heads = reduce(stacked, "B A L h Q_hw K_hw -> B A L 1 K_hw", "sum")
    single_heads = reduce(stacked, "B A L h Q_hw K_hw -> B A L h K_hw", "sum")
    key_maps = torch.concat([pooled_heads, single_heads], dim=3)
    # Fold the flat key axis into a square image.
    key_maps = rearrange(
        key_maps.cpu().numpy(),
        "B A L h (K_h K_w) -> B A L h K_h K_w",
        K_h=side,
        K_w=side,
    )

    rollout_maps = self_attn_rollout(
        attns_dict,
        head_reduction=head_reduction,
        adjust_residual=adjust_residual,
        global_avg_pool=True,
    )
    rollout_maps = rearrange(
        rollout_maps.cpu().numpy(),
        "(A B) (K_h K_w) -> B A K_h K_w",
        A=num_augs,
        B=batch_size,
        K_h=side,
        K_w=side,
    )

    group_width = num_heads + 1
    num_rows = 1 + num_layers + 1
    num_cols = num_augs * group_width

    def _hide_ticks(axis):
        axis.set_xticks([])
        axis.set_yticks([])

    # One figure per batch element, covering all of its augmentations.
    for b in range(batch_size):
        fig, axs = plt.subplots(
            num_rows,
            num_cols,
            figsize=2 * np.array([0.2 + num_cols, 0.5 + num_rows]),
        )

        # Top row: the input image in the first column of each augmentation
        # group; the remaining columns of the group are blanked.
        input_axes = axs[0, :].reshape(num_augs, group_width)
        input_axes[0, 0].set_ylabel("Input")
        for a, ax in enumerate(input_axes[:, 0]):
            ax.imshow(imgs[b, a])
            _hide_ticks(ax)
            ax.set_title(f"Aug {a}")
        for ax in input_axes[:, 1:].flat:
            ax.set_axis_off()

        # Middle rows: one layer per row; pooled heads first, then each head.
        layer_axes = axs[1:-1, :].reshape(num_layers, num_augs, group_width)
        for lyr in range(num_layers):
            layer_axes[lyr, 0, 0].set_ylabel(f"Layer {lyr}")
            for a, h in product(range(num_augs), range(group_width)):
                ax = layer_axes[lyr, a, h]
                colormap = "inferno" if h == 0 else "viridis"
                ax.imshow(key_maps[b, a, lyr, h], cmap=colormap)
                _hide_ticks(ax)
                if lyr == 0:
                    ax.set_title("Avg heads" if h == 0 else f"Head {h - 1}")

        # Bottom row: the rollout, again only in the first column per group.
        rollout_axes = axs[-1, :].reshape(num_augs, group_width)
        rollout_axes[0, 0].set_ylabel("Rollout")
        for a, ax in enumerate(rollout_axes[:, 0]):
            ax.imshow(rollout_maps[b, a], cmap="inferno")
            _hide_ticks(ax)
        for ax in rollout_axes[:, 1:].flat:
            ax.set_axis_off()

        fig.suptitle(f"Image {b}")
        fig.tight_layout()
        fig.set_facecolor("white")
        out_path = f"img-{b}-vit-attns.png"
        fig.savefig(out_path, dpi=200)
        plt.close(fig)
        display(Image(url=out_path, width=1400))
def viz_slot_attns(
    images, vit_attns_dict, slot_attns_dict, head_reduction="mean", adjust_residual=True
):
    """Plot slot-attention maps per iteration together with their rollouts.

    One figure is produced per batch element. Its rows are: the input images
    (one per augmentation), one row per slot-attention iteration, a row with
    the slot rollout of every slot, and a row with the "full" rollout (slot
    rollout composed with the ViT rollout via ``einsum``). Each augmentation
    occupies ``num_slots`` columns. Each figure is saved to the relative path
    ``img-{b}-slot-attns.png``, closed, and shown with IPython's ``display``.

    Args:
        images: Tensor unpacked with the einops pattern ``A B C H W``
            (augmentations, batch, channels, height, width).
        vit_attns_dict: Backbone attentions, forwarded unchanged to
            ``self_attn_rollout`` (called with ``global_avg_pool=False``).
        slot_attns_dict: Mapping from iteration key to slot-attention tensor.
            Iterations are taken in sorted key order and unpacked as
            ``(A B) k K_hw``; the key axis is reshaped into a square, so its
            length must be a perfect square. Also passed to
            ``slot_attn_rollout``.
        head_reduction: Forwarded unchanged to ``self_attn_rollout``.
        adjust_residual: Forwarded unchanged to ``self_attn_rollout``.
    """
    num_augs, batch_size = images.shape[0], images.shape[1]
    imgs = rearrange(images.detach().cpu().numpy(), "A B C H W -> B A H W C")

    per_iter = [slot_attns_dict[key].detach() for key in sorted(slot_attns_dict)]
    slot_maps = rearrange(
        per_iter,
        "I (A B) k K_hw -> B A I k K_hw",
        A=num_augs,
        B=batch_size,
    )
    # Add a small epsilon, then renormalise each slot's map over the keys.
    slot_maps = slot_maps + 1e-8
    slot_maps = slot_maps / slot_maps.sum(axis=-1, keepdims=True)
    num_iters, num_slots = slot_maps.shape[2], slot_maps.shape[3]
    side = int(np.sqrt(slot_maps.shape[-1]))
    slot_maps = rearrange(
        slot_maps,
        "B A I k (K_h K_w) -> B A I k K_h K_w",
        K_h=side,
        K_w=side,
    )
    slot_maps = slot_maps.cpu().numpy()

    # Expected shape [A*B, Q_hw, K_hw]
    vit_rollout = self_attn_rollout(
        vit_attns_dict,
        head_reduction=head_reduction,
        adjust_residual=adjust_residual,
        global_avg_pool=False,
    )
    # Expected shape [A*B, K, K_hw]
    slot_rollouts = slot_attn_rollout(slot_attns_dict)
    # Compose slot -> token with token -> token; expected shape [A*B, K, K_hw]
    full_rollouts = torch.einsum("bki,bij->bkj", slot_rollouts, vit_rollout)

    to_grid = "(A B) k (K_h K_w) -> B A k K_h K_w"
    grid_sizes = dict(A=num_augs, B=batch_size, K_h=side, K_w=side)
    slot_rollouts = rearrange(slot_rollouts.cpu().numpy(), to_grid, **grid_sizes)
    full_rollouts = rearrange(full_rollouts.cpu().numpy(), to_grid, **grid_sizes)

    num_rows = 1 + num_iters + 1 + 1
    num_cols = num_augs * num_slots

    def _hide_ticks(axis):
        axis.set_xticks([])
        axis.set_yticks([])

    for b in range(batch_size):
        fig, axs = plt.subplots(
            num_rows,
            num_cols,
            figsize=2 * np.array([num_cols, 0.5 + num_rows]),
            squeeze=False,
        )

        # Top row: the input image in the first column of each augmentation
        # group; the remaining columns of the group are blanked.
        input_axes = axs[0, :].reshape(num_augs, num_slots)
        for a, ax in enumerate(input_axes[:, 0]):
            ax.imshow(imgs[b, a])
            _hide_ticks(ax)
            ax.set_title(f"Aug {a}")
        for ax in input_axes[:, 1:].flat:
            ax.set_axis_off()

        # Middle rows: one slot-attention iteration per row.
        iter_axes = axs[1:-2, :].reshape(num_iters, num_augs, num_slots)
        for i in range(num_iters):
            iter_axes[i, 0, 0].set_ylabel(f"Iteration {i}")
            for a, k in product(range(num_augs), range(num_slots)):
                ax = iter_axes[i, a, k]
                ax.imshow(slot_maps[b, a, i, k])
                _hide_ticks(ax)
                if i == 0:
                    ax.set_title(f"Slot {k}")

        # Last two rows: slot rollouts, then full rollouts, one map per slot.
        rollout_rows = (
            (-2, "Slot rollouts", slot_rollouts),
            (-1, "Full rollouts", full_rollouts),
        )
        for row, label, maps in rollout_rows:
            row_axes = axs[row, :].reshape(num_augs, num_slots)
            row_axes[0, 0].set_ylabel(label)
            for a, k in product(range(num_augs), range(num_slots)):
                ax = row_axes[a, k]
                ax.imshow(maps[b, a, k], cmap="inferno")
                _hide_ticks(ax)

        fig.suptitle(f"Image {b}")
        fig.tight_layout()
        fig.set_facecolor("white")
        out_path = f"img-{b}-slot-attns.png"
        fig.savefig(out_path, dpi=200)
        plt.close(fig)
        display(Image(url=out_path, width=1700))
def viz_vit_rollout_all_options(images, attns_dict):
    """Visualize backbone rollout: residual adjustment, head reduction.

    Draws one row per (image, augmentation) pair. Column 0 shows the input
    image; each further column shows the rollout returned by
    ``self_attn_rollout`` (with ``global_avg_pool=True``) for one combination
    of ``adjust_residual`` in ``(True, False)`` and ``head_reduction`` in
    ``("mean", "max")``.

    Args:
        images: Tensor unpacked with the einops pattern ``A B C H W``
            (augmentations, batch, channels, height, width).
        attns_dict: Attention tensors, forwarded unchanged to
            ``self_attn_rollout``. The last dimension of the first value is
            used as the number of keys and is reshaped into a square, so it
            must be a perfect square.

    Returns:
        The matplotlib figure. It is neither saved nor closed here.
    """
    A, B = images.shape[:2]
    images = rearrange(images.detach().cpu().numpy(), "A B C H W -> B A H W C")
    K_h = K_w = int(np.sqrt(list(attns_dict.values())[0].shape[-1]))

    options = list(product([True, False], ["mean", "max"]))
    num_cols = 1 + len(options)

    # squeeze=False keeps `axs` two-dimensional even when B * A == 1.
    # Without it, a single image with a single augmentation yields a 1-D
    # array of axes and the "(B A) L" pattern below cannot be matched.
    fig, axs = plt.subplots(
        B * A,
        num_cols,
        figsize=2.5 * np.array([num_cols, B * A]),
        squeeze=False,
    )
    axs = rearrange(axs, "(B A) L -> B A L", B=B, A=A)

    # First column: the input images, labelled by image and augmentation.
    for b, a in np.ndindex(B, A):
        axs[b, a, 0].imshow(images[b, a])
        axs[b, a, 0].set_ylabel(f"Img {b}\nAug {a}")

    # Remaining columns: one rollout variant per column.
    for i, (adjust_residual, head_reduction) in enumerate(options):
        rollout = self_attn_rollout(
            attns_dict,
            head_reduction=head_reduction,
            adjust_residual=adjust_residual,
            global_avg_pool=True,
        )
        rollout = rearrange(
            rollout.cpu().numpy(),
            "(A B) (K_h K_w) -> B A K_h K_w",
            A=A,
            B=B,
            K_h=K_h,
            K_w=K_w,
        )
        # Only the top row carries the column title.
        axs[0, 0, i + 1].set_title(f"{head_reduction} residual={adjust_residual}")
        for b, a in np.ndindex(B, A):
            axs[b, a, i + 1].imshow(rollout[b, a])

    for ax in axs.flat:
        ax.set_xticks([])
        ax.set_yticks([])

    fig.tight_layout()
    fig.set_facecolor("white")
    return fig
def viz_slot_rollout_all_options(images, slot_attns_dict, vit_rollout):
    """Visualize slot attention rollout: per-iteration or final normalization.

    Each (image, augmentation) pair gets four rows, one per combination of
    ``full`` in ``(False, True)`` and ``normalize`` in ``("layer", "all")``.
    Column 0 holds the input image (first of the four rows only) and the row
    label; columns 1..S hold the rollout of each slot. When ``full`` is set,
    the slot rollout is composed with ``vit_rollout`` via ``einsum`` before
    plotting.

    Args:
        images: Tensor unpacked with the einops pattern ``A B C H W``
            (augmentations, batch, channels, height, width).
        slot_attns_dict: Slot attentions, forwarded to ``slot_attn_rollout``.
            The first value's last two dimensions are read as (number of
            slots, number of keys); the key count must be a perfect square.
        vit_rollout: Tensor contracted with the slot rollout using the
            pattern ``"bsj,bjk->bsk"``.

    Returns:
        The matplotlib figure. It is neither saved nor closed here.
    """
    num_augs, batch_size = images.shape[:2]
    imgs = rearrange(images.detach().cpu().numpy(), "A B C H W -> B A H W C")

    first_attn = list(slot_attns_dict.values())[0]
    side = int(np.sqrt(first_attn.shape[-1]))
    num_slots = first_attn.shape[-2]

    num_rows = batch_size * num_augs * 4
    num_cols = 1 + num_slots
    fig, axs = plt.subplots(
        num_rows,
        num_cols,
        figsize=2 * np.array([num_cols, num_rows]),
    )
    axs = rearrange(
        axs,
        "(B A opt) Sp -> B A opt Sp",
        B=batch_size,
        A=num_augs,
        opt=4,
    )

    pairs = list(product(range(batch_size), range(num_augs)))

    # The input image goes only in the first of each pair's four rows.
    for b, a in pairs:
        axs[b, a, 0, 0].imshow(imgs[b, a])

    variants = product([False, True], ["layer", "all"])
    for opt, (full, normalize) in enumerate(variants):
        maps = slot_attn_rollout(slot_attns_dict, normalize=normalize)
        if full:
            # Chain the slot rollout with the backbone rollout.
            maps = torch.einsum("bsj,bjk->bsk", maps, vit_rollout)
        maps = rearrange(
            maps.cpu().numpy(),
            "(A B) S (K_h K_w) -> B A S K_h K_w",
            A=num_augs,
            B=batch_size,
            K_h=side,
            K_w=side,
        )

        variant_name = "full" if full else "slot"
        row_label = f"{normalize}\n{variant_name}"
        for b, a in pairs:
            for s in range(num_slots):
                axs[b, a, opt, s + 1].imshow(maps[b, a, s])
            axs[b, a, opt, 0].set_ylabel(row_label)

    for ax in axs.flat:
        ax.set_xticks([])
        ax.set_yticks([])

    fig.tight_layout()
    fig.set_facecolor("white")
    return fig