Source code for osc.data.utils

"""
Helper functions to bridge tensorflow image data to pytorch.

Functions:

- Image normalization back and forth
- Augmentation

    - Augmentations with "stateless" random seeds
    - Fixed augmentations
    - Double augmentation for contrastive training
"""

from typing import Callable, Protocol, Tuple

import tensorflow as tf
import torch
import torch.nn.functional
import torch.utils.data

from osc.utils import ImgMean, ImgSizeHW, ImgStd

from .random_resized_crop import random_resized_crop


@tf.function
def normalize_tf(img: tf.Tensor, mean: ImgMean, std: ImgStd) -> tf.Tensor:
    """Normalize a floating-point image with [..., H, W, C] channel order.

    The statistics are converted to the dtype of ``img``. TensorFlow does not
    promote dtypes implicitly, so converting them to the default float32 would
    make the subtraction fail for float16/float64 images.

    Args:
        img: floating-point tensor of shape [..., H, W, C]
        mean: tuple of floats shape [C]
        std: tuple of floats shape [C]

    Returns:
        Normalized image ``(img - mean) / std`` with the same dtype as ``img``.
    """
    # Match the image dtype explicitly: `img - mean` requires equal dtypes.
    mean = tf.convert_to_tensor(mean, dtype=img.dtype)
    std = tf.convert_to_tensor(std, dtype=img.dtype)
    # Shape [C] broadcasts against the trailing channel axis of `img`.
    return (img - mean) / std
def unnormalize_pt(img: torch.Tensor, mean: ImgMean, std: ImgStd) -> torch.Tensor:
    """Un-normalize a torch image with [..., C, H, W] channel order.

    The statistics are created directly on the device of ``img`` and, for
    floating-point images, in the dtype of ``img``. This avoids an extra
    host-to-device copy and keeps half/double precision inputs in their own
    precision instead of mixing them with float32 statistics.

    Args:
        img: floating-point tensor of shape [..., C, H, W]
        mean: tuple of floats shape [C]
        std: tuple of floats shape [C]

    Returns:
        Un-normalized image with values clipped to the range [0, 1].
    """
    # Non-floating images fall back to the default float dtype so that the
    # statistics are not truncated to integers.
    stats_dtype = img.dtype if img.is_floating_point() else torch.get_default_dtype()
    mean = torch.as_tensor(mean, dtype=stats_dtype, device=img.device)
    std = torch.as_tensor(std, dtype=stats_dtype, device=img.device)
    # [C] -> [C, 1, 1] so the statistics broadcast over the H and W axes.
    img = img * std[:, None, None] + mean[:, None, None]
    return torch.clip(img, 0, 1)
@tf.function
def img_hwc_to_chw(img: tf.Tensor) -> tf.Tensor:
    """Move the channel axis of an image: [..., H, W, C] -> [..., C, H, W].

    Args:
        img: tensor whose last three axes are [H, W, C]

    Returns:
        The same data with the last axis moved to the third-to-last position.
    """
    channels_first = tf.experimental.numpy.moveaxis(img, source=-1, destination=-3)
    return channels_first
@tf.function
def augment_train(
    img: tf.Tensor,
    *,
    seed: tf.Tensor,
    crop_size: ImgSizeHW,
    crop_scale: Tuple[float, float],
    mean: ImgMean,
    std: ImgStd
) -> tf.Tensor:
    """Randomly augment an image for train/val using stateless random ops.

    The same ``seed`` yields the same augmentation; pass different seeds to
    get different augmentations.

    Steps, in order:

    - uint8 [0-255] -> float32 [0-1]
    - random left-right flip
    - random resized crop (random aspect ratio, zoom factor, resize to ``crop_size``)
    - color jitter (brightness, saturation, hue)
    - normalization to given ``mean`` and ``std``
    - pytorch channel order [H W C] -> [C H W]

    Args:
        img: tf.uint8 tensor
        seed: base random seed, split into one seed per random transformation
        crop_size: size of the image after cropping
        crop_scale: minimum and maximum crop scale for example ``(0.3, 1.0)``
        mean: image mean for normalization
        std: image std for normalization

    Returns:
        Augmented, normalized image in channels-first order.
    """
    # One independent seed per random op, all derived from the base seed.
    seeds = tf.random.experimental.stateless_split(seed, num=5)
    flip_seed = seeds[0]
    crop_seed = seeds[1]
    brightness_seed = seeds[2]
    saturation_seed = seeds[3]
    hue_seed = seeds[4]

    out = tf.image.convert_image_dtype(img, tf.float32)

    # Geometric augmentations.
    out = tf.image.stateless_random_flip_left_right(out, seed=flip_seed)
    out = random_resized_crop(out, size=crop_size, scale=crop_scale, seed=crop_seed)

    # Color jitter.
    out = tf.image.stateless_random_brightness(out, max_delta=0.2, seed=brightness_seed)
    out = tf.image.stateless_random_saturation(
        out, lower=0.9, upper=1.0, seed=saturation_seed
    )
    out = tf.image.stateless_random_hue(out, max_delta=0.05, seed=hue_seed)

    # Normalize, then switch to channels-first order.
    return img_hwc_to_chw(normalize_tf(out, mean, std))
@tf.function
def augment_center_crop(
    img: tf.Tensor,
    *,
    crop: float,
    img_size: ImgSizeHW,
    mean: ImgMean,
    std: ImgStd
) -> tf.Tensor:
    """Deterministic center crop and processing as in :func:`augment_train`

    Steps, in order: convert to float32, central crop, antialiased bilinear
    resize, normalization, channel reordering [H W C] -> [C H W].

    Args:
        img: TF uint8 tensor of shape [H W C]
        crop: central fraction of the image to keep, passed to
            ``tf.image.central_crop``
        img_size: size of the image after resizing
        mean: image mean for normalization
        std: image std for normalization

    Returns:
        Cropped, resized and normalized image in channels-first order.
    """
    as_float = tf.image.convert_image_dtype(img, tf.float32)
    cropped = tf.image.central_crop(as_float, central_fraction=crop)
    resized = tf.image.resize(cropped, size=img_size, method="bilinear", antialias=True)
    normalized = normalize_tf(resized, mean, std)
    return img_hwc_to_chw(normalized)
# Type of an augmentation function: takes one image tensor, returns one tensor.
AugmentFn = Callable[[tf.Tensor], tf.Tensor]
def augment_twice(augment_fn0: AugmentFn, augment_fn1: AugmentFn):
    """Combine two augmentation functions into one that applies both.

    Each function receives the same input image; the results are returned as
    a pair.

    Args:
        augment_fn0: callable, augment_fn0(img)
        augment_fn1: callable, augment_fn1(img)

    Returns:
        Callable, augment_twice(fn0, fn1)(img) -> (fn0(img), fn1(img))
    """

    @tf.function
    def twice(img: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        first = augment_fn0(img)
        second = augment_fn1(img)
        return first, second

    return twice
class AugmentFnStateless(Protocol):
    """Structural type of an augmentation function driven by an explicit seed.

    Matches callables that take an image and a keyword-only ``seed`` and
    return a tensor, as used with TF stateless random ops.
    """

    def __call__(self, img: tf.Tensor, *, seed: tf.Tensor) -> tf.Tensor:
        ...
def wrap_with_seed(
    augment_fn: AugmentFnStateless, *, initial_seed: int = 0
) -> AugmentFn:
    """Turn a seeded augmentation function into one that takes only an image.

    A ``tf.random.Generator`` created from ``initial_seed`` supplies a new
    seed for every invocation of the returned callable.

    Args:
        augment_fn: callable, augment_fn(img, seed)
        initial_seed: initial seed for the generator.

    Returns:
        Callable that wraps ``augment_fn``.
    """
    # Created once per wrapper and shared by all calls of `wrapped`.
    rng = tf.random.Generator.from_seed(initial_seed)

    @tf.function
    def wrapped(img: tf.Tensor) -> tf.Tensor:
        # Take column 0 of the generated seeds, i.e. the single requested seed.
        fresh_seeds = rng.make_seeds(count=1)
        call_seed = fresh_seeds[:, 0]
        return augment_fn(img, seed=call_seed)

    return wrapped