Source code for osc.viz.rollout

from typing import Callable, Mapping, Sequence, Union

import torch
from einops import reduce
from einops.layers.torch import Reduce
from torch import Tensor

from osc.utils import fill_diagonal_, normalize_sum_to_one


def self_attn_rollout(
    attns: Union[Mapping[str, Tensor], Sequence[Tensor]],
    head_reduction: Union[str, Callable] = "mean",
    adjust_residual=True,
    global_avg_pool=True,
):
    """Self-attention rollout: how much the output token(s) attend to the input tokens across layers.

    Args:
        attns: dict or list where each entry has shape ``[B heads Q K]``
        head_reduction: ``'mean'``, ``'max'``, or a callable that reduces the head dimension
        adjust_residual: bool, whether to add 0.5 for the self connection
        global_avg_pool: bool, whether the output of the final attention layer is
            avg-pooled into a single vector of features

    Returns:
        Rollout, shape ``[B Q K]`` if ``global_avg_pool=False``, else ``[B K]``
    """
    if isinstance(attns, Mapping):
        attns = [attns[k] for k in sorted(attns.keys())]
    if isinstance(attns, Sequence):
        attns = [a.detach() for a in attns]

    # Reduce heads: mean or max
    if head_reduction in {"mean", "max"}:
        head_reduction = Reduce("B h Q K -> B Q K", head_reduction)
    attns = [head_reduction(a) for a in attns]

    # Adjust for self-connections
    if adjust_residual:
        attns = [(a + torch.eye(*a.shape[-2:], device=a.device)) / 2 for a in attns]

    rollout = attns[0]
    for layer_idx in range(1, len(attns)):
        rollout = torch.einsum("bij,bjk->bik", attns[layer_idx], rollout)

    # Last layer is global avg pool, i.e. a single query token
    # that attends to all keys with uniform attention
    if global_avg_pool:
        rollout = reduce(rollout, "B Q K -> B K", "mean")

    return rollout
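
# Hedged usage sketch, not part of the original module: the helper below is
# hypothetical and only illustrates the expected input/output shapes of
# self_attn_rollout. The layer names and shapes (2 layers, batch 2, 4 heads,
# 16 tokens) are assumptions for illustration.
def _demo_self_attn_rollout():
    attns = {
        "layer0": torch.rand(2, 4, 16, 16).softmax(dim=-1),  # [B heads Q K]
        "layer1": torch.rand(2, 4, 16, 16).softmax(dim=-1),
    }
    rollout = self_attn_rollout(attns, head_reduction="mean", global_avg_pool=True)
    assert rollout.shape == (2, 16)  # [B K]: one relevance weight per input token
    return rollout
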
def slot_attn_rollout(attns, normalize="layer"):
    """Slot attention rollout: how much a slot token attends to the context tokens.

    Args:
        attns: dict or list where each entry has shape ``[B S K]``
        normalize: ``'layer'`` to normalize the rollout at every layer,
            or ``'all'`` to normalize at the end only.

    Returns:
        Rollout, shape ``[B S K]``
    """
    if isinstance(attns, Mapping):
        attns = [attns[k] for k in sorted(attns.keys())]
    if isinstance(attns, Sequence):
        attns = [a.detach() for a in attns]
    attns = [normalize_sum_to_one(a, dim=-1) for a in attns]

    if normalize == "layer":
        rollout = attns[0]
        for layer_idx in range(1, len(attns)):
            rollout = normalize_sum_to_one(attns[layer_idx] + rollout, dim=-1)
    elif normalize == "all":
        rollout = torch.sum(torch.stack(attns, dim=0), dim=0)
        rollout = normalize_sum_to_one(rollout, dim=-1)
    else:
        raise ValueError(normalize)

    return rollout
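
# Hedged usage sketch, not part of the original module: the helper below is
# hypothetical and shows slot_attn_rollout on random slot-attention maps.
# Assumed shapes: 3 iterations, batch 2, 6 slots, 49 context tokens.
def _demo_slot_attn_rollout():
    attns = [torch.rand(2, 6, 49) for _ in range(3)]  # each [B S K]
    rollout = slot_attn_rollout(attns, normalize="layer")
    assert rollout.shape == (2, 6, 49)  # [B S K], normalized to sum to one over K
    return rollout
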
def cross_attn_rollout(attns):
    """Cross-attention rollout: how much an object token attends to the context tokens.

    Args:
        attns: dict or list where entries alternate self-attention maps with shape
            ``[B heads S S]`` and cross-attention maps with shape ``[B heads S K]``.

    Returns:
        Rollout, shape ``[B S K]``
    """
    if len(attns) % 2 != 0:
        raise ValueError("Must be pairs of self and cross attn")
    if isinstance(attns, Mapping):
        attns = list(attns.values())
    if isinstance(attns, Sequence):
        attns = [a.detach().mean(dim=1) for a in attns]

    device = attns[0].device
    B, S, _ = attns[0].shape
    K = attns[1].shape[-1]

    id_ss = fill_diagonal_(torch.zeros(B, S, S, device=device), 1.0)
    R_ss = id_ss.clone()
    R_sk = torch.zeros(B, S, K, device=device)

    for attn in attns:
        if attn.shape == (B, S, S):
            # Self attention
            R_ss.add_(torch.bmm(attn, R_ss))
            R_sk.add_(torch.bmm(attn, R_sk))
        elif attn.shape == (B, S, K):
            # Cross attention
            R_ss_bar = normalize_sum_to_one(R_ss - id_ss) + id_ss
            R_sk.add_(torch.bmm(R_ss_bar.transpose(-2, -1), attn))
        else:
            raise ValueError(attn.shape)

    return normalize_sum_to_one(R_sk)
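
# Hedged usage sketch, not part of the original module: the helper below is
# hypothetical and shows cross_attn_rollout on alternating self/cross
# attention maps. Assumed shapes: 2 decoder blocks, batch 2, 4 heads,
# 6 object tokens, 49 context tokens.
def _demo_cross_attn_rollout():
    attns = []
    for _ in range(2):
        attns.append(torch.rand(2, 4, 6, 6).softmax(dim=-1))   # self attn  [B heads S S]
        attns.append(torch.rand(2, 4, 6, 49).softmax(dim=-1))  # cross attn [B heads S K]
    rollout = cross_attn_rollout(attns)
    assert rollout.shape == (2, 6, 49)  # [B S K]
    return rollout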