Source code for osc.models.models

"""
Model definitions and wrappers to get attention maps.
"""

from contextlib import contextmanager
from typing import Dict, Generator, List, NamedTuple, Optional

import torch.nn as nn
from torch import Tensor
from torch.utils.hooks import RemovableHandle

from osc.models.attentions import CrossAttention, SelfAttention

from .attentions import SlotAttention
from .embeds import (
    KmeansCosineObjectTokens,
    KmeansEuclideanObjectTokens,
    LearnedObjectTokens,
    SampledObjectTokens,
)


class ModelOutput(NamedTuple):
    """Model output.

    Fields:

    - ``f_backbone``: backbone features, shape ``[B N C]``
    - ``f_global``: global features, shape ``[B C]``
    - ``f_slots``: object features, shape ``[B S C]``
    - ``p_global``: global projections, shape ``[B C]``
    - ``p_slots``: object projections, shape ``[B S C]``
    """

    f_backbone: Tensor
    f_global: Tensor
    f_slots: Optional[Tensor]
    p_global: Tensor
    p_slots: Optional[Tensor]
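
# Example (a minimal sketch, not part of the library): building a ModelOutput
# by hand to illustrate the expected shapes, assuming B=2 images, N=196
# patches, S=7 object slots, and C=256 channels (all sizes are made up for
# illustration):
#
#     import torch
#
#     out = ModelOutput(
#         f_backbone=torch.randn(2, 196, 256),  # [B N C] per-patch features
#         f_global=torch.randn(2, 256),         # [B C]   global features
#         f_slots=torch.randn(2, 7, 256),       # [B S C] object features
#         p_global=torch.randn(2, 256),         # [B C]   global projections
#         p_slots=torch.randn(2, 7, 256),       # [B S C] object projections
#     )
#     assert out.f_slots.shape == (2, 7, 256)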


class Model(nn.Module):
    """Core model.

    See the :meth:`forward` method.
    """

    def __init__(
        self,
        *,
        architecture: str,
        backbone,
        global_fn,
        global_proj,
        obj_queries=None,
        obj_fn=None,
        obj_proj=None,
    ):
        super().__init__()
        self.architecture = architecture
        self.backbone = backbone
        self.global_fn = global_fn
        self.global_proj = global_proj

        if architecture == "backbone-global_fn-global_proj":
            if any(m is not None for m in [obj_queries, obj_fn, obj_proj]):
                raise ValueError(
                    f"With {architecture}, all `obj_*` parameters must be None."
                )
        else:
            if any(m is None for m in [obj_queries, obj_fn, obj_proj]):
                raise ValueError(
                    f"With {architecture}, all `obj_*` parameters must not be None."
                )

        self.obj_queries = obj_queries
        self.obj_fn = obj_fn
        self.obj_proj = obj_proj

    def forward(self, images: Tensor) -> ModelOutput:
        """Forward pass.

        Args:
            images: float tensor of shape ``[B 3 H W]``

        Returns:
            A :class:`ModelOutput` tuple containing per-patch backbone features,
            global features, object features, global projections, and object
            projections.
        """
        # f_backbone_pool: [B C]
        # f_backbone_patch: [B N C], where N is the number of patches
        f_backbone_pool, f_backbone_patch = self.backbone(images)

        if self.architecture == "backbone-global_fn-global_proj":
            f_global = self.global_fn(f_backbone_pool)
            p_global = self.global_proj(f_global)
            f_slots = None
            p_slots = None
            return ModelOutput(f_backbone_patch, f_global, f_slots, p_global, p_slots)

        if self.architecture == "backbone(-global_fn-global_proj)-obj_fn-obj_proj":
            f_global = self.global_fn(f_backbone_pool)
            p_global = self.global_proj(f_global)
            # q_objs: [B S C]
            q_objs = self._get_q_objs(f_backbone_patch)
            # f_slots, p_slots: [B S C]
            f_slots = self.obj_fn(q_objs, f_backbone_patch)
            p_slots = self.obj_proj(f_slots)
            return ModelOutput(f_backbone_patch, f_global, f_slots, p_global, p_slots)

        if self.architecture == "backbone-obj_fn(-global_fn-global_proj)-obj_proj":
            # q_objs: [B S C]
            q_objs = self._get_q_objs(f_backbone_patch)
            # f_slots, p_slots: [B S C]
            f_slots = self.obj_fn(q_objs, f_backbone_patch)
            p_slots = self.obj_proj(f_slots)
            f_global = self.global_fn(f_slots.mean(dim=1))
            p_global = self.global_proj(f_global)
            return ModelOutput(f_backbone_patch, f_global, f_slots, p_global, p_slots)

        raise ValueError(self.architecture)

    def _get_q_objs(self, f_backbone_patch: Tensor) -> Tensor:
        """Build object query tokens ``[B S C]`` from the configured ``obj_queries`` module."""
        B, N, C = f_backbone_patch.shape
        if isinstance(self.obj_queries, SampledObjectTokens):
            return self.obj_queries(B)
        if isinstance(self.obj_queries, LearnedObjectTokens):
            return self.obj_queries().expand(B, -1, -1)
        if isinstance(
            self.obj_queries, (KmeansCosineObjectTokens, KmeansEuclideanObjectTokens)
        ):
            return self.obj_queries(f_backbone_patch)
        raise TypeError(type(self.obj_queries))
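
# Example (a minimal sketch with made-up stand-in modules, not taken from this
# repository): instantiating the simplest "backbone-global_fn-global_proj"
# architecture. The real backbone, global_fn, and global_proj are built
# elsewhere from the experiment config; here a toy backbone that returns
# random pooled and per-patch features stands in for them:
#
#     import torch
#
#     class ToyBackbone(nn.Module):
#         def forward(self, images):
#             B = images.shape[0]
#             # (pooled features [B C], per-patch features [B N C])
#             return torch.randn(B, 256), torch.randn(B, 196, 256)
#
#     model = Model(
#         architecture="backbone-global_fn-global_proj",
#         backbone=ToyBackbone(),
#         global_fn=nn.Identity(),
#         global_proj=nn.Linear(256, 128),
#     )
#     out = model(torch.randn(2, 3, 224, 224))
#     assert out.p_global.shape == (2, 128) and out.f_slots is None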


@contextmanager
def forward_with_attns(model: Model) -> Generator[Dict[str, Tensor], None, None]:
    """Context manager to set and remove hooks that collect attention maps.

    Example:
        Usage::

            >>> with forward_with_attns(model) as attns:
            ...     outputs = model(inputs)
            >>> print(attns.items())

    Args:
        model: Model where the attention hooks will be set.

    Yields:
        A dictionary that will be populated after the forward pass.
    """
    module_to_name: Dict[nn.Module, str] = {
        # nn.Module: name
    }
    attns: Dict[str, Tensor] = {
        # backbone.attn_blocks.*:          [AB heads Q K]
        # obj_fn.slot_attn.*:              [AB S K]
        # obj_fn.attn_blocks.*.self_attn:  [AB heads S S]
        # obj_fn.attn_blocks.*.cross_attn: [AB heads S K]
    }

    def normal_hook(m, inputs):
        (a,) = inputs
        attns[module_to_name[m]] = a.detach()

    def slot_hook(m, inputs):
        ((a, iter_idx),) = inputs
        attns[f"{module_to_name[m]}.slot_attn.{iter_idx}"] = a.detach()

    handles: List[RemovableHandle] = []
    for name, module in model.named_modules():
        if isinstance(module, (SelfAttention, CrossAttention)):
            module_to_name[module.attn_drop] = name
            handle = module.attn_drop.register_forward_pre_hook(normal_hook)
            handles.append(handle)
        elif isinstance(module, SlotAttention):
            module_to_name[module.dot_prod_softmax] = name
            handle = module.dot_prod_softmax.register_forward_pre_hook(slot_hook)
            handles.append(handle)

    try:
        yield attns
    finally:
        for handle in handles:
            handle.remove()
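
# Example (a minimal sketch, assuming `model` is a Model whose backbone and
# obj_fn contain SelfAttention/CrossAttention/SlotAttention modules, and
# `images` is a [B 3 H W] float tensor; neither is defined here). The exact
# dictionary keys depend on how the submodules are named inside the model:
#
#     import torch
#
#     with torch.no_grad(), forward_with_attns(model) as attns:
#         outputs = model(images)
#
#     for name, attn in attns.items():
#         # e.g. a backbone self-attention entry has shape [AB heads Q K]
#         print(name, tuple(attn.shape))
#
#     # Average one backbone self-attention map over heads to get a [Q K]
#     # patch-to-patch map per image:
#     first = next(k for k in attns if k.startswith("backbone"))
#     patch_attn = attns[first].mean(dim=1)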