# Source code for osc.data.iterable_dataloader

import itertools

import torch
import torch.utils.data


class MyIterableDataset(torch.utils.data.IterableDataset):
    """Stream samples from ``osc.data.clevr_with_masks``, sharded across workers.

    Inside a ``DataLoader`` worker, worker ``k`` of ``n`` keeps items
    ``k, k + n, k + 2n, ...`` of the underlying iterator, so workers never
    yield the same sample twice. In the main process (no workers) the full
    stream is yielded. Each sample is passed through ``transform`` (if given)
    before being yielded.

    Args:
        data_root: Value handed to
            ``osc.data.clevr_with_masks.get_iterator``; its exact meaning
            (directory, file, ...) is defined there -- TODO confirm.
        transform: Callable applied to every sample before it is yielded.
            ``None`` (the default) yields samples unchanged.
    """

    def __init__(self, data_root, transform=None):
        # Zero-argument super() so the parent __init__ really runs;
        # ``super(Cls)`` alone creates an *unbound* super object whose
        # ``__init__`` merely re-initialises that proxy.
        super().__init__()
        self.data_root = data_root
        self.transform = transform

    @staticmethod
    def _slice_for_worker(ds_iter):
        """Return the part of ``ds_iter`` that belongs to the current worker.

        Outside a ``DataLoader`` worker (``get_worker_info()`` is ``None``)
        the iterator is returned unchanged. Inside worker ``info.id`` it keeps
        every ``info.num_workers``-th item starting at ``info.id``.

        Note: ``islice`` still advances past the items it skips, so every
        worker walks the entire underlying iterator. If producing an item is
        expensive, that cost is paid by each worker.
        """
        info = torch.utils.data.get_worker_info()
        if info is None:
            return ds_iter
        return itertools.islice(ds_iter, info.id, None, info.num_workers)

    def __iter__(self):
        """Yield this worker's share of samples, transformed if configured."""
        # Imported lazily; presumably to keep importing this module cheap or
        # to avoid an import cycle -- NOTE(review): confirm before hoisting.
        import osc.data.clevr_with_masks

        ds_iter = osc.data.clevr_with_masks.get_iterator(self.data_root)
        for sample in self._slice_for_worker(ds_iter):
            if self.transform is not None:
                sample = self.transform(sample)
            yield sample