"""
Visualization of global losses.
"""
import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from matplotlib.ticker import PercentFormatter
from mpl_toolkits.axes_grid1 import make_axes_locatable
from torch import Tensor
from osc.utils import cos_pairwise
from .utils import subplots_grid
[docs]def viz_contrastive_loss_global(global_feats: Tensor, loss: float) -> Figure:
"""Visualize global contrastive loss between 2B images.
Args:
global_feats: a tensor with shape [2B, C]
loss: loss value to show on the plot
Returns:
A figure with two axes.
"""
A = 2
B = global_feats.shape[0] // A
# cos, prob: [AB AB]
cos = cos_pairwise(global_feats.detach().reshape(A * B, -1))
prob = (
torch.clone(cos)
.fill_diagonal_(-torch.inf)
.softmax(dim=-1)
.fill_diagonal_(np.nan)
.cpu()
.numpy()
)
acc = np.mean(np.nanargmax(prob, axis=1) == np.roll(np.arange(A * B), B))
cos = cos.fill_diagonal_(np.nan).cpu().numpy()
fig, axs = subplots_grid(1, 2, ax_height_inch=6, sharex=True, sharey=True)
ax = axs[0]
img = ax.imshow(cos)
fig.colorbar(img, ax=ax)
ax.scatter(np.nanargmax(cos, axis=1), np.arange(2 * B), color="red")
ax.set_title(
f"Pairwise cos of global projections\n"
f"Min {np.nanmin(cos):.3f} Max {np.nanmax(cos):.3f} Loss {loss:.4f}"
)
ax = axs[1]
img = ax.imshow(prob)
fig.colorbar(img, ax=ax, format=PercentFormatter(1.0))
ax.scatter(np.nanargmax(prob, axis=1), np.arange(2 * B), color="red")
ax.set_title(
f"Match probs of global projections\n"
f"Min {np.nanmin(prob):.1%} Max {np.nanmax(prob):.1%} Acc {acc:.2%}"
)
for ax in axs.flat:
ax.axhline(B - 0.5, color="black", lw=4, ls="--")
ax.axvline(B - 0.5, color="black", lw=4, ls="--")
ax.set_xticks(np.arange(2 * B))
ax.set_xticklabels(np.tile(np.arange(B), 2))
ax.set_xlabel("Image idx")
ax = axs[0]
ax.set_yticks(np.arange(2 * B))
ax.set_yticklabels(np.tile(np.arange(B), 2))
ax.set_ylabel("Image idx")
fig.tight_layout()
return fig
[docs]def viz_contrastive_loss_global_probs(
global_feats: Tensor, temp: float, loss: float
) -> Figure:
"""Visualize global contrastive loss between 2B images, probabilities only.
Args:
global_feats: a tensor with shape [2B, C]
temp: loss temperature
loss: loss value to show on the plot
Returns:
A figure with a single plot of matching probabilities.
"""
A = 2
B = global_feats.shape[0] // A
# cos, prob: [AB AB]
cos = cos_pairwise(global_feats.detach().reshape(A * B, -1))
prob = (
cos.div_(temp)
.fill_diagonal_(-torch.inf)
.softmax(dim=-1)
.fill_diagonal_(np.nan)
.cpu()
.numpy()
)
acc = np.mean(np.nanargmax(prob, axis=1) == np.roll(np.arange(A * B), B))
fig, ax = plt.subplots(1, 1, figsize=(1.7 * 2 * B, 1.7 * 2 * B))
img = ax.imshow(prob)
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="1%", pad=0.1)
fig.colorbar(img, cax=cax, format=PercentFormatter(1.0))
ax.scatter(np.nanargmax(prob, axis=1), np.arange(2 * B), color="red")
ax.set_title(
f"Match probs of global projections (temp {temp:.2f})\n"
f"Min {np.nanmin(prob):.1%} Max {np.nanmax(prob):.1%} "
f"Acc {acc:.2%} Loss {loss:.4f}"
)
ax.axhline(B - 0.5, color="black", lw=4, ls="-")
ax.axvline(B - 0.5, color="black", lw=4, ls="-")
ax.set_xticks(np.arange(2 * B))
ax.set_xticklabels(np.tile(np.arange(B), 2))
ax.set_xlabel("Image idx")
ax.set_yticks(np.arange(2 * B))
ax.set_yticklabels(np.tile(np.arange(B), 2))
ax.set_ylabel("Image idx")
fig.tight_layout()
return fig