# Source code for osc.viz.loss_objects

"""
Visualization of object losses.
"""
import numpy as np
import scipy.optimize
import torch
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from matplotlib.ticker import PercentFormatter
from mpl_toolkits.axes_grid1 import make_axes_locatable
from torch import Tensor

from osc.utils import cos_pairwise

from .utils import subplots_grid


def match_objects(obj_feats: Tensor) -> np.ndarray:
    """Compute matches between 2B images with S objects each.

    The batch is treated as two stacked halves of ``B`` images each: image
    ``b`` of the first half is paired with image ``b`` of the second half.
    Within every pair, objects are matched one-to-one by solving a linear sum
    assignment that maximizes the total cosine similarity of matched features.

    Args:
        obj_feats: object feature tensor of shape ``[2B S C]``, where ``S`` is
            the number of objects per image. The first dimension must be even.
            Features need not be normalized: they are L2-normalized here so
            that the matching is based on cosine similarity rather than on a
            raw dot product that depends on feature magnitude.

    Returns:
        An array of matches (one match per row), length ``2BS``. Rows are flat
        object indices ``image * S + object``; the value at row ``i`` is the
        flat index of the object matched to object ``i`` in the paired image.
    """
    B, S, C = obj_feats.shape
    B //= 2

    # Unit-normalize along the channel axis so the dot products below are
    # actual cosine similarities (no-op for features that are already unit-norm).
    feats = torch.nn.functional.normalize(obj_feats.detach(), dim=-1)
    s0, s1 = feats.reshape(2, B, S, C).unbind(dim=0)

    # cos[b, k, l]: similarity between object k of image b (first half)
    # and object l of image b (second half).
    cos = torch.einsum("ikc,ilc->ikl", s0, s1).cpu().numpy()

    targets = np.zeros(2 * B * S, dtype=int)
    for b in range(B):
        # First output is a vector of sorted row idxs [0, 1, ..., S-1],
        # so only the assigned column of each row is needed.
        _, cols = scipy.optimize.linear_sum_assignment(cos[b, :, :], maximize=True)
        # First-half objects point to their match in the second half ...
        targets[b * S : (b + 1) * S] = (B + b) * S + cols
        # ... and second-half objects point back through the inverse permutation.
        targets[(B + b) * S : (B + b + 1) * S] = b * S + np.argsort(cols)
    return targets
def viz_contrastive_loss_objects(obj_feats: Tensor, loss: float) -> Figure:
    """Visualize contrastive loss between 2B images with S objects each.

    Draws two heatmaps side by side over all ``2BS`` objects: the pairwise
    similarity matrix returned by ``cos_pairwise`` and its row-wise softmax
    with the diagonal excluded. The matches computed by ``match_objects`` are
    overlaid as red dots on both heatmaps.

    Args:
        obj_feats: object feature tensor of shape ``[2B S C]``, where ``S`` is
            the number of objects per image.
        loss: loss value to display on the plots

    Returns:
        A figure with two plots.
    """
    n_imgs = obj_feats.shape[0] // 2
    n_slots = obj_feats.shape[1]
    half = n_imgs * n_slots
    total = 2 * half

    sim = cos_pairwise(obj_feats.detach().reshape(total, -1))

    # Softmax over each row with the self-similarity masked out; the diagonal
    # is then set to NaN so it is ignored by the plot and by nanmin/nanmax.
    masked = torch.clone(sim)
    masked.fill_diagonal_(-torch.inf)
    prob = masked.softmax(dim=-1)
    prob.fill_diagonal_(np.nan)
    prob = prob.cpu().numpy()

    sim.fill_diagonal_(np.nan)
    cos = sim.cpu().numpy()

    matches = match_objects(obj_feats)
    rows = np.arange(total)

    fig, axs = subplots_grid(
        1, 2, ax_height_inch=2 * n_imgs, sharex=True, sharey=True
    )

    # Left plot: pairwise similarities.
    cos_ax = axs[0]
    cos_img = cos_ax.imshow(cos)
    fig.colorbar(cos_img, ax=cos_ax)
    cos_title = (
        f"Pairwise cos of slot projections\n"
        f"Min {np.nanmin(cos):.3f} Max {np.nanmax(cos):.3f} Loss {loss:.4f}"
    )
    cos_ax.set_title(cos_title)
    cos_ax.scatter(matches, rows, color="red", s=10)

    # Right plot: match probabilities, colorbar shown as percentages.
    prob_ax = axs[1]
    prob_img = prob_ax.imshow(prob)
    fig.colorbar(prob_img, ax=prob_ax, format=PercentFormatter(1.0))
    prob_title = (
        f"Match probs of slot projections\n"
        f"Min {np.nanmin(prob):.1%} Max {np.nanmax(prob):.1%}"
    )
    prob_ax.set_title(prob_title)
    prob_ax.scatter(matches, rows, color="red", s=10)

    # One label per object: the first slot of every image carries the image
    # index, the remaining slots only their slot index. The same labels are
    # repeated for the second half of the batch.
    half_labels = []
    for img_idx in range(n_imgs):
        for slot_idx in range(n_slots):
            if slot_idx > 0:
                half_labels.append(slot_idx)
            else:
                half_labels.append(f"img {img_idx} | {slot_idx}")
    tick_labels = half_labels + half_labels

    for axis in axs.flat:
        # Thick lines split the two halves of the batch.
        axis.axhline(half - 0.5, color="black", lw=4)
        axis.axvline(half - 0.5, color="black", lw=4)
        # Thin lines split consecutive images.
        for boundary in range(n_slots, total, n_slots):
            axis.axhline(boundary - 0.5, color="black", lw=2)
            axis.axvline(boundary - 0.5, color="black", lw=2)
        axis.set_xlabel("Image idx | slot idx")
        axis.set_xticks(np.arange(total))
        axis.set_xticklabels(
            tick_labels,
            rotation=90,
            fontdict={"fontsize": "small"},
        )

    # The y axis is shared, so it is only labelled on the left plot.
    left = axs[0]
    left.set_yticks(np.arange(total))
    left.set_yticklabels(tick_labels, fontdict={"fontsize": "small"})
    left.set_ylabel("Image idx | slot idx")

    fig.set_facecolor("white")
    fig.tight_layout()
    return fig
def viz_contrastive_loss_objects_probs(
    obj_feats: Tensor, temp: float, loss: float
) -> Figure:
    """Visualize contrastive loss between 2B images with S objects each, probabilities only.

    Draws a single heatmap over all ``2BS`` objects: the row-wise softmax of
    the pairwise similarity matrix returned by ``cos_pairwise`` divided by
    ``temp``, with the diagonal excluded. The matches computed by
    ``match_objects`` are overlaid as red dots.

    Args:
        obj_feats: object feature tensor of shape ``[2B S C]``, where ``S`` is
            the number of objects per image.
        temp: loss temperature
        loss: loss value to display on the plots

    Returns:
        A figure with a single plot of matching probabilities.
    """
    n_imgs = obj_feats.shape[0] // 2
    n_slots = obj_feats.shape[1]
    half = n_imgs * n_slots
    total = 2 * half

    sim = cos_pairwise(obj_feats.detach().reshape(total, -1))

    # Temperature-scaled softmax over each row with the self-similarity masked
    # out; the diagonal is then set to NaN so it is ignored by the plot and by
    # nanmin/nanmax.
    logits = sim.div_(temp)
    logits.fill_diagonal_(-torch.inf)
    prob = logits.softmax(dim=-1)
    prob.fill_diagonal_(np.nan)
    prob = prob.cpu().numpy()

    side_inch = 1.7 * 2 * n_imgs
    fig, ax = plt.subplots(1, 1, figsize=(side_inch, side_inch))
    heatmap = ax.imshow(prob)

    # Thin dedicated axis on the right for the percentage colorbar.
    colorbar_ax = make_axes_locatable(ax).append_axes("right", size="1%", pad=0.1)
    fig.colorbar(heatmap, cax=colorbar_ax, format=PercentFormatter(1.0))

    matches = match_objects(obj_feats)
    ax.scatter(matches, np.arange(total), color="red", s=10)

    title = (
        f"Match probs of slot projections (temp {temp:.2f})\n"
        f"Min {np.nanmin(prob):.1%} Max {np.nanmax(prob):.1%} Loss {loss:.4f}"
    )
    ax.set_title(title)

    # Thick lines split the two halves of the batch, thin lines split images.
    ax.axhline(half - 0.5, color="black", lw=4)
    ax.axvline(half - 0.5, color="black", lw=4)
    for boundary in range(n_slots, total, n_slots):
        ax.axhline(boundary - 0.5, color="black", lw=2)
        ax.axvline(boundary - 0.5, color="black", lw=2)

    # One label per object: the first slot of every image carries the image
    # index, the remaining slots only their slot index. The same labels are
    # repeated for the second half of the batch.
    half_labels = []
    for img_idx in range(n_imgs):
        for slot_idx in range(n_slots):
            if slot_idx > 0:
                half_labels.append(slot_idx)
            else:
                half_labels.append(f"img {img_idx} | {slot_idx}")
    tick_labels = half_labels + half_labels

    ax.set_xlabel("Image idx | slot idx")
    ax.set_xticks(np.arange(total))
    ax.set_xticklabels(
        tick_labels,
        rotation=90,
        fontdict={"fontsize": "small"},
    )

    ax.set_ylabel("Image idx | slot idx")
    ax.set_yticks(np.arange(total))
    ax.set_yticklabels(tick_labels, fontdict={"fontsize": "small"})

    fig.set_facecolor("white")
    fig.tight_layout()
    return fig