import itertools
import torch
import torch.utils.data
class MyIterableDataset(torch.utils.data.IterableDataset):
    """Iterable dataset streaming samples from ``osc.data.clevr_with_masks``.

    Every sample produced by
    ``osc.data.clevr_with_masks.get_iterator(data_root)`` is passed through
    ``transform`` before being yielded. When iterated inside a multi-worker
    ``DataLoader``, the stream is sharded round-robin so each worker yields a
    disjoint subset of the samples.

    Args:
        data_root: Value passed verbatim to ``get_iterator`` (presumably a
            dataset directory/path — confirm against that helper).
        transform: Callable applied to each sample before it is yielded.
    """

    def __init__(self, data_root, transform):
        # Zero-argument super() is bound to ``self``. The one-argument form
        # ``super(Cls)`` is unbound, and calling ``.__init__()`` on it never
        # reaches the parent class initializer.
        super().__init__()
        self.data_root = data_root
        self.transform = transform

    @staticmethod
    def _slice_for_worker(ds_iter):
        """Return the part of ``ds_iter`` the current DataLoader worker owns.

        In the main process (``get_worker_info()`` returns ``None``) the
        iterator is returned unchanged. Inside worker ``id`` of
        ``num_workers``, every ``num_workers``-th item starting at offset
        ``id`` is kept, so no two workers yield the same sample.
        """
        info = torch.utils.data.get_worker_info()
        if info is None:
            return ds_iter
        # NOTE: each worker still consumes the full underlying stream and
        # simply discards the items that belong to the other workers.
        return itertools.islice(ds_iter, info.id, None, info.num_workers)

    def __iter__(self):
        """Yield transformed samples for the current process/worker."""
        # Imported lazily, so the module is only loaded once iteration starts
        # (presumably to keep it off the import path of this module — confirm).
        import osc.data.clevr_with_masks

        ds_iter = osc.data.clevr_with_masks.get_iterator(self.data_root)
        for sample in self._slice_for_worker(ds_iter):
            yield self.transform(sample)