Source code for osc.data.objects_room

from pathlib import Path
from typing import Union

import IPython.display
import matplotlib.pyplot as plt
import multi_object_datasets.objects_room
import numpy as np
import tensorflow as tf

MODE_TO_FILE = {
    "six_objects": "objects_room_test_six_objects.tfrecords",
    "empty_room": "objects_room_test_empty_room.tfrecords",
    "identical_color": "objects_room_test_identical_color.tfrecords",
    "train": "objects_room_train.tfrecords",
}
MODES = list(MODE_TO_FILE.keys())
IMAGE_SIZE = multi_object_datasets.objects_room.IMAGE_SIZE


[docs]def show_sample(sample): fig, axs = plt.subplots( 1, 1 + sample["mask"].shape[0], figsize=3 * np.array([1 + sample["mask"].shape[0], 1]), sharex=True, sharey=True, ) axs[0].imshow(sample["image"], interpolation="none") axs[0].set_title("image") for m in range(sample["mask"].shape[0]): axs[m + 1].imshow(sample["mask"][m], cmap="gray", interpolation="none") axs[m + 1].set_title(f"mask {m}") fig.set_facecolor("white") IPython.display.display(fig) plt.close(fig)
[docs]def fix_tf_dtypes(sample): sample["mask"] = tf.cast(tf.squeeze(sample["mask"], -1), tf.bool) return sample
[docs]def get_iterator( data_dir: Union[str, Path], mode: str, map_parallel_calls: int = None, take: int = None, batch_size: int = None, drop_remainder=False, shuffle: int = None, numpy=True, ): tfr_path = Path(data_dir) / "objects_room" / MODE_TO_FILE[mode] ds = multi_object_datasets.objects_room.dataset( tfr_path.expanduser().resolve().as_posix(), mode, map_parallel_calls=map_parallel_calls, ) ds = ds.map(fix_tf_dtypes) if take is not None: ds = ds.take(take) if shuffle is not None: ds = ds.shuffle(shuffle, seed=0) if batch_size is not None: ds = ds.batch(batch_size, drop_remainder=drop_remainder) if numpy: ds = ds.as_numpy_iterator() return ds