Skip to content

Commit

Permalink
Put both fetcher variants behind a common interface
Browse files: browse the repository at this point in the history
  • Loading branch information
evansd committed Jan 17, 2025
1 parent 4a1d53f commit 61b2c35
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 19 deletions.
1 change: 1 addition & 0 deletions ehrql/query_engines/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def get_results(self, dataset):
execute_with_retry,
results_table,
key_column=results_table.c.patient_id,
key_is_unique=True,
# This value was copied from the previous cohortextractor. I suspect it
# has no real scientific basis.
batch_size=32000,
Expand Down
38 changes: 26 additions & 12 deletions ehrql/utils/sqlalchemy_exec_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


def fetch_table_in_batches(
execute, table, key_column, batch_size=32000, log=lambda *_: None
execute, table, key_column, key_is_unique, batch_size=32000, log=lambda *_: None
):
"""
Returns an iterator over all the rows in a table by querying it in batches
Expand All @@ -14,19 +14,42 @@ def fetch_table_in_batches(
execute: callable which accepts a SQLAlchemy query and returns results (can be
just a Connection.execute method)
table: SQLAlchemy TableClause
key_column: reference to a unique orderable column on `table`, used for
key_column: reference to an orderable column on `table`, used for
paging (note that this will need an index on it to avoid terrible
performance)
key_is_unique: if the key_column contains only unique values then we can use a
simpler and more efficient algorithm to do the paging
batch_size: how many results to fetch in each batch
log: callback to receive log messages
"""
if key_is_unique:
return fetch_table_in_batches_unique(
execute, table, key_column, batch_size, log
)
else:
return fetch_table_in_batches_nonunique(
execute, table, key_column, batch_size, log
)


def fetch_table_in_batches_unique(
execute, table, key_column, batch_size=32000, log=lambda *_: None
):
"""
Returns an iterator over all the rows in a table by querying it in batches using a
unique key column
"""
assert batch_size > 0
batch_count = 1
total_rows = 0
min_key = None

key_column_index = table.columns.values().index(key_column)

log(f"Fetching rows from '{table}' in batches of {batch_size}")
log(
f"Fetching rows from '{table}' in batches of {batch_size} using unique "
f"column '{key_column.name}'"
)
while True:
query = select(table).order_by(key_column).limit(batch_size)
if min_key is not None:
Expand Down Expand Up @@ -57,15 +80,6 @@ def fetch_table_in_batches_nonunique(
Returns an iterator over all the rows in a table by querying it in batches using a
non-unique key column
Args:
execute: callable which accepts a SQLAlchemy query and returns results (can be
just a Connection.execute method)
table: SQLAlchemy TableClause
key_column: reference to an orderable column on `table`, used for paging (note
that this will need an index on it to avoid terrible performance)
batch_size: how many results to fetch in each batch
log: callback to receive log messages
The algorithm below is designed (and tested) to work correctly without relying on
sort-stability. That is, if we repeatedly ask the database for results sorted by X
then rows with the same value for X may be returned in a different order each time.
Expand Down
8 changes: 6 additions & 2 deletions tests/integration/utils/test_sqlalchemy_exec_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class SomeTable(Base):
foo = sqlalchemy.Column(sqlalchemy.String)


def test_fetch_table_in_batches(engine):
def test_fetch_table_in_batches_unique(engine):
if engine.name == "in_memory":
pytest.skip("SQL tests do not apply to in-memory engine")

Expand All @@ -29,7 +29,11 @@ def test_fetch_table_in_batches(engine):

with engine.sqlalchemy_engine().connect() as connection:
results = fetch_table_in_batches(
connection.execute, table, table.c.pk, batch_size=batch_size
connection.execute,
table,
table.c.pk,
key_is_unique=True,
batch_size=batch_size,
)
results = list(results)

Expand Down
15 changes: 10 additions & 5 deletions tests/unit/utils/test_sqlalchemy_exec_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from ehrql.utils.sqlalchemy_exec_utils import (
execute_with_retry_factory,
fetch_table_in_batches,
fetch_table_in_batches_nonunique,
)


Expand Down Expand Up @@ -74,11 +73,15 @@ def sorted_data(self):
),
batch_size=st.integers(min_value=1, max_value=10),
)
def test_fetch_table_in_batches(table_data, batch_size):
def test_fetch_table_in_batches_unique(table_data, batch_size):
connection = FakeConnection(table_data)

results = fetch_table_in_batches(
connection.execute, sql_table, sql_table.c.key, batch_size=batch_size
connection.execute,
sql_table,
sql_table.c.key,
key_is_unique=True,
batch_size=batch_size,
)

assert sorted(results) == sorted(table_data)
Expand Down Expand Up @@ -129,10 +132,11 @@ def test_fetch_table_in_batches_nonunique(batch_size, table_data):
connection = FakeConnection(table_data)
log_messages = []

results = fetch_table_in_batches_nonunique(
results = fetch_table_in_batches(
connection.execute,
sql_table,
sql_table.c.key,
key_is_unique=False,
batch_size=batch_size,
log=log_messages.append,
)
Expand All @@ -159,10 +163,11 @@ def test_fetch_table_in_batches_nonunique_raises_if_batch_too_small(
):
connection = FakeConnection(table_data)

results = fetch_table_in_batches_nonunique(
results = fetch_table_in_batches(
connection.execute,
sql_table,
sql_table.c.key,
key_is_unique=False,
batch_size=batch_size,
)

Expand Down

0 comments on commit 61b2c35

Please sign in to comment.