Skip to content

Commit

Permalink
Merge pull request #663 from opensafely-core/evansd/schema-definitions
Browse files Browse the repository at this point in the history
Add class-based schema definitions
  • Loading branch information
evansd authored Aug 9, 2022
2 parents ae390a7 + c05df39 commit 4e387e5
Show file tree
Hide file tree
Showing 8 changed files with 333 additions and 188 deletions.
81 changes: 60 additions & 21 deletions databuilder/query_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ def __setattr__(self, name, value):


def compile(dataset): # noqa A003
# Return a dict mapping attribute name to `qm_node` for every attribute in
# `vars(dataset)` whose value is a series instance; other attributes are
# left out.
# NOTE: the two `return` lines below are the removed and added versions of
# the same statement in this diff; the series base class is renamed from
# `Series` to `BaseSeries`.
return {k: v.qm_node for k, v in vars(dataset).items() if isinstance(v, Series)}
return {k: v.qm_node for k, v in vars(dataset).items() if isinstance(v, BaseSeries)}


# BASIC SERIES TYPES
#


@dataclasses.dataclass(frozen=True)
class Series:
class BaseSeries:
qm_node: qm.Node

def __hash__(self):
Expand Down Expand Up @@ -90,7 +90,7 @@ def map_values(self, mapping):
return _apply(qm.Case, cases)


class EventSeries(Series):
class EventSeries(BaseSeries):
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
# Register the series using its `_type` attribute
Expand All @@ -100,7 +100,7 @@ def __init_subclass__(cls, **kwargs):
# they would be defined here as well


class PatientSeries(Series):
class PatientSeries(BaseSeries):
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
# Register the series using its `_type` attribute
Expand Down Expand Up @@ -348,7 +348,7 @@ def _convert(arg):
if isinstance(arg, _DictArg):
return {_convert(key): _convert(value) for key, value in arg}
# If it's an ehrQL series then get the wrapped query model node
elif isinstance(arg, Series):
elif isinstance(arg, BaseSeries):
return arg.qm_node
# If it's a Codelist extract the set of codes and put it in a Value wrapper
elif isinstance(arg, Codelist):
Expand All @@ -362,11 +362,14 @@ def _convert(arg):
#


class Frame:
class BaseFrame:
def __init__(self, qm_node):
self.qm_node = qm_node

def __getattr__(self, name):
return self._select_column(name)

def _select_column(self, name):
return _wrap(qm.SelectColumn(source=self.qm_node, name=name))

def exists_for_patient(self):
Expand All @@ -376,11 +379,11 @@ def count_for_patient(self):
return _wrap(qm.AggregateByPatient.Count(source=self.qm_node))


# Frame type that defines no members of its own (the body is just `pass`).
# NOTE: the two `class` lines below are the removed and added versions of the
# same header in this diff; the base class is renamed from `Frame` to
# `BaseFrame`.
class PatientFrame(Frame):
class PatientFrame(BaseFrame):
pass


class EventFrame(Frame):
class EventFrame(BaseFrame):
def take(self, series):
return EventFrame(
qm.Filter(
Expand Down Expand Up @@ -413,7 +416,7 @@ def sort_by(self, *order_series):
return SortedEventFrame(qm_node)


class SortedEventFrame(Frame):
class SortedEventFrame(BaseFrame):
def first_for_patient(self):
return PatientFrame(
qm.PickOneRowPerPatient(
Expand All @@ -435,20 +438,56 @@ def last_for_patient(self):
#


def build_patient_table(name, schema, contract=None):
    """Build a `PatientFrame` that selects the patient table called `name`.

    If a `contract` is given, `schema` is passed to its `validate_schema`
    method before the frame is built. `schema` is wrapped in a
    `qm.TableSchema` for the underlying `qm.SelectPatientTable` node.
    """
    if contract is not None:
        contract.validate_schema(schema)
    table_schema = qm.TableSchema(schema)
    select_node = qm.SelectPatientTable(name, schema=table_schema)
    return PatientFrame(select_node)
class SchemaError(Exception):
    """Raised when a table schema class is defined or used incorrectly."""


# A class decorator which replaces the class definition with an appropriately configured
# instance of the class. Obviously this is a _bit_ odd, but I think worth it overall.
# Using classes to define tables is (as far as I can tell) the only way to get nice
# autocomplete and type-checking behaviour for column names. But we don't actually want
# these classes accessible anywhere: users should only be interacting with instances of
# the classes, and having the classes themselves in the module namespaces only makes
# autocomplete more confusing and error prone.
def construct(cls):
    """Replace a schema class with a configured instance of that class.

    Intended to be used as a class decorator. The query model node type is
    chosen from the class's direct bases, the table name is taken from the
    class name, and the table schema is collected from the `Series`
    descriptors defined directly on the class.

    Returns:
        An instance of `cls` wrapping the query model node for the table.

    Raises:
        SchemaError: if the bases of `cls` are not exactly `(PatientFrame,)`
            or `(EventFrame,)`.
    """
    qm_classes = {
        (PatientFrame,): qm.SelectPatientTable,
        (EventFrame,): qm.SelectTable,
    }
    try:
        qm_class = qm_classes[cls.__bases__]
    except KeyError:
        # The failed dict lookup is an implementation detail: suppress it so
        # the traceback shows only the schema error.
        raise SchemaError(
            "Schema class must subclass either `PatientFrame` or `EventFrame`"
        ) from None

    table_name = cls.__name__
    # Get all `Series` objects on the class and determine the schema from them.
    # They are read through `vars()` rather than attribute access because
    # `Series.__get__` raises when accessed on the class itself.
    schema = {
        series.name: series.type_
        for series in vars(cls).values()
        if isinstance(series, Series)
    }

    qm_node = qm_class(table_name, qm.TableSchema(schema))
    return cls(qm_node)


# A descriptor which will return the appropriate type of series depending on the type of
# frame it belongs to i.e. a PatientSeries subclass for PatientFrames and an EventSeries
# subclass for EventFrames. This lets schema authors use a consistent syntax when
# defining frames of either type.
class Series:
def __init__(self, type_):
self.type_ = type_

def __set_name__(self, owner, name):
self.name = name

def build_event_table(name, schema, contract=None):
    """Build an `EventFrame` that selects the table called `name`.

    If a `contract` is given, `schema` is passed to its `validate_schema`
    method before the frame is built. `schema` is wrapped in a
    `qm.TableSchema` for the underlying `qm.SelectTable` node.
    """
    if contract is not None:  # pragma: no cover
        contract.validate_schema(schema)
    table_schema = qm.TableSchema(schema)
    select_node = qm.SelectTable(name, schema=table_schema)
    return EventFrame(select_node)
def __get__(self, instance, owner):
# Prevent users attempting to interact with the class rather than an instance
if instance is None:
raise SchemaError("Missing `@construct` decorator on schema class")
return instance._select_column(self.name)


# CASE EXPRESSION FUNCTIONS
Expand Down
17 changes: 6 additions & 11 deletions databuilder/tables.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
import datetime

from databuilder.query_language import build_patient_table
from databuilder.query_language import PatientFrame, Series, construct

from .contracts import universal

# The `patients` table, built with `build_patient_table`: two date columns
# (`date_of_birth`, `date_of_death`) and a string `sex` column, with
# `universal.Patients` passed as the contract.
patients = build_patient_table(
"patients",
{
"date_of_birth": datetime.date,
"date_of_death": datetime.date,
"sex": str,
},
contract=universal.Patients,
)
# The `patients` table declared as a schema class: each `Series` attribute
# declares one column and its type. `@construct` is applied to the class, so
# the module-level name `patients` is bound to whatever `construct` returns.
@construct
class patients(PatientFrame):
date_of_birth = Series(datetime.date)
date_of_death = Series(datetime.date)
sex = Series(str)
Loading

0 comments on commit 4e387e5

Please sign in to comment.