diff --git a/ehrql/debugger.py b/ehrql/debugger.py index 03b950fb4..cf457d095 100644 --- a/ehrql/debugger.py +++ b/ehrql/debugger.py @@ -1,6 +1,8 @@ import contextlib import inspect import sys +import typing +from dataclasses import dataclass from ehrql.query_engines.debug import DebugQueryEngine from ehrql.query_engines.in_memory_database import ( @@ -8,13 +10,13 @@ PatientColumn, render_value, ) -from ehrql.query_language import BaseFrame, BaseSeries, Dataset, DateDifference +from ehrql.query_language import BaseSeries, Dataset, DateDifference from ehrql.query_model import nodes as qm from ehrql.renderers import truncate_table from ehrql.utils.docs_utils import exclude_from_docs -DEBUG_QUERY_ENGINE = None +DEBUG_CONTEXT = None @exclude_from_docs @@ -54,26 +56,17 @@ def show( label = f" {label}" if label else "" print(f"Show line {line_no}:{label}", file=sys.stderr) - if DEBUG_QUERY_ENGINE is None: + if DEBUG_CONTEXT is None: # We're not in debug mode so ignore print( " - show() ignored because we're not running in debug mode", file=sys.stderr ) return - elements = [element, *other_elements] - - is_single_ehrql_object = is_ehrql_object(element) and len(other_elements) == 0 - is_multiple_series_same_domain = len(elements) > 1 and elements_are_related_series( - elements - ) - - # We throw an error unless show() is called with: - # 1. A single ehrql object (series, frame, dataset) - # 2. Multiple series drawn from the same domain - if is_single_ehrql_object | is_multiple_series_same_domain: - print(render(element, *other_elements, head=head, tail=tail), file=sys.stderr) - else: + try: + rendered = DEBUG_CONTEXT.render(element, *other_elements, head=head, tail=tail) + except TypeError: + # raise more helpful show() specific error raise TypeError( "\n\nshow() can be used in two ways. Either:\n" " - call it with a single argument such as:\n" @@ -89,73 +82,66 @@ def show( "clinical_events.numeric_values)" ) - -def render( - element, - *other_elements, - head: int | None = None, - tail: int | None = None, -): - if len(other_elements) == 0: - # single ehrql element so we just display it - element_repr = repr(element) - elif hasattr(element, "__repr_related__"): - # multiple ehrql series so we combine them - element_repr = element.__repr_related__(*other_elements) - else: - raise TypeError( - "When render() is called on multiple series, the first argument must have a __repr_related__ attribute." - ) - - if head or tail: - element_repr = truncate_table(element_repr, head, tail) - return element_repr + print(rendered, file=sys.stderr) @contextlib.contextmanager def activate_debug_context(*, dummy_tables_path, render_function): - global DEBUG_QUERY_ENGINE - - # Record original methods - BaseFrame__repr__ = BaseFrame.__repr__ - BaseSeries__repr__ = BaseSeries.__repr__ - DateDifference__repr__ = DateDifference.__repr__ - Dataset__repr__ = Dataset.__repr__ - - query_engine = DebugQueryEngine(dummy_tables_path) - DEBUG_QUERY_ENGINE = query_engine - - def repr_ehrql(obj, template=None): - return query_engine.evaluate(obj)._render_(render_function) - - def repr_related_ehrql(*series_args): - return render_function( - related_columns_to_records( - [query_engine.evaluate(series) for series in series_args] + global DEBUG_CONTEXT + DEBUG_CONTEXT = DebugContext( + query_engine=DebugQueryEngine(dummy_tables_path), + render_function=render_function, + ) + try: + yield DEBUG_CONTEXT # yield the context for testing convenience + finally: + DEBUG_CONTEXT = None + + +@dataclass +class DebugContext: + """Record the currently active query engine and render function.""" + + query_engine: DebugQueryEngine + render_function: typing.Callable + + def render( + self, + element, + *other_elements, + head: int | None = None, + tail: int | None = None, + ): + elements = [element, *other_elements] + + is_single_ehrql_object = is_ehrql_object(element) and len(other_elements) == 0 + is_multiple_series_same_domain = len( + elements + ) > 1 and elements_are_related_series(elements) + + # We throw an error unless called with: + # 1. A single ehrql object (series, frame, dataset) + # 2. Multiple series drawn from the same domain + if not (is_single_ehrql_object | is_multiple_series_same_domain): + raise TypeError("All elements must be in the same domain") + + if len(other_elements) == 0: + # single ehrql element so we just display it + rendered = self.query_engine.evaluate(element)._render_( + self.render_function + ) + else: + # multiple ehrql series so we combine them + rendered = self.render_function( + related_columns_to_records( + [self.query_engine.evaluate(element) for element in elements] + ) ) - ) - - # Temporarily overwrite __repr__ methods to display contents - BaseFrame.__repr__ = repr_ehrql - BaseSeries.__repr__ = repr_ehrql - DateDifference.__repr__ = repr_ehrql - Dataset.__repr__ = repr_ehrql - # Add additional method for displaying related series together - BaseSeries.__repr_related__ = repr_related_ehrql - DateDifference.__repr_related__ = repr_related_ehrql + if head or tail: + rendered = truncate_table(rendered, head, tail) - try: - yield - finally: - DEBUG_QUERY_ENGINE = None - # Restore original methods - BaseFrame.__repr__ = BaseFrame__repr__ - BaseSeries.__repr__ = BaseSeries__repr__ - DateDifference.__repr__ = DateDifference__repr__ - Dataset.__repr__ = Dataset__repr__ - del BaseSeries.__repr_related__ - del DateDifference.__repr_related__ + return rendered def elements_are_related_series(elements): diff --git a/ehrql/loaders.py b/ehrql/loaders.py index 1ea4536e4..4c4cea2bb 100644 --- a/ehrql/loaders.py +++ b/ehrql/loaders.py @@ -324,6 +324,12 @@ def load_definition_unsafe( def load_module(module_path, user_args=()): + """Load a module containing an ehrql dataset definition. + + Renders the serialized query model as json to the callers stdout, and any + other output to stderr. + """ + # Taken from the official recipe for importing a module from a file path: # https://docs.python.org/3.9/library/importlib.html#importing-a-source-file-directly spec = importlib.util.spec_from_file_location(module_path.stem, module_path) diff --git a/ehrql/quiz.py b/ehrql/quiz.py index beb9355cb..5a5a2fa60 100644 --- a/ehrql/quiz.py +++ b/ehrql/quiz.py @@ -276,8 +276,8 @@ def set_dummy_tables_path(self, path): # itself, regardless of how the quiz script was invoked. This kind of global # nastiness seems tolerable in the context of the quiz, and worth it for # avoiding potential confusion. - if ehrql.debugger.DEBUG_QUERY_ENGINE is not None: - ehrql.debugger.DEBUG_QUERY_ENGINE.dsn = path + if ehrql.debugger.DEBUG_CONTEXT is not None: + ehrql.debugger.DEBUG_CONTEXT.query_engine.dsn = path def __setitem__(self, index, question): question.index = index diff --git a/tests/unit/test_debugger.py b/tests/unit/test_debugger.py index 2a3e97496..d4d28af5e 100644 --- a/tests/unit/test_debugger.py +++ b/tests/unit/test_debugger.py @@ -9,7 +9,6 @@ activate_debug_context, elements_are_related_series, related_patient_columns_to_records, - render, ) from ehrql.query_engines.in_memory_database import PatientColumn from ehrql.tables import EventFrame, PatientFrame, Series, table @@ -24,7 +23,7 @@ def date_serializer(obj): def test_show(capsys): expected_output = textwrap.dedent( """ - Show line 32: + Show line 31: """ ).strip() @@ -37,7 +36,7 @@ def test_show(capsys): def test_show_with_label(capsys): expected_output = textwrap.dedent( """ - Show line 45: Number + Show line 44: Number """ ).strip() @@ -81,48 +80,6 @@ def test_related_patient_columns_to_records_full_join(): assert r == r_expected -def test_render_formatted_table(): - expected_output = textwrap.dedent( - """ - patient_id | value - ------------------+------------------ - 1 | 101 - 2 | 201 - """ - ).strip() - - c = PatientColumn.parse( - """ - 1 | 101 - 2 | 201 - """ - ) - assert render(c).strip() == expected_output - - -def test_render_truncated_table(): - expected_output = textwrap.dedent( - """ - patient_id | value - ------------------+------------------ - 1 | 101 - ... | ... - 4 | 401 - """ - ).strip() - - c = PatientColumn.parse( - """ - 1 | 101 - 2 | 201 - 3 | 301 - 4 | 401 - """ - ) - - assert render(c, head=1, tail=1) == expected_output - - @table class patients(PatientFrame): date_of_birth = Series(date) @@ -218,8 +175,8 @@ def test_activate_debug_context(dummy_tables_path, expression, contents): with activate_debug_context( dummy_tables_path=dummy_tables_path, render_function=lambda value: repr(list(value)), - ): - assert repr(expression) == repr(contents) + ) as ctx: + assert ctx.render(expression) == repr(contents) @pytest.mark.parametrize( @@ -243,8 +200,8 @@ def test_repr_related_patient_series(dummy_tables_path): render_function=lambda value: json.dumps( list(value), indent=4, default=date_serializer ), - ): - rendered = render( + ) as ctx: + rendered = ctx.render( patients.date_of_birth, patients.sex, events.count_for_patient(), @@ -274,8 +231,8 @@ def test_repr_related_event_series(dummy_tables_path): render_function=lambda value: json.dumps( list(value), indent=4, default=date_serializer ), - ): - rendered = render(events.date, events.code, events.test_result) + ) as ctx: + rendered = ctx.render(events.date, events.code, events.test_result) assert json.loads(rendered) == [ { "patient_id": 1, @@ -305,8 +262,8 @@ def test_repr_date_difference(dummy_tables_path): with activate_debug_context( dummy_tables_path=dummy_tables_path, render_function=lambda value: json.dumps(list(value), indent=4), - ): - rendered = render(patients.date_of_death - events.date) + ) as ctx: + rendered = ctx.render(patients.date_of_death - events.date) assert json.loads(rendered) == [ {"patient_id": 1, "row_id": 1, "value": ""}, {"patient_id": 1, "row_id": 2, "value": ""}, @@ -318,8 +275,8 @@ def test_repr_related_date_difference_patient_series(dummy_tables_path): with activate_debug_context( dummy_tables_path=dummy_tables_path, render_function=lambda value: json.dumps(list(value), indent=4), - ): - rendered = render( + ) as ctx: + rendered = ctx.render( "2024-01-01" - patients.date_of_birth, patients.sex, ) @@ -333,8 +290,8 @@ def test_repr_related_date_difference_event_series(dummy_tables_path): with activate_debug_context( dummy_tables_path=dummy_tables_path, render_function=lambda value: json.dumps(list(value), indent=4), - ): - rendered = render( + ) as ctx: + rendered = ctx.render( events.date - patients.date_of_birth, events.code, ) @@ -386,7 +343,7 @@ def test_show_does_not_raise_error_for_series_from_same_domain( def test_show_not_run_outside_debug_context(capsys): expected_output = textwrap.dedent( """ - Show line 394: + Show line 351: - show() ignored because we're not running in debug mode """ ).strip() @@ -394,8 +351,3 @@ def test_show_not_run_outside_debug_context(capsys): show(patients.date_of_birth, patients.sex) captured = capsys.readouterr() assert captured.err.strip() == expected_output, captured.err - - -def test_render_multiple_without_repr_related_errors(): - with pytest.raises(TypeError): - render("hello", "goodbye") diff --git a/tests/unit/test_quiz.py b/tests/unit/test_quiz.py index 7989be975..6517757c3 100644 --- a/tests/unit/test_quiz.py +++ b/tests/unit/test_quiz.py @@ -339,9 +339,9 @@ def test_set_dummy_tables_path_in_debug_context(): ): questions = quiz.Questions() questions.set_dummy_tables_path("bar") - assert debugger.DEBUG_QUERY_ENGINE.dsn.name == "bar" + assert debugger.DEBUG_CONTEXT.query_engine.dsn.name == "bar" # This should be unset outside of the context manager - assert debugger.DEBUG_QUERY_ENGINE is None + assert debugger.DEBUG_CONTEXT is None def test_hint(capfd):