We might want to implement an iterator class that inherits from PyTorch's IterableDataset, so that we get a direct interface to PyTorch's DataLoader functionality.
Here is some prototype code that we had earlier in this direction:
import torch
from infinibatch.iterators import CheckpointableIterator

class IterableCheckpointedDataset(torch.utils.data.IterableDataset):
    """
    Wraps a CheckpointableIterator into a PyTorch IterableDataset, which is recognized
    by its type by PyTorch's DataLoader class.
    """
    def __init__(self, source: CheckpointableIterator):
        super().__init__()
        self._source = source

    def __iter__(self):  # this is called in the forked clone
        worker_info = torch.utils.data.get_worker_info()
        # Multi-worker loading is not supported, since we cannot access the
        # checkpoint state of each individual worker.
        assert worker_info is None or worker_info.num_workers == 1
        return iter(self._source)
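For reference, a minimal usage sketch (the chunk path, seed, and loader settings are made-up placeholders, and we assume the DataLoader runs with num_workers=0 so the assertion above holds):

# Minimal usage sketch; "data/chunks" and the parameter values are placeholders.
from infinibatch.datasets import chunked_dataset_iterator

source = chunked_dataset_iterator(["data/chunks"], shuffle=True, buffer_size=2**20, seed=1)
dataset = IterableCheckpointedDataset(source)
# batch_size=None disables PyTorch's automatic batching, since infinibatch
# pipelines typically do their own batching; num_workers=0 satisfies the assert.
loader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=0)
for sample in loader:
    pass  # consume samples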
from typing import Iterable, Optional, Union
from infinibatch.datasets import chunked_dataset_iterator

class IterableChunkedDataset(torch.utils.data.IterableDataset):
    def __init__(self, paths: Union[str, Iterable[str]], shuffle: bool = True,
                 buffer_size: int = 2**20, transform=None, seed: Optional[int] = None,
                 world_size: int = 1, rank: int = 0, num_workers_per_rank: int = 1):
        super().__init__()
        self.rank = rank
        self.num_workers_per_rank = num_workers_per_rank
        # instance_rank is set assuming num_workers_per_rank == 1; it is adapted
        # dynamically in __iter__ once the actual worker id is known.
        self.dataset = chunked_dataset_iterator(paths, shuffle=shuffle, buffer_size=buffer_size,
                                                transform=transform, seed=seed,
                                                num_instances=world_size * num_workers_per_rank,
                                                instance_rank=rank)

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading
            self.dataset._instance_rank = self.rank
        else:
            assert worker_info.num_workers == self.num_workers_per_rank
            self.dataset._instance_rank = self.rank * self.num_workers_per_rank + worker_info.id
        return iter(self.dataset)
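And a corresponding sketch for the multi-worker, multi-rank case (again with placeholder values; each rank would construct its loader with num_workers equal to num_workers_per_rank to satisfy the assertion in __iter__):

# Hypothetical distributed usage; world_size/rank would come from the launcher.
dataset = IterableChunkedDataset(["data/chunks"], shuffle=True, seed=1,
                                 world_size=2, rank=0, num_workers_per_rank=2)
loader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=2)
for sample in loader:
    pass  # each worker iterates its own instance_rank slice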
Found an issue in the above example: the iter(self.dataset) call doesn't reset the iterator. So once a non-infinite validation iterator is exhausted, the next __iter__ call returns an empty iterator. I guess we have to call setstate in __iter__ for validation datasets.
In the training case we can assume the iterator is infinite and just return it without resetting the state.
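A sketch of what that could look like (assuming setstate(None) rewinds a CheckpointableIterator to the beginning, and a hypothetical self._train flag to distinguish the two cases):

def __iter__(self):  # sketch only
    worker_info = torch.utils.data.get_worker_info()
    assert worker_info is None or worker_info.num_workers == 1
    if not self._train:  # hypothetical flag: finite validation data must be rewound
        self._source.setstate(None)  # reset the iterator to its initial state
    return iter(self._source)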