Source code for osc.loss_global

from typing import Tuple

import numpy as np
import torch
import torch.nn.functional
from torch import Tensor

from osc.utils import cos_pairwise, l2_normalize


@torch.jit.script
def cosine_sim_loss(feats: Tensor, projs: Tensor, reduction: str = "mean") -> Tensor:
    """Cosine similarity loss for BYOL and SimSiam.

    For both `feats` and `projs`, the vectors in position ``i`` and ``i + B``
    must represent the embeddings of different augmentations of the same image.

    In BYOL, `feats` should come from an offline network updated with momentum.
    In SimSiam, `feats` comes from the same online network that computes `projs`.
    Detach (stop gradient) is always called on the `feats` embeddings.

    The loss is made symmetric w.r.t. the image augmentations.

    Examples:

        BYOL:

        >>> imgs = torch.rand(4, 3, 256, 256)
        >>> inputs = torch.cat([aug0(imgs), aug1(imgs)], dim=0)
        >>> feats_offline = backbone_offline(inputs)
        >>> feats = backbone(inputs)
        >>> projs = projection(feats)
        >>> cosine_sim_loss(feats_offline, projs)

        SimSiam:

        >>> imgs = torch.rand(4, 3, 256, 256)
        >>> inputs = torch.cat([aug0(imgs), aug1(imgs)], dim=0)
        >>> feats = backbone(inputs)
        >>> projs = projection(feats)
        >>> cosine_sim_loss(feats, projs)

    Args:
        feats: [2B, C] tensor of image features
        projs: [2B, C] tensor of projected image features
        reduction: 'mean', 'sum', or 'none'

    Returns:
        Scalar loss if reduction is 'mean' or 'sum'.
        A vector of ``2B`` losses if reduction is 'none'.

    """
    B = feats.shape[0] // 2
    feats = l2_normalize(feats.detach()).roll(B, dims=0)
    projs = l2_normalize(projs)
    if reduction == "mean":
        # Same as: - (feats * projs).sum(dim=1).mean(dim=0)
        return torch.einsum("bc,bc->", feats, projs).div_(-2 * B)
    elif reduction == "sum":
        # Same as: - (feats * projs).sum(dim=1).sum(dim=0)
        return torch.einsum("bc,bc->", feats, projs).neg_()
    elif reduction == "none":
        # Same as: - (feats * projs).sum(dim=1)
        return torch.einsum("bc,bc->b", feats, projs).neg_()
    raise ValueError(reduction)
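

# Illustrative sanity check, not part of the osc library: the hypothetical helper
# below shows that when the two augmentations of each image produce identical
# projections (rows ``i`` and ``i + B`` are equal), the mean cosine similarity
# loss reaches its minimum of -1.
def _example_cosine_sim_loss_minimum() -> Tensor:
    B, C = 4, 128
    z = torch.randn(B, C)
    projs = torch.cat([z, z], dim=0)  # rows i and i + B hold the same image
    feats = projs.clone()  # SimSiam-style: feats from the same online network
    return cosine_sim_loss(feats, projs)  # approx. tensor(-1.)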


@torch.jit.script
def contrastive_loss(
    projs: Tensor, temperature: float = 1.0, reduction: str = "mean"
) -> Tensor:
    """Contrastive loss for SimCLR.

    The vectors in position ``i`` and ``i + B`` of `projs` must
    represent the embeddings of different augmentations of the same image.
    The loss can be thought of as a classification problem with ``2B-1`` classes where
    the ``i``-th sample should be classified as the ``i+B``-th sample and vice versa.
    The loss is made symmetric w.r.t. the image augmentations.

    Worst case:
    if all embeddings collapse to the same value, the loss will be ``log(2B-1)``.

    Best case:
    if each image gets an embedding that is orthogonal to all others,
    the loss will be ``log(exp(1/t) + 2B - 2) - 1/t``.

    Examples:

        SimCLR:

        >>> imgs = torch.rand(4, 3, 256, 256)
        >>> inputs = torch.cat([aug0(imgs), aug1(imgs)], dim=0)
        >>> feats = backbone(inputs)
        >>> projs = projection(feats)
        >>> contrastive_loss(projs)

    Args:
        projs: [2B, C] tensor of projected image features
        temperature: temperature scaling
        reduction: 'mean', 'sum', or 'none'

    Returns:
        Scalar loss if reduction is 'mean' or 'sum'.
        A vector of ``2B`` losses if reduction is 'none'.
    """
    B = projs.shape[0] // 2
    cos = cos_pairwise(projs).div_(temperature).fill_diagonal_(-torch.inf)  # [2B, 2B]
    target = torch.arange(2 * B, device=cos.device).roll(B)  # [2B]
    loss = torch.nn.functional.cross_entropy(cos, target, reduction=reduction)
    return loss


def contrastive_loss_best_worst(
    batch_size: int, temperature: float
) -> Tuple[float, float]:
    """Best and worst case for ``contrastive_loss()``."""
    best = np.log(np.exp(1 / temperature) + 2 * batch_size - 2) - 1 / temperature
    worst = np.log(2 * batch_size - 1)
    return best, worst
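

# Illustrative sanity check, not part of the osc library: the hypothetical helper
# below verifies that the collapsed and orthogonal cases of ``contrastive_loss()``
# match the bounds returned by ``contrastive_loss_best_worst()``.
def _example_contrastive_loss_bounds() -> None:
    B, C, t = 4, 16, 0.5
    best, worst = contrastive_loss_best_worst(B, t)
    # Worst case: every projection collapses to the same vector.
    collapsed = torch.randn(1, C).repeat(2 * B, 1)
    assert abs(float(contrastive_loss(collapsed, t)) - worst) < 1e-4
    # Best case: each image gets a one-hot embedding, orthogonal to all other
    # images and shared by its two augmentations (rows i and i + B).
    orthogonal = torch.eye(B, C).repeat(2, 1)
    assert abs(float(contrastive_loss(orthogonal, t)) - best) < 1e-4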