osc.loss_objects.matching_contrastive_loss

matching_contrastive_loss(slots, temperature=1.0, reduction='mean')[source]

Contrastive object-wise loss, all vs. all.

The vectors in position i and i + B of projs must represent K embeddings of different augmentations of the same image. For each (i, i+B) image pair, the K embeddings are 1:1 matched using linear-sum assignment to produce the targets for a 2BK-1-classes classification problem. Except for the matching slot, all 2BK-2 x of all images are considered negative samples for the loss.

If all slot embeddings collapse to the same value, the loss will be log(2BK-1).

Worst case: if all embeddings collapse to the same value, the loss will be log(2BK-1).

Best case: if each image gets an embedding that is orthogonal to all others, the loss will be log(exp(1/t) + 2BK - 2) - 1/t.

Parameters
  • slots (Tensor) – [2B, K, C] tensor of projected image features

  • temperature (float) – temperature scaling

  • reduction (str) – ‘mean’, ‘sum’, or ‘none’

Return type

Tensor

Returns

Scalar loss over all samples and x if reduction is ‘mean’ or ‘sum’. A vector 2BK of losses if reduction is ‘none’

Example

A batch of B=4 images, augmented twice, each with K=3 x. The X represent positive matching targets for the cross entropy loss, the . represent negatives included in the loss (all except diagonal):

                        aug_0                    aug_1
               -----------------------  -----------------------
                 0     1     2     3      0     1     2     3
      |       [  . .|. . .|. . .|. . .||. . X|. . .|. . .|. . .]
      | img_0 [.   .|. . .|. . .|. . .||. X .|. . .|. . .|. . .]
      |       [. .  |. . .|. . .|. . .||X . .|. . .|. . .|. . .]
      |       [-----+-----+-----+------------+-----+-----+-----]
      |       [. . .|  . .|. . .|. . .||. . .|X . .|. . .|. . .]
      | img_1 [. . .|.   .|. . .|. . .||. . .|. X .|. . .|. . .]
      |       [. . .|. .  |. . .|. . .||. . .|. . X|. . .|. . .]
aug_0 |       [-----+-----+-----+------------+-----+-----+-----]
      |       [. . .|. . .|  . .|. . .||. . .|. . .|. . X|. . .]
      | img_2 [. . .|. . .|.   .|. . .||. . .|. . .|X . .|. . .]
      |       [. . .|. . .|. .  |. . .||. . .|. . .|. X .|. . .]
      |       [-----+-----+-----+------------+-----+-----+-----]
      |       [. . .|. . .|. . .|  . .||. . .|. . .|. . .|. . X]
      | img_3 [. . .|. . .|. . .|.   .||. . .|. . .|. . .|. X .]
      |       [. . .|. . .|. . .|. .  ||. . .|. . .|. . .|X . .]
              [=====|=====|=====|============|=====|=====|=====]
      |       [. . X|. . .|. . .|. . .||  . .|. . .|. . .|. . .]
      | img_0 [. X .|. . .|. . .|. . .||.   .|. . .|. . .|. . .]
      |       [X . .|. . .|. . .|. . .||. .  |. . .|. . .|. . .]
      |       [-----+-----+-----+------------+-----+-----+-----]
      |       [. . .|X . .|. . .|. . .||. . .|  . .|. . .|. . .]
      | img_1 [. . .|. X .|. . .|. . .||. . .|.   .|. . .|. . .]
      |       [. . .|. . X|. . .|. . .||. . .|. .  |. . .|. . .]
aug_1 |       [-----+-----+-----+------------+-----+-----+-----]
      |       [. . .|. . .|. X .|. . .||. . .|. . .|  . .|. . .]
      | img_2 [. . .|. . .|. . X|. . .||. . .|. . .|.   .|. . .]
      |       [. . .|. . .|X . .|. . .||. . .|. . .|. .  |. . .]
      |       [-----+-----+-----+------------+-----+-----+-----]
      |       [. . .|. . .|. . .|. . X||. . .|. . .|. . .|  . .]
      | img_3 [. . .|. . .|. . .|. X .||. . .|. . .|. . .|.   .]
      |       [. . .|. . .|. . .|X . .||. . .|. . .|. . .|. .  ]