"""
Self, cross, co, and slot attentions.
Attention modules implement the query-key-value projections,
the attention itself, and the output projections.
Block modules wrap an attention module with layer norm,
feed-forward layers and residual connections.
"""
from typing import Optional, Tuple
import numpy as np
import opt_einsum
import timm.models.vision_transformer
import torch
from einops.layers.torch import Rearrange
from timm.models.layers import DropPath, Mlp
from torch import Tensor, nn
class SelfAttention(nn.Module):
"""Self attention.
Check the :func:`forward` method.
"""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
):
"""Init.
Args:
dim:
num_heads:
qkv_bias:
attn_drop:
proj_drop:
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = np.sqrt(head_dim)
self.proj_qkv = nn.Sequential(
nn.Linear(dim, 3 * dim, bias=qkv_bias),
Rearrange(
"B L (three H C) -> three B L H C", three=3, H=num_heads, C=head_dim
),
)
        # Made-up batch size and sequence length to precompute the einsum contraction paths
B, L = 64, 128
BLHC = (B, L, num_heads, head_dim)
BHLL = (B, num_heads, L, L)
self.dot_fn = opt_einsum.contract_expression("bqhc, bkhc -> bhqk", BLHC, BLHC)
self.out_fn = opt_einsum.contract_expression("bhqk, bkhc -> bqhc", BHLL, BLHC)
self.attn_drop = nn.Dropout(attn_drop)
self.proj_out = nn.Sequential(
Rearrange("B L H C -> B L (H C)"),
nn.Linear(dim, dim),
nn.Dropout(proj_drop),
)
    def forward(self, x: Tensor) -> Tensor:
"""Forward.
Args:
x: tensor of shape ``[B L C]``
Returns:
Tensor of shape ``[B L C]``.
"""
        # Self attention: queries and keys share the sequence length, Q = K = L
        # [B L C] -> [3 B L H C] -> ([B L H C], [B L H C], [B L H C])
        q, k, v = self.proj_qkv(x).unbind(dim=0)
        # Scaled dot-product attention: divide the queries by sqrt(head_dim)
        q = q.div(self.scale)
# ([B Q H C], [B K H C]) -> [B H Q K]
dots = self.dot_fn(q, k)
attn = dots.softmax(dim=-1)
attn = self.attn_drop(attn)
# ([B H Q K], [B K H C]) -> [B Q H C]
out = self.out_fn(attn, v)
# [B Q H C] -> [B Q HC]
out = self.proj_out(out)
return out
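
# A minimal usage sketch for SelfAttention (sizes here are illustrative
# assumptions, not values from this module):
#
#     attn = SelfAttention(dim=256, num_heads=8)
#     x = torch.randn(2, 16, 256)  # [B L C]
#     y = attn(x)                  # [B L C] -> (2, 16, 256)
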
class CrossAttention(nn.Module):
"""Cross attention.
Check the :func:`forward` method.
"""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
):
"""Init.
Args:
dim:
num_heads:
qkv_bias:
attn_drop:
proj_drop:
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = np.sqrt(head_dim)
self.proj_q = nn.Sequential(
nn.Linear(dim, dim, bias=qkv_bias),
Rearrange("B Q (H C) -> B Q H C", H=num_heads, C=head_dim),
)
self.proj_kv = nn.Sequential(
nn.Linear(dim, 2 * dim, bias=qkv_bias),
Rearrange("B K (two H C) -> two B K H C", two=2, H=num_heads, C=head_dim),
)
        # Made-up batch size and sequence lengths to precompute the einsum contraction paths
B, Q, K = 64, 128, 128
BQHC = (B, Q, num_heads, head_dim)
BKHC = (B, K, num_heads, head_dim)
BHQK = (B, num_heads, Q, K)
self.dot_fn = opt_einsum.contract_expression("bqhc, bkhc -> bhqk", BQHC, BKHC)
self.out_fn = opt_einsum.contract_expression("bhqk, bkhc -> bqhc", BHQK, BKHC)
self.attn_drop = nn.Dropout(attn_drop)
self.proj_out = nn.Sequential(
Rearrange("B Q H C -> B Q (H C)"),
nn.Linear(dim, dim),
nn.Dropout(proj_drop),
)
    def forward(self, x: Tensor, ctx: Tensor) -> Tensor:
"""Forward.
Args:
x: query tensor of shape ``[B Q C]``
ctx: key-value tensor of shape ``[B K C]``
Returns:
Tensor of shape ``[B Q C]``.
"""
# x: [B Q C] queries
# ctx: [B K C] context (key-value)
# [B Q HC] -> [B Q H C]
q = self.proj_q(x).div(self.scale)
# [B K HC] -> [2 B K H C] -> ([B K H C], [B K H C])
k, v = self.proj_kv(ctx).unbind(dim=0)
# ([B Q H C], [B K H C]) -> [B H Q K]
dots = self.dot_fn(q, k)
attn = dots.softmax(dim=-1)
attn = self.attn_drop(attn)
# ([B H Q K], [B K H C]) -> [B Q H C]
out = self.out_fn(attn, v)
# [B Q H C] -> [B Q HC]
out = self.proj_out(out)
return out
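
# A minimal usage sketch for CrossAttention (illustrative sizes); the query and
# context sequence lengths may differ, only ``dim`` must match:
#
#     attn = CrossAttention(dim=256, num_heads=8)
#     x = torch.randn(2, 7, 256)     # [B Q C] queries
#     ctx = torch.randn(2, 64, 256)  # [B K C] context (keys and values)
#     y = attn(x, ctx)               # [B Q C] -> (2, 7, 256)
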
class SlotAttention(nn.Module):
"""Slot attention.
Diagram::
TODO
Check the :func:`forward` method.
"""
    def __init__(
        self,
        dim: int,
        pos_embed=None,
        iters: int = 3,
        eps: float = 1e-8,
        hidden_dim: Optional[int] = None,
    ):
        """Init.

        Args:
            dim: dimension of the slots and of the inputs.
            pos_embed: optional positional embedding module applied to the
                inputs (``None`` or ``"none"`` disables it).
            iters: number of iterative slot updates.
            eps: small constant for numerically stable attention normalization.
            hidden_dim: hidden dimension of the slot MLP, defaults to ``dim``.
        """
super().__init__()
self.num_iters = iters
self.eps = eps
self.scale = np.sqrt(dim)
if hidden_dim is None:
hidden_dim = dim
if pos_embed is None or pos_embed == "none":
pos_embed = nn.Identity()
self.pos_embed = pos_embed
self.to_q = nn.Linear(dim, dim, bias=False)
self.to_k = nn.Linear(dim, dim, bias=False)
self.to_v = nn.Linear(dim, dim, bias=False)
# Dummy module to get the attn matrix with a pre-forward hook
self.dot_prod_softmax = nn.Identity()
self.gru = nn.GRUCell(dim, dim)
self.mlp = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, dim),
)
self.norm_input = nn.LayerNorm(dim)
self.norm_slots = nn.LayerNorm(dim)
self.norm_pre_ff = nn.LayerNorm(dim)
    def forward(self, slots: Tensor, inputs: Tensor) -> Tensor:
"""Forward.
Args:
slots: query tensor of shape ``[B S C]``
inputs: key-value tensor of shape ``[B N C]``
Returns:
Tensor of shape ``[B S C]``.
"""
B, S, C = slots.shape
B, N, C = inputs.shape
inputs = self.pos_embed(inputs)
inputs = self.norm_input(inputs)
k = self.to_k(inputs)
v = self.to_v(inputs)
        for i in range(self.num_iters):
            slots_prev = slots
            slots = self.norm_slots(slots)
            q = self.to_q(slots).div_(self.scale)
            # ([B S C], [B N C]) -> [B S N]
            dots = torch.einsum("bqd,bkd->bqk", q, k)
            # Softmax over the slot dimension: slots compete for the inputs
            attn = dots.softmax(dim=-2)
            attn, _ = self.dot_prod_softmax((attn, i))
            # Normalize over the inputs: weighted mean of the values per slot
            attn = attn + self.eps
            attn = attn / attn.sum(dim=-1, keepdim=True)
            # ([B N C], [B S N]) -> [B S C]
            updates = torch.einsum("bjd,bij->bid", v, attn)
            # GRU update of each slot, followed by a residual MLP
            slots = self.gru(
                updates.reshape(-1, C),
                slots_prev.reshape(-1, C),
            ).reshape(B, -1, C)
            slots = slots + self.mlp(self.norm_pre_ff(slots))
return slots
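
# A minimal usage sketch for SlotAttention (illustrative sizes); the initial
# slots are assumed to come from the caller, e.g. sampled from a learned
# Gaussian as in the Slot Attention paper, and are refined over ``iters``
# GRU updates:
#
#     slot_attn = SlotAttention(dim=64, iters=3)
#     slots = torch.randn(2, 7, 64)     # [B S C] initial slots
#     inputs = torch.randn(2, 196, 64)  # [B N C] flattened feature map
#     slots = slot_attn(slots, inputs)  # [B S C] -> (2, 7, 64)
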
class SelfAttentionBlock(nn.Module):
"""Self attention block.
Diagram::
┌───┤
│ ▼
│ norm
│ │
│ ▼
│ attn
│ │
│ ▼
└──►+
┌───┤
│ ▼
│ norm
│ │
│ ▼
│ MLP
│ │
│ ▼
└──►+
▼
Check the :func:`forward` method.
"""
# Reference implementation: timm.models.vision_transformer.Block
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
drop: float = 0.0,
attn_drop: float = 0.0,
drop_path: float = 0.0,
act_layer: nn.Module = nn.GELU,
norm_layer: nn.Module = nn.LayerNorm,
):
"""Init.
Args:
dim:
num_heads:
mlp_ratio:
qkv_bias:
drop:
attn_drop:
drop_path:
act_layer:
norm_layer:
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = SelfAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(
in_features=dim,
hidden_features=int(np.round(dim * mlp_ratio)),
act_layer=act_layer,
drop=drop,
)
    def forward(self, x: Tensor) -> Tensor:
"""Forward.
Args:
            x: input tensor of shape ``[B L C]``
Returns:
Tensor of shape ``[B L C]``.
"""
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
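
# A minimal usage sketch for SelfAttentionBlock (illustrative sizes); the block
# is a standard pre-norm transformer encoder layer and preserves the shape:
#
#     block = SelfAttentionBlock(dim=256, num_heads=8)
#     x = torch.randn(2, 16, 256)  # [B L C]
#     y = block(x)                 # [B L C]
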
class CrossAttentionBlock(nn.Module):
"""Cross attention block.
Diagram::
x ctx
┌───┤ │
│ ▼ │
│ norm │
│ │ │
│ ▼ │
│ self │
│ attn │
│ │ │
│ ▼ │
└──►+ │
┌───┤ │
│ ▼ │
│ norm │
│ │ │
│ ▼ │
│ cross◄─┘
│ attn
│ │
│ ▼
└──►+
┌───┤
│ ▼
│ norm
│ │
│ ▼
│ MLP
│ │
│ ▼
└──►+
▼
Check the :func:`forward` method.
"""
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.LayerNorm,
    ):
        """Init.

        Args:
            dim: embedding dimension.
            num_heads: number of attention heads.
            mlp_ratio: ratio of the MLP hidden dimension to ``dim``.
            qkv_bias: if ``True``, add a bias term to the QKV projections.
            drop: dropout rate for the MLP and the attention output projections.
            attn_drop: dropout rate on the attention weights.
            drop_path: stochastic depth rate.
            act_layer: activation layer class for the MLP.
            norm_layer: normalization layer class.
        """
super().__init__()
self.norm1 = norm_layer(dim)
self.self_attn = SelfAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
self.norm2 = norm_layer(dim)
self.cross_attn = CrossAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm3 = norm_layer(dim)
self.mlp = Mlp(
in_features=dim,
hidden_features=int(np.round(dim * mlp_ratio)),
act_layer=act_layer,
drop=drop,
)
    def forward(self, x: Tensor, ctx: Tensor) -> Tensor:
"""Forward.
Args:
x: query tensor of shape ``[B Q C]``
ctx: key-value tensor of shape ``[B K C]``
Returns:
Tensor of shape ``[B Q C]``.
"""
x = x + self.drop_path(self.self_attn(self.norm1(x)))
x = x + self.drop_path(self.cross_attn(self.norm2(x), ctx))
x = x + self.drop_path(self.mlp(self.norm3(x)))
return x
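
# A minimal usage sketch for CrossAttentionBlock (illustrative sizes); ``x`` is
# first refined with self attention, then attends to the unchanged ``ctx``:
#
#     block = CrossAttentionBlock(dim=256, num_heads=8)
#     x = torch.randn(2, 7, 256)     # [B Q C]
#     ctx = torch.randn(2, 64, 256)  # [B K C]
#     y = block(x, ctx)              # [B Q C]
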
class CoAttentionBlock(nn.Module):
"""Co-attention block, both a->b and b->a.
Check the :func:`forward` method.
"""
def __init__(self, *args, **kwargs):
"""Init.
Args:
*args: Same as :class:`CrossAttentionBlock`
**kwargs: Same as :class:`CrossAttentionBlock`
"""
super().__init__()
self.cross_a_b = CrossAttentionBlock(*args, **kwargs)
self.cross_b_a = CrossAttentionBlock(*args, **kwargs)
    def forward(self, a: Tensor, b: Tensor) -> Tuple[Tensor, Tensor]:
"""Forward.
Args:
a: tensor of shape ``[B N C]``
b: tensor of shape ``[B M C]``
Returns:
Tensors with shapes:
``[B N C]`` (``a`` attending to ``b``),
``[B M C]`` (``b`` attending to ``a``).
"""
new_a = self.cross_a_b(a, b)
new_b = self.cross_b_a(b, a)
return new_a, new_b
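
# A minimal usage sketch for CoAttentionBlock (illustrative sizes); both
# directions are computed from the original ``a`` and ``b``, not sequentially:
#
#     block = CoAttentionBlock(dim=256, num_heads=8)
#     a = torch.randn(2, 7, 256)   # [B N C]
#     b = torch.randn(2, 16, 256)  # [B M C]
#     new_a, new_b = block(a, b)   # [B N C], [B M C]
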
class CrossAttentionDecoder(nn.Module):
    """Cross attention decoder: a stack of :class:`CrossAttentionBlock` layers.

    Check the :func:`forward` method.
    """
def __init__(
self,
dim: int,
pos_embed=None,
num_heads: int = 4,
num_layers: int = 3,
block_drop: float = 0.0,
block_attn_drop: float = 0.0,
drop_path: float = 0.0,
mlp_ratio: float = 4.0,
):
"""Init.
Args:
dim:
pos_embed:
num_heads:
num_layers:
block_drop:
block_attn_drop:
drop_path:
mlp_ratio:
"""
        super().__init__()
if pos_embed is None or pos_embed == "none":
pos_embed = nn.Identity()
self.pos_embed = pos_embed
self.inputs_norm = nn.LayerNorm(dim)
# B N C -> B N C
self.attn_blocks = nn.ModuleList(
[
CrossAttentionBlock(
dim=dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=False,
drop=block_drop,
attn_drop=block_attn_drop,
drop_path=drop_path,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
)
for _ in range(num_layers)
]
)
self.init_weights()
def init_weights(self):
# noinspection PyProtectedMember
self.apply(timm.models.vision_transformer._init_vit_weights)
    def forward(self, obj_queries: Tensor, inputs: Tensor) -> Tensor:
"""Forward.
Args:
obj_queries: query tensor of shape ``[B S C]``
inputs: key-value tensor of shape ``[B N C]``
Returns:
Tensor of shape ``[B S C]``.
"""
inputs = self.inputs_norm(self.pos_embed(inputs))
        for block in self.attn_blocks:
obj_queries = block(obj_queries, inputs)
return obj_queries
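

if __name__ == "__main__":
    # Smoke test with illustrative sizes: shape checks for every module above.
    # This is an assumed usage sketch, not part of the documented API.
    B, N, S, C = 2, 64, 7, 128
    feats = torch.randn(B, N, C)  # [B N C] input tokens
    slots = torch.randn(B, S, C)  # [B S C] queries / slots

    assert SelfAttention(C, num_heads=8)(feats).shape == (B, N, C)
    assert CrossAttention(C, num_heads=8)(slots, feats).shape == (B, S, C)
    assert SlotAttention(C)(slots, feats).shape == (B, S, C)
    assert SelfAttentionBlock(C, num_heads=8)(feats).shape == (B, N, C)
    assert CrossAttentionBlock(C, num_heads=8)(slots, feats).shape == (B, S, C)
    new_a, new_b = CoAttentionBlock(C, num_heads=8)(slots, feats)
    assert new_a.shape == (B, S, C) and new_b.shape == (B, N, C)
    assert CrossAttentionDecoder(C, num_heads=8)(slots, feats).shape == (B, S, C)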