diff --git a/ehrql/debugger.py b/ehrql/debugger.py index 03b950fb4..000e92fb4 100644 --- a/ehrql/debugger.py +++ b/ehrql/debugger.py @@ -1,6 +1,8 @@ import contextlib import inspect import sys +import typing +from dataclasses import dataclass from ehrql.query_engines.debug import DebugQueryEngine from ehrql.query_engines.in_memory_database import ( @@ -8,13 +10,12 @@ PatientColumn, render_value, ) -from ehrql.query_language import BaseFrame, BaseSeries, Dataset, DateDifference +from ehrql.query_language import BaseSeries, Dataset, DateDifference from ehrql.query_model import nodes as qm -from ehrql.renderers import truncate_table from ehrql.utils.docs_utils import exclude_from_docs -DEBUG_QUERY_ENGINE = None +DEBUG_CONTEXT = None @exclude_from_docs @@ -54,26 +55,19 @@ def show( label = f" {label}" if label else "" print(f"Show line {line_no}:{label}", file=sys.stderr) - if DEBUG_QUERY_ENGINE is None: + if DEBUG_CONTEXT is None: # We're not in debug mode so ignore print( " - show() ignored because we're not running in debug mode", file=sys.stderr ) return - elements = [element, *other_elements] - - is_single_ehrql_object = is_ehrql_object(element) and len(other_elements) == 0 - is_multiple_series_same_domain = len(elements) > 1 and elements_are_related_series( - elements - ) - - # We throw an error unless show() is called with: - # 1. A single ehrql object (series, frame, dataset) - # 2. Multiple series drawn from the same domain - if is_single_ehrql_object | is_multiple_series_same_domain: - print(render(element, *other_elements, head=head, tail=tail), file=sys.stderr) - else: + try: + rendered = DEBUG_CONTEXT.render( + element, *other_elements, head=head or 0, tail=tail or 0 + ) + except TypeError: + # raise more helpful show() specific error raise TypeError( "\n\nshow() can be used in two ways. Either:\n" " - call it with a single argument such as:\n" @@ -89,73 +83,63 @@ def show( "clinical_events.numeric_values)" ) - -def render( - element, - *other_elements, - head: int | None = None, - tail: int | None = None, -): - if len(other_elements) == 0: - # single ehrql element so we just display it - element_repr = repr(element) - elif hasattr(element, "__repr_related__"): - # multiple ehrql series so we combine them - element_repr = element.__repr_related__(*other_elements) - else: - raise TypeError( - "When render() is called on multiple series, the first argument must have a __repr_related__ attribute." - ) - - if head or tail: - element_repr = truncate_table(element_repr, head, tail) - return element_repr + print(rendered, file=sys.stderr) @contextlib.contextmanager def activate_debug_context(*, dummy_tables_path, render_function): - global DEBUG_QUERY_ENGINE - - # Record original methods - BaseFrame__repr__ = BaseFrame.__repr__ - BaseSeries__repr__ = BaseSeries.__repr__ - DateDifference__repr__ = DateDifference.__repr__ - Dataset__repr__ = Dataset.__repr__ - - query_engine = DebugQueryEngine(dummy_tables_path) - DEBUG_QUERY_ENGINE = query_engine - - def repr_ehrql(obj, template=None): - return query_engine.evaluate(obj)._render_(render_function) - - def repr_related_ehrql(*series_args): - return render_function( - related_columns_to_records( - [query_engine.evaluate(series) for series in series_args] - ) - ) - - # Temporarily overwrite __repr__ methods to display contents - BaseFrame.__repr__ = repr_ehrql - BaseSeries.__repr__ = repr_ehrql - DateDifference.__repr__ = repr_ehrql - Dataset.__repr__ = repr_ehrql - - # Add additional method for displaying related series together - BaseSeries.__repr_related__ = repr_related_ehrql - DateDifference.__repr_related__ = repr_related_ehrql - + global DEBUG_CONTEXT + DEBUG_CONTEXT = DebugContext( + query_engine=DebugQueryEngine(dummy_tables_path), + render_function=render_function, + ) try: - yield + yield DEBUG_CONTEXT # yield the context for testing convenience finally: - DEBUG_QUERY_ENGINE = None - # Restore original methods - BaseFrame.__repr__ = BaseFrame__repr__ - BaseSeries.__repr__ = BaseSeries__repr__ - DateDifference.__repr__ = DateDifference__repr__ - Dataset.__repr__ = Dataset__repr__ - del BaseSeries.__repr_related__ - del DateDifference.__repr_related__ + DEBUG_CONTEXT = None + + +@dataclass +class DebugContext: + """Record the currently active query engine and render function.""" + + query_engine: DebugQueryEngine + render_function: typing.Callable + + def render( + self, + element, + *other_elements, + head: int = 0, + tail: int = 0, + ): + elements = [element, *other_elements] + + is_single_ehrql_object = is_ehrql_object(element) and len(other_elements) == 0 + is_multiple_series_same_domain = len( + elements + ) > 1 and elements_are_related_series(elements) + + # We throw an error unless called with: + # 1. A single ehrql object (series, frame, dataset) + # 2. Multiple series drawn from the same domain + if not (is_single_ehrql_object | is_multiple_series_same_domain): + raise TypeError("All elements must be in the same domain") + + if len(other_elements) == 0: + # single ehrql element so we just display it + evaluated = list( + self.query_engine.evaluate(element).to_records(convert_null=True) + ) + else: + # multiple ehrql series so we combine them + evaluated = list( + related_columns_to_records( + [self.query_engine.evaluate(element) for element in elements] + ) + ) + + return self.render_function(evaluated, head, tail) def elements_are_related_series(elements): diff --git a/ehrql/loaders.py b/ehrql/loaders.py index 1ea4536e4..4c4cea2bb 100644 --- a/ehrql/loaders.py +++ b/ehrql/loaders.py @@ -324,6 +324,12 @@ def load_definition_unsafe( def load_module(module_path, user_args=()): + """Load a module containing an ehrql dataset definition. + + Renders the serialized query model as json to the callers stdout, and any + other output to stderr. + """ + # Taken from the official recipe for importing a module from a file path: # https://docs.python.org/3.9/library/importlib.html#importing-a-source-file-directly spec = importlib.util.spec_from_file_location(module_path.stem, module_path) diff --git a/ehrql/query_engines/debug.py b/ehrql/query_engines/debug.py index 67b57cc85..e8130eca2 100644 --- a/ehrql/query_engines/debug.py +++ b/ehrql/query_engines/debug.py @@ -57,5 +57,5 @@ class EmptyDataset: """This class exists to render something nice when a user tries to inspect a dataset with no columns with debugger.show().""" - def _render_(self, render_fn): - return render_fn([{"patient_id": ""}]) + def to_records(self, convert_null=False): + return [{"patient_id": ""}] diff --git a/ehrql/query_engines/in_memory_database.py b/ehrql/query_engines/in_memory_database.py index 01f9a6314..fa6cd7697 100644 --- a/ehrql/query_engines/in_memory_database.py +++ b/ehrql/query_engines/in_memory_database.py @@ -93,7 +93,7 @@ def __repr__(self): return self._render_(DISPLAY_RENDERERS["ascii"]) def _render_(self, render_fn): - return render_fn(self.to_records(convert_null=True)) + return render_fn(list(self.to_records(convert_null=True))) def __getitem__(self, name): return self.name_to_col[name] @@ -179,7 +179,7 @@ def __repr__(self): return self._render_(DISPLAY_RENDERERS["ascii"]) def _render_(self, render_fn): - return render_fn(self.to_records(convert_null=True)) + return render_fn(list(self.to_records(convert_null=True))) def __getitem__(self, name): return self.name_to_col[name] @@ -257,7 +257,7 @@ def __repr__(self): return self._render_(DISPLAY_RENDERERS["ascii"]) def _render_(self, render_fn): - return render_fn(self.to_records(convert_null=True)) + return render_fn(list(self.to_records(convert_null=True))) def __getitem__(self, patient): return self.patient_to_value.get(patient, self.default) @@ -318,7 +318,7 @@ def __repr__(self): return self._render_(DISPLAY_RENDERERS["ascii"]) def _render_(self, render_fn): - return render_fn(self.to_records(convert_null=True)) + return render_fn(list(self.to_records(convert_null=True))) def __getitem__(self, patient): return self.patient_to_rows.get(patient, Rows({})) diff --git a/ehrql/quiz.py b/ehrql/quiz.py index beb9355cb..5a5a2fa60 100644 --- a/ehrql/quiz.py +++ b/ehrql/quiz.py @@ -276,8 +276,8 @@ def set_dummy_tables_path(self, path): # itself, regardless of how the quiz script was invoked. This kind of global # nastiness seems tolerable in the context of the quiz, and worth it for # avoiding potential confusion. - if ehrql.debugger.DEBUG_QUERY_ENGINE is not None: - ehrql.debugger.DEBUG_QUERY_ENGINE.dsn = path + if ehrql.debugger.DEBUG_CONTEXT is not None: + ehrql.debugger.DEBUG_CONTEXT.query_engine.dsn = path def __setitem__(self, index, question): question.index = index diff --git a/ehrql/renderers.py b/ehrql/renderers.py index 7fb53d765..0114e9be9 100644 --- a/ehrql/renderers.py +++ b/ehrql/renderers.py @@ -1,134 +1,79 @@ -import re +import typing START_MARKER = "" END_MARKER = "" -def records_to_html_table(records: list[dict]): - rows = [] - headers_written = False - - for record in records: - if not headers_written: - headers = "".join([f"{header}" for header in record.keys()]) - headers_written = True - row = "".join(f"{val}" for val in record.values()) - rows.append(f"{row}") - rows = "".join(rows) - - return f"{START_MARKER}{headers}{rows}
{END_MARKER}" - - -def records_to_ascii_table(records: list[dict]): - width = 17 - lines = [] - headers_written = False - for record in records: - if not headers_written: - lines.append(" | ".join(name.ljust(width) for name in record.keys())) - lines.append("-+-".join("-" * width for _ in record.keys())) - headers_written = True - lines.append(" | ".join(str(value).ljust(width) for value in record.values())) - return "\n".join(line.strip() for line in lines) + "\n" - - -def _truncate_html_table(table_repr: str, head: int | None, tail: int | None): - """ - Truncate an html table string to the first/last N rows, with a row of ... - values to indicate where it's been truncated - """ - regex = re.compile( - rf"(?P^{START_MARKER}.*)(?P.*<\/tr>)(?P<\/tbody>.*<\/table>{END_MARKER})" - ) - match = regex.match(table_repr) - if match is None: - # if we can't parse the table, return None and let the fallback handle it - return - - start = match.group("start") - rows = match.group("rows") - end = match.group("end") +def headtail(sequence: typing.Sequence, head: int = 0, tail: int = 0): + if head == tail == 0: + return sequence, False, [] # Check we have enough rows to truncate to head and tail - if rows.count("") <= ((head or 0) + (tail or 0)): - return table_repr + if len(sequence) <= head + tail: + return sequence, False, [] - # split row on tokens, remove any empty strings (this will lose the first of - # each row, but we'll add it in again later) - rows = [row for row in rows.split("") if row] - # compose an "ellipsis row" to mark the place of truncated rows - td_count = rows[0].count("' * td_count}" + head_rows = [] + tail_rows = [] - # Build the list of rows we need to include, with ellipsis rows where necessary - truncated_rows = [] + if head: + head_rows = sequence[:head] - head_rows = rows[:head] if head is not None else [ellipsis_row] - truncated_rows.extend(head_rows) + if tail: + tail_rows = sequence[-tail:] - if head is not None and tail is not None: - truncated_rows.append(ellipsis_row) + return head_rows, True, tail_rows - tail_rows = rows[-tail:] if tail is not None else [ellipsis_row] - truncated_rows.extend(tail_rows) - # re-join the truncated rows with - truncated_rows = "" + "".join(truncated_rows) - return start + truncated_rows + end +def records_to_html_table(records: list[dict], head: int = 0, tail: int = 0): + rows = [] + headers = ( + "".join([f"" for header in records[0].keys()]) + if records + else "" + ) + head_rows, ellipsis, tail_rows = headtail(records, head, tail) + def html_row(row): + columns = "".join(f"" for v in row.values()) + return f"{columns}" -def _truncate_lines( - table_repr: str, headers: int = 0, head: int | None = None, tail: int | None = None -): - table_rows = [row for row in table_repr.split("\n") if row] + for row in head_rows: + rows.append(html_row(row)) - # Check we have enough rows to truncate to head and tail - if len(table_rows) <= (headers + (head or 0) + (tail or 0)): - return table_repr + if ellipsis: + ellipsis_columns = "" * len(records[0]) + rows.append(f"{ellipsis_columns}") - # compose an "ellipsis row" to mark the place of truncated rows - cell_count = table_rows[0].count("|") + 1 - ellipsis_row = " | ".join("...".ljust(17) for i in range(cell_count)).strip() + for row in tail_rows: + rows.append(html_row(row)) - # Build the list of rows we need to include, with ellipsis rows where necessary - truncated_rows = table_rows[:headers] + html_rows = "".join(rows) - head_rows = ( - table_rows[headers : headers + head] if head is not None else [ellipsis_row] - ) - truncated_rows.extend(head_rows) + return f"{START_MARKER}
") - ellipsis_row = f"{'...
{header}{v}
{headers}{html_rows}
{END_MARKER}" - if head is not None and tail is not None: - truncated_rows.append(ellipsis_row) - tail_rows = table_rows[-tail:] if tail is not None else [ellipsis_row] - truncated_rows.extend(tail_rows) +def records_to_ascii_table(records: list[dict], head: int = 0, tail: int = 0): + head_rows, ellipsis, tail_rows = headtail(records, head, tail) + width = 17 + lines = [] - return "\n".join(truncated_rows) + headers = records[0].keys() if records else [] + lines.append(" | ".join(name.ljust(width) for name in headers)) + lines.append("-+-".join("-" * width for _ in headers)) -def truncate_table(table_repr: str, head: int | None, tail: int | None): - """ - Take a table repr (ascii or html format) and truncate it to show only the - first and/or last N rows. - """ - if head is None and tail is None: - return table_repr + for line in head_rows: + lines.append(" | ".join(str(v).ljust(width) for v in line.values())) - truncated_repr = None + if ellipsis: + ellipsis_row = " | ".join("...".ljust(width) for _ in headers) + lines.append(ellipsis_row) - if "" in table_repr: - truncated_repr = _truncate_html_table(table_repr, head=head, tail=tail) - elif "---+---" in table_repr: - truncated_repr = _truncate_lines(table_repr, headers=2, head=head, tail=tail) + for line in tail_rows: + lines.append(" | ".join(str(v).ljust(width) for v in line.values())) - # if we didn't detect either an ascii or html table, or the html regex - # didn't match as expected, fall back to a simple truncation of lines using - # line breaks. - if truncated_repr is None: - truncated_repr = _truncate_lines(table_repr, head=head, tail=tail) - return truncated_repr + return "\n".join(line.strip() for line in lines) DISPLAY_RENDERERS = { diff --git a/tests/unit/test_debugger.py b/tests/unit/test_debugger.py index 2a3e97496..b04d88c39 100644 --- a/tests/unit/test_debugger.py +++ b/tests/unit/test_debugger.py @@ -9,7 +9,6 @@ activate_debug_context, elements_are_related_series, related_patient_columns_to_records, - render, ) from ehrql.query_engines.in_memory_database import PatientColumn from ehrql.tables import EventFrame, PatientFrame, Series, table @@ -21,15 +20,29 @@ def date_serializer(obj): raise TypeError("Type not serializable") # pragma: no cover +def json_render_function(sequence, head=0, tail=0): + """Render as JSON, useful for testing.""" + return json.dumps(sequence, indent=4, default=date_serializer) + + def test_show(capsys): expected_output = textwrap.dedent( """ - Show line 32: + Show line 3: """ ).strip() - show("Hello") + exec( + textwrap.dedent( + """ + # line 2 + show("Hello") + # line 4 + """ + ) + ) + captured = capsys.readouterr() assert captured.err.strip().startswith(expected_output), captured.err @@ -37,12 +50,21 @@ def test_show(capsys): def test_show_with_label(capsys): expected_output = textwrap.dedent( """ - Show line 45: Number + Show line 3: Number """ ).strip() - show(14, label="Number") + exec( + textwrap.dedent( + """ + # line 2 + show(14, label="Number") + # line 4 + """ + ) + ) + captured = capsys.readouterr() assert captured.err.strip().startswith(expected_output), captured.err @@ -50,7 +72,7 @@ def test_show_with_label(capsys): def test_show_fails_for_non_ehrql_object(dummy_tables_path): with activate_debug_context( dummy_tables_path=dummy_tables_path, - render_function=lambda value: value, + render_function=json_render_function, ): with pytest.raises(TypeError): show("Hello") @@ -81,48 +103,6 @@ def test_related_patient_columns_to_records_full_join(): assert r == r_expected -def test_render_formatted_table(): - expected_output = textwrap.dedent( - """ - patient_id | value - ------------------+------------------ - 1 | 101 - 2 | 201 - """ - ).strip() - - c = PatientColumn.parse( - """ - 1 | 101 - 2 | 201 - """ - ) - assert render(c).strip() == expected_output - - -def test_render_truncated_table(): - expected_output = textwrap.dedent( - """ - patient_id | value - ------------------+------------------ - 1 | 101 - ... | ... - 4 | 401 - """ - ).strip() - - c = PatientColumn.parse( - """ - 1 | 101 - 2 | 201 - 3 | 301 - 4 | 401 - """ - ) - - assert render(c, head=1, tail=1) == expected_output - - @table class patients(PatientFrame): date_of_birth = Series(date) @@ -177,14 +157,14 @@ def dummy_tables_path(tmp_path_factory): [ { "patient_id": 1, - "date_of_birth": date(1970, 1, 1), + "date_of_birth": "1970-01-01", "date_of_death": "", "sex": "male", }, { "patient_id": 2, - "date_of_birth": date(1980, 1, 1), - "date_of_death": date(2020, 1, 1), + "date_of_birth": "1980-01-01", + "date_of_death": "2020-01-01", "sex": "female", }, ], @@ -192,8 +172,8 @@ def dummy_tables_path(tmp_path_factory): ( patients.date_of_birth, [ - {"patient_id": 1, "value": date(1970, 1, 1)}, - {"patient_id": 2, "value": date(1980, 1, 1)}, + {"patient_id": 1, "value": "1970-01-01"}, + {"patient_id": 2, "value": "1980-01-01"}, ], ), ( @@ -203,12 +183,12 @@ def dummy_tables_path(tmp_path_factory): dod=patients.date_of_death, ), [ - {"patient_id": 1, "dob": date(1970, 1, 1), "count": 2, "dod": ""}, + {"patient_id": 1, "dob": "1970-01-01", "count": 2, "dod": ""}, { "patient_id": 2, - "dob": date(1980, 1, 1), + "dob": "1980-01-01", "count": 1, - "dod": date(2020, 1, 1), + "dod": "2020-01-01", }, ], ), @@ -217,9 +197,9 @@ def dummy_tables_path(tmp_path_factory): def test_activate_debug_context(dummy_tables_path, expression, contents): with activate_debug_context( dummy_tables_path=dummy_tables_path, - render_function=lambda value: repr(list(value)), - ): - assert repr(expression) == repr(contents) + render_function=json_render_function, + ) as ctx: + assert json.loads(ctx.render(expression)) == contents @pytest.mark.parametrize( @@ -237,14 +217,12 @@ def test_elements_are_related_series(elements, expected): assert elements_are_related_series(elements) == expected -def test_repr_related_patient_series(dummy_tables_path): +def test_render_related_patient_series(dummy_tables_path): with activate_debug_context( dummy_tables_path=dummy_tables_path, - render_function=lambda value: json.dumps( - list(value), indent=4, default=date_serializer - ), - ): - rendered = render( + render_function=json_render_function, + ) as ctx: + rendered = ctx.render( patients.date_of_birth, patients.sex, events.count_for_patient(), @@ -268,14 +246,12 @@ def test_repr_related_patient_series(dummy_tables_path): ] -def test_repr_related_event_series(dummy_tables_path): +def test_render_related_event_series(dummy_tables_path): with activate_debug_context( dummy_tables_path=dummy_tables_path, - render_function=lambda value: json.dumps( - list(value), indent=4, default=date_serializer - ), - ): - rendered = render(events.date, events.code, events.test_result) + render_function=json_render_function, + ) as ctx: + rendered = ctx.render(events.date, events.code, events.test_result) assert json.loads(rendered) == [ { "patient_id": 1, @@ -301,12 +277,12 @@ def test_repr_related_event_series(dummy_tables_path): ] -def test_repr_date_difference(dummy_tables_path): +def test_render_date_difference(dummy_tables_path): with activate_debug_context( dummy_tables_path=dummy_tables_path, - render_function=lambda value: json.dumps(list(value), indent=4), - ): - rendered = render(patients.date_of_death - events.date) + render_function=json_render_function, + ) as ctx: + rendered = ctx.render(patients.date_of_death - events.date) assert json.loads(rendered) == [ {"patient_id": 1, "row_id": 1, "value": ""}, {"patient_id": 1, "row_id": 2, "value": ""}, @@ -314,12 +290,12 @@ def test_repr_date_difference(dummy_tables_path): ] -def test_repr_related_date_difference_patient_series(dummy_tables_path): +def test_render_related_date_difference_patient_series(dummy_tables_path): with activate_debug_context( dummy_tables_path=dummy_tables_path, - render_function=lambda value: json.dumps(list(value), indent=4), - ): - rendered = render( + render_function=json_render_function, + ) as ctx: + rendered = ctx.render( "2024-01-01" - patients.date_of_birth, patients.sex, ) @@ -329,12 +305,12 @@ def test_repr_related_date_difference_patient_series(dummy_tables_path): ] -def test_repr_related_date_difference_event_series(dummy_tables_path): +def test_render_related_date_difference_event_series(dummy_tables_path): with activate_debug_context( dummy_tables_path=dummy_tables_path, - render_function=lambda value: json.dumps(list(value), indent=4), - ): - rendered = render( + render_function=json_render_function, + ) as ctx: + rendered = ctx.render( events.date - patients.date_of_birth, events.code, ) @@ -357,7 +333,7 @@ def test_repr_related_date_difference_event_series(dummy_tables_path): def test_show_fails_for_mismatched_inputs(example_input): with activate_debug_context( dummy_tables_path=dummy_tables_path, - render_function=lambda value: value, + render_function=json_render_function, ): with pytest.raises(TypeError): assert show(*example_input) @@ -378,7 +354,7 @@ def test_show_does_not_raise_error_for_series_from_same_domain( ): with activate_debug_context( dummy_tables_path=dummy_tables_path, - render_function=lambda value: value, + render_function=json_render_function, ): show(example_input[0], *example_input[1:]) @@ -386,16 +362,20 @@ def test_show_does_not_raise_error_for_series_from_same_domain( def test_show_not_run_outside_debug_context(capsys): expected_output = textwrap.dedent( """ - Show line 394: + Show line 3: - show() ignored because we're not running in debug mode """ ).strip() - show(patients.date_of_birth, patients.sex) + exec( + textwrap.dedent( + """ + # line 2 + show(patients.date_of_birth, patients.sex) + # line 4 + """ + ) + ) + captured = capsys.readouterr() assert captured.err.strip() == expected_output, captured.err - - -def test_render_multiple_without_repr_related_errors(): - with pytest.raises(TypeError): - render("hello", "goodbye") diff --git a/tests/unit/test_quiz.py b/tests/unit/test_quiz.py index 7989be975..6517757c3 100644 --- a/tests/unit/test_quiz.py +++ b/tests/unit/test_quiz.py @@ -339,9 +339,9 @@ def test_set_dummy_tables_path_in_debug_context(): ): questions = quiz.Questions() questions.set_dummy_tables_path("bar") - assert debugger.DEBUG_QUERY_ENGINE.dsn.name == "bar" + assert debugger.DEBUG_CONTEXT.query_engine.dsn.name == "bar" # This should be unset outside of the context manager - assert debugger.DEBUG_QUERY_ENGINE is None + assert debugger.DEBUG_CONTEXT is None def test_hint(capfd): diff --git a/tests/unit/test_renderers.py b/tests/unit/test_renderers.py index 81f87013a..c81eb96c5 100644 --- a/tests/unit/test_renderers.py +++ b/tests/unit/test_renderers.py @@ -7,7 +7,7 @@ PatientColumn, PatientTable, ) -from ehrql.renderers import DISPLAY_RENDERERS, truncate_table +from ehrql.renderers import DISPLAY_RENDERERS TABLE = PatientTable.parse( @@ -41,7 +41,7 @@ def test_render_table(render_format): "" "
" "" - "" + "" "" "" "" @@ -54,7 +54,7 @@ def test_render_table(render_format): "" ), } - rendered = DISPLAY_RENDERERS[render_format](TABLE.to_records()).strip() + rendered = DISPLAY_RENDERERS[render_format](list(TABLE.to_records())).strip() assert rendered == expected_output[render_format], rendered @@ -73,7 +73,7 @@ def test_render_column(render_format): "" "
patient_idi1i2
patient_idi1i2
1101111
" "" - "" + "" "" "" "" @@ -90,7 +90,7 @@ def test_render_column(render_format): 2 | 201 """ ) - rendered = DISPLAY_RENDERERS[render_format](c.to_records()).strip() + rendered = DISPLAY_RENDERERS[render_format](list(c.to_records())).strip() assert rendered == expected_output[render_format], rendered @@ -110,20 +110,19 @@ def test_render_table_head(render_format): "" "
patient_idvalue
patient_idvalue
1101
" "" - "" + "" "" "" "" "" - "" + "" "" "
patient_idi1i2
patient_idi1i2
1101111
2201211
.........
" "" ), } - rendered = DISPLAY_RENDERERS[render_format](TABLE.to_records()) - truncated = truncate_table(rendered, head=2, tail=None) + truncated = DISPLAY_RENDERERS[render_format](list(TABLE.to_records()), head=2) assert truncated == expected_output[render_format], truncated @@ -143,10 +142,10 @@ def test_render_table_tail(render_format): "" "" "" - "" + "" "" "" - "" + "" "" "" "" @@ -155,8 +154,7 @@ def test_render_table_tail(render_format): ), } - rendered = DISPLAY_RENDERERS[render_format](TABLE.to_records()) - truncated = truncate_table(rendered, head=None, tail=2) + truncated = DISPLAY_RENDERERS[render_format](list(TABLE.to_records()), tail=2) assert truncated == expected_output[render_format], truncated @@ -178,12 +176,12 @@ def test_render_table_head_and_tail(render_format): "" "
patient_idi1i2
patient_idi1i2
.........
4401411
5501511
" "" - "" + "" "" "" "" "" - "" + "" "" "" "" @@ -192,16 +190,15 @@ def test_render_table_head_and_tail(render_format): ), } - rendered = DISPLAY_RENDERERS[render_format](TABLE.to_records()) - truncated = truncate_table(rendered, head=2, tail=2) + truncated = DISPLAY_RENDERERS[render_format]( + list(TABLE.to_records()), head=2, tail=2 + ) assert truncated == expected_output[render_format], truncated @pytest.mark.parametrize( "render_format,head_tail", - list( - product(["ascii", "html"], [(None, None), (2, 3), (5, None), (None, 6), (3, 3)]) - ), + list(product(["ascii"], [(0, 0), (2, 3), (5, 0), (0, 6), (3, 3)])), ) def test_render_table_bad_head_tail(render_format, head_tail): expected_output = { @@ -220,7 +217,7 @@ def test_render_table_bad_head_tail(render_format, head_tail): "" "
patient_idi1i2
patient_idi1i2
1101111
2201211
.........
4401411
5501511
" "" - "" + "" "" "" "" @@ -234,41 +231,7 @@ def test_render_table_bad_head_tail(render_format, head_tail): ), } head, tail = head_tail - rendered = DISPLAY_RENDERERS[render_format](TABLE.to_records()) - truncated = truncate_table(rendered, head=head, tail=tail).strip() - assert truncated == expected_output[render_format], (truncated, head, tail) - - -def test_render_head_and_tail_not_a_table(): - expected_output = textwrap.dedent( - """ - a - b - ... - d - e - """ - ).strip() - - input_string = "\n".join(["a", "b", "c", "d", "e"]) - truncated = truncate_table(input_string, head=2, tail=2).strip() - assert truncated == expected_output, truncated - - -def test_truncate_table_bad_html(): - # If we can't parse something that looks like an html - # table as expected, we fall back to the basic line truncator - bad_html = ( - "
patient_idi1i2
patient_idi1i2
1101111
\n" - "\n" - "\n" - "\n" - "\n" - "\n" - "
patient_idi1i2
" + truncated = DISPLAY_RENDERERS[render_format]( + list(TABLE.to_records()), head=head, tail=tail ) - - expected = "\n\n..." - - truncated = truncate_table(bad_html, head=2, tail=None) - assert truncated == expected, truncated + assert truncated == expected_output[render_format], (truncated, head, tail)