osc.viz.rollout.slot_attn_rollout
- slot_attn_rollout(attns, normalize='layer')
Slot attention rollout: how much a slot token attends to context tokens.
- Parameters
attns – dict or list of attention maps, where each entry has shape [B S K]
normalize – 'layer' to normalize the rollout at every layer, or 'all' to normalize only once at the end.
- Returns
Rollout, shape [B S K]
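The difference between the two normalize modes can be illustrated with a small NumPy stand-in. This is a hedged sketch, not the code of osc.viz.rollout: the function name rollout_sketch, the accumulation rule (a running elementwise sum over layers), the normalization axis (the last one, K), and the eps guard are all assumptions made for illustration, since the documentation above specifies none of them. Only the input types, the [B S K] shapes, and the 'layer'/'all' option come from the documented signature.

```python
import numpy as np


def rollout_sketch(attns, normalize="layer", eps=1e-8):
    """Hypothetical stand-in for slot_attn_rollout (NOT the osc implementation)."""
    if normalize not in ("layer", "all"):
        raise ValueError("normalize must be 'layer' or 'all'")

    # Accept a dict (values taken in insertion order, an assumption) or a list.
    maps = list(attns.values()) if isinstance(attns, dict) else list(attns)
    if not maps:
        raise ValueError("attns must contain at least one attention map")

    # Assumed accumulation rule: running elementwise sum across layers.
    rollout = np.zeros_like(maps[0], dtype=float)
    for attn in maps:
        rollout = rollout + attn
        if normalize == "layer":
            # Renormalize over the last axis (K) after every layer.
            rollout = rollout / (rollout.sum(axis=-1, keepdims=True) + eps)

    if normalize == "all":
        # Normalize only once, after all layers have been accumulated.
        rollout = rollout / (rollout.sum(axis=-1, keepdims=True) + eps)
    return rollout


# Toy input: 4 layers of random non-negative maps with shape [B S K].
rng = np.random.default_rng(0)
B, S, K = 2, 3, 5
attns = [rng.random((B, S, K)) for _ in range(4)]

out_layer = rollout_sketch(attns, normalize="layer")
out_all = rollout_sketch(attns, normalize="all")
print(out_layer.shape, out_all.shape)
```

Under these assumptions both modes return an array with the same [B S K] shape as each input entry, and each row sums to approximately 1 over the last axis. The values differ between modes, because per-layer renormalization reweights earlier layers relative to later ones.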