Skip to content

Commit

Permalink
Refactor debug rendering
Browse files Browse the repository at this point in the history
The sandbox method of evaluating and rendering ehrql objects worked by
monkey patching various repr functions. The initial debug implementation
re-used this method, and manually called `repr()` from within the show()
call.

Now that we don't have a sandbox, all rendering can be done explicitly
by the show() function, so there is no need for repr monkey patching.
This refactor switches the evaluate/render steps to be done explicitly
in show(), rather than going via repr.

However, the repr monkey patching implementation used a closure to
capture the supplied renderer and engine details in the monkeypatched
function. As `show()` is a user imported static function, it cannot use
a closure. Instead, we formalize the DEBUG_QUERY_ENGINE global into
a new DEBUG_CONTEXT global, which is set via `activate_debug_context`,
and then available to use inside `show()`. This DEBUG_CONTEXT stores the
engine instance and the renderer, and knows how to render an object
with those.
  • Loading branch information
bloodearnest committed Jan 14, 2025
1 parent dccb00a commit a5052f2
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 143 deletions.
138 changes: 62 additions & 76 deletions ehrql/debugger.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
import contextlib
import inspect
import sys
import typing
from dataclasses import dataclass

from ehrql.query_engines.debug import DebugQueryEngine
from ehrql.query_engines.in_memory_database import (
EventColumn,
PatientColumn,
render_value,
)
from ehrql.query_language import BaseFrame, BaseSeries, Dataset, DateDifference
from ehrql.query_language import BaseSeries, Dataset, DateDifference
from ehrql.query_model import nodes as qm
from ehrql.renderers import truncate_table
from ehrql.utils.docs_utils import exclude_from_docs


DEBUG_QUERY_ENGINE = None
DEBUG_CONTEXT = None


@exclude_from_docs
Expand Down Expand Up @@ -54,26 +56,17 @@ def show(
label = f" {label}" if label else ""
print(f"Show line {line_no}:{label}", file=sys.stderr)

if DEBUG_QUERY_ENGINE is None:
if DEBUG_CONTEXT is None:
# We're not in debug mode so ignore
print(
" - show() ignored because we're not running in debug mode", file=sys.stderr
)
return

elements = [element, *other_elements]

is_single_ehrql_object = is_ehrql_object(element) and len(other_elements) == 0
is_multiple_series_same_domain = len(elements) > 1 and elements_are_related_series(
elements
)

# We throw an error unless show() is called with:
# 1. A single ehrql object (series, frame, dataset)
# 2. Multiple series drawn from the same domain
if is_single_ehrql_object | is_multiple_series_same_domain:
print(render(element, *other_elements, head=head, tail=tail), file=sys.stderr)
else:
try:
rendered = DEBUG_CONTEXT.render(element, *other_elements, head=head, tail=tail)
except TypeError:
# raise more helpful show() specific error
raise TypeError(
"\n\nshow() can be used in two ways. Either:\n"
" - call it with a single argument such as:\n"
Expand All @@ -89,73 +82,66 @@ def show(
"clinical_events.numeric_values)"
)


def render(
element,
*other_elements,
head: int | None = None,
tail: int | None = None,
):
if len(other_elements) == 0:
# single ehrql element so we just display it
element_repr = repr(element)
elif hasattr(element, "__repr_related__"):
# multiple ehrql series so we combine them
element_repr = element.__repr_related__(*other_elements)
else:
raise TypeError(
"When render() is called on multiple series, the first argument must have a __repr_related__ attribute."
)

if head or tail:
element_repr = truncate_table(element_repr, head, tail)
return element_repr
print(rendered, file=sys.stderr)


@contextlib.contextmanager
def activate_debug_context(*, dummy_tables_path, render_function):
global DEBUG_QUERY_ENGINE

# Record original methods
BaseFrame__repr__ = BaseFrame.__repr__
BaseSeries__repr__ = BaseSeries.__repr__
DateDifference__repr__ = DateDifference.__repr__
Dataset__repr__ = Dataset.__repr__

query_engine = DebugQueryEngine(dummy_tables_path)
DEBUG_QUERY_ENGINE = query_engine

def repr_ehrql(obj, template=None):
return query_engine.evaluate(obj)._render_(render_function)

def repr_related_ehrql(*series_args):
return render_function(
related_columns_to_records(
[query_engine.evaluate(series) for series in series_args]
global DEBUG_CONTEXT
DEBUG_CONTEXT = DebugContext(
query_engine=DebugQueryEngine(dummy_tables_path),
render_function=render_function,
)
try:
yield DEBUG_CONTEXT # yield the context for testing convenience
finally:
DEBUG_CONTEXT = None


@dataclass
class DebugContext:
"""Record the currently active query engine and render function."""

query_engine: DebugQueryEngine
render_function: typing.Callable

def render(
self,
element,
*other_elements,
head: int | None = None,
tail: int | None = None,
):
elements = [element, *other_elements]

is_single_ehrql_object = is_ehrql_object(element) and len(other_elements) == 0
is_multiple_series_same_domain = len(
elements
) > 1 and elements_are_related_series(elements)

# We throw an error unless called with:
# 1. A single ehrql object (series, frame, dataset)
# 2. Multiple series drawn from the same domain
if not (is_single_ehrql_object | is_multiple_series_same_domain):
raise TypeError("All elements must be in the same domain")

if len(other_elements) == 0:
# single ehrql element so we just display it
rendered = self.query_engine.evaluate(element)._render_(
self.render_function
)
else:
# multiple ehrql series so we combine them
rendered = self.render_function(
related_columns_to_records(
[self.query_engine.evaluate(element) for element in elements]
)
)
)

# Temporarily overwrite __repr__ methods to display contents
BaseFrame.__repr__ = repr_ehrql
BaseSeries.__repr__ = repr_ehrql
DateDifference.__repr__ = repr_ehrql
Dataset.__repr__ = repr_ehrql

# Add additional method for displaying related series together
BaseSeries.__repr_related__ = repr_related_ehrql
DateDifference.__repr_related__ = repr_related_ehrql
if head or tail:
rendered = truncate_table(rendered, head, tail)

try:
yield
finally:
DEBUG_QUERY_ENGINE = None
# Restore original methods
BaseFrame.__repr__ = BaseFrame__repr__
BaseSeries.__repr__ = BaseSeries__repr__
DateDifference.__repr__ = DateDifference__repr__
Dataset.__repr__ = Dataset__repr__
del BaseSeries.__repr_related__
del DateDifference.__repr_related__
return rendered


def elements_are_related_series(elements):
Expand Down
6 changes: 6 additions & 0 deletions ehrql/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,12 @@ def load_definition_unsafe(


def load_module(module_path, user_args=()):
"""Load a module containing an ehrql dataset definition.
Renders the serialized query model as json to the callers stdout, and any
other output to stderr.
"""

# Taken from the official recipe for importing a module from a file path:
# https://docs.python.org/3.9/library/importlib.html#importing-a-source-file-directly
spec = importlib.util.spec_from_file_location(module_path.stem, module_path)
Expand Down
4 changes: 2 additions & 2 deletions ehrql/quiz.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,8 @@ def set_dummy_tables_path(self, path):
# itself, regardless of how the quiz script was invoked. This kind of global
# nastiness seems tolerable in the context of the quiz, and worth it for
# avoiding potential confusion.
if ehrql.debugger.DEBUG_QUERY_ENGINE is not None:
ehrql.debugger.DEBUG_QUERY_ENGINE.dsn = path
if ehrql.debugger.DEBUG_CONTEXT is not None:
ehrql.debugger.DEBUG_CONTEXT.query_engine.dsn = path

def __setitem__(self, index, question):
question.index = index
Expand Down
78 changes: 15 additions & 63 deletions tests/unit/test_debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
activate_debug_context,
elements_are_related_series,
related_patient_columns_to_records,
render,
)
from ehrql.query_engines.in_memory_database import PatientColumn
from ehrql.tables import EventFrame, PatientFrame, Series, table
Expand All @@ -24,7 +23,7 @@ def date_serializer(obj):
def test_show(capsys):
expected_output = textwrap.dedent(
"""
Show line 32:
Show line 31:
"""
).strip()
Expand All @@ -37,7 +36,7 @@ def test_show(capsys):
def test_show_with_label(capsys):
expected_output = textwrap.dedent(
"""
Show line 45: Number
Show line 44: Number
"""
).strip()
Expand Down Expand Up @@ -81,48 +80,6 @@ def test_related_patient_columns_to_records_full_join():
assert r == r_expected


def test_render_formatted_table():
expected_output = textwrap.dedent(
"""
patient_id | value
------------------+------------------
1 | 101
2 | 201
"""
).strip()

c = PatientColumn.parse(
"""
1 | 101
2 | 201
"""
)
assert render(c).strip() == expected_output


def test_render_truncated_table():
expected_output = textwrap.dedent(
"""
patient_id | value
------------------+------------------
1 | 101
... | ...
4 | 401
"""
).strip()

c = PatientColumn.parse(
"""
1 | 101
2 | 201
3 | 301
4 | 401
"""
)

assert render(c, head=1, tail=1) == expected_output


@table
class patients(PatientFrame):
date_of_birth = Series(date)
Expand Down Expand Up @@ -218,8 +175,8 @@ def test_activate_debug_context(dummy_tables_path, expression, contents):
with activate_debug_context(
dummy_tables_path=dummy_tables_path,
render_function=lambda value: repr(list(value)),
):
assert repr(expression) == repr(contents)
) as ctx:
assert ctx.render(expression) == repr(contents)


@pytest.mark.parametrize(
Expand All @@ -243,8 +200,8 @@ def test_repr_related_patient_series(dummy_tables_path):
render_function=lambda value: json.dumps(
list(value), indent=4, default=date_serializer
),
):
rendered = render(
) as ctx:
rendered = ctx.render(
patients.date_of_birth,
patients.sex,
events.count_for_patient(),
Expand Down Expand Up @@ -274,8 +231,8 @@ def test_repr_related_event_series(dummy_tables_path):
render_function=lambda value: json.dumps(
list(value), indent=4, default=date_serializer
),
):
rendered = render(events.date, events.code, events.test_result)
) as ctx:
rendered = ctx.render(events.date, events.code, events.test_result)
assert json.loads(rendered) == [
{
"patient_id": 1,
Expand Down Expand Up @@ -305,8 +262,8 @@ def test_repr_date_difference(dummy_tables_path):
with activate_debug_context(
dummy_tables_path=dummy_tables_path,
render_function=lambda value: json.dumps(list(value), indent=4),
):
rendered = render(patients.date_of_death - events.date)
) as ctx:
rendered = ctx.render(patients.date_of_death - events.date)
assert json.loads(rendered) == [
{"patient_id": 1, "row_id": 1, "value": ""},
{"patient_id": 1, "row_id": 2, "value": ""},
Expand All @@ -318,8 +275,8 @@ def test_repr_related_date_difference_patient_series(dummy_tables_path):
with activate_debug_context(
dummy_tables_path=dummy_tables_path,
render_function=lambda value: json.dumps(list(value), indent=4),
):
rendered = render(
) as ctx:
rendered = ctx.render(
"2024-01-01" - patients.date_of_birth,
patients.sex,
)
Expand All @@ -333,8 +290,8 @@ def test_repr_related_date_difference_event_series(dummy_tables_path):
with activate_debug_context(
dummy_tables_path=dummy_tables_path,
render_function=lambda value: json.dumps(list(value), indent=4),
):
rendered = render(
) as ctx:
rendered = ctx.render(
events.date - patients.date_of_birth,
events.code,
)
Expand Down Expand Up @@ -386,16 +343,11 @@ def test_show_does_not_raise_error_for_series_from_same_domain(
def test_show_not_run_outside_debug_context(capsys):
expected_output = textwrap.dedent(
"""
Show line 394:
Show line 351:
- show() ignored because we're not running in debug mode
"""
).strip()

show(patients.date_of_birth, patients.sex)
captured = capsys.readouterr()
assert captured.err.strip() == expected_output, captured.err


def test_render_multiple_without_repr_related_errors():
with pytest.raises(TypeError):
render("hello", "goodbye")
4 changes: 2 additions & 2 deletions tests/unit/test_quiz.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,9 +339,9 @@ def test_set_dummy_tables_path_in_debug_context():
):
questions = quiz.Questions()
questions.set_dummy_tables_path("bar")
assert debugger.DEBUG_QUERY_ENGINE.dsn.name == "bar"
assert debugger.DEBUG_CONTEXT.query_engine.dsn.name == "bar"
# This should be unset outside of the context manager
assert debugger.DEBUG_QUERY_ENGINE is None
assert debugger.DEBUG_CONTEXT is None


def test_hint(capfd):
Expand Down

0 comments on commit a5052f2

Please sign in to comment.