osc.loss_objects.matching_contrastive_loss_per_img

matching_contrastive_loss_per_img(slots, temperature=1.0, reduction='mean')

Contrastive object-wise loss, only between corresponding images.

The K slots of the i-th image are matched with the K slots of the (i+B)-th image and vice versa. For each slot, the contrastive loss treats the matching slot in the corresponding image as the positive. The other K-1 slots in the corresponding image serve as negatives, as do the other K-1 slots in the original image. Slots are only matched between an image and its augmented version, never within the same image and never with other images.
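The scheme described above can be sketched in plain Python. This is a minimal, unvectorized illustration and not the library's implementation: the name `matching_contrastive_sketch`, the use of cosine similarity, and the brute-force search for the best slot permutation are all assumptions made for the example.

```python
import math
from itertools import permutations


def _normalize(v):
    """Scale a vector to unit length (assumption: cosine similarity)."""
    norm = math.sqrt(sum(a * a for a in v)) or 1.0
    return [a / norm for a in v]


def _dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def matching_contrastive_sketch(slots, temperature=1.0, reduction="mean"):
    """slots: nested list of shape [2B][K][C]; image i pairs with image i+B."""
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(f"unknown reduction: {reduction!r}")
    B, K = len(slots) // 2, len(slots[0])
    z = [[_normalize(s) for s in img] for img in slots]
    losses = [0.0] * (2 * B * K)

    for i in range(B):
        a, b = z[i], z[i + B]  # the two augmented views of one image
        cross = [[_dot(a[k], b[l]) for l in range(K)] for k in range(K)]
        # Assumed matching rule: the permutation with the highest total
        # similarity, found by brute force (only practical for small K).
        perm = max(permutations(range(K)),
                   key=lambda p: sum(cross[k][p[k]] for k in range(K)))
        inverse = [perm.index(l) for l in range(K)]

        for img, view, other, match in ((i, a, b, perm), (i + B, b, a, inverse)):
            for k in range(K):
                # Candidates: all K slots of the other view, then the K-1
                # other slots of the same view. The slot itself is excluded.
                logits = [_dot(view[k], other[l]) / temperature for l in range(K)]
                logits += [_dot(view[k], view[l]) / temperature
                           for l in range(K) if l != k]
                # Numerically stable log-sum-exp over all candidates.
                top = max(logits)
                log_norm = top + math.log(sum(math.exp(x - top) for x in logits))
                # Cross entropy with the matched slot as the target class.
                losses[img * K + k] = log_norm - logits[match[k]]

    if reduction == "mean":
        return sum(losses) / len(losses)
    if reduction == "sum":
        return sum(losses)
    return losses
```

In this sketch each slot contributes one cross-entropy term over 2K-1 candidates: one positive, K-1 negatives from the paired view, and K-1 negatives from its own view.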

If all slot embeddings collapse to the same value, the loss will be log(2K-2).

Parameters
  • slots (Tensor) – [2B, K, C] tensor of projected image features

  • temperature (float) – temperature scaling

  • reduction (str) – ‘mean’, ‘sum’, or ‘none’

Return type

Tensor

Returns

Scalar loss over all samples and slots if reduction is ‘mean’ or ‘sum’. A vector of 2BK losses if reduction is ‘none’.
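When reduction is ‘none’, it can be handy to map an entry of the 2BK vector back to the slot it belongs to. The helper below is a hypothetical sketch, not part of the library, and it assumes the vector is the [2B, K] grid of per-slot losses flattened in row-major order, which this page does not state explicitly.

```python
def describe_loss_index(n, B, K):
    """Map a flat index in [0, 2*B*K) to (augmentation, image, slot).

    Assumes row-major flattening of a [2B, K] grid, where rows 0..B-1
    hold the first augmentation and rows B..2B-1 hold the second.
    """
    if not 0 <= n < 2 * B * K:
        raise IndexError(f"index {n} out of range for B={B}, K={K}")
    image_row, slot = divmod(n, K)
    augmentation, image = divmod(image_row, B)
    return augmentation, image, slot
```

For instance, with B=4 and K=3, `describe_loss_index(14, 4, 3)` would point at the third slot of the first image in the second augmentation, under the ordering assumed here.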

Example

A batch of B=4 images, augmented twice, each with K=3 slots. Each X marks the positive matching target for the cross-entropy loss, and each . marks a negative included in the loss (taken only from the augmented image and the image itself, excluding the slot itself):

                        aug_0                    aug_1
               -----------------------  -----------------------
                 0     1     2     3      0     1     2     3
      |       [  . .|     |     |     ||. . X|     |     |     ]
      | img_0 [.   .|     |     |     ||. X .|     |     |     ]
      |       [. .  |     |     |     ||X . .|     |     |     ]
      |       [-----+-----+-----+------------+-----+-----+-----]
      |       [     |  . .|     |     ||     |X . .|     |     ]
      | img_1 [     |.   .|     |     ||     |. X .|     |     ]
      |       [     |. .  |     |     ||     |. . X|     |     ]
aug_0 |       [-----+-----+-----+------------+-----+-----+-----]
      |       [     |     |  . .|     ||     |     |. . X|     ]
      | img_2 [     |     |.   .|     ||     |     |X . .|     ]
      |       [     |     |. .  |     ||     |     |. X .|     ]
      |       [-----+-----+-----+------------+-----+-----+-----]
      |       [     |     |     |  . .||     |     |     |. . X]
      | img_3 [     |     |     |.   .||     |     |     |. X .]
      |       [     |     |     |. .  ||     |     |     |X . .]
              [=====|=====|=====|============|=====|=====|=====]
      |       [. . X|     |     |     ||  . .|     |     |     ]
      | img_0 [. X .|     |     |     ||.   .|     |     |     ]
      |       [X . .|     |     |     ||. .  |     |     |     ]
      |       [-----+-----+-----+------------+-----+-----+-----]
      |       [     |X . .|     |     ||     |  . .|     |     ]
      | img_1 [     |. X .|     |     ||     |.   .|     |     ]
      |       [     |. . X|     |     ||     |. .  |     |     ]
aug_1 |       [-----+-----+-----+------------+-----+-----+-----]
      |       [     |     |. X .|     ||     |     |  . .|     ]
      | img_2 [     |     |. . X|     ||     |     |.   .|     ]
      |       [     |     |X . .|     ||     |     |. .  |     ]
      |       [-----+-----+-----+------------+-----+-----+-----]
      |       [     |     |     |. . X||     |     |     |  . .]
      | img_3 [     |     |     |. X .||     |     |     |.   .]
      |       [     |     |     |X . .||     |     |     |. .  ]
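To make the layout of the diagram concrete, the following sketch rebuilds it from one slot permutation per image. The function name `target_grid` and the hard-coded permutations (read off the diagram above) are illustrative only and not part of the library.

```python
def target_grid(perms):
    """Build the 2BK x 2BK grid of positives ("X") and negatives (".").

    perms[i][k] is the slot of the second augmentation matched with
    slot k of the first augmentation, for image i.
    """
    B, K = len(perms), len(perms[0])
    n = 2 * B * K
    grid = [[" "] * n for _ in range(n)]
    for i, perm in enumerate(perms):
        # The second augmentation uses the inverse permutation.
        inverse = [perm.index(l) for l in range(K)]
        for row_img, col_img, match in ((i, i + B, perm), (i + B, i, inverse)):
            for k in range(K):
                row = row_img * K + k
                for l in range(K):
                    # Same image: negatives everywhere except the slot itself.
                    if l != k:
                        grid[row][row_img * K + l] = "."
                    # Paired image: one positive, the rest negatives.
                    grid[row][col_img * K + l] = "X" if l == match[k] else "."
    return grid


# Permutations read off the diagram for img_0 .. img_3.
perms = [[2, 1, 0], [0, 1, 2], [2, 0, 1], [2, 1, 0]]
grid = target_grid(perms)
for row in grid:
    print(" ".join(row))
```

Every row of the resulting grid contains exactly one X and 2K-2 dots, and all cells belonging to other images stay empty.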