diff --git a/ehrql/query_engines/in_memory_database.py b/ehrql/query_engines/in_memory_database.py index f2484fcb1..ce5e9feb9 100644 --- a/ehrql/query_engines/in_memory_database.py +++ b/ehrql/query_engines/in_memory_database.py @@ -208,7 +208,7 @@ def sort(self, sort_index): def pick_at_index(self, ix): return PatientTable( { - name: col.pick_at_index(ix, name == "patient_id") + name: col.pick_at_index(ix) for name, col in self.name_to_col.items() if name != "row_id" } @@ -351,14 +351,19 @@ def sort(self, sort_index): {p: rows.sort(sort_index[p]) for p, rows in self.patient_to_rows.items()} ) - def pick_at_index(self, ix, is_patient_id=False): - if is_patient_id: - # The patient_id column is special, and should always be a mapping - # from an id to itself. Rows.pick_at_index will return None if a - # patient has no rows in the column. - return PatientColumn({p: p for p in self.patient_to_rows}) + def pick_at_index(self, ix): + # It is arguable that for a patient with no rows (which would occur if + # this EventColumn was derived by filtering another EventColumn), the + # patient should be present in the new PatientColumn, with value None. + # + # However, we have decided to instead omit the patient from the new + # PatientColumn. return PatientColumn( - {p: rows.pick_at_index(ix) for p, rows in self.patient_to_rows.items()} + { + p: rows.pick_at_index(ix) + for p, rows in self.patient_to_rows.items() + if rows + } ) @@ -430,10 +435,7 @@ def sort(self, sort_index): def pick_at_index(self, ix): """Return element at given position.""" - try: - k = list(self)[ix] - except IndexError: - return None + k = list(self)[ix] return self[k] diff --git a/tests/integration/test_query_engines.py b/tests/integration/test_query_engines.py index a04497d0a..b26338d9f 100644 --- a/tests/integration/test_query_engines.py +++ b/tests/integration/test_query_engines.py @@ -350,3 +350,23 @@ def test_population_which_uses_combine_as_set_and_no_patient_frame(engine): assert engine.extract_qm(variables) == [ {"patient_id": 1, "v": True}, ] + + +def test_picking_row_doesnt_cause_filtered_rows_to_reappear(engine): + # Regression test for a bug we introduced in the in-memory engine + dataset = create_dataset() + dataset.define_population(events.exists_for_patient()) + + rows = events.where(events.i < 0).sort_by(events.i).first_for_patient() + dataset.has_row = rows.exists_for_patient() + dataset.row_count = rows.count_for_patient() + + engine.populate( + { + events: [{"patient_id": 1, "i": 2}], + } + ) + + assert engine.extract(dataset) == [ + {"patient_id": 1, "has_row": False, "row_count": 0}, + ] diff --git a/tests/unit/query_engines/test_in_memory_database.py b/tests/unit/query_engines/test_in_memory_database.py index c3a592143..cf91646c7 100644 --- a/tests/unit/query_engines/test_in_memory_database.py +++ b/tests/unit/query_engines/test_in_memory_database.py @@ -395,7 +395,6 @@ def test_event_table_filter_then_pick_at_index(): --+-----+----- 1 | 102 | 112 2 | 203 | 211 - 3 | | """, ) diff --git a/tests/unit/test_quiz.py b/tests/unit/test_quiz.py index 347b56560..48e2901bb 100644 --- a/tests/unit/test_quiz.py +++ b/tests/unit/test_quiz.py @@ -12,6 +12,7 @@ clinical_events, medications, patients, + practice_registrations, ) @@ -159,6 +160,25 @@ def test_check_answer_dataset_column_has_missing_patients(engine): assert msg == "Incorrect `age` value for patient 1: expected 49, got 50 instead." +@pytest.mark.parametrize( + "order, message", + [ + ([0, 1], "Missing patient(s): 7."), + ([1, 0], "Found extra patient(s): 7."), + ], +) +def test_check_answer_patient_series_has_missing_or_extra_patients( + engine, order, message +): + series = [ + practice_registrations.for_patient_on("2013-12-01").practice_pseudo_id, + practice_registrations.for_patient_on("2014-01-01").practice_pseudo_id, + ] + answer, expected = (series[i] for i in order) + msg = quiz.check_answer(engine=engine, answer=answer, expected=expected) + assert msg == message + + def test_check_answer_patient_series_has_incorrect_value(engine): msg = quiz.check_answer( engine=engine,