Skip to content

Commit

Permalink
Merge pull request #29 from gmyr/too_few_chunks
Browse files Browse the repository at this point in the history
Fix bugs for edge cases in InfinitePermutationSourceIterator
  • Loading branch information
gmyr authored Mar 6, 2021
2 parents 507aa65 + 0214897 commit f1eea7c
Show file tree
Hide file tree
Showing 3 changed files with 270 additions and 115 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,10 @@ To run unit tests, run the following command.
```
python -m unittest discover -s test
```
If you would like the unit tests to stop after the first failed test, use:
```
python -m unittest discover -s test --failfast
```
To type-check with `mypy` (if installed):
```
mypy infinibatch
Expand Down
95 changes: 57 additions & 38 deletions infinibatch/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,68 +387,87 @@ class InfinitePermutationSourceIterator(CheckpointableIterator):
For example, this is used for randomizing the pathnames of data blocks read by ChunkedReadlinesIterator.
"""
def __init__(self, source_items: List, seed: Optional[int]=0, shuffle: bool=True, num_instances: int=1, instance_rank: int=0):

def __init__(
self,
source_items: List,
seed: int = 0,
shuffle: bool = True,
num_instances: int = 1,
instance_rank: int = 0,
):
"""
Args:
source_items: input list, must not be empty and must be small enough to fit into RAM entirely, ownership of the list and the data goes to the iterator, do not modify it!
source_items: input list, must not be empty, must be small enough to fit into RAM entirely, and must support deepcopies
seed: random seed used for shuffling
shuffle: set False to bypass the shuffling. Then this is just a checkpointed version of itertools.cycle(). (Default: True)
num_instances: number of instances of this iterator. Meant for use with multi-process data loading, e.g., in distributed training.
instance_rank: rank of this instance of the iterator. Meant for use with multi-process data loading, e.g., in distributed training.
"""
if not source_items:
raise ValueError("source must not be empty")
self._source_items = source_items
self._shuffle = shuffle
self._seed = seed
if instance_rank >= num_instances:
raise ValueError("invalid instance_rank")
self._source_items = copy.deepcopy(source_items)
self._shuffle = shuffle
self._seed = seed
self._num_instances = num_instances
self._instance_rank = instance_rank
self.setstate(None)

def getstate(self) -> Dict:
return {'random_state': self._random_state, # state of random generator before generating the current shuffling of the sequence
'num_items_yielded': self._num_items_yielded} # how many items have already been iterated over in the current shuffling
return {"random_state": self._random_state, "index": self._index}

def setstate(self, checkpoint: Optional[Dict]):
# set iteration state. Do this outside the generator below in case getstate() is called before ever iterating
self._random_state = checkpoint['random_state'] if checkpoint else None
self._num_items_yielded = checkpoint['num_items_yielded'] if checkpoint else 0
# We define the iteration itself as a generator for ease of implementation.
# We could as well just have used an explicit state machine represented by class members.
def _generate() -> Iterator:
# create and reset random generator
random = Random(self._seed)
if self._random_state is not None: # restore the random generator's state
random.setstate(self._random_state)
skip_to_checkpoint = self._num_items_yielded # items to skip in order to advance to checkpoint
# main outer loop for infinite passes over items (reshuffle before each pass)
while True:
# (re-)shuffle all items
self._random_state = random.getstate() # remember random state before shuffling
self._num_items_yielded = 0
shuffled_items = self._source_items[:] # note: if underlying iterator is checkpointable, use setstate(checkpoint['nested_state']) on it
if self._shuffle:
random.shuffle(shuffled_items)
shuffled_iterator = iter(shuffled_items)
# skip initial items when restarting from checkpoint
if skip_to_checkpoint: # @TODO: find a way to abstract this more, so that we can plug it into the 'for' statement directly
self._num_items_yielded += _advance_iterator(shuffled_iterator, skip_to_checkpoint)
skip_to_checkpoint = 0 # done skipping
# main inner loop over items
for item in shuffled_iterator:
self._num_items_yielded += 1 # record how many items we have iterated over in this pass over the items
if (self._num_items_yielded-1) % self._num_instances == self._instance_rank: # build-in islice facility
yield item
self._iterator = _generate()
self._random_state = checkpoint["random_state"] if checkpoint else None
self._index = checkpoint["index"] if checkpoint else self._instance_rank

self._random = None # this will trigger the lazy initialization in self.__next__

def __next__(self):
return next(self._iterator)
if self._random == None:
# lazy initialization
self._random = Random(self._seed)
if self._random_state is not None:
self._random.setstate(self._random_state)
if self._shuffle:
self._reshuffle() # create initial permutation
self._reshuffle_as_necessary() # reshuffle as often as necesary to bring self._index into range
else:
self._index = self._index % len(self._source_items)

assert 0 <= self._index and self._index < len(self._source_items)
if self._shuffle:
result = self._shuffled_items[self._index]
self._index += self._num_instances
self._reshuffle_as_necessary() # reshuffle as often as necesary to bring self._index into range
else:
result = self._source_items[self._index]
self._index = (self._index + self._num_instances) % len(self._source_items)
assert 0 <= self._index and self._index < len(self._source_items)
return result

def close(self):
pass

def _reshuffle_as_necessary(self):
while self._index >= len(self._source_items):
# The new index is out of range, so we need to reshuffle.
# Since len(self._source_items) can be smaller than self._num_instances,
# we might have to reshuffle multiple times to "skip through" permutations of self._source_items.
# Even though there might be intermediate permutations that are not actually used,
# we have to generate all of them to make sure we get the right RNG state
# to guarantee correctness when using multiple instances.
self._reshuffle()
self._index -= len(self._source_items)

def _reshuffle(self):
self._random_state = self._random.getstate()
self._shuffled_items = copy.deepcopy(self._source_items)
self._random.shuffle(self._shuffled_items)




class MultiplexIterator(CheckpointableIterator):
"""
Expand Down
Loading

0 comments on commit f1eea7c

Please sign in to comment.