Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert buggy in-memory engine change #2256

Merged
merged 2 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions ehrql/query_engines/in_memory_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def sort(self, sort_index):
def pick_at_index(self, ix):
return PatientTable(
{
name: col.pick_at_index(ix, name == "patient_id")
name: col.pick_at_index(ix)
for name, col in self.name_to_col.items()
if name != "row_id"
}
Expand Down Expand Up @@ -351,14 +351,19 @@ def sort(self, sort_index):
{p: rows.sort(sort_index[p]) for p, rows in self.patient_to_rows.items()}
)

def pick_at_index(self, ix, is_patient_id=False):
if is_patient_id:
# The patient_id column is special, and should always be a mapping
# from an id to itself. Rows.pick_at_index will return None if a
# patient has no rows in the column.
return PatientColumn({p: p for p in self.patient_to_rows})
def pick_at_index(self, ix):
# It is arguable that for a patient with no rows (which would occur if
# this EventColumn was derived by filtering another EventColumn), the
# patient should be present in the new PatientColumn, with value None.
#
# However, we have decided to instead omit the patient from the new
# PatientColumn.
return PatientColumn(
{p: rows.pick_at_index(ix) for p, rows in self.patient_to_rows.items()}
{
p: rows.pick_at_index(ix)
for p, rows in self.patient_to_rows.items()
if rows
}
)


Expand Down Expand Up @@ -430,10 +435,7 @@ def sort(self, sort_index):
def pick_at_index(self, ix):
"""Return element at given position."""

try:
k = list(self)[ix]
except IndexError:
return None
k = list(self)[ix]
return self[k]


Expand Down
20 changes: 20 additions & 0 deletions tests/integration/test_query_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,23 @@ def test_population_which_uses_combine_as_set_and_no_patient_frame(engine):
assert engine.extract_qm(variables) == [
{"patient_id": 1, "v": True},
]


def test_picking_row_doesnt_cause_filtered_rows_to_reappear(engine):
# Regression test for a bug we introduced in the in-memory engine
dataset = create_dataset()
dataset.define_population(events.exists_for_patient())

rows = events.where(events.i < 0).sort_by(events.i).first_for_patient()
dataset.has_row = rows.exists_for_patient()
dataset.row_count = rows.count_for_patient()

engine.populate(
{
events: [{"patient_id": 1, "i": 2}],
}
)

assert engine.extract(dataset) == [
{"patient_id": 1, "has_row": False, "row_count": 0},
]
1 change: 0 additions & 1 deletion tests/unit/query_engines/test_in_memory_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,6 @@ def test_event_table_filter_then_pick_at_index():
--+-----+-----
1 | 102 | 112
2 | 203 | 211
3 | |
""",
)

Expand Down
20 changes: 20 additions & 0 deletions tests/unit/test_quiz.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
clinical_events,
medications,
patients,
practice_registrations,
)


Expand Down Expand Up @@ -159,6 +160,25 @@ def test_check_answer_dataset_column_has_missing_patients(engine):
assert msg == "Incorrect `age` value for patient 1: expected 49, got 50 instead."


@pytest.mark.parametrize(
"order, message",
[
([0, 1], "Missing patient(s): 7."),
([1, 0], "Found extra patient(s): 7."),
],
)
def test_check_answer_patient_series_has_missing_or_extra_patients(
engine, order, message
):
series = [
practice_registrations.for_patient_on("2013-12-01").practice_pseudo_id,
practice_registrations.for_patient_on("2014-01-01").practice_pseudo_id,
]
answer, expected = (series[i] for i in order)
msg = quiz.check_answer(engine=engine, answer=answer, expected=expected)
assert msg == message


def test_check_answer_patient_series_has_incorrect_value(engine):
msg = quiz.check_answer(
engine=engine,
Expand Down