Skip to content

Commit

Permalink
Initial prototype of getting specific column constraints out of the p…
Browse files Browse the repository at this point in the history
…opulation definition
  • Loading branch information
DRMacIver committed Oct 22, 2024
1 parent 6ffc98f commit b862186
Show file tree
Hide file tree
Showing 7 changed files with 597 additions and 7 deletions.
59 changes: 58 additions & 1 deletion ehrql/dummy_data/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,30 @@ def get_patient_data(self, patient_id, table_names):
]
return data

def get_patient_column(self, column_name):
    """Return the column info for ``column_name`` from the population tables.

    The tables named in ``self.query_info.population_table_names`` are
    searched in order and the first ``columns[column_name]`` entry found is
    returned.

    Returns ``None`` when none of those tables provides the column.
    """
    for table_name in self.query_info.population_table_names:
        try:
            return self.query_info.tables[table_name].columns[column_name]
        except KeyError:
            # Either the table is unknown or it has no such column: keep
            # looking in the remaining population tables.
            continue
    # Make the "not found" result explicit rather than falling off the end.
    return None

def generate_patient_facts(self, patient_id):
# Seed the random generator using the patient_id so we always generate the same
# data for the same patient
self.rnd.seed(f"{self.random_seed}:{patient_id}")
# TODO: We could obviously generate more realistic age distributions than this
date_of_birth = self.today - timedelta(days=self.rnd.randrange(0, 120 * 365))

dob_column = self.get_patient_column("date_of_birth")
if dob_column is not None and dob_column.get_constraint(
Constraint.GeneralRange
):
self.events_start = self.today - timedelta(days=120 * 365)
self.events_end = self.today
date_of_birth = self.get_random_value(dob_column)
else:
date_of_birth = self.today - timedelta(
days=self.rnd.randrange(0, 120 * 365)
)
age_days = self.rnd.randrange(105 * 365)
date_of_death = date_of_birth + timedelta(days=age_days)

Expand Down Expand Up @@ -235,6 +253,45 @@ def get_random_value(self, column_info):
range_constraint.maximum + 1,
range_constraint.step,
)
elif (column_info.type is date) and (
date_range_constraint := column_info.get_constraint(Constraint.GeneralRange)
):
if date_range_constraint.maximum is not None:
maximum = date_range_constraint.maximum
else:
maximum = self.today

if not date_range_constraint.includes_maximum:
maximum -= timedelta(days=1)
if date_range_constraint.minimum is not None:
minimum = date_range_constraint.minimum
if not date_range_constraint.includes_minimum:
minimum += timedelta(days=1)
# TODO: Currently this code only runs when the column is date_of_birth
# so condition is always hit. Remove this pragma when that stops being
# the case.
if column_info.get_constraint(
Constraint.FirstOfMonth
): # pragma: no branch
if minimum.month == 12:
minimum = minimum.replace(year=minimum.year + 1, month=1, day=1)
else:
minimum = minimum.replace(month=minimum.month + 1, day=1)
else:
minimum = (maximum - timedelta(days=100 * 365)).replace(day=1)

assert minimum <= maximum

days = (maximum - minimum).days
result = minimum + timedelta(days=random.randint(0, days))
# TODO: Currently this code only runs when the column is date_of_birth
# so condition is always hit. Remove this pragma when that stops being
# the case.
if column_info.get_constraint(Constraint.FirstOfMonth): # pragma: no branch
assert minimum.day == 1
result = result.replace(day=1)
assert minimum <= result <= maximum
return result
elif column_info.values_used:
if self.rnd.randint(0, len(column_info.values_used)) != 0:
return self.rnd.choice(column_info.values_used)
Expand Down
164 changes: 160 additions & 4 deletions ehrql/dummy_data/query_info.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import dataclasses
from collections import defaultdict
from functools import cached_property
from datetime import date, timedelta
from functools import cached_property, lru_cache

from ehrql.query_model.introspection import all_unique_nodes, get_table_nodes
from ehrql.query_model.nodes import (
Expand All @@ -14,6 +15,7 @@
Value,
get_root_frame,
)
from ehrql.query_model.table_schema import Constraint


@dataclasses.dataclass
Expand All @@ -28,11 +30,17 @@ class ColumnInfo:
_values_used: set = dataclasses.field(default_factory=set)

@classmethod
def from_column(cls, name, column, extra_constraints=()):
    """Build an instance describing ``column`` under the given ``name``.

    ``extra_constraints`` is an optional iterable of additional constraints
    which are combined with the column's own constraints and passed through
    ``normalize_constraints``. It defaults to empty, so existing two-argument
    callers keep working.
    """
    type_ = column.type_
    # Some column types wrap a primitive type; unwrap to the primitive one.
    if hasattr(type_, "_primitive_type"):
        type_ = type_._primitive_type()
    return cls(
        name,
        type_,
        constraints=normalize_constraints(
            tuple(column.constraints) + tuple(extra_constraints)
        ),
    )

def __post_init__(self):
self._constraints_by_type = {type(c): c for c in self.constraints}
Expand Down Expand Up @@ -101,6 +109,10 @@ def from_variable_definitions(cls, variable_definitions):
all_nodes = all_unique_nodes(*variable_definitions.values())
by_type = get_nodes_by_type(all_nodes)

extra_constraints = query_to_column_constraints(
variable_definitions["population"]
)

tables = {
# Create a TableInfo object …
table.name: TableInfo.from_table(table)
Expand Down Expand Up @@ -128,7 +140,9 @@ def from_variable_definitions(cls, variable_definitions):
if column_info is None:
# … insert a ColumnInfo object into the appropriate table
column_info = ColumnInfo.from_column(
name, table.schema.get_column(name)
name,
table.schema.get_column(name),
extra_constraints=extra_constraints.get(column, ()),
)
table_info.columns[name] = column_info
# Record the ColumnInfo object associated with each SelectColumn node
Expand Down Expand Up @@ -176,3 +190,145 @@ def get_nodes_by_type(nodes):

def sort_by_name(iterable):
    """Return a new list of the items in ``iterable`` ordered by ``name``."""
    ordered = list(iterable)
    ordered.sort(key=lambda item: item.name)
    return ordered


@lru_cache
def query_to_column_constraints(query):
"""Converts a query (typically a population definition) into
constraints that would have to be applied to a record in order
to satisfy it."""
match query:
case Function.And(lhs=lhs, rhs=rhs):
left = query_to_column_constraints(lhs)
right = query_to_column_constraints(rhs)
keys = set(left) | set(right)
return {k: left.get(k, []) + right.get(k, []) for k in keys}
case Function.Or(lhs=lhs, rhs=rhs):
left = query_to_column_constraints(lhs)
right = query_to_column_constraints(rhs)
result = {}
for k, v in left.items():
try:
result[k] = list(set(v) & set(right[k]))
except KeyError:
pass
for k, v in list(result.items()):
if not v:
del result[k]
return result
case Function.EQ(
lhs=SelectColumn() as lhs,
rhs=Value(value=value),
):
return {lhs: [Constraint.Categorical(values=(value,))]}
case Function.EQ(
lhs=Function.YearFromDate(source=SelectColumn() as column),
rhs=Value(value=year),
):
return {
column: [
Constraint.GeneralRange(
minimum=date(year, 1, 1),
maximum=date(year, 12, 31),
)
]
}
case Function.In(
lhs=SelectColumn() as lhs,
rhs=Value(value=values),
):
return {lhs: [Constraint.Categorical(values=values)]}
case Function.GE(
lhs=Function.DateDifferenceInYears(
lhs=Value(value=reference_date), rhs=column
),
rhs=Value(value=difference),
):
return {
column: [
Constraint.GeneralRange(
maximum=reference_date - timedelta(days=365 * difference)
)
]
}
case Function.LE(
lhs=Function.DateDifferenceInYears(
lhs=Value(value=reference_date), rhs=column
),
rhs=Value(value=difference),
):
return {
column: [
Constraint.GeneralRange(
minimum=reference_date - timedelta(days=365 * difference)
)
]
}
case Function.GT(lhs=SelectColumn() as column, rhs=Value(value=min_value)):
return {
column: [
Constraint.GeneralRange(minimum=min_value, includes_minimum=False)
]
}
case Function.GE(lhs=SelectColumn() as column, rhs=Value(value=min_value)):
return {
column: [
Constraint.GeneralRange(minimum=min_value, includes_minimum=True)
]
}
case Function.LT(lhs=SelectColumn() as column, rhs=Value(value=max_value)):
return {
column: [
Constraint.GeneralRange(maximum=max_value, includes_maximum=False)
]
}
case Function.LE(lhs=SelectColumn() as column, rhs=Value(value=max_value)):
return {
column: [
Constraint.GeneralRange(maximum=max_value, includes_maximum=True)
]
}
case Function.IsNull(source=SelectColumn() as column):
return {column: [Constraint.NotNull()]}

return {}


def normalize_constraints(constraints):
    """Collapse repeated constraints of the same kind into a single one.

    Constraints are grouped by their concrete type. Several ``Categorical``
    constraints are replaced by one holding only the values common to all of
    them, and several ``GeneralRange`` constraints by one range covering
    their intersection. Every other constraint is passed through unchanged.

    Returns a tuple of constraints, grouped by type.
    """
    group_by_type = defaultdict(list)
    for constraint in constraints:
        group_by_type[type(constraint)].append(constraint)

    categoricals = group_by_type.get(Constraint.Categorical, [])
    if len(categoricals) > 1:
        first, *rest = categoricals
        # Keep the merged values as an ordered tuple rather than a set: the
        # result is then deterministic and the values stay immutable, matching
        # how single-value Categorical constraints are built elsewhere.
        values = tuple(dict.fromkeys(first.values))
        for other in rest:
            values = tuple(v for v in values if v in other.values)
        group_by_type[Constraint.Categorical] = [
            Constraint.Categorical(values=values)
        ]

    ranges = group_by_type.get(Constraint.GeneralRange, [])
    if len(ranges) > 1:
        # The intersection runs from the largest lower bound to the smallest
        # upper bound; a side with no bounds at all stays unbounded (None).
        minimum = max(
            (r.minimum for r in ranges if r.minimum is not None), default=None
        )
        maximum = min(
            (r.maximum for r in ranges if r.maximum is not None), default=None
        )

        # An endpoint is only included if every individual range accepts it.
        includes_minimum = minimum is None or all(
            r.validate(minimum) for r in ranges
        )
        includes_maximum = maximum is None or all(
            r.validate(maximum) for r in ranges
        )
        group_by_type[Constraint.GeneralRange] = [
            Constraint.GeneralRange(
                minimum=minimum,
                maximum=maximum,
                includes_maximum=includes_maximum,
                includes_minimum=includes_minimum,
            )
        ]

    return tuple(
        constraint for group in group_by_type.values() for constraint in group
    )
41 changes: 41 additions & 0 deletions ehrql/query_model/table_schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
from re import match
from typing import Any

from ehrql.utils.regex_utils import validate_regex

Expand Down Expand Up @@ -72,6 +73,46 @@ def description(self):
def validate(self, value):
return self.minimum <= value <= self.maximum if value is not None else True

class GeneralRange(BaseConstraint):
    """Range constraint whose bounds may each be absent or exclusive."""

    # Lower / upper bound of the range; None means unbounded on that side.
    minimum: Any = None
    maximum: Any = None

    # Whether a value exactly equal to the corresponding bound is allowed.
    includes_minimum: bool = True
    includes_maximum: bool = True

    @property
    def description(self):
        """Human-readable summary of the bounds, e.g. ``Always >= 1, < 5``."""
        parts = []
        if self.minimum is not None:
            if self.includes_minimum:
                parts.append(f">= {self.minimum}")
            else:
                parts.append(f"> {self.minimum}")
        if self.maximum is not None:
            if self.includes_maximum:
                parts.append(f"<= {self.maximum}")
            else:
                parts.append(f"< {self.maximum}")
        if parts:
            return "Always " + ", ".join(parts)
        else:
            return "Any value"

    def validate(self, value):
        """Return True if ``value`` is None or lies within the range."""
        if value is None:
            return True
        if self.minimum is not None:
            if self.minimum > value:
                return False
            # Sitting exactly on an exclusive bound fails, but sitting on an
            # inclusive one must not short-circuit: the maximum still has to
            # be checked (e.g. minimum == maximum with an exclusive maximum).
            if self.minimum == value and not self.includes_minimum:
                return False
        if self.maximum is not None:
            if self.maximum < value:
                return False
            if self.maximum == value and not self.includes_maximum:
                return False
        return True


@dataclasses.dataclass(frozen=True)
class Column:
Expand Down
9 changes: 7 additions & 2 deletions tests/generative/test_query_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@
sqla_metadata,
) = data_setup.setup(schema, num_patient_tables=2, num_event_tables=2)

_singleton_pprinters[id(schema)] = lambda obj, p, cycle: p.text("schema")
# This will only get run during a failing example, so shows up as uncovered when the tests pass.
_singleton_pprinters[id(schema)] = lambda obj, p, cycle: p.text(
"schema"
) # pragma: no cover

# Use the same strategies for values both for query generation and data generation.
value_strategies = {
Expand Down Expand Up @@ -115,7 +118,9 @@ class EnabledTests(Enum):
all_population = auto()


if TEST_NAMES_TO_RUN := set(os.environ.get("GENTEST_TESTS_TO_RUN", "").lower().split()):
if TEST_NAMES_TO_RUN := set(
os.environ.get("GENTEST_TESTS_TO_RUN", "").lower().split()
): # pragma: no cover
TESTS_TO_RUN = [t for t in EnabledTests if t.name in TEST_NAMES_TO_RUN]
else:
TESTS_TO_RUN = list(EnabledTests)
Expand Down
Loading

0 comments on commit b862186

Please sign in to comment.