Add boundary_key feature to BucketedReadahead
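BucketedReadaheadBatchIterator gains an optional boundary_key callback. When it is given, _create_batches starts a new batch whenever the key of the current item differs from that of the previous one, so every batch is homogeneous with respect to the key. Tests cover fixed and dynamic batch sizes, shuffled and unshuffled, each with and without a boundary key.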
gmyr committed Apr 14, 2021
1 parent f1eea7c commit e258854
Showing 2 changed files with 143 additions and 9 deletions.
infinibatch/iterators.py: 13 additions & 1 deletion
@@ -1390,13 +1390,14 @@ class BucketedReadaheadBatchIterator(CheckpointableIterator):
This is based on Marian NMT's BatchGenerator.
"""

- def __init__(self, source_iterator: CheckpointableIterator, read_ahead: int, key: Callable[[Any], Any], batch_size: Union[int,Callable[[Any], int]], shuffle: bool=True, seed: int=0):
def __init__(self, source_iterator: CheckpointableIterator, read_ahead: int, key: Callable[[Any], Any], batch_size: Union[int,Callable[[Any], int]], boundary_key: Callable[[Any], Any]=None, shuffle: bool=True, seed: int=0):
"""
Args:
source_iterator: The data set that is read from. Typically this is an infinite source.
read_ahead: Number of items to fetch ahead for grouping purposes.
key: User-provided callback to define how data is sorted for purpose of batching.
batch_size: Batch size in number of items. Either an integer or a callback to determine batch size for a given first batch item.
boundary_key: This optional callback, which maps an item to a key, imposes an additional restriction on how batches are formed: the iterator starts a new batch whenever the key changes, which guarantees that all items in a batch have the same key. Keys must not be None.
shuffle: Pass False to not randomize the batches. (default: True)
seed: Random seed for batch shuffling.
"""
@@ -1405,6 +1406,7 @@ def __init__(self, source_iterator: CheckpointableIterator, read_ahead: int, key
# keep arguments
self._key = key # type: Callable[[Any], Any]
self._batch_size = batch_size # type: Union[int,Callable[[Any], int]]
self._boundary_key = boundary_key # type: Callable[[Any], Any]
self._read_ahead = read_ahead # type: int
# initialize state
self._seed = seed
@@ -1461,16 +1463,26 @@ def _create_batches(self, items: List[Any]) -> List[List[Any]]: # helper to for
items.sort(key=self._key, reverse=True) # note: sort() is stable, so we won't undo any randomization besides the bucketing
# group into batches
cur_batch = None # type: Optional[List[Any]]
prev_val = None
batches = [] # type: List[Any]
for item in items:
if self._boundary_key and self._boundary_key(item) != prev_val:
if cur_batch:
batches.append(cur_batch)
cur_batch = None
prev_val = None
if not cur_batch:
batch_size = self._batch_size if isinstance(self._batch_size, int) else \
self._batch_size(item)
cur_batch = []
cur_batch.append(item)
if self._boundary_key:
prev_val = self._boundary_key(item)
assert prev_val is not None
if len(cur_batch) >= batch_size: # this batch is full
batches.append(cur_batch)
cur_batch = None
prev_val = None
if cur_batch:
batches.append(cur_batch)
return batches
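
To illustrate the grouping logic: with batch_size=3 and boundary_key=lambda x: x < 5, the descending-sorted items 9, 8, 7, 6, 5, 4, 3, 2, 1 are grouped as [9, 8, 7], [6, 5], [4, 3, 2], [1]; the batch [6, 5] closes early because the boundary key flips at 4. Below is a minimal usage sketch of the new parameter against the constructor shown above (the data and callbacks are illustrative, not part of this commit):

from infinibatch.iterators import BucketedReadaheadBatchIterator, NativeCheckpointableIterator

# hypothetical data: token sequences of varying lengths 1..7
data = [[n] * (n % 7 + 1) for n in range(100)]

batches = BucketedReadaheadBatchIterator(
    NativeCheckpointableIterator(data),
    read_ahead=32,
    key=len,                                  # items within a read-ahead window are sorted by this key
    batch_size=8,
    boundary_key=lambda item: len(item) < 5,  # short and long sequences never share a batch
    shuffle=False,
)

for batch in batches:
    keys = [len(item) < 5 for item in batch]
    assert all(keys) or not any(keys)         # each batch is uniform w.r.t. the boundary key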
test/test_iterators.py: 130 additions & 8 deletions
@@ -754,6 +754,10 @@ def key_fn(item):
def batch_size_fn(item):
return TestBucketedReadaheadBatchIterator.dynamic_batch_size // len(item)

@staticmethod
def boundary_key_fn(item):
return len(item) < 5

@staticmethod
def setup_data(n):
data = []
@@ -766,7 +770,87 @@ def setUp(self):
self.batch_sizes = [1, 2, 3, 9]
self.test_cases = []

- # fixed batch size, not shuffled
# fixed batch size, not shuffled, no boundary key
for n, read_ahead in itertools.product(self.lengths, self.lengths):
for batch_size in self.batch_sizes:
data = self.setup_data(n)
it = BucketedReadaheadBatchIterator(
NativeCheckpointableIterator(copy.deepcopy(data)),
read_ahead=read_ahead,
key=self.key_fn,
batch_size=batch_size,
shuffle=False,
)
self.test_cases.append(
(
"n={}, read_ahead={}, batch_size={}, boundary_key=None, shuffled=False".format(
n, read_ahead, batch_size
),
data,
it,
)
)

# fixed batch size, shuffled, no boundary key
for n, read_ahead in itertools.product(self.lengths, self.lengths):
for batch_size in self.batch_sizes:
data = self.setup_data(n)
it = BucketedReadaheadBatchIterator(
NativeCheckpointableIterator(copy.deepcopy(data)),
read_ahead=read_ahead,
key=self.key_fn,
batch_size=batch_size,
shuffle=True,
seed=self.seed,
)
self.test_cases.append(
(
"n={}, read_ahead={}, batch_size={}, boundary_key=None, shuffled=True".format(
n, read_ahead, batch_size
),
data,
it,
)
)

# dynamic batch size, not shuffled, no boundary key
for n, read_ahead in itertools.product(self.lengths, self.lengths):
data = self.setup_data(n)
it = BucketedReadaheadBatchIterator(
NativeCheckpointableIterator(copy.deepcopy(data)),
read_ahead=read_ahead,
key=self.key_fn,
batch_size=self.batch_size_fn,
shuffle=False,
)
self.test_cases.append(
(
"n={}, read_ahead={}, batch_size=dynamic, boundary_key=None, shuffled=False".format(n, read_ahead),
data,
it,
)
)

# dynamic batch size, shuffled, no boundary key
for n, read_ahead in itertools.product(self.lengths, self.lengths):
data = self.setup_data(n)
it = BucketedReadaheadBatchIterator(
NativeCheckpointableIterator(copy.deepcopy(data)),
read_ahead=read_ahead,
key=self.key_fn,
batch_size=self.batch_size_fn,
shuffle=True,
seed=self.seed,
)
self.test_cases.append(
(
"n={}, read_ahead={}, batch_size=dynamic, boundary_key=None, shuffled=True".format(n, read_ahead),
data,
it,
)
)

# fixed batch size, not shuffled, boundary key
for n, read_ahead in itertools.product(self.lengths, self.lengths):
for batch_size in self.batch_sizes:
data = self.setup_data(n)
@@ -775,13 +859,20 @@ def setUp(self):
read_ahead=read_ahead,
key=self.key_fn,
batch_size=batch_size,
boundary_key=self.boundary_key_fn,
shuffle=False,
)
self.test_cases.append(
("n={}, read_ahead={}, batch_size={}, shuffled=False".format(n, read_ahead, batch_size), data, it)
(
"n={}, read_ahead={}, batch_size={}, boundary_key=len(item)<5, shuffled=False".format(
n, read_ahead, batch_size
),
data,
it,
)
)

- # fixed batch size, shuffled
# fixed batch size, shuffled, boundary key
for n, read_ahead in itertools.product(self.lengths, self.lengths):
for batch_size in self.batch_sizes:
data = self.setup_data(n)
@@ -790,40 +881,62 @@ def setUp(self):
read_ahead=read_ahead,
key=self.key_fn,
batch_size=batch_size,
boundary_key=self.boundary_key_fn,
shuffle=True,
seed=self.seed,
)
self.test_cases.append(
("n={}, read_ahead={}, batch_size={}, shuffled=True".format(n, read_ahead, batch_size), data, it)
(
"n={}, read_ahead={}, batch_size={}, boundary_key=len(item)<5, shuffled=True".format(
n, read_ahead, batch_size
),
data,
it,
)
)

- # dynamic batch size, not shuffled
# dynamic batch size, not shuffled, boundary key
for n, read_ahead in itertools.product(self.lengths, self.lengths):
data = self.setup_data(n)
it = BucketedReadaheadBatchIterator(
NativeCheckpointableIterator(copy.deepcopy(data)),
read_ahead=read_ahead,
key=self.key_fn,
batch_size=self.batch_size_fn,
boundary_key=self.boundary_key_fn,
shuffle=False,
seed=self.seed,
)
self.test_cases.append(
("n={}, read_ahead={}, batch_size=dynamic, shuffled=False".format(n, read_ahead), data, it)
(
"n={}, read_ahead={}, batch_size=dynamic, boundary_key=len(item)<5, shuffled=False".format(
n, read_ahead
),
data,
it,
)
)

- # dynamic batch size, shuffled
# dynamic batch size, shuffled, boundary key
for n, read_ahead in itertools.product(self.lengths, self.lengths):
data = self.setup_data(n)
it = BucketedReadaheadBatchIterator(
NativeCheckpointableIterator(copy.deepcopy(data)),
read_ahead=read_ahead,
key=self.key_fn,
batch_size=self.batch_size_fn,
boundary_key=self.boundary_key_fn,
shuffle=True,
seed=self.seed,
)
self.test_cases.append(
("n={}, read_ahead={}, batch_size=dynamic, shuffled=True".format(n, read_ahead), data, it)
(
"n={}, read_ahead={}, batch_size=dynamic, boundary_key=len(item)<5, shuffled=True".format(
n, read_ahead
),
data,
it,
)
)

def test_basic(self):
@@ -841,3 +954,12 @@ def test_max_len(self):
for batch in result:
length = sum((len(item) for item in batch))
self.assertTrue(length <= TestBucketedReadaheadBatchIterator.dynamic_batch_size)

def test_boundary_key(self):
for case_name, expected_result, it in self.test_cases:
if "boundary_key=len(item)<5" in case_name:
with self.subTest(case_name):
result = list(it)
for batch in result:
boundary_keys = [self.boundary_key_fn(item) for item in batch]  # materialize: a generator cannot be consumed by both all() and any()
self.assertTrue(all(boundary_keys) or not any(boundary_keys))
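
# For a boolean boundary key, all(keys) or not any(keys) holds exactly when
# every key in the batch is identical, i.e. the homogeneity that boundary_key guarantees.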
