diff --git a/ehrql/query_engines/mssql.py b/ehrql/query_engines/mssql.py
index 788d976f4..1b9c803b7 100644
--- a/ehrql/query_engines/mssql.py
+++ b/ehrql/query_engines/mssql.py
@@ -197,6 +197,7 @@ def get_results(self, dataset):
             execute_with_retry,
             results_table,
             key_column=results_table.c.patient_id,
+            key_is_unique=True,
             # This value was copied from the previous cohortextractor. I suspect it
             # has no real scientific basis.
             batch_size=32000,
diff --git a/ehrql/utils/sqlalchemy_exec_utils.py b/ehrql/utils/sqlalchemy_exec_utils.py
index c8e19eb57..296adcc98 100644
--- a/ehrql/utils/sqlalchemy_exec_utils.py
+++ b/ehrql/utils/sqlalchemy_exec_utils.py
@@ -5,7 +5,7 @@
 
 
 def fetch_table_in_batches(
-    execute, table, key_column, batch_size=32000, log=lambda *_: None
+    execute, table, key_column, key_is_unique, batch_size=32000, log=lambda *_: None
 ):
     """
     Returns an iterator over all the rows in a table by querying it in batches
@@ -14,19 +14,42 @@ def fetch_table_in_batches(
         execute: callable which accepts a SQLAlchemy query and returns results (can be
             just a Connection.execute method)
         table: SQLAlchemy TableClause
-        key_column: reference to a unique orderable column on `table`, used for
+        key_column: reference to an orderable column on `table`, used for
             paging (note that this will need an index on it to avoid terrible
             performance)
+        key_is_unique: if the key_column contains only unique values then we can use a
+            simpler and more efficient algorithm to do the paging
         batch_size: how many results to fetch in each batch
         log: callback to receive log messages
     """
+    if key_is_unique:
+        return fetch_table_in_batches_unique(
+            execute, table, key_column, batch_size, log
+        )
+    else:
+        return fetch_table_in_batches_nonunique(
+            execute, table, key_column, batch_size, log
+        )
+
+
+def fetch_table_in_batches_unique(
+    execute, table, key_column, batch_size=32000, log=lambda *_: None
+):
+    """
+    Returns an iterator over all the rows in a table by querying it in batches using a
+    unique key column
+    """
+    assert batch_size > 0
     batch_count = 1
     total_rows = 0
     min_key = None
 
     key_column_index = table.columns.values().index(key_column)
 
-    log(f"Fetching rows from '{table}' in batches of {batch_size}")
+    log(
+        f"Fetching rows from '{table}' in batches of {batch_size} using unique "
+        f"column '{key_column.name}'"
+    )
 
     while True:
         query = select(table).order_by(key_column).limit(batch_size)
         if min_key is not None:
@@ -57,15 +80,6 @@ def fetch_table_in_batches_nonunique(
     Returns an iterator over all the rows in a table by querying it in batches using a
     non-unique key column
 
-    Args:
-        execute: callable which accepts a SQLAlchemy query and returns results (can be
-            just a Connection.execute method)
-        table: SQLAlchemy TableClause
-        key_column: reference to an orderable column on `table`, used for paging (note
-            that this will need an index on it to avoid terrible performance)
-        batch_size: how many results to fetch in each batch
-        log: callback to receive log messages
-
     The algorithm below is designed (and tested) to work correctly without relying on
     sort-stability. That is, if we repeatedly ask the database for results sorted by X
     then rows with the same value for X may be returned in a different order each time.
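For reference, the unique-key fast path added above is plain keyset pagination: fetch one batch ordered by the key, remember the last key seen, then resume with `WHERE key > last_key`. Below is a minimal, self-contained sketch of that loop (not ehrql code: the `events` table, column names, and `fetch_in_batches_unique` helper are all illustrative), runnable against in-memory SQLite with SQLAlchemy 1.4+:

```python
# Toy re-implementation of the unique-key paging strategy; names are illustrative.
import sqlalchemy
from sqlalchemy import select

engine = sqlalchemy.create_engine("sqlite://")
metadata = sqlalchemy.MetaData()
events = sqlalchemy.Table(
    "events",
    metadata,
    sqlalchemy.Column("pk", sqlalchemy.Integer, primary_key=True),
    sqlalchemy.Column("value", sqlalchemy.String),
)
metadata.create_all(engine)

with engine.begin() as connection:
    connection.execute(
        events.insert(), [{"pk": i, "value": f"row-{i}"} for i in range(10)]
    )


def fetch_in_batches_unique(execute, table, key_column, batch_size):
    # Keyset paging: because key_column is unique, "WHERE key > last_seen"
    # resumes exactly where the previous batch stopped, with no risk of
    # skipping or repeating rows that share a key value.
    key_index = table.columns.values().index(key_column)
    min_key = None
    while True:
        query = select(table).order_by(key_column).limit(batch_size)
        if min_key is not None:
            query = query.where(key_column > min_key)
        rows = execute(query).fetchall()
        yield from rows
        if len(rows) < batch_size:
            break  # a short batch means we've reached the end of the table
        min_key = rows[-1][key_index]


with engine.connect() as connection:
    fetched = list(
        fetch_in_batches_unique(connection.execute, events, events.c.pk, batch_size=3)
    )
assert [row.pk for row in fetched] == list(range(10))
```

In this sketch a short batch doubles as the termination signal, so a row count that is an exact multiple of `batch_size` costs one extra (empty) query.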
diff --git a/tests/integration/utils/test_sqlalchemy_exec_utils.py b/tests/integration/utils/test_sqlalchemy_exec_utils.py
index bce31a026..bcc1c91a9 100644
--- a/tests/integration/utils/test_sqlalchemy_exec_utils.py
+++ b/tests/integration/utils/test_sqlalchemy_exec_utils.py
@@ -14,7 +14,7 @@ class SomeTable(Base):
     foo = sqlalchemy.Column(sqlalchemy.String)
 
 
-def test_fetch_table_in_batches(engine):
+def test_fetch_table_in_batches_unique(engine):
     if engine.name == "in_memory":
         pytest.skip("SQL tests do not apply to in-memory engine")
 
@@ -29,7 +29,11 @@ def test_fetch_table_in_batches(engine):
 
     with engine.sqlalchemy_engine().connect() as connection:
         results = fetch_table_in_batches(
-            connection.execute, table, table.c.pk, batch_size=batch_size
+            connection.execute,
+            table,
+            table.c.pk,
+            key_is_unique=True,
+            batch_size=batch_size,
         )
         results = list(results)
diff --git a/tests/unit/utils/test_sqlalchemy_exec_utils.py b/tests/unit/utils/test_sqlalchemy_exec_utils.py
index 117ad6a84..6072bf2f6 100644
--- a/tests/unit/utils/test_sqlalchemy_exec_utils.py
+++ b/tests/unit/utils/test_sqlalchemy_exec_utils.py
@@ -10,7 +10,6 @@
 from ehrql.utils.sqlalchemy_exec_utils import (
     execute_with_retry_factory,
     fetch_table_in_batches,
-    fetch_table_in_batches_nonunique,
 )
 
 
@@ -74,11 +73,15 @@ def sorted_data(self):
     ),
     batch_size=st.integers(min_value=1, max_value=10),
 )
-def test_fetch_table_in_batches(table_data, batch_size):
+def test_fetch_table_in_batches_unique(table_data, batch_size):
     connection = FakeConnection(table_data)
     results = fetch_table_in_batches(
-        connection.execute, sql_table, sql_table.c.key, batch_size=batch_size
+        connection.execute,
+        sql_table,
+        sql_table.c.key,
+        key_is_unique=True,
+        batch_size=batch_size,
     )
     assert sorted(results) == sorted(table_data)
@@ -129,10 +132,11 @@ def test_fetch_table_in_batches_nonunique(batch_size, table_data):
     connection = FakeConnection(table_data)
     log_messages = []
 
-    results = fetch_table_in_batches_nonunique(
+    results = fetch_table_in_batches(
         connection.execute,
         sql_table,
         sql_table.c.key,
+        key_is_unique=False,
         batch_size=batch_size,
         log=log_messages.append,
     )
@@ -159,10 +163,11 @@ def test_fetch_table_in_batches_nonunique_raises_if_batch_too_small(
 ):
     connection = FakeConnection(table_data)
 
-    results = fetch_table_in_batches_nonunique(
+    results = fetch_table_in_batches(
         connection.execute,
         sql_table,
         sql_table.c.key,
+        key_is_unique=False,
         batch_size=batch_size,
     )
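The renamed unit tests now exercise both paths through the single public entry point. The reason the `key_is_unique` dispatch exists at all: with duplicate key values, the strict `key > last_key` predicate can silently drop rows. A toy, pure-Python illustration (data and batch size are made up; no database involved):

```python
# Why "WHERE key > last_key" is unsafe on a non-unique key: with key values
# [1, 1, 1, 2] and batch_size=2, the first batch ends midway through the
# key=1 rows, and the next batch's "key > 1" filter skips the rest of them.
# Data and batch size here are made up for illustration.
table_data = [(1, "a"), (1, "b"), (1, "c"), (2, "d")]
batch_size = 2

fetched = []
min_key = None
while True:
    rows = sorted(r for r in table_data if min_key is None or r[0] > min_key)
    batch = rows[:batch_size]
    fetched.extend(batch)
    if len(batch) < batch_size:
        break
    min_key = batch[-1][0]

assert len(fetched) == 3  # one of the key=1 rows was silently dropped
assert len(table_data) == 4
```

The `test_fetch_table_in_batches_nonunique_raises_if_batch_too_small` case above appears to guard against the related failure mode: a batch must be large enough to get past a run of repeated key values, so the non-unique algorithm raises on an undersized batch rather than risking this kind of silent row loss.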