osc.viz.rollout.slot_attn_rollout

slot_attn_rollout(attns, normalize='layer')

Slot attention rollout: how much each slot token attends to each context token, aggregated over layers.

Parameters
  • attns – dict or list of per-layer attention tensors, each of shape [B S K].

  • normalize – 'layer' to normalize the rollout after every layer, or 'all' to normalize only once, at the end.
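This page does not say how the per-layer maps are combined. The following is a minimal sketch that assumes the maps are summed and that each slot's row is normalized over the K context tokens. It shows how the two normalize modes could differ under that assumption. The function slot_attn_rollout_sketch and its helpers are hypothetical and not part of osc. Nested Python lists stand in for tensors so the example runs without dependencies.

```python
from typing import Dict, List, Sequence, Union

# A [B S K] tensor as nested lists: batch -> slots -> context tokens.
Tensor3 = List[List[List[float]]]


def _normalize_last_axis(x: Tensor3, eps: float = 1e-12) -> Tensor3:
    """Rescale every length-K row so it sums to 1 (per batch item, per slot)."""
    return [[[v / (sum(row) + eps) for v in row] for row in item] for item in x]


def _add(a: Tensor3, b: Tensor3) -> Tensor3:
    """Elementwise sum of two [B S K] tensors (shapes assumed equal)."""
    return [
        [[u + v for u, v in zip(ra, rb)] for ra, rb in zip(ia, ib)]
        for ia, ib in zip(a, b)
    ]


def slot_attn_rollout_sketch(
    attns: Union[Dict[str, Tensor3], Sequence[Tensor3]],
    normalize: str = "layer",
) -> Tensor3:
    """Hypothetical rollout: accumulate per-layer slot->context attention.

    ASSUMPTION: layers are combined by summation and normalized over K;
    the real osc implementation may combine them differently.
    """
    if normalize not in ("layer", "all"):
        raise ValueError(f"normalize must be 'layer' or 'all', got {normalize!r}")
    # Dict inputs are consumed in insertion order, list inputs in list order.
    layers = list(attns.values()) if isinstance(attns, dict) else list(attns)
    if not layers:
        raise ValueError("attns must contain at least one layer")

    rollout = layers[0]
    if normalize == "layer":
        rollout = _normalize_last_axis(rollout)
    for attn in layers[1:]:
        rollout = _add(rollout, attn)
        if normalize == "layer":
            # Renormalize after every layer.
            rollout = _normalize_last_axis(rollout)
    if normalize == "all":
        # Normalize once, after all layers have been accumulated.
        rollout = _normalize_last_axis(rollout)
    return rollout
```

Consider three single-slot layers [1, 0], [0, 1] and [0, 1]. Under this assumption, normalize='all' gives approximately [1/3, 2/3], which is a plain average of the layers. normalize='layer' gives approximately [0.25, 0.75]. When each layer's rows already sum to 1, renormalizing after every step halves the weight of everything accumulated so far, so later layers count more.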

Returns

Rollout, shape [B S K]
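The following sketch shows one way to read the result. It assumes that the last axis indexes context tokens, so rollout[b][s] holds slot s's weights over the K context tokens for batch item b. The helper top_context_token is hypothetical and not part of osc. It picks each slot's most-attended context token and again uses nested lists in place of tensors.

```python
from typing import List

# A [B S K] rollout as nested lists: batch -> slots -> context tokens.
Tensor3 = List[List[List[float]]]


def top_context_token(rollout: Tensor3) -> List[List[int]]:
    """For each batch item and slot, return the index k of the context token
    with the largest rollout weight (first index wins on ties).

    The result has shape [B S].
    """
    return [
        [max(range(len(row)), key=row.__getitem__) for row in item]
        for item in rollout
    ]


# One batch item with two slots over two context tokens.
print(top_context_token([[[0.25, 0.75], [0.9, 0.1]]]))  # [[1, 0]]
```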