diff --git a/databuilder/query_language.py b/databuilder/query_language.py index 5e6f49488..1b3bb1848 100644 --- a/databuilder/query_language.py +++ b/databuilder/query_language.py @@ -1,7 +1,9 @@ import dataclasses import datetime import enum +import functools import warnings +from collections import ChainMap from typing import Union from databuilder.codes import BaseCode @@ -606,12 +608,6 @@ class BaseFrame: def __init__(self, qm_node): self.qm_node = qm_node - def __getattr__(self, name): - if not name.startswith("__"): - return self._select_column(name) - else: - raise AttributeError(f"object has no attribute {name!r}") - def _select_column(self, name): return _wrap(qm.SelectColumn(source=self.qm_node, name=name)) @@ -666,12 +662,14 @@ def sort_by(self, *order_series): source=qm_node, sort_by=_convert(series), ) - return SortedEventFrame(qm_node) + cls = make_sorted_event_frame_class(self.__class__) + return cls(qm_node) -class SortedEventFrame(EventFrame): +class SortedEventFrameMethods: def first_for_patient(self): - return PatientFrame( + cls = make_patient_frame_class(self.__class__) + return cls( qm.PickOneRowPerPatient( position=qm.Position.FIRST, source=self.qm_node, @@ -679,7 +677,8 @@ def first_for_patient(self): ) def last_for_patient(self): - return PatientFrame( + cls = make_patient_frame_class(self.__class__) + return cls( qm.PickOneRowPerPatient( position=qm.Position.LAST, source=self.qm_node, @@ -687,6 +686,32 @@ def last_for_patient(self): ) +@functools.cache +def make_sorted_event_frame_class(cls): + """ + Given a class return a subclass which has the SortedEventFrameMethods + """ + if issubclass(cls, SortedEventFrameMethods): + return cls + else: + return type(cls.__name__, (SortedEventFrameMethods, cls), {}) + + +@functools.cache +def make_patient_frame_class(cls): + """ + Given an EventFrame subclass return a PatientFrame subclass with the same columns as + the original frame + """ + # Because `Series` is a descriptor we can't access the column objects via class + # attributes without invoking the descriptor: instead, we have to access them using + # `vars()`. But `vars()` only gives us attributes defined directly on the class, not + # inherited ones. So we reproduced the inheritance behaviour using `ChainMap`. + attrs = ChainMap(*[vars(base) for base in cls.__mro__]) + columns = {key: value for key, value in attrs.items() if isinstance(value, Series)} + return type(cls.__name__, (PatientFrame,), columns) + + # FRAME CONSTRUCTOR ENTRYPOINTS # diff --git a/tests/unit/test_query_language.py b/tests/unit/test_query_language.py index 2ccc6a8d6..e24b6cf8e 100644 --- a/tests/unit/test_query_language.py +++ b/tests/unit/test_query_language.py @@ -44,12 +44,26 @@ Value, ) + +@table +class patients(PatientFrame): + date_of_birth = Series(date) + i = Series(int) + f = Series(float) + + patients_schema = TableSchema( date_of_birth=Column(date), i=Column(int), f=Column(float) ) -patients = PatientFrame(SelectPatientTable("patients", patients_schema)) + + +@table +class events(EventFrame): + event_date = Series(date) + f = Series(float) + + events_schema = TableSchema(event_date=Column(date), f=Column(float)) -events = EventFrame(SelectTable("coded_events", events_schema)) def test_dataset():