Skip to content

Commit

Permalink
feat: Preserve frame types though ehrQL operations
Browse files Browse the repository at this point in the history
We define tables in ehrQL using subclasses of `EventFrame` and
`PatientFrame` e.g.
```py
@table
class event(EventFrame):
  date = Series(datetime.date)
  code = Series(SNOMEDCTCode)
```

This gives us two big advantages:
 * auto-complete for column names;
 * the option to define [table-specific helper][1] methods.

Previously however, as soon as you performed any kind of operation on
one of these tables (i.e. called a method) you'd get back a plain
`EventFrame` or `PatientFrame` with no column auto-completion and no
helper methods.

This PR ensures that we return the appropriate frame subclass from all
methods. This also allows us to remove the `__getattr__` magic from the
`BaseFrame` class.

**The nasty bits**

Returning an appropriate type in all cases requires two bits of trickery
in the form of dynamic class compilation:

 1. When we call `sort_by()` we need the result to have, as well as its
    existing methods, the `get_first/last_for_patient()` methods. So we
    construct a subclass which mixes in the necessary methods.

 2. When we call one of the `get_first/last_for_patient()` methods we
    need to get back a `PatientFrame`. This should have all the columns
    defined on the original frame, but none of the methods. We introspect
    the class definition to extract all the columns and construct a new
    `PatientFrame` with those columns included.

**Static auto-complete**

The above gives us auto-complete in a dynamic context like an IPython
session where code is actually executed. We also get a limited form of
static (type-based) auto-complete in VSCode. Previously, this worked
only on the original frame and this PR extends this so that it persists
through `where/except_where` calls. However it won't persist through
`sort_by` or `get_first/last_for_patient()`.

After reasonably extensive investigation (which I need to write up in
[this ticket][2]) I don't _think_ our ideal behaviour is acheivable in
Pylance (VSCode's type checker) as things currently stand. But I don't
think anything in this PR makes things worse in that regard.

[1]: #1021
[2]: #506
  • Loading branch information
evansd committed Mar 14, 2023
1 parent 492a56f commit 4ce9b66
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 12 deletions.
45 changes: 35 additions & 10 deletions databuilder/query_language.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import dataclasses
import datetime
import enum
import functools
import warnings
from collections import ChainMap
from typing import Union

from databuilder.codes import BaseCode
Expand Down Expand Up @@ -606,12 +608,6 @@ class BaseFrame:
def __init__(self, qm_node):
self.qm_node = qm_node

def __getattr__(self, name):
if not name.startswith("__"):
return self._select_column(name)
else:
raise AttributeError(f"object has no attribute {name!r}")

def _select_column(self, name):
return _wrap(qm.SelectColumn(source=self.qm_node, name=name))

Expand Down Expand Up @@ -666,27 +662,56 @@ def sort_by(self, *order_series):
source=qm_node,
sort_by=_convert(series),
)
return SortedEventFrame(qm_node)
cls = make_sorted_event_frame_class(self.__class__)
return cls(qm_node)


class SortedEventFrame(EventFrame):
class SortedEventFrameMethods:
def first_for_patient(self):
return PatientFrame(
cls = make_patient_frame_class(self.__class__)
return cls(
qm.PickOneRowPerPatient(
position=qm.Position.FIRST,
source=self.qm_node,
)
)

def last_for_patient(self):
return PatientFrame(
cls = make_patient_frame_class(self.__class__)
return cls(
qm.PickOneRowPerPatient(
position=qm.Position.LAST,
source=self.qm_node,
)
)


@functools.cache
def make_sorted_event_frame_class(cls):
"""
Given a class return a subclass which has the SortedEventFrameMethods
"""
if issubclass(cls, SortedEventFrameMethods):
return cls
else:
return type(cls.__name__, (SortedEventFrameMethods, cls), {})


@functools.cache
def make_patient_frame_class(cls):
"""
Given an EventFrame subclass return a PatientFrame subclass with the same columns as
the original frame
"""
# Because `Series` is a descriptor we can't access the column objects via class
# attributes without invoking the descriptor: instead, we have to access them using
# `vars()`. But `vars()` only gives us attributes defined directly on the class, not
# inherited ones. So we reproduced the inheritance behaviour using `ChainMap`.
attrs = ChainMap(*[vars(base) for base in cls.__mro__])
columns = {key: value for key, value in attrs.items() if isinstance(value, Series)}
return type(cls.__name__, (PatientFrame,), columns)


# FRAME CONSTRUCTOR ENTRYPOINTS
#

Expand Down
42 changes: 40 additions & 2 deletions tests/unit/test_query_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,26 @@
Value,
)


@table
class patients(PatientFrame):
date_of_birth = Series(date)
i = Series(int)
f = Series(float)


patients_schema = TableSchema(
date_of_birth=Column(date), i=Column(int), f=Column(float)
)
patients = PatientFrame(SelectPatientTable("patients", patients_schema))


@table
class events(EventFrame):
event_date = Series(date)
f = Series(float)


events_schema = TableSchema(event_date=Column(date), f=Column(float))
events = EventFrame(SelectTable("coded_events", events_schema))


def test_dataset():
Expand Down Expand Up @@ -526,3 +540,27 @@ class p(PatientFrame):
# Test invalid codes are rejected
with pytest.raises(ValueError, match="Invalid SNOMEDCTCode"):
p.code == "abc"


def test_frame_classes_are_preserved():
@table
class e(EventFrame):
start_date = Series(date)

def after_2020(self):
return self.where(self.start_date > "2020-01-01")

# Check that the helper method is preserved through `where`
filtered_frame = e.where(e.start_date > "1990-01-01")
assert isinstance(filtered_frame.after_2020(), EventFrame)

# Check that the helper method is preserved through `sort_by`
sorted_frame = e.sort_by(e.start_date)
assert isinstance(sorted_frame.after_2020(), EventFrame)

# Check that helper method is not available on patient frame
latest_event = sorted_frame.last_for_patient()
assert not hasattr(latest_event, "after_2020")

# Check that column is still available
assert "start_date" in dir(latest_event)

0 comments on commit 4ce9b66

Please sign in to comment.