From e258854378a8c341c9960d9d17f3b98311058311 Mon Sep 17 00:00:00 2001
From: Robert Gmyr
Date: Wed, 14 Apr 2021 15:57:09 -0700
Subject: [PATCH 1/2] Add boundary_key feature to BucketedReadahead

---
 infinibatch/iterators.py |  14 +++-
 test/test_iterators.py   | 138 ++++++++++++++++++++++++++++++++++++---
 2 files changed, 143 insertions(+), 9 deletions(-)

diff --git a/infinibatch/iterators.py b/infinibatch/iterators.py
index ea80721..cbd029f 100755
--- a/infinibatch/iterators.py
+++ b/infinibatch/iterators.py
@@ -1390,13 +1390,14 @@ class BucketedReadaheadBatchIterator(CheckpointableIterator):
     This is based on Marian NMT's BatchGenerator.
     """
 
-    def __init__(self, source_iterator: CheckpointableIterator, read_ahead: int, key: Callable[[Any], Any], batch_size: Union[int,Callable[[Any], int]], shuffle: bool=True, seed: int=0):
+    def __init__(self, source_iterator: CheckpointableIterator, read_ahead: int, key: Callable[[Any], Any], batch_size: Union[int,Callable[[Any], int]], boundary_key: Callable[[Any], Any]=None, shuffle: bool=True, seed: int=0):
         """
         Args:
             source_iterator: The data set that is read from. Typically this is an infinite source.
             read_ahead: Number of items to fetch ahead for grouping purposes.
             key: User-provided callback to define how data is sorted for purpose of batching.
             batch_size: Batch size in number of items. Either an integer or a callback to determine batch size for a given first batch item.
+            boundary_key: Optional callback that maps an item to a key, imposing an additional restriction on how batches are formed: the iterator starts a new batch whenever the key changes, which guarantees that all items in a batch have the same key. Keys are not allowed to be None.
             shuffle: Pass False to not randomize the batches. (default: True)
             seed: Random seed for batch shuffling.
""" @@ -1405,6 +1406,7 @@ def __init__(self, source_iterator: CheckpointableIterator, read_ahead: int, key # keep arguments self._key = key # type: Callable[[Any], Any] self._batch_size = batch_size # type: Union[int,Callable[[Any], int]] + self._boundary_key = boundary_key # type: Callable[[Any], Any] self._read_ahead = read_ahead # type: int # initialize state self._seed = seed @@ -1461,16 +1463,26 @@ def _create_batches(self, items: List[Any]) -> List[List[Any]]: # helper to for items.sort(key=self._key, reverse=True) # note: sort() is stable, so we won't undo any randomization besides the bucketing # group into batches cur_batch = None # type: Optional[List[Any]] + prev_val = None batches = [] # type: List[Any] for item in items: + if self._boundary_key and self._boundary_key(item) != prev_val: + if cur_batch: + batches.append(cur_batch) + cur_batch = None + prev_val = None if not cur_batch: batch_size = self._batch_size if isinstance(self._batch_size, int) else \ self._batch_size(item) cur_batch = [] cur_batch.append(item) + if self._boundary_key: + prev_val = self._boundary_key(item) + assert prev_val is not None if len(cur_batch) >= batch_size: # this batch is full batches.append(cur_batch) cur_batch = None + prev_val = None if cur_batch: batches.append(cur_batch) return batches diff --git a/test/test_iterators.py b/test/test_iterators.py index 8e96d53..d82bf86 100644 --- a/test/test_iterators.py +++ b/test/test_iterators.py @@ -754,6 +754,10 @@ def key_fn(item): def batch_size_fn(item): return TestBucketedReadaheadBatchIterator.dynamic_batch_size // len(item) + @staticmethod + def boundary_key_fn(item): + return len(item) < 5 + @staticmethod def setup_data(n): data = [] @@ -766,7 +770,87 @@ def setUp(self): self.batch_sizes = [1, 2, 3, 9] self.test_cases = [] - # fixed batch size, not shuffled + # fixed batch size, not shuffled, no boundary key + for n, read_ahead in itertools.product(self.lengths, self.lengths): + for batch_size in self.batch_sizes: + data = self.setup_data(n) + it = BucketedReadaheadBatchIterator( + NativeCheckpointableIterator(copy.deepcopy(data)), + read_ahead=read_ahead, + key=self.key_fn, + batch_size=batch_size, + shuffle=False, + ) + self.test_cases.append( + ( + "n={}, read_ahead={}, batch_size={}, boundary_key=None, shuffled=False".format( + n, read_ahead, batch_size + ), + data, + it, + ) + ) + + # fixed batch size, shuffled, no boundary key + for n, read_ahead in itertools.product(self.lengths, self.lengths): + for batch_size in self.batch_sizes: + data = self.setup_data(n) + it = BucketedReadaheadBatchIterator( + NativeCheckpointableIterator(copy.deepcopy(data)), + read_ahead=read_ahead, + key=self.key_fn, + batch_size=batch_size, + shuffle=True, + seed=self.seed, + ) + self.test_cases.append( + ( + "n={}, read_ahead={}, batch_size={}, boundary_key=None, shuffled=True".format( + n, read_ahead, batch_size + ), + data, + it, + ) + ) + + # dynamic batch size, not shuffled, no boundary key + for n, read_ahead in itertools.product(self.lengths, self.lengths): + data = self.setup_data(n) + it = BucketedReadaheadBatchIterator( + NativeCheckpointableIterator(copy.deepcopy(data)), + read_ahead=read_ahead, + key=self.key_fn, + batch_size=self.batch_size_fn, + shuffle=False, + ) + self.test_cases.append( + ( + "n={}, read_ahead={}, batch_size=dynamic, boundary_key=None, shuffled=False".format(n, read_ahead), + data, + it, + ) + ) + + # dynamic batch size, shuffled, no boundary key + for n, read_ahead in itertools.product(self.lengths, self.lengths): + data = 
+            it = BucketedReadaheadBatchIterator(
+                NativeCheckpointableIterator(copy.deepcopy(data)),
+                read_ahead=read_ahead,
+                key=self.key_fn,
+                batch_size=self.batch_size_fn,
+                shuffle=True,
+                seed=self.seed,
+            )
+            self.test_cases.append(
+                (
+                    "n={}, read_ahead={}, batch_size=dynamic, boundary_key=None, shuffled=True".format(n, read_ahead),
+                    data,
+                    it,
+                )
+            )
+
+        # fixed batch size, not shuffled, boundary key
         for n, read_ahead in itertools.product(self.lengths, self.lengths):
             for batch_size in self.batch_sizes:
                 data = self.setup_data(n)
@@ -775,13 +859,20 @@ def setUp(self):
                     read_ahead=read_ahead,
                     key=self.key_fn,
                     batch_size=batch_size,
+                    boundary_key=self.boundary_key_fn,
                     shuffle=False,
                 )
                 self.test_cases.append(
-                    ("n={}, read_ahead={}, batch_size={}, shuffled=False".format(n, read_ahead, batch_size), data, it)
+                    (
+                        "n={}, read_ahead={}, batch_size={}, boundary_key=len(item)<5, shuffled=False".format(
+                            n, read_ahead, batch_size
+                        ),
+                        data,
+                        it,
+                    )
                 )
 
-        # fixed batch size, shuffled
+        # fixed batch size, shuffled, boundary key
         for n, read_ahead in itertools.product(self.lengths, self.lengths):
             for batch_size in self.batch_sizes:
                 data = self.setup_data(n)
@@ -790,14 +881,21 @@ def setUp(self):
                     read_ahead=read_ahead,
                     key=self.key_fn,
                     batch_size=batch_size,
+                    boundary_key=self.boundary_key_fn,
                     shuffle=True,
                     seed=self.seed,
                 )
                 self.test_cases.append(
-                    ("n={}, read_ahead={}, batch_size={}, shuffled=True".format(n, read_ahead, batch_size), data, it)
+                    (
+                        "n={}, read_ahead={}, batch_size={}, boundary_key=len(item)<5, shuffled=True".format(
+                            n, read_ahead, batch_size
+                        ),
+                        data,
+                        it,
+                    )
                 )
 
-        # dynamic batch size, not shuffled
+        # dynamic batch size, not shuffled, boundary key
         for n, read_ahead in itertools.product(self.lengths, self.lengths):
             data = self.setup_data(n)
             it = BucketedReadaheadBatchIterator(
@@ -805,13 +903,21 @@ def setUp(self):
                 read_ahead=read_ahead,
                 key=self.key_fn,
                 batch_size=self.batch_size_fn,
+                boundary_key=self.boundary_key_fn,
                 shuffle=False,
+                seed=self.seed,
             )
             self.test_cases.append(
-                ("n={}, read_ahead={}, batch_size=dynamic, shuffled=False".format(n, read_ahead), data, it)
+                (
+                    "n={}, read_ahead={}, batch_size=dynamic, boundary_key=len(item)<5, shuffled=False".format(
+                        n, read_ahead
+                    ),
+                    data,
+                    it,
+                )
             )
 
-        # dynamic batch size, shuffled
+        # dynamic batch size, shuffled, boundary key
         for n, read_ahead in itertools.product(self.lengths, self.lengths):
             data = self.setup_data(n)
             it = BucketedReadaheadBatchIterator(
@@ -819,11 +925,18 @@ def setUp(self):
                 read_ahead=read_ahead,
                 key=self.key_fn,
                 batch_size=self.batch_size_fn,
+                boundary_key=self.boundary_key_fn,
                 shuffle=True,
                 seed=self.seed,
             )
             self.test_cases.append(
-                ("n={}, read_ahead={}, batch_size=dynamic, shuffled=True".format(n, read_ahead), data, it)
+                (
+                    "n={}, read_ahead={}, batch_size=dynamic, boundary_key=len(item)<5, shuffled=True".format(
+                        n, read_ahead
+                    ),
+                    data,
+                    it,
+                )
             )
 
     def test_basic(self):
@@ -841,3 +954,12 @@ def test_max_len(self):
         for batch in result:
             length = sum((len(item) for item in batch))
             self.assertTrue(length <= TestBucketedReadaheadBatchIterator.dynamic_batch_size)
+
+    def test_boundary_key(self):
+        for case_name, expected_result, it in self.test_cases:
+            if "boundary_key=len(item)<5" in case_name:
+                with self.subTest(case_name):
+                    result = list(it)
+                    for batch in result:
+                        boundary_keys = (self.boundary_key_fn(item) for item in batch)
+                        self.assertTrue(all(boundary_keys) or not any(boundary_keys))
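For reference, a minimal usage sketch of the new boundary_key parameter. It uses only names exercised by this patch (BucketedReadaheadBatchIterator, NativeCheckpointableIterator); the toy data and the len(item) < 5 predicate mirror the unit tests above and are otherwise hypothetical:

    from infinibatch.iterators import NativeCheckpointableIterator, BucketedReadaheadBatchIterator

    # toy "sentences" of varying length (hypothetical example data)
    data = [list(range(n)) for n in (1, 2, 3, 5, 6, 7, 2, 8)]

    it = BucketedReadaheadBatchIterator(
        NativeCheckpointableIterator(data),
        read_ahead=8,
        key=lambda item: len(item),               # sort by length for bucketing
        batch_size=3,
        boundary_key=lambda item: len(item) < 5,  # never mix short and long items in one batch
        shuffle=False,
    )

    for batch in it:
        # every batch is homogeneous with respect to the boundary key
        assert len({len(item) < 5 for item in batch}) == 1
        print([len(item) for item in batch])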
From 6aa69c1614b77f9357050bc412e1d0a6d55d3a2c Mon Sep 17 00:00:00 2001
From: Robert Gmyr
Date: Thu, 15 Apr 2021 13:31:49 -0700
Subject: [PATCH 2/2] Fix issue in unit test

---
 test/test_iterators.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_iterators.py b/test/test_iterators.py
index d82bf86..52c3da5 100644
--- a/test/test_iterators.py
+++ b/test/test_iterators.py
@@ -961,5 +961,5 @@ def test_boundary_key(self):
                 with self.subTest(case_name):
                     result = list(it)
                     for batch in result:
-                        boundary_keys = (self.boundary_key_fn(item) for item in batch)
+                        boundary_keys = [self.boundary_key_fn(item) for item in batch]
                         self.assertTrue(all(boundary_keys) or not any(boundary_keys))
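For context on the one-line fix above: all() stops consuming a generator at the first falsy element, so the subsequent any() only sees the leftover items. A mixed batch whose keys end in False, e.g. [True, False], therefore satisfied the assertion even though it should fail it. Materializing the keys into a list lets both checks see every item:

    keys = (k for k in [True, False])  # generator, as in the original test
    print(all(keys) or not any(keys))  # True: any() sees an exhausted generator
    keys = [True, False]               # list, as in the fix
    print(all(keys) or not any(keys))  # False, as intended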