Skip to content

Commit

Permalink
Handle LT queries with year addition
Browse files Browse the repository at this point in the history
  • Loading branch information
rebkwok authored and DRMacIver committed Oct 22, 2024
1 parent b862186 commit 88572ed
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 16 deletions.
41 changes: 28 additions & 13 deletions ehrql/dummy_data/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,19 +186,34 @@ def generate_patient_facts(self, patient_id):
self.rnd.seed(f"{self.random_seed}:{patient_id}")
# TODO: We could obviously generate more realistic age distributions than this

dob_column = self.get_patient_column("date_of_birth")
if dob_column is not None and dob_column.get_constraint(
Constraint.GeneralRange
):
self.events_start = self.today - timedelta(days=120 * 365)
self.events_end = self.today
date_of_birth = self.get_random_value(dob_column)
else:
date_of_birth = self.today - timedelta(
days=self.rnd.randrange(0, 120 * 365)
)
age_days = self.rnd.randrange(105 * 365)
date_of_death = date_of_birth + timedelta(days=age_days)
while True:
# Retry until we have a date of birth and date of death that are
# within reasonable ranges
dob_column = self.get_patient_column("date_of_birth")
if dob_column is not None and dob_column.get_constraint(
Constraint.GeneralRange
):
self.events_start = self.today - timedelta(days=120 * 365)
self.events_end = self.today
date_of_birth = self.get_random_value(dob_column)
else:
date_of_birth = self.today - timedelta(
days=self.rnd.randrange(0, 120 * 365)
)

dod_column = self.get_patient_column("date_of_death")
if dod_column is not None and dod_column.get_constraint(
Constraint.GeneralRange
):
date_of_death = self.get_random_value(dod_column)
else:
age_days = self.rnd.randrange(105 * 365)
date_of_death = date_of_birth + timedelta(days=age_days)

if date_of_death >= date_of_birth and (
date_of_death - date_of_birth < timedelta(105 * 365)
):
break

self.date_of_birth = date_of_birth
self.date_of_death = date_of_death if date_of_death < self.today else None
Expand Down
21 changes: 20 additions & 1 deletion ehrql/dummy_data/query_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,26 @@ def query_to_column_constraints(query):
return {
column: [
Constraint.GeneralRange(
minimum=reference_date - timedelta(days=365 * difference)
minimum=reference_date.replace(
year=reference_date.year - difference
)
)
]
}
case Function.LT(
lhs=Function.DateAddYears(
lhs=SelectColumn() as column,
rhs=Value(value=difference),
),
rhs=Value(value=reference_date),
):
return {
column: [
Constraint.GeneralRange(
maximum=reference_date.replace(
year=reference_date.year - difference
),
includes_maximum=False,
)
]
}
Expand Down
15 changes: 14 additions & 1 deletion tests/unit/dummy_data/test_query_to_constraints.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import date

from ehrql import create_dataset
from ehrql import create_dataset, years
from ehrql.dummy_data.query_info import (
normalize_constraints,
query_to_column_constraints,
Expand Down Expand Up @@ -74,3 +74,16 @@ def test_or_query_does_not_includes_constraints_on_only_one_size():
constraints = query_to_column_constraints(variable_definitions["population"])

assert len(constraints) == 0


def test_gt_query_with_date_addition():
dataset = create_dataset()

index_date = date(2022, 3, 1)
died_more_than_10_years_ago = (patients.date_of_death + years(10)) < index_date
dataset.define_population(died_more_than_10_years_ago)

variable_definitions = compile(dataset)
constraints = query_to_column_constraints(variable_definitions["population"])

assert len(constraints) == 1
33 changes: 32 additions & 1 deletion tests/unit/dummy_data/test_specific_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from hypothesis import example, given
from hypothesis import strategies as st

from ehrql import create_dataset
from ehrql import create_dataset, years
from ehrql.dummy_data.generator import DummyDataGenerator
from ehrql.query_language import compile
from ehrql.tables.core import patients
Expand Down Expand Up @@ -220,3 +220,34 @@ def test_combined_age_range_in_one_shot(patched_time, query, target_size):
data_for_table = list(data.values())[0]
# Within that table expecting we generated a full population
assert len(data_for_table) == target_size


@mock.patch("ehrql.dummy_data.generator.time")
def test_date_arithmetic_comparison(patched_time):
dataset = create_dataset()

index_date = date(2022, 3, 1)
died_more_than_10_years_ago = (patients.date_of_death + years(10)) < index_date
dataset.define_population(died_more_than_10_years_ago)
dataset.date_of_birth = patients.date_of_birth
dataset.date_of_death = patients.date_of_death

target_size = 1000

variable_definitions = compile(dataset)
generator = DummyDataGenerator(variable_definitions, population_size=target_size)
generator.batch_size = target_size
generator.timeout = 10

# Configure `time.time()` so we timeout after one loop pass, as we
# should be able to generate these correctly in the first pass.
patched_time.time.side_effect = [0.0, 20.0]
data = generator.get_data()

# Expecting a single table
assert len(data) == 1
data_for_table = list(data.values())[0]
# Confirm that all patients have date of birth before date of death
assert all(row[1] <= row[2] for row in data_for_table)
# Within that table expecting we generated a full population
assert len(data_for_table) == target_size

0 comments on commit 88572ed

Please sign in to comment.