Skip to content

Commit

Permalink
Render DateDifference elements with other related series
Browse files Browse the repository at this point in the history
DateDifference nodes are not themselves a Series. However, we
render them as the repr of their .days method, which IS a Series.
In order to determine whether a DateDifference is related to other
debugged elements, we need to compare it as days.

We also need to evaluate a DateDifference as days before turning it
into related records with other Series, and we keep track of the indices
of the DateDiff elements so that we can show them as "N days" in the
tables, as we do for individual formatted columns.
  • Loading branch information
rebkwok committed Nov 28, 2024
1 parent a61b4cd commit aa10f88
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 9 deletions.
52 changes: 43 additions & 9 deletions ehrql/debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ def render(
):
elements = [element, *other_elements]

if hasattr(element, "__repr_related__") and elements_are_related_series(elements):
if (
len(elements) > 1
and hasattr(element, "__repr_related__")
and elements_are_related_series(elements)
):
element_reprs = [element.__repr_related__(*other_elements)]
else:
element_reprs = [repr(el) for el in elements]
Expand Down Expand Up @@ -99,9 +103,28 @@ def activate_debug_context(*, dummy_tables_path, render_function):
Dataset.__repr__ = lambda self: query_engine.evaluate_dataset(self)._render_(
render_function
)

def _get_columns_and_datediff_indices(*args):
    """Evaluate every series and note which ones are DateDifferences.

    A DateDifference is evaluated after converting it to days (via its
    ``.days`` attribute); every other argument is evaluated as-is.

    Returns a ``(columns, datediff_indices)`` tuple, where
    ``datediff_indices`` is the set of 1-based positions of the
    DateDifference arguments, so that they can be displayed as "N days".
    """
    datediff_indices = {
        position
        for position, item in enumerate(args, start=1)
        if isinstance(item, DateDifference)
    }
    columns = [
        query_engine.evaluate(item.days if position in datediff_indices else item)
        for position, item in enumerate(args, start=1)
    ]
    return columns, datediff_indices

# Add additional method for displaying related series together
BaseSeries.__repr_related__ = lambda *args: render_function(
related_columns_to_records([query_engine.evaluate(series) for series in args])
related_columns_to_records(*_get_columns_and_datediff_indices(*args))
)
DateDifference.__repr_related__ = lambda *args: render_function(
related_columns_to_records(*_get_columns_and_datediff_indices(*args))
)

try:
Expand All @@ -114,41 +137,52 @@ def activate_debug_context(*, dummy_tables_path, render_function):
DateDifference.__repr__ = DateDifference__repr__
Dataset.__repr__ = Dataset__repr__
del BaseSeries.__repr_related__
del DateDifference.__repr_related__


def format_column(column, template):
    """Return the result of applying ``template.format`` to ``column``.

    The bound ``str.format`` method of ``template`` is handed to
    ``apply_function`` together with the column.
    """
    formatter = template.format
    return apply_function(formatter, column)


def elements_are_related_series(elements):
    """Return True if every element is a Series and all share one domain.

    A DateDifference is rendered in days and is not itself a Series, so it
    is swapped for its ``.days`` attribute before being compared with the
    other elements.
    """
    nodes = []
    for item in elements:
        candidate = item.days if isinstance(item, DateDifference) else item
        nodes.append(getattr(candidate, "_qm_node", None))

    # Anything lacking a query model Series node cannot be related.
    if any(not isinstance(node, qm.Series) for node in nodes):
        return False

    distinct_domains = set(map(qm.get_domain, nodes))
    return len(distinct_domains) == 1


def related_columns_to_records(columns, datediff_indices=frozenset()):
    """Turn a list of related columns into an iterable of record dicts.

    The type of the first column selects the record builder:
    ``PatientColumn`` -> ``related_patient_columns_to_records`` and
    ``EventColumn`` -> ``related_event_columns_to_records``.

    Args:
        columns: non-empty sequence of evaluated columns.
        datediff_indices: 1-based positions of the columns whose values
            should be displayed as "N days". Defaults to an empty,
            immutable set so the default can never be shared-and-mutated
            between calls.

    Raises:
        AssertionError: if the first column is neither a PatientColumn nor
            an EventColumn.
    """
    first = columns[0]
    if isinstance(first, PatientColumn):
        return related_patient_columns_to_records(columns, datediff_indices)
    if isinstance(first, EventColumn):
        return related_event_columns_to_records(columns, datediff_indices)
    # Raised explicitly (rather than `assert False`) so the check is not
    # stripped when Python runs with -O; the exception type is unchanged.
    raise AssertionError(f"Unsupported column type: {type(first).__name__}")


def related_patient_columns_to_records(columns, datediff_indices=frozenset()):
    """Yield one record dict per patient across the related columns.

    Patient ids are taken from the first column's ``patient_to_value``
    keys. Each record holds ``"patient_id"`` plus one ``"series_<i>"``
    entry per column, numbered from 1.

    Args:
        columns: non-empty sequence of columns, each indexable by
            patient id.
        datediff_indices: 1-based positions of the columns whose values
            are rendered as "<value> days". Optional, so the original
            one-argument call form keeps working.
    """
    for patient_id in columns[0].patient_to_value.keys():
        record = {"patient_id": patient_id}
        for i, column in enumerate(columns, start=1):
            value = column[patient_id]
            # NOTE(review): a missing value in a datediff column renders as
            # "None days" - confirm this matches single-column rendering.
            record[f"series_{i}"] = f"{value} days" if i in datediff_indices else value
        yield record


def related_event_columns_to_records(columns, datediff_indices=frozenset()):
    """Yield one record dict per (patient, row) across the related columns.

    Patient ids and row ids are taken from the first column's
    ``patient_to_rows`` mapping. Each record holds ``"patient_id"``,
    ``"row_id"`` and one ``"series_<i>"`` entry per column, numbered
    from 1.

    Args:
        columns: non-empty sequence of columns, each indexable by
            patient id and then by row id.
        datediff_indices: 1-based positions of the columns whose values
            are rendered as "<value> days". Optional, so the original
            one-argument call form keeps working.
    """
    for patient_id, rows in columns[0].patient_to_rows.items():
        for row_id in rows.keys():
            record = {"patient_id": patient_id, "row_id": row_id}
            for i, column in enumerate(columns, start=1):
                value = column[patient_id][row_id]
                # NOTE(review): a missing value in a datediff column renders
                # as "None days" - confirm this matches single-column rendering.
                record[f"series_{i}"] = f"{value} days" if i in datediff_indices else value
            yield record
16 changes: 16 additions & 0 deletions tests/unit/test_debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,19 @@ def test_repr_date_difference(dummy_tables_path):
{"patient_id": 1, "row_id": 2, "value": "18262 days"},
{"patient_id": 2, "row_id": 3, "value": "9132 days"},
]


def test_repr_related_date_difference_series(dummy_tables_path):
    """A DateDifference rendered next to a related series shows "N days"."""

    def to_json(value):
        return json.dumps(list(value), indent=4)

    expected = [
        {"patient_id": 1, "row_id": 1, "series_1": "14610 days", "series_2": "abc"},
        {"patient_id": 1, "row_id": 2, "series_1": "18262 days", "series_2": "def"},
        {"patient_id": 2, "row_id": 3, "series_1": "9132 days", "series_2": "abc"},
    ]

    with activate_debug_context(
        dummy_tables_path=dummy_tables_path,
        render_function=to_json,
    ):
        date_difference = events.date - patients.date_of_birth
        rendered = render(date_difference, events.code)

    assert json.loads(rendered) == expected

0 comments on commit aa10f88

Please sign in to comment.