Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor debug rendering #2356

Merged
merged 4 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 62 additions & 78 deletions ehrql/debugger.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import contextlib
import inspect
import sys
import typing
from dataclasses import dataclass

from ehrql.query_engines.debug import DebugQueryEngine
from ehrql.query_engines.in_memory_database import (
EventColumn,
PatientColumn,
render_value,
)
from ehrql.query_language import BaseFrame, BaseSeries, Dataset, DateDifference
from ehrql.query_language import BaseSeries, Dataset, DateDifference
from ehrql.query_model import nodes as qm
from ehrql.renderers import truncate_table
from ehrql.utils.docs_utils import exclude_from_docs


DEBUG_QUERY_ENGINE = None
DEBUG_CONTEXT = None


@exclude_from_docs
Expand Down Expand Up @@ -54,26 +55,19 @@ def show(
label = f" {label}" if label else ""
print(f"Show line {line_no}:{label}", file=sys.stderr)

if DEBUG_QUERY_ENGINE is None:
if DEBUG_CONTEXT is None:
# We're not in debug mode so ignore
print(
" - show() ignored because we're not running in debug mode", file=sys.stderr
)
return

elements = [element, *other_elements]

is_single_ehrql_object = is_ehrql_object(element) and len(other_elements) == 0
is_multiple_series_same_domain = len(elements) > 1 and elements_are_related_series(
elements
)

# We throw an error unless show() is called with:
# 1. A single ehrql object (series, frame, dataset)
# 2. Multiple series drawn from the same domain
if is_single_ehrql_object | is_multiple_series_same_domain:
print(render(element, *other_elements, head=head, tail=tail), file=sys.stderr)
else:
try:
rendered = DEBUG_CONTEXT.render(
element, *other_elements, head=head or 0, tail=tail or 0
)
except TypeError:
# raise more helpful show() specific error
raise TypeError(
"\n\nshow() can be used in two ways. Either:\n"
" - call it with a single argument such as:\n"
Expand All @@ -89,73 +83,63 @@ def show(
"clinical_events.numeric_values)"
)


def render(
element,
*other_elements,
head: int | None = None,
tail: int | None = None,
):
if len(other_elements) == 0:
# single ehrql element so we just display it
element_repr = repr(element)
elif hasattr(element, "__repr_related__"):
# multiple ehrql series so we combine them
element_repr = element.__repr_related__(*other_elements)
else:
raise TypeError(
"When render() is called on multiple series, the first argument must have a __repr_related__ attribute."
)

if head or tail:
element_repr = truncate_table(element_repr, head, tail)
return element_repr
print(rendered, file=sys.stderr)


@contextlib.contextmanager
def activate_debug_context(*, dummy_tables_path, render_function):
global DEBUG_QUERY_ENGINE

# Record original methods
BaseFrame__repr__ = BaseFrame.__repr__
BaseSeries__repr__ = BaseSeries.__repr__
DateDifference__repr__ = DateDifference.__repr__
Dataset__repr__ = Dataset.__repr__

query_engine = DebugQueryEngine(dummy_tables_path)
DEBUG_QUERY_ENGINE = query_engine

def repr_ehrql(obj, template=None):
return query_engine.evaluate(obj)._render_(render_function)

def repr_related_ehrql(*series_args):
return render_function(
related_columns_to_records(
[query_engine.evaluate(series) for series in series_args]
)
)

# Temporarily overwrite __repr__ methods to display contents
BaseFrame.__repr__ = repr_ehrql
BaseSeries.__repr__ = repr_ehrql
DateDifference.__repr__ = repr_ehrql
Dataset.__repr__ = repr_ehrql

# Add additional method for displaying related series together
BaseSeries.__repr_related__ = repr_related_ehrql
DateDifference.__repr_related__ = repr_related_ehrql

global DEBUG_CONTEXT
DEBUG_CONTEXT = DebugContext(
query_engine=DebugQueryEngine(dummy_tables_path),
render_function=render_function,
)
try:
yield
yield DEBUG_CONTEXT # yield the context for testing convenience
finally:
DEBUG_QUERY_ENGINE = None
# Restore original methods
BaseFrame.__repr__ = BaseFrame__repr__
BaseSeries.__repr__ = BaseSeries__repr__
DateDifference.__repr__ = DateDifference__repr__
Dataset.__repr__ = Dataset__repr__
del BaseSeries.__repr_related__
del DateDifference.__repr_related__
DEBUG_CONTEXT = None


@dataclass
class DebugContext:
    """Record the currently active query engine and render function.

    Attributes are documented inline below; ``render`` evaluates ehrQL
    objects with the query engine and hands the resulting records to the
    render function.
    """

    # Engine used to evaluate ehrQL objects into in-memory results.
    query_engine: DebugQueryEngine
    # Called as ``render_function(records, head, tail)``; its return value
    # is returned unchanged from ``render``.
    render_function: typing.Callable

    def render(
        self,
        element,
        *other_elements,
        head: int = 0,
        tail: int = 0,
    ):
        """Evaluate the given element(s) and render the resulting records.

        Accepted call shapes:
          1. A single ehrql object (series, frame, dataset)
          2. Multiple series drawn from the same domain

        Args:
            element: the first (or only) object to render.
            *other_elements: further series to render alongside ``element``.
            head: forwarded to the render function (presumably the number of
                leading rows to keep — confirm against the renderer).
            tail: forwarded to the render function (presumably the number of
                trailing rows to keep — confirm against the renderer).

        Returns:
            Whatever ``self.render_function`` returns for the evaluated
            records.

        Raises:
            TypeError: if the arguments match neither accepted call shape.
        """
        elements = [element, *other_elements]

        # Only one of the two checks can apply to a given call, so run just
        # that one and combine nothing with bitwise operators.
        if other_elements:
            is_valid = elements_are_related_series(elements)
        else:
            is_valid = is_ehrql_object(element)

        if not is_valid:
            # The message is deliberately left generic: show() catches this
            # TypeError and raises its own, more detailed, explanation.
            raise TypeError("All elements must be in the same domain")

        if other_elements:
            # multiple ehrql series so we combine them
            columns = [self.query_engine.evaluate(item) for item in elements]
            evaluated = list(related_columns_to_records(columns))
        else:
            # single ehrql element so we just display it
            result = self.query_engine.evaluate(element)
            evaluated = list(result.to_records(convert_null=True))

        return self.render_function(evaluated, head, tail)


def elements_are_related_series(elements):
Expand Down
6 changes: 6 additions & 0 deletions ehrql/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,12 @@ def load_definition_unsafe(


def load_module(module_path, user_args=()):
"""Load a module containing an ehrql dataset definition.
Renders the serialized query model as json to the caller's stdout, and any
other output to stderr.
"""

# Taken from the official recipe for importing a module from a file path:
# https://docs.python.org/3.9/library/importlib.html#importing-a-source-file-directly
spec = importlib.util.spec_from_file_location(module_path.stem, module_path)
Expand Down
4 changes: 2 additions & 2 deletions ehrql/query_engines/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,5 @@ class EmptyDataset:
"""This class exists to render something nice when a user tries to inspect a dataset
with no columns with debugger.show()."""

def _render_(self, render_fn):
return render_fn([{"patient_id": ""}])
def to_records(self, convert_null=False):
return [{"patient_id": ""}]
8 changes: 4 additions & 4 deletions ehrql/query_engines/in_memory_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __repr__(self):
return self._render_(DISPLAY_RENDERERS["ascii"])

def _render_(self, render_fn):
return render_fn(self.to_records(convert_null=True))
return render_fn(list(self.to_records(convert_null=True)))

def __getitem__(self, name):
return self.name_to_col[name]
Expand Down Expand Up @@ -179,7 +179,7 @@ def __repr__(self):
return self._render_(DISPLAY_RENDERERS["ascii"])

def _render_(self, render_fn):
return render_fn(self.to_records(convert_null=True))
return render_fn(list(self.to_records(convert_null=True)))

def __getitem__(self, name):
return self.name_to_col[name]
Expand Down Expand Up @@ -257,7 +257,7 @@ def __repr__(self):
return self._render_(DISPLAY_RENDERERS["ascii"])

def _render_(self, render_fn):
return render_fn(self.to_records(convert_null=True))
return render_fn(list(self.to_records(convert_null=True)))

def __getitem__(self, patient):
return self.patient_to_value.get(patient, self.default)
Expand Down Expand Up @@ -318,7 +318,7 @@ def __repr__(self):
return self._render_(DISPLAY_RENDERERS["ascii"])

def _render_(self, render_fn):
return render_fn(self.to_records(convert_null=True))
return render_fn(list(self.to_records(convert_null=True)))

def __getitem__(self, patient):
return self.patient_to_rows.get(patient, Rows({}))
Expand Down
4 changes: 2 additions & 2 deletions ehrql/quiz.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,8 @@ def set_dummy_tables_path(self, path):
# itself, regardless of how the quiz script was invoked. This kind of global
# nastiness seems tolerable in the context of the quiz, and worth it for
# avoiding potential confusion.
if ehrql.debugger.DEBUG_QUERY_ENGINE is not None:
ehrql.debugger.DEBUG_QUERY_ENGINE.dsn = path
if ehrql.debugger.DEBUG_CONTEXT is not None:
ehrql.debugger.DEBUG_CONTEXT.query_engine.dsn = path

def __setitem__(self, index, question):
question.index = index
Expand Down
Loading
Loading