Source code for osc.utils

"""
Collection of utilities.
"""

import inspect
import logging
import random
import signal
from contextlib import AbstractContextManager
from pathlib import Path
from typing import Any, Iterator, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import tabulate
import tensorflow as tf
import torch
from torch import Tensor

log = logging.getLogger(__name__)

ImgSizeHW = Tuple[int, int]
ImgMean = Tuple[float, float, float]
ImgStd = Tuple[float, float, float]


@torch.jit.script
def l2_normalize_(a: Tensor) -> Tensor:
    """L2 normalization in-place along the last dimension.

    Args:
        a: [..., C] tensor to normalize.

    Returns:
        The input tensor, normalized along the last dimension.
    """
    norm = torch.linalg.vector_norm(a, dim=-1, keepdim=True)
    return a.div_(norm.clamp_min_(1e-10))


@torch.jit.script
def l2_normalize(a: Tensor) -> Tensor:
    """L2 normalization along the last dimension.

    Args:
        a: [..., C] tensor to normalize.

    Returns:
        A new tensor, normalized along the last dimension.
    """
    norm = torch.linalg.vector_norm(a, dim=-1, keepdim=True)
    return a / norm.clamp_min(1e-10)
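

# Illustrative doctest-style sketch for the two normalizers above (values and
# shapes are assumed, not part of the original module):
#
#   >>> x = torch.full((2, 4), 2.0)
#   >>> l2_normalize(x)[0, 0].item()  # each row now has unit L2 norm
#   0.5
#   >>> _ = l2_normalize_(x)          # same result, but modifies ``x`` in place
#   >>> x[0, 0].item()
#   0.5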


@torch.jit.script
def cos_pairwise(a: Tensor, b: Optional[Tensor] = None) -> Tensor:
    """Cosine between all pairs of entries in two tensors.

    Args:
        a: [*N, C] tensor, where ``*N`` can be any number of leading dimensions.
        b: [*M, C] tensor, where ``*M`` can be any number of leading dimensions.
            Defaults to ``a`` if missing.

    Returns:
        [*N, *M] tensor of cosine values.
    """
    a = l2_normalize(a)
    b = a if b is None else l2_normalize(b)
    N = a.shape[:-1]
    M = b.shape[:-1]
    a = a.flatten(end_dim=-2)
    b = b.flatten(end_dim=-2)
    cos = torch.einsum("nc,mc->nm", a, b)
    return cos.reshape(N + M)
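

# Shape sketch for ``cos_pairwise`` (illustrative, shapes assumed):
#
#   >>> a = torch.randn(2, 3, 8)
#   >>> b = torch.randn(5, 8)
#   >>> cos_pairwise(a, b).shape      # leading dims of ``a`` then of ``b``
#   torch.Size([2, 3, 5])
#   >>> cos_pairwise(a).shape         # ``b`` defaults to ``a``
#   torch.Size([2, 3, 2, 3])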


def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool) -> int:
    """Compute the number of batches in one epoch according to batch size and drop behavior.

    Args:
        num_samples: total number of samples.
        batch_size: number of samples per batch.
        drop_last: whether the last incomplete batch is dropped.

    Returns:
        The number of batches.
    """
    if drop_last:
        return int(np.floor(num_samples / batch_size))
    else:
        return int(np.ceil(num_samples / batch_size))
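

# Quick check (illustrative values): 10 samples, batch size 3.
#
#   >>> batches_per_epoch(10, 3, drop_last=True)   # floor(10 / 3)
#   3
#   >>> batches_per_epoch(10, 3, drop_last=False)  # ceil(10 / 3)
#   4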


def seed_everything(seed: int):
    """Seed python, torch, tensorflow, and numpy."""
    random.seed(seed)
    torch.manual_seed(seed)
    tf.random.set_seed(seed)
    np.random.seed(np.uint32(seed))
    log.info("All random seeds: %d", seed)


def latest_checkpoint(run_dir: Union[Path, str]) -> Path:
    """Find the checkpoint with the highest numerical epoch.

    Args:
        run_dir: the search path containing ``checkpoint.*.pth`` files.

    Returns:
        Path to the checkpoint.
    """
    glob = Path(run_dir).glob("checkpoint.*.pth")
    glob = sorted(glob, key=lambda p: int(p.name.split(".")[1]))
    return glob[-1]
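

# Usage sketch for ``latest_checkpoint`` (the run directory and file names
# here are hypothetical):
#
#   >>> latest_checkpoint("runs/exp0")  # doctest: +SKIP
#   PosixPath('runs/exp0/checkpoint.120.pth')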


@torch.jit.script
def normalize_sum_to_one(tensor: Tensor, dim: int = -1) -> Tensor:
    """Normalize a tensor so that it sums to one over the given dimension."""
    return tensor / tensor.sum(dim=dim, keepdim=True).clamp_min(1e-8)
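

# Quick check (illustrative values):
#
#   >>> normalize_sum_to_one(torch.tensor([1.0, 3.0]))
#   tensor([0.2500, 0.7500])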


def fill_diagonal_(x: Tensor, value: float = -torch.inf) -> Tensor:
    """Fill the diagonal of a torch tensor, batched and in-place.

    Args:
        x: ``[..., N, M]`` tensor with leading batch dimensions.
        value: value to fill the diagonal with.

    Returns:
        The same ``[..., N, M]`` input tensor with diagonal entries
        ``x[..., i, i]`` replaced.
    """
    N, M = x.shape[-2:]
    mask = torch.eye(N, M, dtype=torch.bool, device=x.device)
    return x.masked_fill_(mask, value)
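

# Usage sketch (illustrative): mask out self-similarities in a batch of
# square matrices.
#
#   >>> m = torch.zeros(2, 3, 3)
#   >>> fill_diagonal_(m, value=float("-inf"))[0]
#   tensor([[-inf, 0., 0.],
#           [0., -inf, 0.],
#           [0., 0., -inf]])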


@tf.function
def ravel_multi_index_tf(
    multi_index: Tuple[tf.Tensor, ...], dims: Tuple[int, ...]
) -> tf.Tensor:
    """Same as :func:`np.ravel_multi_index` with no checks.

    Args:
        multi_index: a tuple of index tensors, one per dimension.
        dims: the shape of the tensor. The number of dimensions should be the
            same as the number of index tensors in ``multi_index``. Also,
            ``all(multi_index[i] < dims[i])`` is expected to hold for all ``i``.

    Example:
        How to use:

        >>> ravel_multi_index_tf(
        ...     (
        ...         tf.constant([4, 3, 3, 2]),
        ...         tf.constant([0, 1, 5, 6]),
        ...         tf.constant([0, 1, 2, 0]),
        ...     ),
        ...     dims=(5, 7, 3),
        ... )
        [84, 67, 80, 60]

    Returns:
        A 1D tensor of flat indexes.
    """
    strides = tf.math.cumprod(tf.constant([*dims, 1]), reverse=True)[1:]
    multi_index = tf.stack(multi_index, axis=0)
    result = multi_index * strides[:, None]
    return tf.reduce_sum(result, axis=0)


def to_dotlist(d, prf: Tuple[str, ...] = tuple()) -> Iterator[Tuple[str, Any]]:
    """Iterate over a flattened dict using dot-separated keys.

    Args:
        d: a hierarchical dictionary with strings as keys.
        prf: tuple of prefixes used for recursion.

    Notes:
        The type of ``d`` could be defined as ``StringDict``, but it clashes
        with Sphinx: ``StringDict = Mapping[str, Union[Any, StringDict]]``.

    Examples:
        Flatten a dict:

        >>> print(*to_dotlist({'a': {'b': {}, 'c': 1, 'd': [2]}, 'x': []}), sep="\\n")
        ('a.b', {})
        ('a.c', 1)
        ('a.d', [2])
        ('x', [])

    Yields:
        Tuples of dot-separated keys and values. Call :func:`dict` on the
        returned iterator to build a flat dict.
    """
    for k, v in d.items():
        if isinstance(v, dict) and len(v) > 0:
            yield from to_dotlist(v, prf=(*prf, k))
        else:
            yield ".".join((*prf, k)), v


class SigIntCatcher(AbstractContextManager):
    """Context manager to gracefully handle SIGINT or KeyboardInterrupt.

    Example:
        Gracefully terminate a loop:

        >>> with SigIntCatcher() as should_stop:
        ...     for i in range(1000):
        ...         print(f"Step {i}...")
        ...         time.sleep(5)
        ...         print("Done")
        ...         if should_stop:
        ...             break
    """

    def __init__(self):
        self._interrupted = None
        self._old_handler = None

    def __bool__(self):
        return bool(self._interrupted)

    def __enter__(self):
        self._old_handler = signal.getsignal(signal.SIGINT)
        self._interrupted = False
        signal.signal(signal.SIGINT, self._handler)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        signal.signal(signal.SIGINT, self._old_handler)
        self._interrupted = None
        self._old_handler = None

    # noinspection PyUnusedLocal
    def _handler(self, signum, frame):
        if not self._interrupted:
            log.warning(
                "Interrupted, wait for graceful termination, repeat to raise exception"
            )
            self._interrupted = True
        else:
            log.warning("Interrupted again, raising exception")
            raise KeyboardInterrupt()


class StepCounter(object):
    """A simple counter to keep track of update steps."""

    def __init__(self):
        self.steps = 0

    def __int__(self):
        return self.steps

    def step(self):
        """Increment by 1."""
        self.steps += 1

    def __str__(self):
        return f"{self.__class__.__name__}({int(self)})"
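

# Usage sketch for ``StepCounter`` (illustrative):
#
#   >>> counter = StepCounter()
#   >>> counter.step()
#   >>> counter.step()
#   >>> int(counter), str(counter)
#   (2, 'StepCounter(2)')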


class AverageMetric(object):
    """Keep track of the average value of a metric across batches."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, values: Tensor):
        """Accumulate a batch of values.

        Args:
            values: a 1D tensor of per-sample metric values.
        """
        self.total += values.detach().sum().item()
        self.count += values.numel()

    def compute(self) -> float:
        """Compute the running average.

        Returns:
            The current average.
        """
        return self.total / self.count
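

# Usage sketch for ``AverageMetric`` (illustrative): average over two batches.
#
#   >>> metric = AverageMetric()
#   >>> metric.update(torch.tensor([1.0, 2.0]))
#   >>> metric.update(torch.tensor([3.0]))
#   >>> metric.compute()  # (1 + 2 + 3) / 3
#   2.0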