Source code for osc.data.multi_dsprites

from pathlib import Path
from typing import Union

import IPython.display
import matplotlib.pyplot as plt
import multi_object_datasets.multi_dsprites
import numpy as np
import pandas as pd
import tensorflow as tf

# Names of the dataset variants this module knows about.
MODES = [
    "colored_on_grayscale",
    "colored_on_colored",
    "binarized",
]
# File name of the TFRecords archive for each variant, keyed by variant name.
MODE_TO_FILE = {
    name: "multi_dsprites_{}.tfrecords".format(name) for name in MODES
}
# Re-exported from the upstream dataset module.
IMAGE_SIZE = multi_object_datasets.multi_dsprites.IMAGE_SIZE


def show_sample(sample):
    """Display one sample: its image, its masks, and a table of attributes.

    A single row of subplots is drawn: the image first, followed by one panel
    per entry along the first axis of ``sample["mask"]``. The figure is shown
    through ``IPython.display`` and then closed. Afterwards a
    ``pandas.DataFrame`` holding the per-object attributes is displayed.

    Args:
        sample: Mapping with the keys ``"image"``, ``"mask"``,
            ``"visibility"``, ``"x"``, ``"y"``, ``"shape"``, ``"scale"``,
            ``"orientation"`` and ``"color"``. ``"image"`` is indexed as a
            three-axis array; when its last axis has length 3 it is drawn
            as-is, otherwise only channel 0 is drawn in grayscale.
            ``"mask"`` is indexed along its first axis, one panel per entry.

    Returns:
        None. The function only produces notebook output.
    """
    num_masks = sample["mask"].shape[0]
    num_panels = 1 + num_masks

    fig, axs = plt.subplots(
        1,
        num_panels,
        figsize=3 * np.array([num_panels, 1]),
        sharex=True,
        sharey=True,
        # With the default squeeze=True a single-panel figure (a sample with
        # no masks) yields a bare Axes object and ``axs[0]`` would fail.
        squeeze=False,
    )
    axs = axs[0]  # the only row of the (1, num_panels) axes grid

    image = sample["image"]
    if image.shape[2] == 3:
        axs[0].imshow(image, interpolation="none")
    else:
        axs[0].imshow(image[:, :, 0], cmap="gray", interpolation="none")
    axs[0].set_title("image")

    for m in range(num_masks):
        axs[m + 1].imshow(sample["mask"][m], cmap="gray", interpolation="none")
        axs[m + 1].set_title(f"mask {m}")

    fig.set_facecolor("white")
    IPython.display.display(fig)
    # Close explicitly so the figure is not rendered a second time by the
    # notebook backend and does not accumulate in pyplot's figure registry.
    plt.close(fig)

    IPython.display.display(
        pd.DataFrame(
            {
                "visibility": sample["visibility"],
                "x": sample["x"],
                "y": sample["y"],
                "shape": sample["shape"],
                "scale": sample["scale"],
                "orientation": sample["orientation"],
                "color_RGB": list(sample["color"]),
            }
        )
    )


def fix_tf_dtypes(sample):
    """Cast selected fields of a sample to compact dtypes, in place.

    ``sample["mask"]`` has its last axis squeezed away and is cast to
    ``tf.bool``; ``sample["visibility"]`` is cast to ``tf.bool`` and
    ``sample["shape"]`` to ``tf.uint8``. All other entries are left alone.

    Args:
        sample: Mutable mapping containing at least the keys ``"mask"``,
            ``"visibility"`` and ``"shape"``.

    Returns:
        The same ``sample`` object, with the three entries replaced.
    """
    mask_without_last_axis = tf.squeeze(sample["mask"], -1)
    sample["mask"] = tf.cast(mask_without_last_axis, tf.bool)

    for key, dtype in (("visibility", tf.bool), ("shape", tf.uint8)):
        sample[key] = tf.cast(sample[key], dtype)

    return sample


def get_iterator(
    data_dir: Union[str, Path],
    mode: str,
    map_parallel_calls: int = None,
    take: int = None,
    batch_size: int = None,
    drop_remainder=False,
    shuffle: int = None,
    numpy=True,
):
    """Build an input pipeline over one Multi-dSprites TFRecords file.

    The file is looked up at ``<data_dir>/multi_dsprites/<MODE_TO_FILE[mode]>``
    and opened with ``multi_object_datasets.multi_dsprites.dataset``. The
    optional stages are applied in this fixed order: dtype fix-up, ``take``,
    ``shuffle``, ``batch``, and conversion to a NumPy iterator.

    Args:
        data_dir: Directory that contains the ``multi_dsprites`` folder.
            ``~`` is expanded and the path is resolved before use.
        mode: Dataset variant; must be a key of ``MODE_TO_FILE``. It is also
            forwarded to the upstream ``dataset`` function.
        map_parallel_calls: Forwarded unchanged to the upstream ``dataset``
            function.
        take: If not ``None``, keep only this many leading elements.
        batch_size: If not ``None``, group elements into batches of this size.
        drop_remainder: Forwarded to ``batch``; only used when ``batch_size``
            is given.
        shuffle: If not ``None``, shuffle buffer size. Shuffling uses the
            fixed seed ``0`` and happens after ``take`` and before batching.
        numpy: If true, return ``as_numpy_iterator()`` of the pipeline
            instead of the dataset object itself.

    Returns:
        The dataset object, or its NumPy iterator when ``numpy`` is true.

    Raises:
        KeyError: If ``mode`` is not a key of ``MODE_TO_FILE``.
    """
    dataset_root = Path(data_dir) / "multi_dsprites"
    record_file = dataset_root / MODE_TO_FILE[mode]
    record_path = record_file.expanduser().resolve().as_posix()

    dataset = multi_object_datasets.multi_dsprites.dataset(
        record_path,
        mode,
        map_parallel_calls=map_parallel_calls,
    )
    dataset = dataset.map(fix_tf_dtypes)

    if take is not None:
        dataset = dataset.take(take)
    if shuffle is not None:
        dataset = dataset.shuffle(shuffle, seed=0)
    if batch_size is not None:
        dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)

    return dataset.as_numpy_iterator() if numpy else dataset