Skip to content

Commit

Permalink
Merge pull request #2255 from opensafely-core/DRMacIver/better-event-…
Browse files Browse the repository at this point in the history
…distribution

Better distribution of number of events
  • Loading branch information
DRMacIver authored Dec 4, 2024
2 parents aeaa61a + 797e12b commit 92fd4d6
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 4 deletions.
58 changes: 56 additions & 2 deletions ehrql/dummy_data_nextgen/generator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import functools
import itertools
import logging
import math
import random
import string
import time
from bisect import bisect_left
from collections import defaultdict
from contextlib import contextmanager
from datetime import date, timedelta

Expand Down Expand Up @@ -87,6 +89,7 @@ def get_data(self):
generated += len(patient_batch)
database.populate(merge_table_data(*patient_batch.values()))
results = engine.get_results(population_query)
valid_patient_ids = set()
# Accumulate all data from matching patients, returning once we have enough
for row in results:
# Because of the existence of InlinePatientTables it's possible to get
Expand All @@ -95,6 +98,8 @@ def get_data(self):
if row.patient_id not in patient_batch:
continue

valid_patient_ids.add(row.patient_id)

extend_table_data(
data,
patient_batch[row.patient_id],
Expand All @@ -106,6 +111,45 @@ def get_data(self):
if found >= self.population_size:
break

# With each batch of patients we can look at what empirical
# characteristics patients that make it into the population definition
# have. We can then use this to inform how we generate patients for
# future batches by ensuring that they have all the characteristics
# we believe are necessary and none of the characteristics we believe
# are forbidden.
#
# In this particular case what we're doing is we're looking for
# which tables are needed to satisfy or block satisfaction of the
# population definition. For example, we might have a requirement that
# the patient has an asthma diagnosis. If so, the patient needs at least
# one clinical event to satisfy the population condition, so on future
# iterations we ensure all patients are drawn with at least one clinical event.
#
# Similarly it might be that actually any clinical event we generate will block
# the patient being generated. A population definition that says that the
# patient doesn't have asthma will do this, because the asthma code is often
# then the one we will generate, so any events will result in a patient not
# satisfying the population definition. In this case we mark the clinical_events
# table as forbidden and never generate clinical events in the dummy data.
if generator.required_tables is None and valid_patient_ids:
forbidden_tables = set(database.tables)
assert generator.forbidden_tables is None
tables_by_id = defaultdict(set)
for table, rows in database.tables.items():
for row in rows.to_records():
tables_by_id[row["patient_id"]].add(table)
required_tables = None
for patient_id in valid_patient_ids:
tables = tables_by_id[patient_id]
forbidden_tables -= tables
if required_tables is None:
required_tables = set(tables)
else:
required_tables &= tables
assert required_tables is not None
generator.required_tables = frozenset(required_tables)
generator.forbidden_tables = frozenset(forbidden_tables)

if found >= self.population_size:
return data

Expand Down Expand Up @@ -152,6 +196,8 @@ def __init__(self, variable_definitions, random_seed, today, population_size):

self.__column_values = {}
self.__reset_event_range()
self.required_tables = None
self.forbidden_tables = None

@property
def rnd(self):
Expand Down Expand Up @@ -287,8 +333,16 @@ def rows_for_practice_registrations(self, table_info):

def empty_rows(self, table_info):
# Generate a small handful of events for event-level tables
max_rows = 1 if table_info.has_one_row_per_patient else 16
row_count = self.rnd.randrange(max_rows + 1)
if self.forbidden_tables and table_info.name in self.forbidden_tables:
return []
if table_info.has_one_row_per_patient:
row_count = self.rnd.randint(0, 1)
else:
# Geometric distribution with parameter 0.2. Will average 4 (=1/0.2 - 1) events
# per patient.
row_count = math.floor(math.log(self.rnd.random()) / math.log(1 - 0.2))
if self.required_tables and table_info.name in self.required_tables:
row_count += 1
return [{} for _ in range(row_count)]

def populate_row(self, table_info, row):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/dummy_data_nextgen/test_measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_dummy_measures_data_generator():

intervals = years(2).starting_on("2020-01-01")
measures = Measures()
measures.dummy_data_config.population_size = 200
measures.dummy_data_config.population_size = 300

measures.define_measure(
"foo_events_by_sex",
Expand Down
61 changes: 60 additions & 1 deletion tests/unit/dummy_data_nextgen/test_specific_datasets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import operator
from collections import Counter
from datetime import date
from unittest import mock

Expand All @@ -18,7 +19,11 @@
table,
table_from_rows,
)
from ehrql.tables.core import clinical_events, medications, patients
from ehrql.tables.core import (
clinical_events,
medications,
patients,
)


index_date = date(2022, 3, 1)
Expand Down Expand Up @@ -81,6 +86,37 @@ def test_queries_with_exact_one_shot_generation(patched_time, query):
assert len(patient_ids) == target_size


@mock.patch("ehrql.dummy_data_nextgen.generator.time")
@pytest.mark.parametrize(
"query",
[
clinical_events.exists_for_patient(),
~clinical_events.exists_for_patient(),
],
ids=pretty,
)
def test_queries_with_exact_two_shot_generation(patched_time, query):
"""For queries which we can't guarantee correct from the start
but we can reliably figure out enough in the first batch of results
that the second one is complete."""
dataset = create_dataset()

dataset.define_population(patients.exists_for_patient() & query)

target_size = 1000

variable_definitions = compile(dataset)
generator = DummyDataGenerator(variable_definitions, population_size=target_size)
generator.batch_size = target_size
generator.timeout = 10

# Configure `time.time()` so we timeout after two loop passes.
patched_time.time.side_effect = [0.0, 1.0, 20.0]
patient_ids = {row.patient_id for row in generator.get_results()}

assert len(patient_ids) == target_size


@st.composite
def birthday_range_query(draw):
# We generate a single date that we require to be valid for
Expand Down Expand Up @@ -315,3 +351,26 @@ def test_generates_events_starting_from_birthdate():

for row in generator.get_results():
assert row.after_dob


def test_distribution_of_booleans():
"""Ensures that the distribution of boolean properties depending on the existence
of an event is not too badly biased."""
dataset = create_dataset()

dataset.has_the_thing = clinical_events.where(
clinical_events.snomedct_code == "123456789"
).exists_for_patient()

dataset.define_population(patients.exists_for_patient())

target_size = 1000

variable_definitions = compile(dataset)
generator = DummyDataGenerator(variable_definitions, population_size=target_size)
generator.batch_size = target_size

property_counts = Counter(row.has_the_thing for row in generator.get_results())

assert property_counts[False] + property_counts[True] == target_size
assert 0.2 < property_counts[True] / target_size < 0.8

0 comments on commit 92fd4d6

Please sign in to comment.