Source code for osc.viz.segmentation

from typing import Tuple

import numpy as np
import scipy.optimize
import torch
from einops import rearrange
from torch import Tensor

def match_segmentation_masks(
    masks: Tensor, preds: Tensor, preds_thres: Tensor
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Match predicted and ground-truth segmentation masks.

    For every batch element, ground-truth masks are paired one-to-one with
    predictions by solving a linear sum assignment that maximizes the total
    score. The matching is done twice: once using IoU against the binarized
    predictions and once using a soft DICE score against the float predictions.
    No gradient flows through the returned tensors.

    Args:
        masks: binary ground-truth masks ``[B T H W]``, dtype ``torch.bool``
        preds: float prediction masks ``[B S H W]`` for DICE matching
        preds_thres: binarized prediction masks ``[B S H W]`` for IoU matching

    Returns:
        Tuple of ``(iou_idx, iou_val, dice_idx, dice_val)``, each of shape
        ``[B min(T, S)]`` and placed on ``masks.device``.

        When ``T <= S`` every ground-truth mask is matched and the outputs are
        ``[B T]``: IoU indexes indicate that the best match for ``masks[b, t]``
        is ``preds_thres[b, iou_idx[b, t]]`` with an IoU of ``iou_val[b, t]``.
        DICE indexes indicate that the best match for ``masks[b, t]`` is
        ``preds[b, dice_idx[b, t]]`` with a DICE score of ``dice_val[b, t]``.

        When ``T > S`` only ``S`` ground-truth masks can be matched. The
        outputs then list the matched pairs in increasing order of the
        ground-truth index, so position ``i`` is no longer mask ``t == i``.

    Raises:
        AssertionError: if ``masks`` is not a boolean tensor.
        ValueError: if any of the inputs is not 4-dimensional.
    """
    assert masks.dtype == torch.bool, f"masks must be torch.bool, got {masks.dtype}"
    for name, tensor in (
        ("masks", masks),
        ("preds", preds),
        ("preds_thres", preds_thres),
    ):
        if tensor.ndim != 4:
            raise ValueError(
                f"{name} must be 4-dimensional, got shape {tuple(tensor.shape)}"
            )
    device = masks.device

    # The scores are only used to pick a discrete assignment, so cut the
    # autograd graph up front: this avoids building a graph for nothing and
    # lets the scores be converted to numpy even if `preds` requires grad.
    preds = preds.detach()

    # Flatten the pixels and insert singleton axes so that masks and
    # predictions broadcast against each other into a [B T S ...] layout.
    masks = masks.flatten(start_dim=2).unsqueeze(2)  # [B T 1 (H W)]
    preds = preds.flatten(start_dim=2).unsqueeze(1)  # [B 1 S (H W)]
    preds_thres = preds_thres.flatten(start_dim=2).unsqueeze(1)  # [B 1 S (H W)]

    # iou: [B T S]
    intersection = torch.logical_and(masks, preds_thres).float().sum(-1)
    union = torch.logical_or(masks, preds_thres).float().sum(-1)
    # Clamping the denominator gives an IoU of 0 (not NaN) for empty pairs.
    iou = intersection.div_(union.clamp_min_(1e-8))
    iou_idx, iou_val = _assign_best_matches(iou, device)

    # dice: [B T S]
    masks = masks.float()
    intersection = torch.sum(masks * preds, dim=-1)
    union = torch.sum(masks, dim=-1) + torch.sum(preds, dim=-1)
    dice = intersection.mul_(2).div_(union.clamp_min_(1e-8))
    dice_idx, dice_val = _assign_best_matches(dice, device)

    return iou_idx, iou_val, dice_idx, dice_val


def _assign_best_matches(
    scores: Tensor, device: torch.device
) -> Tuple[Tensor, Tensor]:
    """Solve one score-maximizing assignment per batch element.

    Args:
        scores: pairwise scores ``[B T S]`` between masks and predictions
        device: device on which the results are returned

    Returns:
        Tuple ``(idx, val)`` of shape ``[B min(T, S)]``: the matched prediction
        index and the corresponding score, ordered by ground-truth index.
    """
    scores_np = scores.detach().cpu().numpy()
    batch, num_masks, num_preds = scores_np.shape
    num_pairs = min(num_masks, num_preds)

    # Pre-allocating (instead of stacking a list) keeps an empty batch valid.
    idx = np.empty((batch, num_pairs), dtype=np.intp)
    val = np.empty((batch, num_pairs), dtype=scores_np.dtype)
    for b in range(batch):
        rows, cols = scipy.optimize.linear_sum_assignment(
            scores_np[b], maximize=True
        )
        # Apply the same ordering to indexes and values so they stay aligned.
        order = np.argsort(rows)
        rows, cols = rows[order], cols[order]
        idx[b] = cols
        val[b] = scores_np[b, rows, cols]

    return torch.from_numpy(idx).to(device), torch.from_numpy(val).to(device)