Source code for osc.data.tfrecords

"""
Helper functions to serialize/deserialize TFRecords.
"""
from typing import Tuple

import tensorflow as tf

from osc.utils import ImgSizeHW


def serialize_image(image: tf.Tensor) -> bytes:
    """Pack a single image into a serialized ``tf.train.Example``.

    The image tensor is serialized and stored as one bytes feature under
    the key ``"image"``.

    Args:
        image: uint8 tensor of shape ``[H W 3]``

    Example:
        The tensor is converted with ``.numpy()``, so inside a
        :mod:`tf.data` pipeline this function has to be wrapped in a
        :func:`tf.py_function`::

        >>> tf.py_function(serialize_image, (image,), tf.string)

    Returns:
        The serialized example as a byte string, ready to be written to a
        TFRecord.
    """
    # Raw bytes of the serialized tensor; `.numpy()` needs an eager tensor.
    image_bytes = tf.io.serialize_tensor(image).numpy()
    image_feature = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[image_bytes])
    )
    features = tf.train.Features(feature={"image": image_feature})
    return tf.train.Example(features=features).SerializeToString()
@tf.function
def deserialize_image(example: bytes, *, img_size: ImgSizeHW) -> tf.Tensor:
    """Decode an image from a serialized example read from a TFRecord.

    Args:
        example: byte serialization of the image
        img_size: expected spatial size of the image ``(H, W)``

    Returns:
        A uint8 tensor of shape ``[H W 3]``
    """
    # The record holds a single scalar string feature with the tensor bytes.
    schema = {"image": tf.io.FixedLenFeature([], tf.string)}
    parsed = tf.io.parse_single_example(example, schema)

    decoded = tf.io.parse_tensor(parsed["image"], tf.uint8)

    # parse_tensor yields an unknown static shape; pin it to the expected one.
    expected_shape = (*img_size, 3)
    return tf.ensure_shape(decoded, expected_shape)
def serialize_image_and_mask(image: tf.Tensor, mask: tf.Tensor) -> bytes:
    """Serialize an image and a mask to be written to a TFRecord.

    Both tensors are serialized and stored as bytes features of a single
    ``tf.train.Example``, under the keys ``"image"`` and ``"mask"``.

    Args:
        image: uint8 tensor of shape ``[H W 3]``
        mask: bool tensor of shape ``[H W C]``, where ``C`` is the number of
            classes.

    Example:
        The tensors are converted with ``.numpy()``, so inside a
        :mod:`tf.data` pipeline this function has to be wrapped in a
        :func:`tf.py_function`::

        >>> tf.py_function(serialize_image_and_mask, (image, mask), tf.string)

    Returns:
        A byte string.
    """
    tensors = {"image": image, "mask": mask}
    # One bytes feature per tensor, built the same way for every key.
    feature = {
        name: tf.train.Feature(
            bytes_list=tf.train.BytesList(
                value=[tf.io.serialize_tensor(tensor).numpy()]
            )
        )
        for name, tensor in tensors.items()
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()
@tf.function
def deserialize_image_and_mask(
    example: bytes,
    *,
    img_size: ImgSizeHW,
    num_classes: int,
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Deserialize an image and a mask as read from a TFRecord.

    Args:
        example: byte serialization of the image and mask.
        img_size: expected shape of the image ``(H, W)``.
        num_classes: expected number of classes ``C`` for the mask.

    Returns:
        One uint8 tensor of shape ``[H W 3]`` and
        one bool tensor of shape ``[H W C]``.
    """
    parsed = tf.io.parse_single_example(
        example,
        features={
            "image": tf.io.FixedLenFeature([], tf.string),
            "mask": tf.io.FixedLenFeature([], tf.string),
        },
    )

    image = tf.io.parse_tensor(parsed["image"], tf.uint8)
    image = tf.ensure_shape(image, (*img_size, 3))

    # The mask is documented as a bool tensor, and parse_tensor performs no
    # implicit dtype conversion, so it must be decoded as tf.bool rather
    # than tf.uint8.
    mask = tf.io.parse_tensor(parsed["mask"], tf.bool)
    mask = tf.ensure_shape(mask, (*img_size, num_classes))

    return image, mask