diff --git a/databuilder/query_language.py b/databuilder/query_language.py index f70b7142c..917d5823d 100644 --- a/databuilder/query_language.py +++ b/databuilder/query_language.py @@ -39,7 +39,7 @@ def __setattr__(self, name, value): def compile(dataset): # noqa A003 - return {k: v.qm_node for k, v in vars(dataset).items() if isinstance(v, Series)} + return {k: v.qm_node for k, v in vars(dataset).items() if isinstance(v, BaseSeries)} # BASIC SERIES TYPES @@ -47,7 +47,7 @@ def compile(dataset): # noqa A003 @dataclasses.dataclass(frozen=True) -class Series: +class BaseSeries: qm_node: qm.Node def __hash__(self): @@ -90,7 +90,7 @@ def map_values(self, mapping): return _apply(qm.Case, cases) -class EventSeries(Series): +class EventSeries(BaseSeries): def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) # Register the series using its `_type` attribute @@ -100,7 +100,7 @@ def __init_subclass__(cls, **kwargs): # they would be defined here as well -class PatientSeries(Series): +class PatientSeries(BaseSeries): def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) # Register the series using its `_type` attribute @@ -348,7 +348,7 @@ def _convert(arg): if isinstance(arg, _DictArg): return {_convert(key): _convert(value) for key, value in arg} # If it's an ehrQL series then get the wrapped query model node - elif isinstance(arg, Series): + elif isinstance(arg, BaseSeries): return arg.qm_node # If it's a Codelist extract the set of codes and put it in a Value wrapper elif isinstance(arg, Codelist): @@ -362,11 +362,14 @@ def _convert(arg): # -class Frame: +class BaseFrame: def __init__(self, qm_node): self.qm_node = qm_node def __getattr__(self, name): + return self._select_column(name) + + def _select_column(self, name): return _wrap(qm.SelectColumn(source=self.qm_node, name=name)) def exists_for_patient(self): @@ -376,11 +379,11 @@ def count_for_patient(self): return _wrap(qm.AggregateByPatient.Count(source=self.qm_node)) -class PatientFrame(Frame): +class PatientFrame(BaseFrame): pass -class EventFrame(Frame): +class EventFrame(BaseFrame): def take(self, series): return EventFrame( qm.Filter( @@ -413,7 +416,7 @@ def sort_by(self, *order_series): return SortedEventFrame(qm_node) -class SortedEventFrame(Frame): +class SortedEventFrame(BaseFrame): def first_for_patient(self): return PatientFrame( qm.PickOneRowPerPatient( @@ -435,20 +438,56 @@ def last_for_patient(self): # -def build_patient_table(name, schema, contract=None): - if contract is not None: - contract.validate_schema(schema) - return PatientFrame( - qm.SelectPatientTable(name, schema=qm.TableSchema(schema)), - ) +class SchemaError(Exception): + ... + + +# A class decorator which replaces the class definition with an appropriately configured +# instance of the class. Obviously this is a _bit_ odd, but I think worth it overall. +# Using classes to define tables is (as far as I can tell) the only way to get nice +# autocomplete and type-checking behaviour for column names. But we don't actually want +# these classes accessible anywhere: users should only be interacting with instances of +# the classes, and having the classes themselves in the module namespaces only makes +# autocomplete more confusing and error prone. +def construct(cls): + try: + qm_class = { + (PatientFrame,): qm.SelectPatientTable, + (EventFrame,): qm.SelectTable, + }[cls.__bases__] + except KeyError: + raise SchemaError( + "Schema class must subclass either `PatientFrame` or `EventFrame`" + ) + + table_name = cls.__name__ + # Get all `Series` objects on the class and determine the schema from them + schema = { + series.name: series.type_ + for series in vars(cls).values() + if isinstance(series, Series) + } + + qm_node = qm_class(table_name, qm.TableSchema(schema)) + return cls(qm_node) + + +# A descriptor which will return the appropriate type of series depending on the type of +# frame it belongs to i.e. a PatientSeries subclass for PatientFrames and an EventSeries +# subclass for EventFrames. This lets schema authors use a consistent syntax when +# defining frames of either type. +class Series: + def __init__(self, type_): + self.type_ = type_ + def __set_name__(self, owner, name): + self.name = name -def build_event_table(name, schema, contract=None): - if contract is not None: # pragma: no cover - contract.validate_schema(schema) - return EventFrame( - qm.SelectTable(name, schema=qm.TableSchema(schema)), - ) + def __get__(self, instance, owner): + # Prevent users attempting to interact with the class rather than an instance + if instance is None: + raise SchemaError("Missing `@construct` decorator on schema class") + return instance._select_column(self.name) # CASE EXPRESSION FUNCTIONS diff --git a/databuilder/tables.py b/databuilder/tables.py index ce188151f..5f919e638 100644 --- a/databuilder/tables.py +++ b/databuilder/tables.py @@ -1,15 +1,10 @@ import datetime -from databuilder.query_language import build_patient_table +from databuilder.query_language import PatientFrame, Series, construct -from .contracts import universal -patients = build_patient_table( - "patients", - { - "date_of_birth": datetime.date, - "date_of_death": datetime.date, - "sex": str, - }, - contract=universal.Patients, -) +@construct +class patients(PatientFrame): + date_of_birth = Series(datetime.date) + date_of_death = Series(datetime.date) + sex = Series(str) diff --git a/tests/acceptance/comparative_booster_study/schema.py b/tests/acceptance/comparative_booster_study/schema.py index 9cd3683dd..ccd69c355 100644 --- a/tests/acceptance/comparative_booster_study/schema.py +++ b/tests/acceptance/comparative_booster_study/schema.py @@ -3,157 +3,172 @@ import sqlalchemy.orm from databuilder.codes import CTV3Code, ICD10Code, SNOMEDCTCode -from databuilder.query_language import build_event_table, build_patient_table +from databuilder.query_language import EventFrame, PatientFrame, Series, construct from ...lib.util import orm_class_from_table Base = sqlalchemy.orm.declarative_base() -patients = build_patient_table( - "patients", - { - "date_of_birth": datetime.date, - "sex": str, - }, -) +@construct +class patients(PatientFrame): + date_of_birth = Series(datetime.date) + sex = Series(str) + Patient = orm_class_from_table(Base, patients) -vaccinations = build_event_table( - "vaccinations", - { - "date": datetime.date, - "target_disease": str, - "product_name": str, - }, -) +@construct +class vaccinations(EventFrame): + date = Series(datetime.date) + target_disease = Series(str) + product_name = Series(str) + Vaccination = orm_class_from_table(Base, vaccinations) -practice_registrations = build_event_table( - "practice_registrations", - { - "start_date": datetime.date, - "end_date": datetime.date, - "practice_pseudo_id": int, - "practice_stp": str, - "practice_nuts1_region_name": str, - }, -) +@construct +class practice_registrations(EventFrame): + start_date = Series(datetime.date) + end_date = Series(datetime.date) + practice_pseudo_id = Series(int) + practice_stp = Series(str) + practice_nuts1_region_name = Series(str) + PracticeRegistration = orm_class_from_table(Base, practice_registrations) -ons_deaths = build_event_table( - "ons_deaths", - { - "date": datetime.date, - # TODO: Revisit this when we have support for multi-valued fields - **{f"cause_of_death_{i:02d}": ICD10Code for i in range(1, 16)}, - }, -) +@construct +class ons_deaths(EventFrame): + date = Series(datetime.date) + # TODO: Revisit this when we have support for multi-valued fields + cause_of_death_01 = Series(ICD10Code) + cause_of_death_02 = Series(ICD10Code) + cause_of_death_03 = Series(ICD10Code) + cause_of_death_04 = Series(ICD10Code) + cause_of_death_05 = Series(ICD10Code) + cause_of_death_06 = Series(ICD10Code) + cause_of_death_07 = Series(ICD10Code) + cause_of_death_08 = Series(ICD10Code) + cause_of_death_09 = Series(ICD10Code) + cause_of_death_10 = Series(ICD10Code) + cause_of_death_11 = Series(ICD10Code) + cause_of_death_12 = Series(ICD10Code) + cause_of_death_13 = Series(ICD10Code) + cause_of_death_14 = Series(ICD10Code) + cause_of_death_15 = Series(ICD10Code) + ONSDeath = orm_class_from_table(Base, ons_deaths) -coded_events = build_event_table( - "coded_events", - { - "date": datetime.date, - "snomedct_code": SNOMEDCTCode, - "ctv3_code": CTV3Code, - "numeric_value": float, - }, -) +@construct +class coded_events(EventFrame): + date = Series(datetime.date) + snomedct_code = Series(SNOMEDCTCode) + ctv3_code = Series(CTV3Code) + numeric_value = Series(float) + CodedEvent = orm_class_from_table(Base, coded_events) -medications = build_event_table( - "medications", - { - "date": datetime.date, - "snomedct_code": SNOMEDCTCode, - }, -) +@construct +class medications(EventFrame): + date = Series(datetime.date) + snomedct_code = Series(SNOMEDCTCode) + Medication = orm_class_from_table(Base, medications) -addresses = build_event_table( - "addresses", - { - "address_id": int, - "start_date": datetime.date, - "end_date": datetime.date, - "address_type": int, - "rural_urban_classification": int, - "imd_rounded": int, - "msoa_code": str, - # Is the address potentially a match for a care home? (Using TPP's algorithm) - "care_home_is_potential_match": bool, - # These two fields look like they should be a single boolean, but this is how - # they're represented in the data - "care_home_requires_nursing": bool, - "care_home_does_not_require_nursing": bool, - }, -) +@construct +class addresses(EventFrame): + address_id = Series(int) + start_date = Series(datetime.date) + end_date = Series(datetime.date) + address_type = Series(int) + rural_urban_classification = Series(int) + imd_rounded = Series(int) + msoa_code = Series(str) + # Is the address potentially a match for a care home? (Using TPP's algorithm) + care_home_is_potential_match = Series(bool) + # These two fields look like they should be a single boolean, but this is how + # they're represented in the data + care_home_requires_nursing = Series(bool) + care_home_does_not_require_nursing = Series(bool) + Address = orm_class_from_table(Base, addresses) -sgss_covid_all_tests = build_event_table( - "sgss_covid_all_tests", - { - "specimen_taken_date": datetime.date, - "is_positive": bool, - }, -) +@construct +class sgss_covid_all_tests(EventFrame): + specimen_taken_date = Series(datetime.date) + is_positive = Series(bool) + SGSSCovidAllTestsResult = orm_class_from_table(Base, sgss_covid_all_tests) -occupation_on_covid_vaccine_record = build_event_table( - "occupation_on_covid_vaccine_record", - { - "is_healthcare_worker": bool, - }, -) +@construct +class occupation_on_covid_vaccine_record(EventFrame): + is_healthcare_worker = Series(bool) + OccupationOnCovidVaccineRecord = orm_class_from_table( Base, occupation_on_covid_vaccine_record ) -emergency_care_attendances = build_event_table( - "emergency_care_attendances", - { - "id": int, - "arrival_date": datetime.date, - "discharge_destination": SNOMEDCTCode, - # TODO: Revisit this when we have support for multi-valued fields - **{f"diagnosis_{i:02d}": SNOMEDCTCode for i in range(1, 25)}, - }, -) +@construct +class emergency_care_attendances(EventFrame): + id = Series(int) # noqa: A003 + arrival_date = Series(datetime.date) + discharge_destination = Series(SNOMEDCTCode) + # TODO: Revisit this when we have support for multi-valued fields + diagnosis_01 = Series(SNOMEDCTCode) + diagnosis_02 = Series(SNOMEDCTCode) + diagnosis_03 = Series(SNOMEDCTCode) + diagnosis_04 = Series(SNOMEDCTCode) + diagnosis_05 = Series(SNOMEDCTCode) + diagnosis_06 = Series(SNOMEDCTCode) + diagnosis_07 = Series(SNOMEDCTCode) + diagnosis_08 = Series(SNOMEDCTCode) + diagnosis_09 = Series(SNOMEDCTCode) + diagnosis_10 = Series(SNOMEDCTCode) + diagnosis_11 = Series(SNOMEDCTCode) + diagnosis_12 = Series(SNOMEDCTCode) + diagnosis_13 = Series(SNOMEDCTCode) + diagnosis_14 = Series(SNOMEDCTCode) + diagnosis_15 = Series(SNOMEDCTCode) + diagnosis_16 = Series(SNOMEDCTCode) + diagnosis_17 = Series(SNOMEDCTCode) + diagnosis_18 = Series(SNOMEDCTCode) + diagnosis_19 = Series(SNOMEDCTCode) + diagnosis_20 = Series(SNOMEDCTCode) + diagnosis_21 = Series(SNOMEDCTCode) + diagnosis_22 = Series(SNOMEDCTCode) + diagnosis_23 = Series(SNOMEDCTCode) + diagnosis_24 = Series(SNOMEDCTCode) + EmergencyCareAttendance = orm_class_from_table(Base, emergency_care_attendances) -hospital_admissions = build_event_table( - "hospital_admissions", - { - "id": int, - "admission_date": datetime.date, - "discharge_date": datetime.date, - "admission_method": str, - # TODO: Revisit this when we have support for multi-valued fields - "all_diagnoses": str, - "patient_classification": str, - "days_in_critical_care": int, - }, -) +@construct +class hospital_admissions(EventFrame): + id = Series(int) # noqa: A003 + admission_date = Series(datetime.date) + discharge_date = Series(datetime.date) + admission_method = Series(str) + # TODO: Revisit this when we have support for multi-valued fields + all_diagnoses = Series(str) + patient_classification = Series(str) + days_in_critical_care = Series(int) + HospitalAdmission = orm_class_from_table(Base, hospital_admissions) diff --git a/tests/acceptance/comparative_booster_study/test_variables_lib.py b/tests/acceptance/comparative_booster_study/test_variables_lib.py index b303e6e03..7c85858f9 100644 --- a/tests/acceptance/comparative_booster_study/test_variables_lib.py +++ b/tests/acceptance/comparative_booster_study/test_variables_lib.py @@ -4,7 +4,7 @@ import pytest import sqlalchemy.orm -from databuilder.query_language import Dataset, build_event_table +from databuilder.query_language import Dataset, EventFrame, Series, construct from ...lib.util import orm_class_from_table from .variables_lib import create_sequential_variables @@ -12,10 +12,11 @@ @pytest.fixture def schema(): - events = build_event_table( - "events", - {"date": date, "value": int}, - ) + @construct + class events(EventFrame): + date = Series(date) + value = Series(int) + Event = orm_class_from_table(sqlalchemy.orm.declarative_base(), events) return SimpleNamespace(events=events, Event=Event) diff --git a/tests/integration/backends/test_base.py b/tests/integration/backends/test_base.py index b70d5b82d..6ecedf89f 100644 --- a/tests/integration/backends/test_base.py +++ b/tests/integration/backends/test_base.py @@ -6,7 +6,13 @@ from databuilder import sqlalchemy_types from databuilder.backends.base import BaseBackend, Column, MappedTable, QueryTable from databuilder.query_engines.base_sql import BaseSQLQueryEngine -from databuilder.query_language import Dataset, build_event_table, build_patient_table +from databuilder.query_language import ( + Dataset, + EventFrame, + PatientFrame, + Series, + construct, +) from ...lib.util import next_id @@ -14,14 +20,15 @@ # Simple schema to test against -patients = build_patient_table( - "patients", - {"date_of_birth": datetime.date}, -) -covid_tests = build_event_table( - "covid_tests", - {"date": datetime.date, "positive": int}, -) +@construct +class patients(PatientFrame): + date_of_birth = Series(datetime.date) + + +@construct +class covid_tests(EventFrame): + date = Series(datetime.date) + positive = Series(int) class TestBackend(BaseBackend): diff --git a/tests/spec/tables.py b/tests/spec/tables.py index cd72494e1..b5314a412 100644 --- a/tests/spec/tables.py +++ b/tests/spec/tables.py @@ -3,42 +3,41 @@ import sqlalchemy.orm from databuilder.codes import SNOMEDCTCode -from databuilder.query_language import build_event_table, build_patient_table +from databuilder.query_language import EventFrame, PatientFrame, Series, construct from ..lib.util import orm_class_from_table -p = build_patient_table( - "patient_level_table", - { - "i1": int, - "i2": int, - "b1": bool, - "b2": bool, - "c1": SNOMEDCTCode, - "d1": datetime.date, - "d2": datetime.date, - "s1": str, - "s2": str, - }, -) - - -e = build_event_table( - "event_level_table", - { - "i1": int, - "i2": int, - "b1": bool, - "b2": bool, - "c1": SNOMEDCTCode, - "d1": datetime.date, - "d2": datetime.date, - "s1": str, - "s2": str, - }, -) +@construct +class patient_level_table(PatientFrame): + i1 = Series(int) + i2 = Series(int) + b1 = Series(bool) + b2 = Series(bool) + c1 = Series(SNOMEDCTCode) + d1 = Series(datetime.date) + d2 = Series(datetime.date) + s1 = Series(str) + s2 = Series(str) + + +@construct +class event_level_table(EventFrame): + i1 = Series(int) + i2 = Series(int) + b1 = Series(bool) + b2 = Series(bool) + c1 = Series(SNOMEDCTCode) + d1 = Series(datetime.date) + d2 = Series(datetime.date) + s1 = Series(str) + s2 = Series(str) + + +# Define short aliases for terser tests +p = patient_level_table +e = event_level_table Base = sqlalchemy.orm.declarative_base() -PatientLevelTable = orm_class_from_table(Base, p) -EventLevelTable = orm_class_from_table(Base, e) +PatientLevelTable = orm_class_from_table(Base, patient_level_table) +EventLevelTable = orm_class_from_table(Base, event_level_table) diff --git a/tests/unit/contracts/test_base.py b/tests/unit/contracts/test_base.py index db110296a..02cba6712 100644 --- a/tests/unit/contracts/test_base.py +++ b/tests/unit/contracts/test_base.py @@ -1,3 +1,5 @@ +import datetime + import pytest from databuilder.backends.base import BaseBackend, Column, MappedTable @@ -88,3 +90,33 @@ class BadBackend(BaseBackend): match="Column date_of_birth is defined with an invalid type 'integer'.\n\nAllowed types are: date", ): PatientsContract.validate_implementation(BadBackend, "patients") + + +def test_validate_schema(): + PatientsContract.validate_schema( + { + "sex": str, + "date_of_birth": datetime.date, + } + ) + + +def test_validate_schema_wrong_columns(): + with pytest.raises(AssertionError): + PatientsContract.validate_schema( + { + "sex": str, + "date_of_birth": datetime.date, + "extra_column": int, + } + ) + + +def test_validate_schema_wrong_types(): + with pytest.raises(AssertionError): + PatientsContract.validate_schema( + { + "sex": str, + "date_of_birth": str, + } + ) diff --git a/tests/unit/test_query_language.py b/tests/unit/test_query_language.py index 999d169b3..15ad89ef2 100644 --- a/tests/unit/test_query_language.py +++ b/tests/unit/test_query_language.py @@ -5,9 +5,16 @@ from databuilder.query_language import ( Dataset, DateEventSeries, + EventFrame, IntEventSeries, - build_patient_table, + IntPatientSeries, + PatientFrame, + SchemaError, + Series, + StrEventSeries, + StrPatientSeries, compile, + construct, ) from databuilder.query_model import ( Function, @@ -19,10 +26,8 @@ Value, ) -patients_schema = { - "date_of_birth": date, -} -patients = build_patient_table("patients", patients_schema) +patients_schema = TableSchema(date_of_birth=date) +patients = PatientFrame(SelectPatientTable("patients", patients_schema)) def test_dataset(): @@ -141,3 +146,55 @@ def test_series_are_not_hashable(): int_series = IntEventSeries(qm_int_series) with pytest.raises(TypeError): {int_series: True} + + +# TEST CLASS-BASED FRAME CONSTRUCTOR +# + + +def test_construct_constructs_patient_frame(): + @construct + class some_table(PatientFrame): + some_int = Series(int) + some_str = Series(str) + + assert isinstance(some_table, PatientFrame) + assert some_table.qm_node.name == "some_table" + assert isinstance(some_table.some_int, IntPatientSeries) + assert isinstance(some_table.some_str, StrPatientSeries) + + +def test_construct_constructs_event_frame(): + @construct + class some_table(EventFrame): + some_int = Series(int) + some_str = Series(str) + + assert isinstance(some_table, EventFrame) + assert some_table.qm_node.name == "some_table" + assert isinstance(some_table.some_int, IntEventSeries) + assert isinstance(some_table.some_str, StrEventSeries) + + +def test_construct_enforces_correct_base_class(): + with pytest.raises(SchemaError, match="Schema class must subclass"): + + @construct + class some_table(Dataset): + some_int = Series(int) + + +def test_construct_enforces_exactly_one_base_class(): + with pytest.raises(SchemaError, match="Schema class must subclass"): + + @construct + class some_table(PatientFrame, Dataset): + some_int = Series(int) + + +def test_must_reference_instance_not_class(): + class some_table(PatientFrame): + some_int = Series(int) + + with pytest.raises(SchemaError, match="Missing `@construct` decorator"): + some_table.some_int