"""
Visualization of object losses.
"""
import numpy as np
import scipy.optimize
import torch
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from matplotlib.ticker import PercentFormatter
from mpl_toolkits.axes_grid1 import make_axes_locatable
from torch import Tensor
from osc.utils import cos_pairwise
from .utils import subplots_grid
def match_objects(obj_feats: Tensor) -> np.ndarray:
    """Compute matches between 2B images with S objects each.

    The batch is split into two halves of ``B`` images each, and image ``b``
    of the first half is paired with image ``b`` of the second half. Within
    each pair, objects are matched one-to-one by solving a linear sum
    assignment that maximizes the total dot product between matched features.

    Note:
        The similarity is a plain dot product over the channel dimension; it
        is a cosine similarity only if the features are already L2-normalized,
        which this function does not enforce.

    Args:
        obj_feats: object feature tensor of shape ``[2B S C]``,
            where ``S`` is the number of objects per image.

    Returns:
        An integer array of matches (one match per row), length ``2BS``.
        Entry ``i`` is the flat index (``image * S + object``) of the object
        matched to the object with flat index ``i``.
    """
    B, S, C = obj_feats.shape
    B //= 2
    s0, s1 = obj_feats.detach().reshape(2, B, S, C).unbind(dim=0)

    # [B S S] similarities: entry [b, k, l] compares object k of image b in
    # the first half with object l of image b in the second half.
    # Widen to float64 on the CPU before the NumPy conversion: SciPy solves the
    # assignment in float64 anyway, and dtypes without a NumPy equivalent
    # (e.g. bfloat16) would otherwise fail in ``.numpy()``.
    cos = torch.einsum("ikc,ilc->ikl", s0, s1).cpu().double().numpy()

    targets = np.zeros(2 * B * S, dtype=int)
    for b in range(B):
        # First output is a vector of sorted row idxs [0, 1, ..., S-1]
        _, cols = scipy.optimize.linear_sum_assignment(cos[b, :, :], maximize=True)
        # First-half objects point at their match in the paired second-half image.
        targets[b * S : (b + 1) * S] = (B + b) * S + cols
        # Second-half objects point back: argsort inverts the permutation ``cols``.
        targets[(B + b) * S : (B + b + 1) * S] = b * S + np.argsort(cols)
    return targets
def viz_contrastive_loss_objects(obj_feats: Tensor, loss: float) -> Figure:
    """Visualize contrastive loss between 2B images with S objects each.

    Two heatmaps of size ``2BS x 2BS`` are drawn side by side: the pairwise
    similarities returned by ``cos_pairwise`` and the row-wise softmax of those
    similarities with the diagonal (self-similarity) excluded. Red dots mark,
    for every row, the column selected by :func:`match_objects`.

    Args:
        obj_feats: object feature tensor of shape ``[2B S C]``,
            where ``S`` is the number of objects per image.
        loss: loss value to display on the plots

    Returns:
        A figure with two plots.
    """
    B = obj_feats.shape[0] // 2
    S = obj_feats.shape[1]
    num_objs = 2 * B * S

    # NOTE(review): cos_pairwise is assumed to return a fresh [2BS, 2BS]
    # similarity tensor; it is modified in place below.
    cos = cos_pairwise(obj_feats.detach().reshape(num_objs, -1))

    # Softmax over each row with the diagonal masked out (-inf -> probability 0);
    # the diagonal is then set to NaN so it is left blank in the plot and
    # ignored by nanmin/nanmax.
    prob = (
        torch.clone(cos)
        .fill_diagonal_(-torch.inf)
        .softmax(dim=-1)
        .fill_diagonal_(np.nan)
        .cpu()
        .numpy()
    )
    cos = cos.fill_diagonal_(np.nan).cpu().numpy()
    matches = match_objects(obj_feats)
    rows = np.arange(num_objs)

    # Same labels on both axes: the first slot of each image carries the image
    # index, the remaining slots only their slot index. The list is repeated
    # for the second half of the batch.
    tick_labels = 2 * [
        k if k > 0 else f"img {b} | {k}" for b in range(B) for k in range(S)
    ]

    fig, axs = subplots_grid(1, 2, ax_height_inch=2 * B, sharex=True, sharey=True)

    ax = axs[0]
    img = ax.imshow(cos)
    fig.colorbar(img, ax=ax)
    ax.set_title(
        f"Pairwise cos of slot projections\n"
        f"Min {np.nanmin(cos):.3f} Max {np.nanmax(cos):.3f} Loss {loss:.4f}"
    )
    ax.scatter(matches, rows, color="red", s=10)

    ax = axs[1]
    img = ax.imshow(prob)
    fig.colorbar(img, ax=ax, format=PercentFormatter(1.0))
    ax.set_title(
        f"Match probs of slot projections\n"
        f"Min {np.nanmin(prob):.1%} Max {np.nanmax(prob):.1%}"
    )
    ax.scatter(matches, rows, color="red", s=10)

    for ax in axs.flat:
        # Thick lines separate the two halves of the batch.
        ax.axhline(B * S - 0.5, color="black", lw=4)
        ax.axvline(B * S - 0.5, color="black", lw=4)
        # Thin lines separate consecutive images.
        for line in range(S, num_objs, S):
            ax.axhline(line - 0.5, color="black", lw=2)
            ax.axvline(line - 0.5, color="black", lw=2)
        ax.set_xlabel("Image idx | slot idx")
        ax.set_xticks(rows)
        ax.set_xticklabels(
            tick_labels,
            rotation=90,
            fontdict={"fontsize": "small"},
        )

    # The y axis is shared, so ticks and label are only set on the first plot.
    ax = axs[0]
    ax.set_yticks(rows)
    ax.set_yticklabels(
        tick_labels,
        fontdict={"fontsize": "small"},
    )
    ax.set_ylabel("Image idx | slot idx")

    fig.set_facecolor("white")
    fig.tight_layout()
    return fig
def viz_contrastive_loss_objects_probs(
    obj_feats: Tensor, temp: float, loss: float
) -> Figure:
    """Visualize contrastive loss between 2B images with S objects each, probabilities only.

    A single ``2BS x 2BS`` heatmap is drawn: the row-wise softmax of the
    pairwise similarities returned by ``cos_pairwise`` divided by ``temp``,
    with the diagonal (self-similarity) excluded. Red dots mark, for every
    row, the column selected by :func:`match_objects`.

    Args:
        obj_feats: object feature tensor of shape ``[2B S C]``,
            where ``S`` is the number of objects per image.
        temp: loss temperature
        loss: loss value to display on the plots

    Returns:
        A figure with a single plot of matching probabilities.
    """
    B = obj_feats.shape[0] // 2
    S = obj_feats.shape[1]
    num_objs = 2 * B * S

    # NOTE(review): cos_pairwise is assumed to return a fresh [2BS, 2BS]
    # similarity tensor; it is scaled and masked in place below.
    cos = cos_pairwise(obj_feats.detach().reshape(num_objs, -1))

    # Temperature-scaled softmax over each row with the diagonal masked out
    # (-inf -> probability 0); the diagonal is then set to NaN so it is left
    # blank in the plot and ignored by nanmin/nanmax.
    prob = (
        cos.div_(temp)
        .fill_diagonal_(-torch.inf)
        .softmax(dim=-1)
        .fill_diagonal_(np.nan)
        .cpu()
        .numpy()
    )
    rows = np.arange(num_objs)

    # Same labels on both axes: the first slot of each image carries the image
    # index, the remaining slots only their slot index. The list is repeated
    # for the second half of the batch.
    tick_labels = 2 * [
        k if k > 0 else f"img {b} | {k}" for b in range(B) for k in range(S)
    ]

    fig, ax = plt.subplots(1, 1, figsize=(1.7 * 2 * B, 1.7 * 2 * B))
    img = ax.imshow(prob)

    # Put the colorbar on a dedicated narrow axes appended to the right of the
    # heatmap.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="1%", pad=0.1)
    fig.colorbar(img, cax=cax, format=PercentFormatter(1.0))

    matches = match_objects(obj_feats)
    ax.scatter(matches, rows, color="red", s=10)
    ax.set_title(
        f"Match probs of slot projections (temp {temp:.2f})\n"
        f"Min {np.nanmin(prob):.1%} Max {np.nanmax(prob):.1%} Loss {loss:.4f}"
    )

    # Thick lines separate the two halves of the batch.
    ax.axhline(B * S - 0.5, color="black", lw=4)
    ax.axvline(B * S - 0.5, color="black", lw=4)
    # Thin lines separate consecutive images.
    for line in range(S, num_objs, S):
        ax.axhline(line - 0.5, color="black", lw=2)
        ax.axvline(line - 0.5, color="black", lw=2)

    ax.set_xlabel("Image idx | slot idx")
    ax.set_xticks(rows)
    ax.set_xticklabels(
        tick_labels,
        rotation=90,
        fontdict={"fontsize": "small"},
    )
    ax.set_ylabel("Image idx | slot idx")
    ax.set_yticks(rows)
    ax.set_yticklabels(
        tick_labels,
        fontdict={"fontsize": "small"},
    )

    fig.set_facecolor("white")
    fig.tight_layout()
    return fig