Update QueryEngine API to support multiple results tables #2363

Open · wants to merge 11 commits into base: evansd/batch-download-nonunique
28 changes: 23 additions & 5 deletions ehrql/__main__.py
@@ -662,13 +662,31 @@ def existing_python_file(value):


def valid_output_path(value):
# This can be either a single file or a directory, but either way it needs to
# specify a valid output format
path = Path(value)
extension = get_file_extension(path)
if extension not in FILE_FORMATS:
directory_ext = split_directory_and_extension(path)[1]
file_ext = get_file_extension(path)
if not directory_ext and not file_ext:
raise ArgumentTypeError(
f"'{extension}' is not a supported format, must be one of: "
f"{backtick_join(FILE_FORMATS)}"
f"No file format supplied\n"
f"To write multiple files use a directory extension: "
f"{backtick_join(format_directory_extension(e) for e in FILE_FORMATS)}\n"
f"To write a single file use a file extension: {backtick_join(FILE_FORMATS)}"
)
elif directory_ext:
if directory_ext not in FILE_FORMATS:
raise ArgumentTypeError(
f"'{format_directory_extension(directory_ext)}' is not a supported format, "
f"must be one of: "
f"{backtick_join(format_directory_extension(e) for e in FILE_FORMATS)}"
)
else:
if file_ext not in FILE_FORMATS:
raise ArgumentTypeError(
f"'{file_ext}' is not a supported format, must be one of: "
f"{backtick_join(FILE_FORMATS)}"
)
return path
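To make the new three-way check concrete, here is a minimal, self-contained sketch of the same branching. The ":format" spelling for directory outputs, the helper behaviour, and the FILE_FORMATS values are assumptions made purely for illustration; the real split_directory_and_extension and get_file_extension helpers in ehrql.file_formats may parse paths differently.

    from pathlib import Path

    FILE_FORMATS = {".csv", ".csv.gz", ".arrow"}  # illustrative subset only

    def classify_output_path(value):
        path = Path(value)
        # Assumption for this sketch: a trailing ":<format>" names a directory of
        # per-table files; anything else is treated as a single-file output.
        name, sep, dir_ext = path.name.partition(":")
        directory_ext = f".{dir_ext}" if sep else ""
        file_ext = "".join(path.suffixes) if not sep else ""
        if not directory_ext and not file_ext:
            raise ValueError("No file format supplied")
        if directory_ext:
            if directory_ext not in FILE_FORMATS:
                raise ValueError(f"'{directory_ext}' is not a supported directory format")
            return "directory", path.with_name(name), directory_ext
        if file_ext not in FILE_FORMATS:
            raise ValueError(f"'{file_ext}' is not a supported file format")
        return "file", path, file_ext

    print(classify_output_path("outputs/dataset.csv.gz"))   # single-file output
    print(classify_output_path("outputs/tables:arrow"))     # multi-table directory output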


@@ -701,7 +719,7 @@ def query_engine_from_id(str_id):
f"(or a full dotted path to a query engine class)"
)
query_engine = import_string(str_id)
assert_duck_type(query_engine, "query engine", "get_results")
assert_duck_type(query_engine, "query engine", "get_results_tables")
return query_engine


10 changes: 8 additions & 2 deletions ehrql/dummy_data/generator.py
@@ -131,10 +131,16 @@ def get_patient_id_stream(self):
if i not in inline_patient_ids:
yield i

def get_results(self):
def get_results_tables(self):
database = InMemoryDatabase(self.get_data())
engine = InMemoryQueryEngine(database)
return engine.get_results(self.dataset)
return engine.get_results_tables(self.dataset)

def get_results(self):
tables = self.get_results_tables()
yield from next(tables)
for remaining in tables:
assert False, "Expected only one results table"


class DummyPatientGenerator:
10 changes: 8 additions & 2 deletions ehrql/dummy_data_nextgen/generator.py
@@ -245,10 +245,16 @@ def get_patient_id_stream(self):
if i not in inline_patient_ids:
yield i

def get_results(self):
def get_results_tables(self):
database = InMemoryDatabase(self.get_data())
engine = InMemoryQueryEngine(database)
return engine.get_results(self.dataset)
return engine.get_results_tables(self.dataset)

def get_results(self):
tables = self.get_results_tables()
yield from next(tables)
for remaining in tables:
assert False, "Expected only one results table"


class DummyPatientGenerator:
47 changes: 47 additions & 0 deletions ehrql/file_formats/main.py
@@ -48,6 +48,24 @@ def read_rows(filename, column_specs, allow_missing_columns=False):


def read_tables(filename, table_specs, allow_missing_columns=False):
# If we've got a single-table input file and only a single table to read then that's
# fine, but it needs slightly special handling
if not input_filename_supports_multiple_tables(filename):
if len(table_specs) == 1:
column_specs = list(table_specs.values())[0]
rows = read_rows(
filename,
column_specs,
allow_missing_columns=allow_missing_columns,
)
yield from [rows]
return
else:
raise FileValidationError(
f"Attempting to read {len(table_specs)} tables, but '{filename}' "
f"only supports a single table"
)

extension = get_extension_from_directory(filename)
# Using ExitStack here allows us to open and validate all files before emitting any
# rows while still correctly closing all open files if we raise an error part way
@@ -66,6 +84,19 @@ def read_tables(filename, table_specs, allow_missing_columns=False):


def write_tables(filename, tables, table_specs):
# If we've got a single-table output file and only a single table to write then
# that's fine, but it needs slightly special handling
if not output_filename_supports_multiple_tables(filename):
if len(table_specs) == 1:
column_specs = list(table_specs.values())[0]
rows = next(iter(tables))
return write_rows(filename, rows, column_specs)
else:
raise FileValidationError(
f"Attempting to write {len(table_specs)} tables, but '{filename}' "
f"only supports a single table"
)

filename, extension = split_directory_and_extension(filename)
for rows, (table_name, column_specs) in zip(tables, table_specs.items()):
table_filename = get_table_filename(filename, table_name, extension)
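The same dispatch appears on both the read and the write side: a single-file target gets exactly one table routed through the existing row-level functions, while a directory target gets one file per table. A rough sketch of the write-side shape, using placeholder writer and path helpers rather than the real ehrql.file_formats functions, looks like this:

    def write_one_file(filename, rows, column_specs):
        # Placeholder writer: the real code dispatches on the file extension.
        print(f"writing columns {list(column_specs)} to {filename}")
        for _ in rows:
            pass

    def write_tables_sketch(target, tables, table_specs, supports_multiple_tables):
        if not supports_multiple_tables(target):
            if len(table_specs) != 1:
                raise ValueError(
                    f"Attempting to write {len(table_specs)} tables, but '{target}' "
                    f"only supports a single table"
                )
            (column_specs,) = table_specs.values()
            rows = next(iter(tables))
            return write_one_file(target, rows, column_specs)
        for rows, (table_name, column_specs) in zip(tables, table_specs.items()):
            write_one_file(f"{target}/{table_name}", rows, column_specs)

    specs = {"dataset": {"patient_id": int}, "events": {"patient_id": int, "code": str}}
    write_tables_sketch("out", [iter([(1,)]), iter([(1, "abc")])], specs, lambda t: True)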
@@ -121,6 +152,22 @@ def split_directory_and_extension(filename):
return filename.with_name(name), f".{extension}"


def input_filename_supports_multiple_tables(filename):
# At present, supplying a directory is the only way to provide multiple input
# tables, but it's not inconceivable that in future we might support single-file
# multiple-table formats e.g. SQLite or DuckDB files. If we do then updating this
# function and its sibling below should be all that's required.
return filename.is_dir()


def output_filename_supports_multiple_tables(filename):
if filename is None:
return False
# Again, at present only directories support multiple output tables but see above
extension = split_directory_and_extension(filename)[1]
return extension != ""


def get_table_filename(base_filename, table_name, extension):
# Use URL quoting as an easy way of escaping any potentially problematic characters
# in filenames
35 changes: 19 additions & 16 deletions ehrql/main.py
@@ -14,6 +14,7 @@
)
from ehrql.file_formats import (
read_rows,
read_tables,
split_directory_and_extension,
write_rows,
write_tables,
@@ -35,8 +36,8 @@
from ehrql.query_engines.local_file import LocalFileQueryEngine
from ehrql.query_engines.sqlite import SQLiteQueryEngine
from ehrql.query_model.column_specs import (
get_column_specs,
get_column_specs_from_schema,
get_table_specs,
)
from ehrql.query_model.graphs import graph_to_svg
from ehrql.serializer import serialize
@@ -71,11 +72,11 @@ def generate_dataset(
log.info(f"Testing dataset definition with tests in {str(definition_file)}")
assure(test_data_file, environ=environ, user_args=user_args)

column_specs = get_column_specs(dataset)
table_specs = get_table_specs(dataset)

if dsn:
log.info("Generating dataset")
results = generate_dataset_with_dsn(
results_tables = generate_dataset_with_dsn(
dataset=dataset,
dsn=dsn,
backend_class=backend_class,
@@ -84,15 +85,15 @@
)
else:
log.info("Generating dummy dataset")
results = generate_dataset_with_dummy_data(
results_tables = generate_dataset_with_dummy_data(
dataset=dataset,
dummy_data_config=dummy_data_config,
column_specs=column_specs,
table_specs=table_specs,
dummy_data_file=dummy_data_file,
dummy_tables_path=dummy_tables_path,
)

write_rows(output_file, results, column_specs)
write_tables(output_file, results_tables, table_specs)


def generate_dataset_with_dsn(
Expand All @@ -105,23 +106,22 @@ def generate_dataset_with_dsn(
environ,
default_query_engine_class=LocalFileQueryEngine,
)
return query_engine.get_results(dataset)
return query_engine.get_results_tables(dataset)


def generate_dataset_with_dummy_data(
*, dataset, dummy_data_config, column_specs, dummy_data_file, dummy_tables_path
*, dataset, dummy_data_config, table_specs, dummy_data_file, dummy_tables_path
):
if dummy_data_file:
log.info(f"Reading dummy data from {dummy_data_file}")
reader = read_rows(dummy_data_file, column_specs)
return iter(reader)
return read_tables(dummy_data_file, table_specs)
elif dummy_tables_path:
log.info(f"Reading table data from {dummy_tables_path}")
query_engine = LocalFileQueryEngine(dummy_tables_path)
return query_engine.get_results(dataset)
return query_engine.get_results_tables(dataset)
else:
generator = get_dummy_data_generator(dataset, dummy_data_config)
return generator.get_results()
return generator.get_results_tables()
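All three branches now return the same shape: an iterator yielding one rows-iterator per table, in the same order as the table_specs mapping, which is what write_tables pairs them with positionally. A toy illustration of that contract (made-up table names and rows, not the real ehrql data structures):

    table_specs = {
        "dataset": {"patient_id": "int"},
        "events": {"patient_id": "int", "code": "str"},
    }

    def fake_results_tables():
        yield iter([(1,), (2,)])                 # rows for "dataset"
        yield iter([(1, "abc"), (2, "def")])     # rows for "events"

    # write_tables pairs the two streams positionally, like this:
    for rows, (table_name, column_specs) in zip(fake_results_tables(), table_specs.items()):
        print(table_name, list(rows))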


def create_dummy_tables(definition_file, dummy_tables_path, user_args, environ):
@@ -175,17 +175,20 @@ def dump_dataset_sql(


def get_sql_strings(query_engine, dataset):
results_query = query_engine.get_query(dataset)
setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_query)
results_queries = query_engine.get_queries(dataset)
setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_queries)
dialect = query_engine.sqlalchemy_dialect()
sql_strings = []

for i, query in enumerate(setup_queries, start=1):
sql = clause_as_str(query, dialect)
sql_strings.append(f"-- Setup query {i:03} / {len(setup_queries):03}\n{sql}")

sql = clause_as_str(results_query, dialect)
sql_strings.append(f"-- Results query\n{sql}")
for i, query in enumerate(results_queries, start=1):
sql = clause_as_str(query, dialect)
sql_strings.append(
f"-- Results query {i:03} / {len(results_queries):03}\n{sql}"
)

for i, query in enumerate(cleanup_queries, start=1):
sql = clause_as_str(query, dialect)
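Concretely, for a dataset that needs two setup queries and (for now) produces a single results query, the returned list is shaped roughly as below. The cleanup entries are assumed to keep their existing labelling, which sits just past the visible portion of this hunk.

    # Illustrative shape of the return value (SQL bodies elided):
    sql_strings = [
        "-- Setup query 001 / 002\n<sql>",
        "-- Setup query 002 / 002\n<sql>",
        "-- Results query 001 / 001\n<sql>",
        "-- Cleanup query 001 / 001\n<sql>",
    ]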
43 changes: 40 additions & 3 deletions ehrql/query_engines/base.py
@@ -2,6 +2,10 @@
from typing import Any

from ehrql.query_model import nodes as qm
from ehrql.utils.itertools_utils import iter_groups


class Marker: ...


class BaseQueryEngine:
@@ -12,6 +16,9 @@ class BaseQueryEngine:
flavour of tables and query language (SQL, pandas dataframes etc).
"""

# Sentinel value used to mark the start of a new results table in a stream of results
RESULTS_START = Marker()

def __init__(self, dsn: str, backend: Any = None, config: dict | None = None):
"""
`dsn` is Data Source Name — a string (usually a URL) which provides connection
@@ -25,12 +32,42 @@ def __init__(self, dsn: str, backend: Any = None, config: dict | None = None):
self.backend = backend
self.config = config or {}

def get_results(self, dataset: qm.Dataset) -> Iterator[Sequence]:
def get_results_tables(self, dataset: qm.Dataset) -> Iterator[Iterator[Sequence]]:
"""
Given a query model `Dataset` return an iterator of "results tables", where each
table is an iterator of rows (usually tuples, but any sequence type will do)

This is the primary interface to query engines and the one required method.

Typically however, query engine subclasses will implement `get_results_stream`
instead which yields a flat sequence of rows, with tables separated by the
`RESULTS_START` marker value. This is converted into the appropriate structure
by `iter_groups` which also enforces that the caller interacts with it safely.
"""
Given a query model `Dataset` return the results as an iterator of "rows" (which
are usually tuples, but any sequence type will do)
return iter_groups(self.get_results_stream(dataset), self.RESULTS_START)

def get_results_stream(self, dataset: qm.Dataset) -> Iterator[Sequence | Marker]:
"""
Given a query model `Dataset` return an iterator of "results tables", where each
table is an iterator of rows (usually tuples, but any sequence type will do)

Override this method to do the things necessary to generate query code and
execute it against a particular backend.

Emitting results in a flat sequence like this with separators between the tables
ends up making the query code _much_ easier to reason about because everything
happens in a clear linear sequence rather than inside nested generators. This
makes things like transaction management and error handling much more
straightforward.
"""
raise NotImplementedError()

def get_results(self, dataset: qm.Dataset) -> Iterator[Sequence]:
"""
Temporary method to continue to support code which assumes only a single results
table
"""
tables = self.get_results_tables(dataset)
yield from next(tables)
for remaining in tables:
assert False, "Expected only one results table"
41 changes: 23 additions & 18 deletions ehrql/query_engines/base_sql.py
@@ -83,9 +83,9 @@ def get_next_id(self):
self.counter += 1
return self.counter

def get_query(self, dataset):
def get_queries(self, dataset):
"""
Return the SQL query to fetch the results for `dataset`
Return the SQL queries to fetch the results for `dataset`

Note that this query might make use of intermediate tables. The SQL queries
needed to create these tables and clean them up can be retrieved by calling
@@ -126,7 +126,9 @@
self.get_sql.cache_clear()
self.get_table.cache_clear()

return query
# At the moment we only support a single results table and so we'll only ever
# have a single query
return [query]

def select_patient_id_for_population(self, population_expression):
"""
@@ -826,26 +828,29 @@ def get_select_query_for_node_domain(self, node):
query = query.where(sqlalchemy.and_(*where_clauses))
return query

def get_results(self, dataset):
results_query = self.get_query(dataset)
setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_query)
def get_results_stream(self, dataset):
results_queries = self.get_queries(dataset)
setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_queries)

with self.engine.connect() as connection:
for i, setup_query in enumerate(setup_queries, start=1):
log.info(f"Running setup query {i:03} / {len(setup_queries):03}")
connection.execute(setup_query)

log.info("Fetching results")
cursor_result = connection.execute(results_query)
try:
yield from cursor_result
except Exception: # pragma: no cover
# If we hit an error part way through fetching results then we should
# close the cursor to make it clear we're not going to be fetching any
# more (only really relevant for the in-memory SQLite tests, but good
# hygiene in any case)
cursor_result.close()
# Make sure the cleanup happens before raising the error
raise
for i, results_query in enumerate(results_queries, start=1):
log.info(f"Fetching results {i:03} / {len(setup_queries):03}")
cursor_result = connection.execute(results_query)
yield self.RESULTS_START
try:
yield from cursor_result
except Exception: # pragma: no cover
# If we hit an error part way through fetching results then we should
# close the cursor to make it clear we're not going to be fetching any
# more (only really relevant for the in-memory SQLite tests, but good
# hygiene in any case)
cursor_result.close()
# Make sure the cleanup happens before raising the error
raise

for i, cleanup_query in enumerate(cleanup_queries, start=1):
log.info(f"Running cleanup query {i:03} / {len(cleanup_queries):03}")