Skip to content

Commit

Permalink
Add a quiz module with a check_answer function (#2214)
Browse files Browse the repository at this point in the history
* Refactor in-memory-database reprs

* Support Markdown reprs for in-memory db objects

This is the method name supported by IPython and Marimo, among others.

* Create quiz module and corresponding tests

* Check dataset answers

* Improve type checking and fall-back error message

* Check patient series

* Check event frames and series

* Check patient/row IDs column-by-column and add test for fall-back message

* Add simple generative tests using hypothesis

* No-cover certain bits of code and provide justification

* Tiny formatting fix

* Fix repr methods from wonky rebase

* Use `ev_answer` and `ev_expected` instead of `ev_ans` and `ev_exp`

---------

Co-authored-by: David Evans <[email protected]>
Co-authored-by: Becky Smith <[email protected]>
  • Loading branch information
3 people authored Nov 18, 2024
1 parent 707fb3b commit 495f1e5
Show file tree
Hide file tree
Showing 4 changed files with 663 additions and 2 deletions.
30 changes: 28 additions & 2 deletions ehrql/query_engines/in_memory_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ def __repr__(self):
def _render_(self, render_fn):
    """Hand this object's records to ``render_fn`` and return its result."""
    records = self.to_records()
    return render_fn(records)

def _repr_markdown_(self):
    """Return this object's records formatted as a Markdown table.

    ``_repr_markdown_`` is the method name looked up by IPython and Marimo,
    among others, when displaying an object.
    """
    records = self.to_records()
    return records_to_markdown(records)

def __getitem__(self, name):
    """Return the column stored under ``name`` (``KeyError`` if absent)."""
    columns = self.name_to_col
    return columns[name]

Expand Down Expand Up @@ -178,6 +181,9 @@ def __repr__(self):
def _render_(self, render_fn):
    """Hand this object's records to ``render_fn`` and return its result."""
    records = self.to_records()
    return render_fn(records)

def _repr_markdown_(self):
    """Return this object's records formatted as a Markdown table.

    ``_repr_markdown_`` is the method name looked up by IPython and Marimo,
    among others, when displaying an object.
    """
    records = self.to_records()
    return records_to_markdown(records)

def __getitem__(self, name):
    """Return the column stored under ``name`` (``KeyError`` if absent)."""
    columns = self.name_to_col
    return columns[name]

Expand Down Expand Up @@ -253,12 +259,17 @@ def __repr__(self):
def _render_(self, render_fn):
    """Hand this object's records to ``render_fn`` and return its result."""
    records = self.to_records()
    return render_fn(records)

def _repr_markdown_(self):
    """Return this object's records formatted as a Markdown table.

    ``_repr_markdown_`` is the method name looked up by IPython and Marimo,
    among others, when displaying an object.
    """
    records = self.to_records()
    return records_to_markdown(records)

def __getitem__(self, patient):
    """Return the value held for ``patient``, or ``self.default`` if none."""
    lookup = self.patient_to_value.get
    return lookup(patient, self.default)

def to_records(self):
    """Return a generator of ``{"patient_id": ..., "value": ...}`` records.

    Records are produced in sorted order of the ``patient_to_value`` items,
    i.e. ordered by patient.
    """
    # Only the generator-expression form is kept: having the older
    # ``for``/``yield`` loop in the same body as well would turn this into a
    # generator function whose ``return`` value is silently discarded.
    return (
        {"patient_id": p, "value": v}
        for p, v in sorted(self.patient_to_value.items())
    )

def patients(self):
    """Return the set of patients that have an entry in ``patient_to_value``."""
    return {*self.patient_to_value}
Expand Down Expand Up @@ -312,6 +323,9 @@ def __repr__(self):
def _render_(self, render_fn):
    """Hand this object's records to ``render_fn`` and return its result."""
    records = self.to_records()
    return render_fn(records)

def _repr_markdown_(self):
    """Return this object's records formatted as a Markdown table.

    ``_repr_markdown_`` is the method name looked up by IPython and Marimo,
    among others, when displaying an object.
    """
    records = self.to_records()
    return records_to_markdown(records)

def __getitem__(self, patient):
    """Return the rows held for ``patient``, or an empty ``Rows`` if none."""
    lookup = self.patient_to_rows.get
    return lookup(patient, Rows({}))

Expand Down Expand Up @@ -516,3 +530,15 @@ def parse_value(value):
def nulls_first_order(key):
    """Key function for ``sorted()`` that orders NULLs (``None``) first.

    Returns a ``(rank, key)`` pair where ``rank`` is 0 for ``None`` and 1 for
    anything else, so ``None`` is never compared directly with a real value.
    """
    rank = 1 if key is not None else 0
    return (rank, key)


def records_to_markdown(records):
    """Format an iterable of record mappings as a Markdown table.

    The header row and the ``---`` separator row are taken from the keys of
    the first record. Every following line is the ``str()`` of each record's
    values. An empty iterable produces an empty string.
    """
    rows = []
    for position, record in enumerate(records):
        if position == 0:
            # Header and separator are emitted once, from the first record.
            rows.append(" | ".join(record.keys()))
            rows.append(" | ".join(["---"] * len(record.keys())))
        rows.append(" | ".join(str(value) for value in record.values()))
    return "\n".join(f"| {row.strip()} |" for row in rows)
268 changes: 268 additions & 0 deletions ehrql/quiz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
from collections.abc import Callable
from typing import Any

from ehrql.query_engines.in_memory_database import (
    EventColumn,
    EventTable,
    PatientColumn,
    PatientTable,
)
from ehrql.query_engines.sandbox import EmptyDataset, SandboxQueryEngine
from ehrql.query_language import BaseFrame, BaseSeries, Dataset


def check_answer(
    engine: SandboxQueryEngine, answer: Any, expected: Dataset | BaseFrame | BaseSeries
) -> str:
    """Compare a quiz ``answer`` with the ``expected`` query and return feedback.

    The types of the unevaluated objects are compared first. Both are then
    evaluated with ``engine`` and the types of the results are compared.
    Equal results give ``"Correct!"``. Otherwise a sequence of increasingly
    specific checks runs and the first message produced is returned. If no
    check produces a message, a generic expected/got message is returned.
    """
    type_feedback = check_type(answer, expected)
    if type_feedback:
        return type_feedback

    ev_answer = evaluate(engine, answer)
    ev_expected = evaluate(engine, expected)

    type_feedback = check_type(ev_answer, ev_expected)
    if type_feedback:
        return type_feedback

    if ev_answer == ev_expected:
        return "Correct!"

    # Order matters: the first check that yields a message decides the feedback.
    ordered_checks = (
        check_dataset_not_empty,
        check_dataset_columns,
        check_patient_ids,
        check_patient_table_values,
        check_patient_column_values,
        check_event_row_ids,
        check_event_table_values,
        check_event_column_values,
    )
    for run_check in ordered_checks:
        feedback = run_check(ev_answer, ev_expected)
        if feedback is not None:
            return feedback

    # Fall-back when none of the specific checks could explain the difference.
    fallback_lines = (
        "Incorrect answer.",
        "Expected:",
        str(ev_expected),
        "Got:",
        str(ev_answer),
    )
    return "\n".join(fallback_lines)


def get_type_name(obj: Any) -> str:
    """Return a simplified type name for ``obj``, for use in error messages.

    ``BaseFrame`` instances are reported as ``"Table"`` and ``BaseSeries``
    instances as ``"Series"``. Anything else uses its class name.
    """
    friendly_names = ((BaseFrame, "Table"), (BaseSeries, "Series"))
    for base_class, label in friendly_names:
        if isinstance(obj, base_class):
            return label
    return type(obj).__name__


def check_type(answer: Any, expected: Any) -> str:
if not isinstance(answer, type(expected)):
answer_type_name = get_type_name(answer)
expected_type_name = get_type_name(expected)

if answer_type_name == "EmptyDataset":
return None # Return a different message for empty datasets
if answer_type_name != expected_type_name:
# Only return an error message if it is helpful
return f"Expected {expected_type_name}, got {answer_type_name} instead."
return None


def evaluate(
    engine: SandboxQueryEngine, answer: Dataset | BaseFrame | BaseSeries
) -> Any:
    """Evaluate ``answer`` with ``engine`` and return the result.

    ``Dataset`` objects go through ``engine.evaluate_dataset``. Everything
    else goes through ``engine.evaluate``.
    """
    run = engine.evaluate_dataset if isinstance(answer, Dataset) else engine.evaluate
    return run(answer)


def check_dataset_not_empty(ev_answer: Any, ev_expected: Any) -> str | None:
    """Return feedback when the evaluated answer is an ``EmptyDataset``.

    Returns ``None`` when the answer is not an empty dataset, so that later
    checks can run.
    """
    if not isinstance(ev_answer, EmptyDataset):
        return None
    if isinstance(ev_expected, EmptyDataset):
        return "Correct!"  # Special case: Not an error
    return "The dataset is empty."


def check_dataset_columns(ev_answer: Any, ev_expected: Any) -> str | None:
    """Report missing or extra columns when a ``PatientTable`` is expected.

    Named so because we do not expect PatientTables from Frames to have
    varying columns. Returns ``None`` for any other expected type.
    """
    if not isinstance(ev_expected, PatientTable):
        return None

    def column_names(table):
        return table.name_to_col.keys()

    return _check_missing_extra(ev_answer, ev_expected, "column", getter=column_names)


def check_patient_ids(ev_answer: Any, ev_expected: Any) -> str | None:
    """Report missing or extra patients in a patient column or patient table.

    A ``PatientColumn`` is checked directly. A ``PatientTable`` is checked as
    a whole first and then column by column (excluding ``patient_id``).
    Returns ``None`` when nothing is wrong or when ``ev_expected`` is neither
    type.
    """

    def check(_ev_answer, _ev_expected) -> str | None:
        return _check_missing_extra(
            _ev_answer,
            _ev_expected,
            "patient",
            getter=lambda c: c.patients(),
        )

    if isinstance(ev_expected, PatientColumn):
        return check(ev_answer, ev_expected)
    if isinstance(ev_expected, PatientTable):
        # Walk the columns in the table's own order. Building the list from a
        # set difference would make the order (and therefore which column's
        # message is reported first) depend on string hashing.
        column_names = [
            name for name in ev_expected.name_to_col if name != "patient_id"
        ]
        return _check_table_then_columns_one_by_one(
            ev_answer,
            ev_expected,
            check,
            column_names=column_names,
        )
    return None


def check_patient_table_values(ev_answer: Any, ev_expected: Any) -> str | None:
    """Check the values of each column when a ``PatientTable`` is expected.

    Returns the first message produced by ``check_patient_column_values``, or
    ``None`` when no column produces one or a different type is expected.
    """
    if isinstance(ev_expected, PatientTable):
        # Patient ID handled separately. Columns are visited in the table's
        # own order so that the reported column is deterministic (a set
        # difference would iterate in hash-dependent order).
        column_names = [
            name for name in ev_expected.name_to_col if name != "patient_id"
        ]
        return _check_columns_one_by_one(
            ev_answer,
            ev_expected,
            check_patient_column_values,
            column_names=column_names,
        )
    return None


def check_patient_column_values(
    ev_answer, ev_expected, column_name: str | None = None
) -> str | None:
    """Report the first incorrect value when a ``PatientColumn`` is expected.

    Considers the ``(patient, value)`` pairs present in the answer but not in
    the expected column, in sorted order. ``column_name``, when given, is
    included in the message. Returns ``None`` for any other expected type.
    """
    if not isinstance(ev_expected, PatientColumn):
        return None

    label = f" `{column_name}` " if column_name else " "
    wrong_items = sorted(
        ev_answer.patient_to_value.items() - ev_expected.patient_to_value.items()
    )
    # Last check for Patient Frames/Series; Expect wrong_items to be non-empty
    # Only show the first incorrect value
    if wrong_items:  # pragma: no branch
        patient, value = wrong_items[0]
        return f"Incorrect{label}value for patient {patient}: expected {str(ev_expected[patient])}, got {str(value)} instead."
    return None


def check_event_row_ids(ev_answer: Any, ev_expected: Any) -> str | None:
    """Report missing or extra row IDs in an event column or event table.

    An ``EventColumn`` is checked directly. An ``EventTable`` is checked as a
    whole first and then column by column (excluding ``patient_id`` and
    ``row_id``). Returns ``None`` when nothing is wrong or when
    ``ev_expected`` is neither type.
    """

    def check(_ev_answer, _ev_expected) -> str | None:
        return _check_missing_extra(
            _ev_answer,
            _ev_expected,
            "row",
            getter=lambda c: {row["row_id"] for row in c.to_records()},
        )

    if isinstance(ev_expected, EventColumn):
        return check(ev_answer, ev_expected)
    if isinstance(ev_expected, EventTable):
        # Walk the columns in the table's own order. Building the list from a
        # set difference would make the order (and therefore which column's
        # message is reported first) depend on string hashing.
        id_columns = ("patient_id", "row_id")
        column_names = [
            name for name in ev_expected.name_to_col if name not in id_columns
        ]
        return _check_table_then_columns_one_by_one(
            ev_answer,
            ev_expected,
            check,
            column_names=column_names,
        )
    return None


# Cannot find a wrong answer that triggers this, but a wild user answer might
# So still include this check in the list but pragma: no cover it
def check_event_table_values(
    ev_answer: Any, ev_expected: Any
) -> str | None:  # pragma: no cover
    """Check the values of each column when an ``EventTable`` is expected.

    Returns the first message produced by ``check_event_column_values``, or
    ``None`` when no column produces one or a different type is expected.
    """
    if isinstance(ev_expected, EventTable):
        # Patient ID and Row ID handled separately. Columns are visited in
        # the table's own order so that the reported column is deterministic
        # (a set difference would iterate in hash-dependent order).
        id_columns = ("patient_id", "row_id")
        column_names = [
            name for name in ev_expected.name_to_col if name not in id_columns
        ]
        return _check_columns_one_by_one(
            ev_answer,
            ev_expected,
            check_event_column_values,
            column_names=column_names,
        )
    return None


def check_event_column_values(
    ev_answer, ev_expected, column_name: str | None = None
) -> str | None:
    """Report the first incorrect value when an ``EventColumn`` is expected.

    Each record is reduced to a tuple of its values. Tuples present in the
    answer but not in the expected column are considered in sorted order.
    ``column_name``, when given, is included in the message. Returns ``None``
    for any other expected type.
    """
    if not isinstance(ev_expected, EventColumn):
        return None

    label = f" `{column_name}` " if column_name else " "
    answer_rows = {tuple(rec.values()) for rec in ev_answer.to_records()}
    expected_rows = {tuple(rec.values()) for rec in ev_expected.to_records()}
    wrong_rows = sorted(answer_rows - expected_rows)
    # Last check for Event Frames/Series; Expect wrong_rows to be non-empty
    # Only show the first incorrect value
    if wrong_rows:  # pragma: no branch
        patient, row, value = wrong_rows[0]
        return f"Incorrect{label}value for patient {patient}, row {row}: expected {str(ev_expected[patient][row])}, got {str(value)} instead."
    return None


# Utils functions
def get_items_missing_extra(
set_answer: set,
set_expected: set,
item_name: str,
) -> tuple[str | None, str | None]:
missing = list(map(str, sorted(set_expected - set_answer)))
missing = f"Missing {item_name}(s): {', '.join(missing)}." if missing else None
extra = list(map(str, sorted(set_answer - set_expected)))
extra = f"Found extra {item_name}(s): {', '.join(extra)}." if extra else None
return missing, extra


def _check_missing_extra(
    ev_answer: Any,
    ev_expected: Any,
    item_name: str,
    getter: Callable[[Any], Any],
) -> str | None:
    """Report items that are missing from or extra in the answer.

    ``getter`` extracts the collection of items (columns, patients, rows, ...)
    from each evaluated object, and ``item_name`` is the noun used in the
    message. Returns the missing and extra sentences joined by a newline, or
    ``None`` when both collections hold the same items.
    """
    missing, extra = get_items_missing_extra(
        getter(ev_answer), getter(ev_expected), item_name
    )
    # Keep only the sentences that were actually produced.
    parts = [part for part in (missing, extra) if part]
    return "\n".join(parts) if parts else None


def _check_columns_one_by_one(
ev_answer: Any,
ev_expected: Any,
check_column: callable,
column_names: list[str],
) -> str | None:
for name in column_names:
msg = check_column(ev_answer[name], ev_expected[name], name)
if msg is None:
continue
return msg


def _check_table_then_columns_one_by_one(
ev_answer: Any,
ev_expected: Any,
check: callable,
column_names: list[str],
):
def check_column(col_ans, col_exp, column_name: str | None = None) -> str | None:
msg = check(col_ans, col_exp)
if msg:
return f"Column `{column_name}`:\n" + msg
return msg

msg_table = check(ev_answer, ev_expected)
if msg_table:
return msg_table
return _check_columns_one_by_one(
ev_answer,
ev_expected,
check_column,
column_names=column_names,
)
Loading

0 comments on commit 495f1e5

Please sign in to comment.