
Implement wrapper iterator that inherits from PyTorch's IterableDataset #14

Open
gmyr opened this issue Aug 11, 2020 · 1 comment
gmyr commented Aug 11, 2020

We might want to implement an Iterator that inherits from PyTorch's IterableDataset to have a direct interface to PyTorch's data loader functionality.

Here is some prototype code we wrote earlier along these lines.

from typing import Iterable, Optional, Union

import torch

# imports below assume infinibatch's module layout; adjust if these live elsewhere
from infinibatch.iterators import CheckpointableIterator
from infinibatch.datasets import chunked_dataset_iterator


class IterableCheckpointedDataset(torch.utils.data.IterableDataset):
    """
    Wraps a CheckpointableIterator into a PyTorch IterableDataset, so that PyTorch's
    DataLoader class recognizes it by its type.
    """
    def __init__(self, source: CheckpointableIterator):
        super().__init__()
        self._source = source

    def __iter__(self):  # this is called in the forked clone
        worker_info = torch.utils.data.get_worker_info()
        # multi-worker loading is not supported, since we cannot get at the checkpoint of each worker
        assert worker_info is None or worker_info.num_workers == 1
        return iter(self._source)


class IterableChunkedDataset(torch.utils.data.IterableDataset):
    def __init__(self, paths: Union[str, Iterable[str]], shuffle: bool = True, buffer_size: int = 2**20,
                 transform=None, seed: Optional[int] = None, world_size: int = 1, rank: int = 0,
                 num_workers_per_rank: int = 1):
        super().__init__()
        self.rank = rank
        self.num_workers_per_rank = num_workers_per_rank
        # instance_rank is set assuming num_workers_per_rank == 1 and is adapted dynamically in __iter__
        self.dataset = chunked_dataset_iterator(paths, shuffle=shuffle, buffer_size=buffer_size,
                                                transform=transform, seed=seed,
                                                num_instances=world_size * num_workers_per_rank,
                                                instance_rank=rank)

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading
            self.dataset._instance_rank = self.rank
        else:  # worker processes: each worker reads a distinct slice of the data
            assert worker_info.num_workers == self.num_workers_per_rank
            self.dataset._instance_rank = self.rank * self.num_workers_per_rank + worker_info.id
        return iter(self.dataset)
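
For context, a rough usage sketch of how the second wrapper might be wired into PyTorch's DataLoader; the paths, batch size, and collate_fn below are placeholders, and it assumes chunked_dataset_iterator yields individual samples:

# hypothetical usage: the wrapper is picked up by DataLoader as an IterableDataset
dataset = IterableChunkedDataset(paths=["data/chunk0.txt.gz", "data/chunk1.txt.gz"],  # placeholder paths
                                 shuffle=True, seed=42,
                                 world_size=1, rank=0, num_workers_per_rank=1)
loader = torch.utils.data.DataLoader(dataset,
                                     batch_size=32,           # batching is done by the DataLoader here
                                     num_workers=1,           # must match num_workers_per_rank
                                     collate_fn=lambda x: x)  # placeholder collate; replace as needed
for batch in loader:
    pass  # training/validation loop goes here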

sai-prasanna commented Sep 11, 2020

Found an issue in the above example. The iter(self.dataset) call doesn't reset the iterator, so once a non-infinite validation iterator has been exhausted, subsequent __iter__ calls return an empty iterator. I guess we have to call setstate in __iter__ for validation datasets.

And in the case of training we can assume the iterator is infinite and just return it without resetting the state.
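
A minimal sketch of that idea, assuming CheckpointableIterator.setstate(None) rewinds the source to its initial state; the rewind_on_iter flag is a hypothetical addition, not part of the prototype above.

class IterableCheckpointedDataset(torch.utils.data.IterableDataset):
    """
    Variant of the wrapper above that can rewind its source on every __iter__ call,
    so that a finite validation iterator can be re-used across epochs.
    """
    def __init__(self, source: CheckpointableIterator, rewind_on_iter: bool = False):
        super().__init__()
        self._source = source
        self._rewind_on_iter = rewind_on_iter  # set True for finite (validation) iterators

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        assert worker_info is None or worker_info.num_workers == 1  # multi-worker loading not supported
        if self._rewind_on_iter:
            self._source.setstate(None)  # assumption: setstate(None) resets to the beginning
        return iter(self._source)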
