Source code for osc.models.embeds

"""
Positional embeddings and query tokens for objects.
"""

import numpy as np
import timm.models
import torch
from torch import Tensor
from torch import nn as nn

from osc.utils import l2_normalize


class PositionalEmbedding(nn.Module):
    """Learnable positional embedding added to a sequence of tokens."""

    def __init__(self, num_embeddings, embed_dim, dropout=0.0):
        super(PositionalEmbedding, self).__init__()
        self.embed = nn.Parameter(torch.zeros((num_embeddings, embed_dim)))
        self.drop = nn.Dropout(p=dropout)
        self.init_weights()

    def init_weights(self):
        timm.models.layers.trunc_normal_(self.embed, std=0.02)

    def forward(self, x):
        return self.drop(x + self.embed)
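
# Usage sketch for PositionalEmbedding (illustrative only, not part of the module).
# It shows adding learned positional embeddings to a batch of patch tokens; the
# shapes below are assumptions, not values used elsewhere in this codebase.
#
#   pos = PositionalEmbedding(num_embeddings=196, embed_dim=256, dropout=0.1)
#   x = torch.randn(8, 196, 256)  # [B N C] patch tokens
#   x = pos(x)                    # same shape, embedding added then dropout applied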


class LearnedObjectTokens(nn.Module):
    """Fixed number of learnable object tokens, shape ``[S C]``"""

    def __init__(self, embed_dim: int, num_objects: int):
        super(LearnedObjectTokens, self).__init__()
        self.tokens = nn.Parameter(torch.zeros((num_objects, embed_dim)))
        self.init_weights()

    def init_weights(self):
        timm.models.layers.trunc_normal_(self.tokens, std=0.02)

    def forward(self):
        return self.tokens
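
# Usage sketch for LearnedObjectTokens (illustrative only). The module holds a
# fixed [S C] parameter; callers typically broadcast it over the batch themselves.
# The numbers below are assumptions.
#
#   obj = LearnedObjectTokens(embed_dim=256, num_objects=10)
#   tokens = obj()                                   # [S C] = [10, 256]
#   queries = tokens.unsqueeze(0).expand(8, -1, -1)  # [B S C], shared across the batch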


class SampledObjectTokens(nn.Module):
    """Dynamically sample ``S`` object tokens from ``K`` gaussian components.

    See :meth:`forward`.
    """

    def __init__(self, embed_dim: int, num_objects: int, num_components: int = 1):
        """
        Args:
            embed_dim: embedding dimension ``D``
            num_objects: default number of object tokens to sample ``S``
            num_components: number of gaussian components to learn ``K``.
                If greater than one, each object token is drawn from one of the
                ``K`` gaussians chosen uniformly at random with replacement.
        """
        super(SampledObjectTokens, self).__init__()
        self.mu = nn.Parameter(torch.zeros(num_components, embed_dim))
        self.log_sigma = nn.Parameter(torch.zeros(num_components, embed_dim))
        self.init_weights()
        self.num_objects = num_objects

    def init_weights(self):
        """Glorot/Xavier uniform initialization."""
        embed_dim = self.mu.shape[-1]
        limit = np.sqrt(6.0 / embed_dim)
        nn.init.uniform_(self.mu, -limit, +limit)
        nn.init.uniform_(self.log_sigma, -limit, +limit)

    def forward(self, batch_size: int, num_objects: int = None, seed: int = None):
        """Forward.

        Args:
            batch_size: number of independent batches ``B``
            num_objects: number of object tokens to sample ``S``
            seed: optional random seed

        Returns:
            A tensor of sampled object tokens with shape ``[B S D]``
        """
        device = self.mu.device
        B = batch_size
        S = num_objects if num_objects is not None else self.num_objects
        K, D = self.mu.shape

        # Pick which random generator to use
        if seed is not None:
            rng = torch.Generator(device).manual_seed(seed)
        else:
            if self.training:
                # Use global rng
                rng = None
            else:
                # Use fixed rng
                rng = torch.Generator(device).manual_seed(18327415066407224732)

        # TODO: during val slots should be the same for all images,
        #       but that makes val loss artificially low, so the training
        #       behavior is always enabled by doing `if True or` in two places

        # Sample from a normal gaussian
        if True or self.training:
            # Sample B*S independent vectors of length D
            slots = torch.randn(B, S, D, device=device, generator=rng)
        else:
            # Sample S independent vectors of length D,
            # then reuse them for all B images
            slots = torch.randn(1, S, D, device=device, generator=rng).expand(B, -1, -1)

        # If the parameters describe a single component, use that single (mu, sigma).
        # Otherwise uniformly sample S (mu, sigma) pairs out of the K learned pairs.
        if K == 1:
            mu = self.mu
            log_sigma = self.log_sigma
        else:
            if True or self.training:
                # Choose S components for each image in the batch independently
                idx = torch.randint(0, K, size=(B * S,), device=device, generator=rng)
            else:
                # Choose the same S components for all images in the batch
                idx = torch.randint(0, K, size=(S,), device=device, generator=rng)
            mu = self.mu[idx, :].reshape(-1, S, D).expand(B, S, D)
            log_sigma = self.log_sigma[idx, :].reshape(-1, S, D).expand(B, S, D)

        # Combine learned parameters (mu, sigma) and normal-distributed samples
        return mu + log_sigma.exp() * slots
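
# Usage sketch for SampledObjectTokens (illustrative only). Each call draws a new
# set of tokens as mu + exp(log_sigma) * eps, so repeated calls differ unless a
# seed is given. The numbers below are assumptions.
#
#   sampler = SampledObjectTokens(embed_dim=256, num_objects=10, num_components=4)
#   slots = sampler(batch_size=8)            # [B S D] = [8, 10, 256], random draw
#   fixed = sampler(batch_size=8, seed=42)   # reproducible draw, e.g. for evaluation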


class KmeansCosineObjectTokens(nn.Module):
    """Object tokens initialized as k-means centroids of L2-normalized features."""

    def __init__(self, num_objects: int):
        super(KmeansCosineObjectTokens, self).__init__()
        self.max_iters = 20
        self.tol = 1e-6
        self.num_objects = num_objects

    def forward(self, x: Tensor, num_objects: int = None, seed=None):
        """K-means clustering to initialize object tokens.

        Args:
            x: feature tensor, shape ``[B N C]``. It will be L2-normalized internally.
            num_objects: number of objects
            seed: int seed, leave empty to use the default numpy generator

        Returns:
            A ``[B K C]`` tensor of centroid features.
        """
        if num_objects is None:
            num_objects = self.num_objects
        x = l2_normalize(x)
        return torch_kmeans_cosine(x, num_objects, seed, self.max_iters, self.tol)


class KmeansEuclideanObjectTokens(nn.Module):
    """Object tokens initialized as k-means centroids of the input features."""

    def __init__(self, num_objects: int):
        super(KmeansEuclideanObjectTokens, self).__init__()
        self.max_iters = 20
        self.tol = 1e-6
        self.num_objects = num_objects

    def forward(self, x: Tensor, num_objects: int = None, seed=None):
        """K-means clustering to initialize object tokens.

        Args:
            x: feature tensor, shape ``[B N C]``
            num_objects: number of objects
            seed: int seed, leave empty to use the default numpy generator

        Returns:
            A ``[B K C]`` tensor of centroid features.
        """
        if num_objects is None:
            num_objects = self.num_objects
        return torch_kmeans_euclidean(x, num_objects, seed, self.max_iters, self.tol)
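
# Usage sketch for the k-means based object tokens above (illustrative only).
# Both modules cluster the backbone features of each image and return one centroid
# per object; the cosine variant L2-normalizes the features first. Shapes below
# are assumptions.
#
#   feats = torch.randn(8, 196, 256)              # [B N C] backbone features
#   kmeans_tokens = KmeansCosineObjectTokens(num_objects=10)
#   slots = kmeans_tokens(feats)                  # [B K C] = [8, 10, 256]
#   # KmeansEuclideanObjectTokens is used the same way, without the normalization.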


@torch.no_grad()
def torch_kmeans_euclidean(
    x: Tensor, num_clusters: int, seed: int = None, max_iters=10, tol=1e-8
):
    """Batched k-means for PyTorch with Euclidean distance.

    Args:
        x: tensor containing ``B`` batches of ``N`` ``C``-dimensional vectors each,
            shape ``[B N C]``
        num_clusters: number of clusters ``K``
        seed: int seed for centroid initialization,
            leave empty to use the default numpy generator
        max_iters: max number of iterations
        tol: tolerance for early termination

    Returns:
        A ``[B K C]`` tensor of centroids.
    """
    B, N, C = x.shape
    K = num_clusters

    # [B K] indexes of the initial centroids
    rng = np.random if seed is None else np.random.default_rng(seed)
    idx = rng.choice(N, size=K, replace=False)
    idx = torch.from_numpy(idx).expand(B, -1).to(x.device)

    # [B K C] gather centroids from x using the sampled indexes
    centroids = x.gather(dim=1, index=idx[:, :, None].expand(-1, -1, C))

    for i in range(max_iters):
        # [B N K] squared Euclidean distance between each point and each centroid
        mse = (centroids[:, None, :, :] - x[:, :, None, :]).pow_(2).sum(dim=-1)

        # [B N] assign the closest centroid to each sample
        idx = mse.argmin(dim=-1)

        # [B K] how many samples are assigned to each centroid?
        counts = idx.new_zeros(B, K).scatter_add_(
            dim=1, index=idx, src=idx.new_ones(B, N)
        )

        # [B K C] for each centroid, sum all samples assigned to it, divide by count.
        # If a centroid has zero points assigned to it, keep the old centroid value.
        new_centr = (
            torch.zeros_like(centroids)
            .scatter_add_(dim=1, index=idx[:, :, None].expand(-1, -1, C), src=x)
            .div_(counts[:, :, None])
        )
        new_centr = torch.where(counts[:, :, None] == 0, centroids, new_centr)

        # Compute tolerance and exit early
        error = centroids.sub_(new_centr).abs_().max()
        centroids = new_centr
        if error.item() < tol:
            break

    return centroids


@torch.no_grad()
def torch_kmeans_cosine(
    x: Tensor, num_clusters: int, seed: int = None, max_iters=10, tol=1e-8
):
    """Batched k-means for PyTorch with cosine distance.

    Args:
        x: tensor containing ``B`` batches of ``N`` ``C``-dimensional vectors each,
            shape ``[B N C]``. The vectors should be L2-normalized along the
            ``C`` dimension.
        num_clusters: number of clusters ``K``
        seed: int seed for centroid initialization,
            leave empty to use the default numpy generator
        max_iters: max number of iterations
        tol: tolerance for early termination

    Returns:
        A ``[B K C]`` tensor of centroids.
    """
    B, N, C = x.shape
    K = num_clusters

    # [B K] indexes of the initial centroids
    rng = np.random if seed is None else np.random.default_rng(seed)
    idx = torch.from_numpy(rng.choice(N, size=(B, K), replace=True)).to(x.device)

    # [B K C] gather centroids from x using the sampled indexes
    centroids = x.detach().gather(dim=1, index=idx[:, :, None].expand(-1, -1, C))

    for i in range(max_iters):
        # [B N K] cosine similarity between each point and each centroid
        cos = torch.einsum("bnc, bkc -> bnk", x, centroids)

        # [B N] assign the closest centroid to each sample
        idx = cos.argmax(dim=-1)

        # [B K] how many samples are assigned to each centroid?
        counts = (
            idx.new_zeros(B, K)
            .scatter_add_(dim=1, index=idx, src=idx.new_ones(B, N))
            .clamp_min(1)
        )

        # [B K C] for each centroid, sum all samples assigned to it,
        # divide by count and re-normalize
        new_centr = torch.zeros_like(centroids)
        new_centr.scatter_add_(dim=1, index=idx[:, :, None].expand(-1, -1, C), src=x)
        new_centr.div_(counts[:, :, None])
        new_centr.div_(torch.linalg.vector_norm(new_centr, dim=-1, keepdim=True))

        # Compute tolerance and exit early
        error = centroids.sub_(new_centr).abs_().max()
        centroids = new_centr
        if error.item() < tol:
            break

    return centroids
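
# Usage sketch for the standalone k-means helpers (illustrative only). Both run a
# bounded number of Lloyd iterations per batch element and return only the centroids,
# not the point-to-cluster assignments. Shapes and the seed below are assumptions.
#
#   x = torch.randn(8, 196, 256)                                     # [B N C]
#   centr_euc = torch_kmeans_euclidean(x, num_clusters=10, seed=0)   # [B K C]
#   centr_cos = torch_kmeans_cosine(l2_normalize(x), num_clusters=10, seed=0)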