From 86b6deff44179b2a85a21ecc91cb92d1996cfef0 Mon Sep 17 00:00:00 2001 From: bloodearnest Date: Tue, 14 Jan 2025 12:46:06 +0000 Subject: [PATCH 1/4] Refactor debug rendering The sandbox method of evaluating and rendering ehrql objects was via monkey patching various repr functions. The initial debug implementation re-used this method, and manually called `repr()` from within the show() call. Now that we don't have a sandbox, all rendering can be done explicitly by the show() function, so there is no need for repr monkey patching. This refactor switches the evaluate/render steps to be done explicitly in show(), rather than going via repr. However, the repr monkey patching implementation used a closure to capture the supplied renderer and engine details in the monkeypatched function. As `show()` is a user imported static function, it cannot use a closure. Instead, we formalize the DEBUG_QUERY_ENGINE global into a new DEBUG_CONTEXT global, which is set via `activate_debug_context`, and then available to use inside `show()`. This DEBUG_CONTEXT stores the engine instance and the renderer, and knows how to render an object with those. 
--- ehrql/debugger.py | 138 ++++++++++++++++-------------------- ehrql/loaders.py | 6 ++ ehrql/quiz.py | 4 +- tests/unit/test_debugger.py | 78 ++++---------------- tests/unit/test_quiz.py | 4 +- 5 files changed, 87 insertions(+), 143 deletions(-) diff --git a/ehrql/debugger.py b/ehrql/debugger.py index 03b950fb4..cf457d095 100644 --- a/ehrql/debugger.py +++ b/ehrql/debugger.py @@ -1,6 +1,8 @@ import contextlib import inspect import sys +import typing +from dataclasses import dataclass from ehrql.query_engines.debug import DebugQueryEngine from ehrql.query_engines.in_memory_database import ( @@ -8,13 +10,13 @@ PatientColumn, render_value, ) -from ehrql.query_language import BaseFrame, BaseSeries, Dataset, DateDifference +from ehrql.query_language import BaseSeries, Dataset, DateDifference from ehrql.query_model import nodes as qm from ehrql.renderers import truncate_table from ehrql.utils.docs_utils import exclude_from_docs -DEBUG_QUERY_ENGINE = None +DEBUG_CONTEXT = None @exclude_from_docs @@ -54,26 +56,17 @@ def show( label = f" {label}" if label else "" print(f"Show line {line_no}:{label}", file=sys.stderr) - if DEBUG_QUERY_ENGINE is None: + if DEBUG_CONTEXT is None: # We're not in debug mode so ignore print( " - show() ignored because we're not running in debug mode", file=sys.stderr ) return - elements = [element, *other_elements] - - is_single_ehrql_object = is_ehrql_object(element) and len(other_elements) == 0 - is_multiple_series_same_domain = len(elements) > 1 and elements_are_related_series( - elements - ) - - # We throw an error unless show() is called with: - # 1. A single ehrql object (series, frame, dataset) - # 2. 
Multiple series drawn from the same domain - if is_single_ehrql_object | is_multiple_series_same_domain: - print(render(element, *other_elements, head=head, tail=tail), file=sys.stderr) - else: + try: + rendered = DEBUG_CONTEXT.render(element, *other_elements, head=head, tail=tail) + except TypeError: + # raise more helpful show() specific error raise TypeError( "\n\nshow() can be used in two ways. Either:\n" " - call it with a single argument such as:\n" @@ -89,73 +82,66 @@ def show( "clinical_events.numeric_values)" ) - -def render( - element, - *other_elements, - head: int | None = None, - tail: int | None = None, -): - if len(other_elements) == 0: - # single ehrql element so we just display it - element_repr = repr(element) - elif hasattr(element, "__repr_related__"): - # multiple ehrql series so we combine them - element_repr = element.__repr_related__(*other_elements) - else: - raise TypeError( - "When render() is called on multiple series, the first argument must have a __repr_related__ attribute." 
- ) - - if head or tail: - element_repr = truncate_table(element_repr, head, tail) - return element_repr + print(rendered, file=sys.stderr) @contextlib.contextmanager def activate_debug_context(*, dummy_tables_path, render_function): - global DEBUG_QUERY_ENGINE - - # Record original methods - BaseFrame__repr__ = BaseFrame.__repr__ - BaseSeries__repr__ = BaseSeries.__repr__ - DateDifference__repr__ = DateDifference.__repr__ - Dataset__repr__ = Dataset.__repr__ - - query_engine = DebugQueryEngine(dummy_tables_path) - DEBUG_QUERY_ENGINE = query_engine - - def repr_ehrql(obj, template=None): - return query_engine.evaluate(obj)._render_(render_function) - - def repr_related_ehrql(*series_args): - return render_function( - related_columns_to_records( - [query_engine.evaluate(series) for series in series_args] + global DEBUG_CONTEXT + DEBUG_CONTEXT = DebugContext( + query_engine=DebugQueryEngine(dummy_tables_path), + render_function=render_function, + ) + try: + yield DEBUG_CONTEXT # yield the context for testing convenience + finally: + DEBUG_CONTEXT = None + + +@dataclass +class DebugContext: + """Record the currently active query engine and render function.""" + + query_engine: DebugQueryEngine + render_function: typing.Callable + + def render( + self, + element, + *other_elements, + head: int | None = None, + tail: int | None = None, + ): + elements = [element, *other_elements] + + is_single_ehrql_object = is_ehrql_object(element) and len(other_elements) == 0 + is_multiple_series_same_domain = len( + elements + ) > 1 and elements_are_related_series(elements) + + # We throw an error unless called with: + # 1. A single ehrql object (series, frame, dataset) + # 2. 
Multiple series drawn from the same domain + if not (is_single_ehrql_object | is_multiple_series_same_domain): + raise TypeError("All elements must be in the same domain") + + if len(other_elements) == 0: + # single ehrql element so we just display it + rendered = self.query_engine.evaluate(element)._render_( + self.render_function + ) + else: + # multiple ehrql series so we combine them + rendered = self.render_function( + related_columns_to_records( + [self.query_engine.evaluate(element) for element in elements] + ) ) - ) - - # Temporarily overwrite __repr__ methods to display contents - BaseFrame.__repr__ = repr_ehrql - BaseSeries.__repr__ = repr_ehrql - DateDifference.__repr__ = repr_ehrql - Dataset.__repr__ = repr_ehrql - # Add additional method for displaying related series together - BaseSeries.__repr_related__ = repr_related_ehrql - DateDifference.__repr_related__ = repr_related_ehrql + if head or tail: + rendered = truncate_table(rendered, head, tail) - try: - yield - finally: - DEBUG_QUERY_ENGINE = None - # Restore original methods - BaseFrame.__repr__ = BaseFrame__repr__ - BaseSeries.__repr__ = BaseSeries__repr__ - DateDifference.__repr__ = DateDifference__repr__ - Dataset.__repr__ = Dataset__repr__ - del BaseSeries.__repr_related__ - del DateDifference.__repr_related__ + return rendered def elements_are_related_series(elements): diff --git a/ehrql/loaders.py b/ehrql/loaders.py index 1ea4536e4..4c4cea2bb 100644 --- a/ehrql/loaders.py +++ b/ehrql/loaders.py @@ -324,6 +324,12 @@ def load_definition_unsafe( def load_module(module_path, user_args=()): + """Load a module containing an ehrql dataset definition. + + Renders the serialized query model as json to the callers stdout, and any + other output to stderr. 
+ """ + # Taken from the official recipe for importing a module from a file path: # https://docs.python.org/3.9/library/importlib.html#importing-a-source-file-directly spec = importlib.util.spec_from_file_location(module_path.stem, module_path) diff --git a/ehrql/quiz.py b/ehrql/quiz.py index beb9355cb..5a5a2fa60 100644 --- a/ehrql/quiz.py +++ b/ehrql/quiz.py @@ -276,8 +276,8 @@ def set_dummy_tables_path(self, path): # itself, regardless of how the quiz script was invoked. This kind of global # nastiness seems tolerable in the context of the quiz, and worth it for # avoiding potential confusion. - if ehrql.debugger.DEBUG_QUERY_ENGINE is not None: - ehrql.debugger.DEBUG_QUERY_ENGINE.dsn = path + if ehrql.debugger.DEBUG_CONTEXT is not None: + ehrql.debugger.DEBUG_CONTEXT.query_engine.dsn = path def __setitem__(self, index, question): question.index = index diff --git a/tests/unit/test_debugger.py b/tests/unit/test_debugger.py index 2a3e97496..d4d28af5e 100644 --- a/tests/unit/test_debugger.py +++ b/tests/unit/test_debugger.py @@ -9,7 +9,6 @@ activate_debug_context, elements_are_related_series, related_patient_columns_to_records, - render, ) from ehrql.query_engines.in_memory_database import PatientColumn from ehrql.tables import EventFrame, PatientFrame, Series, table @@ -24,7 +23,7 @@ def date_serializer(obj): def test_show(capsys): expected_output = textwrap.dedent( """ - Show line 32: + Show line 31: """ ).strip() @@ -37,7 +36,7 @@ def test_show(capsys): def test_show_with_label(capsys): expected_output = textwrap.dedent( """ - Show line 45: Number + Show line 44: Number """ ).strip() @@ -81,48 +80,6 @@ def test_related_patient_columns_to_records_full_join(): assert r == r_expected -def test_render_formatted_table(): - expected_output = textwrap.dedent( - """ - patient_id | value - ------------------+------------------ - 1 | 101 - 2 | 201 - """ - ).strip() - - c = PatientColumn.parse( - """ - 1 | 101 - 2 | 201 - """ - ) - assert render(c).strip() == 
expected_output - - -def test_render_truncated_table(): - expected_output = textwrap.dedent( - """ - patient_id | value - ------------------+------------------ - 1 | 101 - ... | ... - 4 | 401 - """ - ).strip() - - c = PatientColumn.parse( - """ - 1 | 101 - 2 | 201 - 3 | 301 - 4 | 401 - """ - ) - - assert render(c, head=1, tail=1) == expected_output - - @table class patients(PatientFrame): date_of_birth = Series(date) @@ -218,8 +175,8 @@ def test_activate_debug_context(dummy_tables_path, expression, contents): with activate_debug_context( dummy_tables_path=dummy_tables_path, render_function=lambda value: repr(list(value)), - ): - assert repr(expression) == repr(contents) + ) as ctx: + assert ctx.render(expression) == repr(contents) @pytest.mark.parametrize( @@ -243,8 +200,8 @@ def test_repr_related_patient_series(dummy_tables_path): render_function=lambda value: json.dumps( list(value), indent=4, default=date_serializer ), - ): - rendered = render( + ) as ctx: + rendered = ctx.render( patients.date_of_birth, patients.sex, events.count_for_patient(), @@ -274,8 +231,8 @@ def test_repr_related_event_series(dummy_tables_path): render_function=lambda value: json.dumps( list(value), indent=4, default=date_serializer ), - ): - rendered = render(events.date, events.code, events.test_result) + ) as ctx: + rendered = ctx.render(events.date, events.code, events.test_result) assert json.loads(rendered) == [ { "patient_id": 1, @@ -305,8 +262,8 @@ def test_repr_date_difference(dummy_tables_path): with activate_debug_context( dummy_tables_path=dummy_tables_path, render_function=lambda value: json.dumps(list(value), indent=4), - ): - rendered = render(patients.date_of_death - events.date) + ) as ctx: + rendered = ctx.render(patients.date_of_death - events.date) assert json.loads(rendered) == [ {"patient_id": 1, "row_id": 1, "value": ""}, {"patient_id": 1, "row_id": 2, "value": ""}, @@ -318,8 +275,8 @@ def test_repr_related_date_difference_patient_series(dummy_tables_path): with 
activate_debug_context( dummy_tables_path=dummy_tables_path, render_function=lambda value: json.dumps(list(value), indent=4), - ): - rendered = render( + ) as ctx: + rendered = ctx.render( "2024-01-01" - patients.date_of_birth, patients.sex, ) @@ -333,8 +290,8 @@ def test_repr_related_date_difference_event_series(dummy_tables_path): with activate_debug_context( dummy_tables_path=dummy_tables_path, render_function=lambda value: json.dumps(list(value), indent=4), - ): - rendered = render( + ) as ctx: + rendered = ctx.render( events.date - patients.date_of_birth, events.code, ) @@ -386,7 +343,7 @@ def test_show_does_not_raise_error_for_series_from_same_domain( def test_show_not_run_outside_debug_context(capsys): expected_output = textwrap.dedent( """ - Show line 394: + Show line 351: - show() ignored because we're not running in debug mode """ ).strip() @@ -394,8 +351,3 @@ def test_show_not_run_outside_debug_context(capsys): show(patients.date_of_birth, patients.sex) captured = capsys.readouterr() assert captured.err.strip() == expected_output, captured.err - - -def test_render_multiple_without_repr_related_errors(): - with pytest.raises(TypeError): - render("hello", "goodbye") diff --git a/tests/unit/test_quiz.py b/tests/unit/test_quiz.py index 7989be975..6517757c3 100644 --- a/tests/unit/test_quiz.py +++ b/tests/unit/test_quiz.py @@ -339,9 +339,9 @@ def test_set_dummy_tables_path_in_debug_context(): ): questions = quiz.Questions() questions.set_dummy_tables_path("bar") - assert debugger.DEBUG_QUERY_ENGINE.dsn.name == "bar" + assert debugger.DEBUG_CONTEXT.query_engine.dsn.name == "bar" # This should be unset outside of the context manager - assert debugger.DEBUG_QUERY_ENGINE is None + assert debugger.DEBUG_CONTEXT is None def test_hint(capfd): From dca3138f8be13b8386271981329c4143bc3e33c8 Mon Sep 17 00:00:00 2001 From: bloodearnest Date: Thu, 16 Jan 2025 11:58:37 +0000 Subject: [PATCH 2/4] Use exec() to make the line numbers in show() tests independent of the test 
file line numbers --- tests/unit/test_debugger.py | 39 +++++++++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/tests/unit/test_debugger.py b/tests/unit/test_debugger.py index d4d28af5e..3a58b52a0 100644 --- a/tests/unit/test_debugger.py +++ b/tests/unit/test_debugger.py @@ -23,12 +23,21 @@ def date_serializer(obj): def test_show(capsys): expected_output = textwrap.dedent( """ - Show line 31: + Show line 3: """ ).strip() - show("Hello") + exec( + textwrap.dedent( + """ + # line 2 + show("Hello") + # line 4 + """ + ) + ) + captured = capsys.readouterr() assert captured.err.strip().startswith(expected_output), captured.err @@ -36,12 +45,21 @@ def test_show(capsys): def test_show_with_label(capsys): expected_output = textwrap.dedent( """ - Show line 44: Number + Show line 3: Number """ ).strip() - show(14, label="Number") + exec( + textwrap.dedent( + """ + # line 2 + show(14, label="Number") + # line 4 + """ + ) + ) + captured = capsys.readouterr() assert captured.err.strip().startswith(expected_output), captured.err @@ -343,11 +361,20 @@ def test_show_does_not_raise_error_for_series_from_same_domain( def test_show_not_run_outside_debug_context(capsys): expected_output = textwrap.dedent( """ - Show line 351: + Show line 3: - show() ignored because we're not running in debug mode """ ).strip() - show(patients.date_of_birth, patients.sex) + exec( + textwrap.dedent( + """ + # line 2 + show(patients.date_of_birth, patients.sex) + # line 4 + """ + ) + ) + captured = capsys.readouterr() assert captured.err.strip() == expected_output, captured.err From d00b3e31f0a4230aaa944752b733142d017eba96 Mon Sep 17 00:00:00 2001 From: bloodearnest Date: Thu, 16 Jan 2025 13:11:08 +0000 Subject: [PATCH 3/4] Refactor debug output truncation Previously, we did the truncation post render, probably in order not to mess with the default repr, which cannot be parametrised with head/tail arguments. 
However, that is gone now, and we no longer render via a parameterless repr, so we refactor the truncation to happen at render time. This has several advantages: - can use the same logic for both html and ascii rendering - only done in one place - more efficient, as we do not render the truncated rows at all. This includes the following changes: - head and tail are arguments to the render() function call. The user interface in show() allows them to be `int | None`, but we force them to integer values in the implementation, with 0 meaning None, as both are falsey. This necessitated some changes to not use _render_ functions, and call to_records() directly, so we could pass the head/tail args at render time. - head and tail logic is factored out into a shared function, and handles ellipsis logic consistently. - The render function's `records` parameter was annotated as list[dict], but it was actually a generator. We assume that it is a list, and use indexing for simple head/tail slicing, and fix the call sites to list() the generator. There's no way to do tail w/o exhausting the generator anyway, and we were rendering the full thing before, so this is no less efficient, but is simpler to read. - fixed the debug tests to always use a json renderer for testing, for consistency. 
--- ehrql/debugger.py | 20 ++- ehrql/query_engines/debug.py | 4 +- ehrql/query_engines/in_memory_database.py | 8 +- ehrql/renderers.py | 147 +++++++--------------- tests/unit/test_debugger.py | 55 ++++---- tests/unit/test_renderers.py | 79 ++++-------- 6 files changed, 108 insertions(+), 205 deletions(-) diff --git a/ehrql/debugger.py b/ehrql/debugger.py index cf457d095..000e92fb4 100644 --- a/ehrql/debugger.py +++ b/ehrql/debugger.py @@ -12,7 +12,6 @@ ) from ehrql.query_language import BaseSeries, Dataset, DateDifference from ehrql.query_model import nodes as qm -from ehrql.renderers import truncate_table from ehrql.utils.docs_utils import exclude_from_docs @@ -64,7 +63,9 @@ def show( return try: - rendered = DEBUG_CONTEXT.render(element, *other_elements, head=head, tail=tail) + rendered = DEBUG_CONTEXT.render( + element, *other_elements, head=head or 0, tail=tail or 0 + ) except TypeError: # raise more helpful show() specific error raise TypeError( @@ -109,8 +110,8 @@ def render( self, element, *other_elements, - head: int | None = None, - tail: int | None = None, + head: int = 0, + tail: int = 0, ): elements = [element, *other_elements] @@ -127,21 +128,18 @@ def render( if len(other_elements) == 0: # single ehrql element so we just display it - rendered = self.query_engine.evaluate(element)._render_( - self.render_function + evaluated = list( + self.query_engine.evaluate(element).to_records(convert_null=True) ) else: # multiple ehrql series so we combine them - rendered = self.render_function( + evaluated = list( related_columns_to_records( [self.query_engine.evaluate(element) for element in elements] ) ) - if head or tail: - rendered = truncate_table(rendered, head, tail) - - return rendered + return self.render_function(evaluated, head, tail) def elements_are_related_series(elements): diff --git a/ehrql/query_engines/debug.py b/ehrql/query_engines/debug.py index 67b57cc85..e8130eca2 100644 --- a/ehrql/query_engines/debug.py +++ b/ehrql/query_engines/debug.py 
@@ -57,5 +57,5 @@ class EmptyDataset: """This class exists to render something nice when a user tries to inspect a dataset with no columns with debugger.show().""" - def _render_(self, render_fn): - return render_fn([{"patient_id": ""}]) + def to_records(self, convert_null=False): + return [{"patient_id": ""}] diff --git a/ehrql/query_engines/in_memory_database.py b/ehrql/query_engines/in_memory_database.py index 01f9a6314..fa6cd7697 100644 --- a/ehrql/query_engines/in_memory_database.py +++ b/ehrql/query_engines/in_memory_database.py @@ -93,7 +93,7 @@ def __repr__(self): return self._render_(DISPLAY_RENDERERS["ascii"]) def _render_(self, render_fn): - return render_fn(self.to_records(convert_null=True)) + return render_fn(list(self.to_records(convert_null=True))) def __getitem__(self, name): return self.name_to_col[name] @@ -179,7 +179,7 @@ def __repr__(self): return self._render_(DISPLAY_RENDERERS["ascii"]) def _render_(self, render_fn): - return render_fn(self.to_records(convert_null=True)) + return render_fn(list(self.to_records(convert_null=True))) def __getitem__(self, name): return self.name_to_col[name] @@ -257,7 +257,7 @@ def __repr__(self): return self._render_(DISPLAY_RENDERERS["ascii"]) def _render_(self, render_fn): - return render_fn(self.to_records(convert_null=True)) + return render_fn(list(self.to_records(convert_null=True))) def __getitem__(self, patient): return self.patient_to_value.get(patient, self.default) @@ -318,7 +318,7 @@ def __repr__(self): return self._render_(DISPLAY_RENDERERS["ascii"]) def _render_(self, render_fn): - return render_fn(self.to_records(convert_null=True)) + return render_fn(list(self.to_records(convert_null=True))) def __getitem__(self, patient): return self.patient_to_rows.get(patient, Rows({})) diff --git a/ehrql/renderers.py b/ehrql/renderers.py index 7fb53d765..df84d867c 100644 --- a/ehrql/renderers.py +++ b/ehrql/renderers.py @@ -1,134 +1,75 @@ -import re +import typing START_MARKER = "" END_MARKER = "" -def 
records_to_html_table(records: list[dict]): - rows = [] - headers_written = False - - for record in records: - if not headers_written: - headers = "".join([f"{header}" for header in record.keys()]) - headers_written = True - row = "".join(f"{val}" for val in record.values()) - rows.append(f"{row}") - rows = "".join(rows) - - return f"{START_MARKER}{headers}{rows}
{END_MARKER}" - - -def records_to_ascii_table(records: list[dict]): - width = 17 - lines = [] - headers_written = False - for record in records: - if not headers_written: - lines.append(" | ".join(name.ljust(width) for name in record.keys())) - lines.append("-+-".join("-" * width for _ in record.keys())) - headers_written = True - lines.append(" | ".join(str(value).ljust(width) for value in record.values())) - return "\n".join(line.strip() for line in lines) + "\n" - - -def _truncate_html_table(table_repr: str, head: int | None, tail: int | None): - """ - Truncate an html table string to the first/last N rows, with a row of ... - values to indicate where it's been truncated - """ - regex = re.compile( - rf"(?P^{START_MARKER}.*)(?P.*<\/tr>)(?P<\/tbody>.*<\/table>{END_MARKER})" - ) - match = regex.match(table_repr) - if match is None: - # if we can't parse the table, return None and let the fallback handle it - return - - start = match.group("start") - rows = match.group("rows") - end = match.group("end") +def headtail(sequence: typing.Sequence, head: int = 0, tail: int = 0): + if head == tail == 0: + return sequence, False, [] # Check we have enough rows to truncate to head and tail - if rows.count("") <= ((head or 0) + (tail or 0)): - return table_repr + if len(sequence) <= head + tail: + return sequence, False, [] - # split row on tokens, remove any empty strings (this will lose the first of - # each row, but we'll add it in again later) - rows = [row for row in rows.split("") if row] - # compose an "ellipsis row" to mark the place of truncated rows - td_count = rows[0].count("' * td_count}" + head_rows = [] + tail_rows = [] - # Build the list of rows we need to include, with ellipsis rows where necessary - truncated_rows = [] + if head: + head_rows = sequence[:head] - head_rows = rows[:head] if head is not None else [ellipsis_row] - truncated_rows.extend(head_rows) + if tail: + tail_rows = sequence[-tail:] - if head is not None and tail is not None: - 
truncated_rows.append(ellipsis_row) + return head_rows, True, tail_rows - tail_rows = rows[-tail:] if tail is not None else [ellipsis_row] - truncated_rows.extend(tail_rows) - # re-join the truncated rows with - truncated_rows = "" + "".join(truncated_rows) - return start + truncated_rows + end +def records_to_html_table(records: list[dict], head: int = 0, tail: int = 0): + rows = [] + headers = "".join([f"" for header in records[0].keys()]) + head_rows, ellipsis, tail_rows = headtail(records, head, tail) + def html_row(row): + columns = "".join(f"" for v in row.values()) + return f"{columns}" -def _truncate_lines( - table_repr: str, headers: int = 0, head: int | None = None, tail: int | None = None -): - table_rows = [row for row in table_repr.split("\n") if row] + for row in head_rows: + rows.append(html_row(row)) - # Check we have enough rows to truncate to head and tail - if len(table_rows) <= (headers + (head or 0) + (tail or 0)): - return table_repr + if ellipsis: + ellipsis_columns = "" * len(records[0]) + rows.append(f"{ellipsis_columns}") - # compose an "ellipsis row" to mark the place of truncated rows - cell_count = table_rows[0].count("|") + 1 - ellipsis_row = " | ".join("...".ljust(17) for i in range(cell_count)).strip() + for row in tail_rows: + rows.append(html_row(row)) - # Build the list of rows we need to include, with ellipsis rows where necessary - truncated_rows = table_rows[:headers] + html_rows = "".join(rows) - head_rows = ( - table_rows[headers : headers + head] if head is not None else [ellipsis_row] - ) - truncated_rows.extend(head_rows) + return f"{START_MARKER}
") - ellipsis_row = f"{'...
{header}{v}
{headers}{html_rows}
{END_MARKER}" - if head is not None and tail is not None: - truncated_rows.append(ellipsis_row) - tail_rows = table_rows[-tail:] if tail is not None else [ellipsis_row] - truncated_rows.extend(tail_rows) +def records_to_ascii_table(records: list[dict], head: int = 0, tail: int = 0): + head_rows, ellipsis, tail_rows = headtail(records, head, tail) + width = 17 + lines = [] - return "\n".join(truncated_rows) + headers = records[0].keys() + lines.append(" | ".join(name.ljust(width) for name in headers)) + lines.append("-+-".join("-" * width for _ in headers)) -def truncate_table(table_repr: str, head: int | None, tail: int | None): - """ - Take a table repr (ascii or html format) and truncate it to show only the - first and/or last N rows. - """ - if head is None and tail is None: - return table_repr + for line in head_rows: + lines.append(" | ".join(str(v).ljust(width) for v in line.values())) - truncated_repr = None + if ellipsis: + ellipsis_row = " | ".join("...".ljust(width) for _ in headers) + lines.append(ellipsis_row) - if "" in table_repr: - truncated_repr = _truncate_html_table(table_repr, head=head, tail=tail) - elif "---+---" in table_repr: - truncated_repr = _truncate_lines(table_repr, headers=2, head=head, tail=tail) + for line in tail_rows: + lines.append(" | ".join(str(v).ljust(width) for v in line.values())) - # if we didn't detect either an ascii or html table, or the html regex - # didn't match as expected, fall back to a simple truncation of lines using - # line breaks. 
- if truncated_repr is None: - truncated_repr = _truncate_lines(table_repr, head=head, tail=tail) - return truncated_repr + return "\n".join(line.strip() for line in lines) DISPLAY_RENDERERS = { diff --git a/tests/unit/test_debugger.py b/tests/unit/test_debugger.py index 3a58b52a0..b04d88c39 100644 --- a/tests/unit/test_debugger.py +++ b/tests/unit/test_debugger.py @@ -20,6 +20,11 @@ def date_serializer(obj): raise TypeError("Type not serializable") # pragma: no cover +def json_render_function(sequence, head=0, tail=0): + """Render as JSON, useful for testing.""" + return json.dumps(sequence, indent=4, default=date_serializer) + + def test_show(capsys): expected_output = textwrap.dedent( """ @@ -67,7 +72,7 @@ def test_show_with_label(capsys): def test_show_fails_for_non_ehrql_object(dummy_tables_path): with activate_debug_context( dummy_tables_path=dummy_tables_path, - render_function=lambda value: value, + render_function=json_render_function, ): with pytest.raises(TypeError): show("Hello") @@ -152,14 +157,14 @@ def dummy_tables_path(tmp_path_factory): [ { "patient_id": 1, - "date_of_birth": date(1970, 1, 1), + "date_of_birth": "1970-01-01", "date_of_death": "", "sex": "male", }, { "patient_id": 2, - "date_of_birth": date(1980, 1, 1), - "date_of_death": date(2020, 1, 1), + "date_of_birth": "1980-01-01", + "date_of_death": "2020-01-01", "sex": "female", }, ], @@ -167,8 +172,8 @@ def dummy_tables_path(tmp_path_factory): ( patients.date_of_birth, [ - {"patient_id": 1, "value": date(1970, 1, 1)}, - {"patient_id": 2, "value": date(1980, 1, 1)}, + {"patient_id": 1, "value": "1970-01-01"}, + {"patient_id": 2, "value": "1980-01-01"}, ], ), ( @@ -178,12 +183,12 @@ def dummy_tables_path(tmp_path_factory): dod=patients.date_of_death, ), [ - {"patient_id": 1, "dob": date(1970, 1, 1), "count": 2, "dod": ""}, + {"patient_id": 1, "dob": "1970-01-01", "count": 2, "dod": ""}, { "patient_id": 2, - "dob": date(1980, 1, 1), + "dob": "1980-01-01", "count": 1, - "dod": date(2020, 1, 
1), + "dod": "2020-01-01", }, ], ), @@ -192,9 +197,9 @@ def dummy_tables_path(tmp_path_factory): def test_activate_debug_context(dummy_tables_path, expression, contents): with activate_debug_context( dummy_tables_path=dummy_tables_path, - render_function=lambda value: repr(list(value)), + render_function=json_render_function, ) as ctx: - assert ctx.render(expression) == repr(contents) + assert json.loads(ctx.render(expression)) == contents @pytest.mark.parametrize( @@ -212,12 +217,10 @@ def test_elements_are_related_series(elements, expected): assert elements_are_related_series(elements) == expected -def test_repr_related_patient_series(dummy_tables_path): +def test_render_related_patient_series(dummy_tables_path): with activate_debug_context( dummy_tables_path=dummy_tables_path, - render_function=lambda value: json.dumps( - list(value), indent=4, default=date_serializer - ), + render_function=json_render_function, ) as ctx: rendered = ctx.render( patients.date_of_birth, @@ -243,12 +246,10 @@ def test_repr_related_patient_series(dummy_tables_path): ] -def test_repr_related_event_series(dummy_tables_path): +def test_render_related_event_series(dummy_tables_path): with activate_debug_context( dummy_tables_path=dummy_tables_path, - render_function=lambda value: json.dumps( - list(value), indent=4, default=date_serializer - ), + render_function=json_render_function, ) as ctx: rendered = ctx.render(events.date, events.code, events.test_result) assert json.loads(rendered) == [ @@ -276,10 +277,10 @@ def test_repr_related_event_series(dummy_tables_path): ] -def test_repr_date_difference(dummy_tables_path): +def test_render_date_difference(dummy_tables_path): with activate_debug_context( dummy_tables_path=dummy_tables_path, - render_function=lambda value: json.dumps(list(value), indent=4), + render_function=json_render_function, ) as ctx: rendered = ctx.render(patients.date_of_death - events.date) assert json.loads(rendered) == [ @@ -289,10 +290,10 @@ def 
test_repr_date_difference(dummy_tables_path): ] -def test_repr_related_date_difference_patient_series(dummy_tables_path): +def test_render_related_date_difference_patient_series(dummy_tables_path): with activate_debug_context( dummy_tables_path=dummy_tables_path, - render_function=lambda value: json.dumps(list(value), indent=4), + render_function=json_render_function, ) as ctx: rendered = ctx.render( "2024-01-01" - patients.date_of_birth, @@ -304,10 +305,10 @@ def test_repr_related_date_difference_patient_series(dummy_tables_path): ] -def test_repr_related_date_difference_event_series(dummy_tables_path): +def test_render_related_date_difference_event_series(dummy_tables_path): with activate_debug_context( dummy_tables_path=dummy_tables_path, - render_function=lambda value: json.dumps(list(value), indent=4), + render_function=json_render_function, ) as ctx: rendered = ctx.render( events.date - patients.date_of_birth, @@ -332,7 +333,7 @@ def test_repr_related_date_difference_event_series(dummy_tables_path): def test_show_fails_for_mismatched_inputs(example_input): with activate_debug_context( dummy_tables_path=dummy_tables_path, - render_function=lambda value: value, + render_function=json_render_function, ): with pytest.raises(TypeError): assert show(*example_input) @@ -353,7 +354,7 @@ def test_show_does_not_raise_error_for_series_from_same_domain( ): with activate_debug_context( dummy_tables_path=dummy_tables_path, - render_function=lambda value: value, + render_function=json_render_function, ): show(example_input[0], *example_input[1:]) diff --git a/tests/unit/test_renderers.py b/tests/unit/test_renderers.py index 81f87013a..c81eb96c5 100644 --- a/tests/unit/test_renderers.py +++ b/tests/unit/test_renderers.py @@ -7,7 +7,7 @@ PatientColumn, PatientTable, ) -from ehrql.renderers import DISPLAY_RENDERERS, truncate_table +from ehrql.renderers import DISPLAY_RENDERERS TABLE = PatientTable.parse( @@ -41,7 +41,7 @@ def test_render_table(render_format): "" "
" "" - "" + "" "" "" "" @@ -54,7 +54,7 @@ def test_render_table(render_format): "" ), } - rendered = DISPLAY_RENDERERS[render_format](TABLE.to_records()).strip() + rendered = DISPLAY_RENDERERS[render_format](list(TABLE.to_records())).strip() assert rendered == expected_output[render_format], rendered @@ -73,7 +73,7 @@ def test_render_column(render_format): "" "
patient_idi1i2
patient_idi1i2
1101111
" "" - "" + "" "" "" "" @@ -90,7 +90,7 @@ def test_render_column(render_format): 2 | 201 """ ) - rendered = DISPLAY_RENDERERS[render_format](c.to_records()).strip() + rendered = DISPLAY_RENDERERS[render_format](list(c.to_records())).strip() assert rendered == expected_output[render_format], rendered @@ -110,20 +110,19 @@ def test_render_table_head(render_format): "" "
patient_idvalue
patient_idvalue
1101
" "" - "" + "" "" "" "" "" - "" + "" "" "
patient_idi1i2
patient_idi1i2
1101111
2201211
.........
" "" ), } - rendered = DISPLAY_RENDERERS[render_format](TABLE.to_records()) - truncated = truncate_table(rendered, head=2, tail=None) + truncated = DISPLAY_RENDERERS[render_format](list(TABLE.to_records()), head=2) assert truncated == expected_output[render_format], truncated @@ -143,10 +142,10 @@ def test_render_table_tail(render_format): "" "" "" - "" + "" "" "" - "" + "" "" "" "" @@ -155,8 +154,7 @@ def test_render_table_tail(render_format): ), } - rendered = DISPLAY_RENDERERS[render_format](TABLE.to_records()) - truncated = truncate_table(rendered, head=None, tail=2) + truncated = DISPLAY_RENDERERS[render_format](list(TABLE.to_records()), tail=2) assert truncated == expected_output[render_format], truncated @@ -178,12 +176,12 @@ def test_render_table_head_and_tail(render_format): "" "
patient_idi1i2
patient_idi1i2
.........
4401411
5501511
" "" - "" + "" "" "" "" "" - "" + "" "" "" "" @@ -192,16 +190,15 @@ def test_render_table_head_and_tail(render_format): ), } - rendered = DISPLAY_RENDERERS[render_format](TABLE.to_records()) - truncated = truncate_table(rendered, head=2, tail=2) + truncated = DISPLAY_RENDERERS[render_format]( + list(TABLE.to_records()), head=2, tail=2 + ) assert truncated == expected_output[render_format], truncated @pytest.mark.parametrize( "render_format,head_tail", - list( - product(["ascii", "html"], [(None, None), (2, 3), (5, None), (None, 6), (3, 3)]) - ), + list(product(["ascii"], [(0, 0), (2, 3), (5, 0), (0, 6), (3, 3)])), ) def test_render_table_bad_head_tail(render_format, head_tail): expected_output = { @@ -220,7 +217,7 @@ def test_render_table_bad_head_tail(render_format, head_tail): "" "
patient_idi1i2
patient_idi1i2
1101111
2201211
.........
4401411
5501511
" "" - "" + "" "" "" "" @@ -234,41 +231,7 @@ def test_render_table_bad_head_tail(render_format, head_tail): ), } head, tail = head_tail - rendered = DISPLAY_RENDERERS[render_format](TABLE.to_records()) - truncated = truncate_table(rendered, head=head, tail=tail).strip() - assert truncated == expected_output[render_format], (truncated, head, tail) - - -def test_render_head_and_tail_not_a_table(): - expected_output = textwrap.dedent( - """ - a - b - ... - d - e - """ - ).strip() - - input_string = "\n".join(["a", "b", "c", "d", "e"]) - truncated = truncate_table(input_string, head=2, tail=2).strip() - assert truncated == expected_output, truncated - - -def test_truncate_table_bad_html(): - # If we can't parse something that looks like an html - # table as expected, we fall back to the basic line truncator - bad_html = ( - "
patient_idi1i2
patient_idi1i2
1101111
\n" - "\n" - "\n" - "\n" - "\n" - "\n" - "
patient_idi1i2
" + truncated = DISPLAY_RENDERERS[render_format]( + list(TABLE.to_records()), head=head, tail=tail ) - - expected = "\n\n..." - - truncated = truncate_table(bad_html, head=2, tail=None) - assert truncated == expected, truncated + assert truncated == expected_output[render_format], (truncated, head, tail) From 9761f4b829778dd13a1ce232dbb1eda6a544e549 Mon Sep 17 00:00:00 2001 From: bloodearnest Date: Fri, 17 Jan 2025 15:03:30 +0000 Subject: [PATCH 4/4] Handle renderer case where records is empty list --- ehrql/renderers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ehrql/renderers.py b/ehrql/renderers.py index df84d867c..0114e9be9 100644 --- a/ehrql/renderers.py +++ b/ehrql/renderers.py @@ -27,7 +27,11 @@ def headtail(sequence: typing.Sequence, head: int = 0, tail: int = 0): def records_to_html_table(records: list[dict], head: int = 0, tail: int = 0): rows = [] - headers = "".join([f"" for header in records[0].keys()]) + headers = ( + "".join([f"" for header in records[0].keys()]) + if records + else "" + ) head_rows, ellipsis, tail_rows = headtail(records, head, tail) def html_row(row): @@ -54,7 +58,7 @@ def records_to_ascii_table(records: list[dict], head: int = 0, tail: int = 0): width = 17 lines = [] - headers = records[0].keys() + headers = records[0].keys() if records else [] lines.append(" | ".join(name.ljust(width) for name in headers)) lines.append("-+-".join("-" * width for _ in headers))
{header}{header}