diff --git a/ehrql/debugger.py b/ehrql/debugger.py
index 03b950fb4..000e92fb4 100644
--- a/ehrql/debugger.py
+++ b/ehrql/debugger.py
@@ -1,6 +1,8 @@
import contextlib
import inspect
import sys
+import typing
+from dataclasses import dataclass
from ehrql.query_engines.debug import DebugQueryEngine
from ehrql.query_engines.in_memory_database import (
@@ -8,13 +10,12 @@
PatientColumn,
render_value,
)
-from ehrql.query_language import BaseFrame, BaseSeries, Dataset, DateDifference
+from ehrql.query_language import BaseSeries, Dataset, DateDifference
from ehrql.query_model import nodes as qm
-from ehrql.renderers import truncate_table
from ehrql.utils.docs_utils import exclude_from_docs
-DEBUG_QUERY_ENGINE = None
+DEBUG_CONTEXT = None
@exclude_from_docs
@@ -54,26 +55,19 @@ def show(
label = f" {label}" if label else ""
print(f"Show line {line_no}:{label}", file=sys.stderr)
- if DEBUG_QUERY_ENGINE is None:
+ if DEBUG_CONTEXT is None:
# We're not in debug mode so ignore
print(
" - show() ignored because we're not running in debug mode", file=sys.stderr
)
return
- elements = [element, *other_elements]
-
- is_single_ehrql_object = is_ehrql_object(element) and len(other_elements) == 0
- is_multiple_series_same_domain = len(elements) > 1 and elements_are_related_series(
- elements
- )
-
- # We throw an error unless show() is called with:
- # 1. A single ehrql object (series, frame, dataset)
- # 2. Multiple series drawn from the same domain
- if is_single_ehrql_object | is_multiple_series_same_domain:
- print(render(element, *other_elements, head=head, tail=tail), file=sys.stderr)
- else:
+ try:
+ rendered = DEBUG_CONTEXT.render(
+ element, *other_elements, head=head or 0, tail=tail or 0
+ )
+ except TypeError:
+ # raise more helpful show() specific error
raise TypeError(
"\n\nshow() can be used in two ways. Either:\n"
" - call it with a single argument such as:\n"
@@ -89,73 +83,63 @@ def show(
"clinical_events.numeric_values)"
)
-
-def render(
- element,
- *other_elements,
- head: int | None = None,
- tail: int | None = None,
-):
- if len(other_elements) == 0:
- # single ehrql element so we just display it
- element_repr = repr(element)
- elif hasattr(element, "__repr_related__"):
- # multiple ehrql series so we combine them
- element_repr = element.__repr_related__(*other_elements)
- else:
- raise TypeError(
- "When render() is called on multiple series, the first argument must have a __repr_related__ attribute."
- )
-
- if head or tail:
- element_repr = truncate_table(element_repr, head, tail)
- return element_repr
+ print(rendered, file=sys.stderr)
@contextlib.contextmanager
def activate_debug_context(*, dummy_tables_path, render_function):
- global DEBUG_QUERY_ENGINE
-
- # Record original methods
- BaseFrame__repr__ = BaseFrame.__repr__
- BaseSeries__repr__ = BaseSeries.__repr__
- DateDifference__repr__ = DateDifference.__repr__
- Dataset__repr__ = Dataset.__repr__
-
- query_engine = DebugQueryEngine(dummy_tables_path)
- DEBUG_QUERY_ENGINE = query_engine
-
- def repr_ehrql(obj, template=None):
- return query_engine.evaluate(obj)._render_(render_function)
-
- def repr_related_ehrql(*series_args):
- return render_function(
- related_columns_to_records(
- [query_engine.evaluate(series) for series in series_args]
- )
- )
-
- # Temporarily overwrite __repr__ methods to display contents
- BaseFrame.__repr__ = repr_ehrql
- BaseSeries.__repr__ = repr_ehrql
- DateDifference.__repr__ = repr_ehrql
- Dataset.__repr__ = repr_ehrql
-
- # Add additional method for displaying related series together
- BaseSeries.__repr_related__ = repr_related_ehrql
- DateDifference.__repr_related__ = repr_related_ehrql
-
+ global DEBUG_CONTEXT
+ DEBUG_CONTEXT = DebugContext(
+ query_engine=DebugQueryEngine(dummy_tables_path),
+ render_function=render_function,
+ )
try:
- yield
+ yield DEBUG_CONTEXT # yield the context for testing convenience
finally:
- DEBUG_QUERY_ENGINE = None
- # Restore original methods
- BaseFrame.__repr__ = BaseFrame__repr__
- BaseSeries.__repr__ = BaseSeries__repr__
- DateDifference.__repr__ = DateDifference__repr__
- Dataset.__repr__ = Dataset__repr__
- del BaseSeries.__repr_related__
- del DateDifference.__repr_related__
+ DEBUG_CONTEXT = None
+
+
+@dataclass
+class DebugContext:
+ """Record the currently active query engine and render function."""
+
+ query_engine: DebugQueryEngine
+ render_function: typing.Callable
+
+ def render(
+ self,
+ element,
+ *other_elements,
+ head: int = 0,
+ tail: int = 0,
+ ):
+ elements = [element, *other_elements]
+
+ is_single_ehrql_object = is_ehrql_object(element) and len(other_elements) == 0
+ is_multiple_series_same_domain = len(
+ elements
+ ) > 1 and elements_are_related_series(elements)
+
+ # We throw an error unless called with:
+ # 1. A single ehrql object (series, frame, dataset)
+ # 2. Multiple series drawn from the same domain
+ if not (is_single_ehrql_object | is_multiple_series_same_domain):
+ raise TypeError("All elements must be in the same domain")
+
+ if len(other_elements) == 0:
+ # single ehrql element so we just display it
+ evaluated = list(
+ self.query_engine.evaluate(element).to_records(convert_null=True)
+ )
+ else:
+ # multiple ehrql series so we combine them
+ evaluated = list(
+ related_columns_to_records(
+ [self.query_engine.evaluate(element) for element in elements]
+ )
+ )
+
+ return self.render_function(evaluated, head, tail)
def elements_are_related_series(elements):
diff --git a/ehrql/loaders.py b/ehrql/loaders.py
index 1ea4536e4..4c4cea2bb 100644
--- a/ehrql/loaders.py
+++ b/ehrql/loaders.py
@@ -324,6 +324,12 @@ def load_definition_unsafe(
def load_module(module_path, user_args=()):
+ """Load a module containing an ehrql dataset definition.
+
+ Renders the serialized query model as json to the callers stdout, and any
+ other output to stderr.
+ """
+
# Taken from the official recipe for importing a module from a file path:
# https://docs.python.org/3.9/library/importlib.html#importing-a-source-file-directly
spec = importlib.util.spec_from_file_location(module_path.stem, module_path)
diff --git a/ehrql/query_engines/debug.py b/ehrql/query_engines/debug.py
index 67b57cc85..e8130eca2 100644
--- a/ehrql/query_engines/debug.py
+++ b/ehrql/query_engines/debug.py
@@ -57,5 +57,5 @@ class EmptyDataset:
"""This class exists to render something nice when a user tries to inspect a dataset
with no columns with debugger.show()."""
- def _render_(self, render_fn):
- return render_fn([{"patient_id": ""}])
+ def to_records(self, convert_null=False):
+ return [{"patient_id": ""}]
diff --git a/ehrql/query_engines/in_memory_database.py b/ehrql/query_engines/in_memory_database.py
index 01f9a6314..fa6cd7697 100644
--- a/ehrql/query_engines/in_memory_database.py
+++ b/ehrql/query_engines/in_memory_database.py
@@ -93,7 +93,7 @@ def __repr__(self):
return self._render_(DISPLAY_RENDERERS["ascii"])
def _render_(self, render_fn):
- return render_fn(self.to_records(convert_null=True))
+ return render_fn(list(self.to_records(convert_null=True)))
def __getitem__(self, name):
return self.name_to_col[name]
@@ -179,7 +179,7 @@ def __repr__(self):
return self._render_(DISPLAY_RENDERERS["ascii"])
def _render_(self, render_fn):
- return render_fn(self.to_records(convert_null=True))
+ return render_fn(list(self.to_records(convert_null=True)))
def __getitem__(self, name):
return self.name_to_col[name]
@@ -257,7 +257,7 @@ def __repr__(self):
return self._render_(DISPLAY_RENDERERS["ascii"])
def _render_(self, render_fn):
- return render_fn(self.to_records(convert_null=True))
+ return render_fn(list(self.to_records(convert_null=True)))
def __getitem__(self, patient):
return self.patient_to_value.get(patient, self.default)
@@ -318,7 +318,7 @@ def __repr__(self):
return self._render_(DISPLAY_RENDERERS["ascii"])
def _render_(self, render_fn):
- return render_fn(self.to_records(convert_null=True))
+ return render_fn(list(self.to_records(convert_null=True)))
def __getitem__(self, patient):
return self.patient_to_rows.get(patient, Rows({}))
diff --git a/ehrql/quiz.py b/ehrql/quiz.py
index beb9355cb..5a5a2fa60 100644
--- a/ehrql/quiz.py
+++ b/ehrql/quiz.py
@@ -276,8 +276,8 @@ def set_dummy_tables_path(self, path):
# itself, regardless of how the quiz script was invoked. This kind of global
# nastiness seems tolerable in the context of the quiz, and worth it for
# avoiding potential confusion.
- if ehrql.debugger.DEBUG_QUERY_ENGINE is not None:
- ehrql.debugger.DEBUG_QUERY_ENGINE.dsn = path
+ if ehrql.debugger.DEBUG_CONTEXT is not None:
+ ehrql.debugger.DEBUG_CONTEXT.query_engine.dsn = path
def __setitem__(self, index, question):
question.index = index
diff --git a/ehrql/renderers.py b/ehrql/renderers.py
index 7fb53d765..0114e9be9 100644
--- a/ehrql/renderers.py
+++ b/ehrql/renderers.py
@@ -1,134 +1,79 @@
-import re
+import typing
START_MARKER = ""
END_MARKER = ""
-def records_to_html_table(records: list[dict]):
- rows = []
- headers_written = False
-
- for record in records:
- if not headers_written:
- headers = "".join([f"
{header} | " for header in record.keys()])
- headers_written = True
- row = "".join(f"{val} | " for val in record.values())
- rows.append(f"{row}
")
- rows = "".join(rows)
-
- return f"{START_MARKER}{END_MARKER}"
-
-
-def records_to_ascii_table(records: list[dict]):
- width = 17
- lines = []
- headers_written = False
- for record in records:
- if not headers_written:
- lines.append(" | ".join(name.ljust(width) for name in record.keys()))
- lines.append("-+-".join("-" * width for _ in record.keys()))
- headers_written = True
- lines.append(" | ".join(str(value).ljust(width) for value in record.values()))
- return "\n".join(line.strip() for line in lines) + "\n"
-
-
-def _truncate_html_table(table_repr: str, head: int | None, tail: int | None):
- """
- Truncate an html table string to the first/last N rows, with a row of ...
- values to indicate where it's been truncated
- """
- regex = re.compile(
- rf"(?P^{START_MARKER}.*)(?P.*<\/tr>)(?P<\/tbody>.*<\/table>{END_MARKER})"
- )
- match = regex.match(table_repr)
- if match is None:
- # if we can't parse the table, return None and let the fallback handle it
- return
-
- start = match.group("start")
- rows = match.group("rows")
- end = match.group("end")
+def headtail(sequence: typing.Sequence, head: int = 0, tail: int = 0):
+ if head == tail == 0:
+ return sequence, False, []
# Check we have enough rows to truncate to head and tail
- if rows.count("") <= ((head or 0) + (tail or 0)):
- return table_repr
+ if len(sequence) <= head + tail:
+ return sequence, False, []
- # split row on
tokens, remove any empty strings (this will lose the first
of
- # each row, but we'll add it in again later)
- rows = [row for row in rows.split("
") if row]
- # compose an "ellipsis row" to mark the place of truncated rows
- td_count = rows[0].count("")
- ellipsis_row = f"{' | ... | ' * td_count}
"
+ head_rows = []
+ tail_rows = []
- # Build the list of rows we need to include, with ellipsis rows where necessary
- truncated_rows = []
+ if head:
+ head_rows = sequence[:head]
- head_rows = rows[:head] if head is not None else [ellipsis_row]
- truncated_rows.extend(head_rows)
+ if tail:
+ tail_rows = sequence[-tail:]
- if head is not None and tail is not None:
- truncated_rows.append(ellipsis_row)
+ return head_rows, True, tail_rows
- tail_rows = rows[-tail:] if tail is not None else [ellipsis_row]
- truncated_rows.extend(tail_rows)
- # re-join the truncated rows with
- truncated_rows = "
" + "
".join(truncated_rows)
- return start + truncated_rows + end
+def records_to_html_table(records: list[dict], head: int = 0, tail: int = 0):
+ rows = []
+ headers = (
+ "".join([f"{header} | " for header in records[0].keys()])
+ if records
+ else ""
+ )
+ head_rows, ellipsis, tail_rows = headtail(records, head, tail)
+ def html_row(row):
+ columns = "".join(f"{v} | " for v in row.values())
+ return f"
{columns}
"
-def _truncate_lines(
- table_repr: str, headers: int = 0, head: int | None = None, tail: int | None = None
-):
- table_rows = [row for row in table_repr.split("\n") if row]
+ for row in head_rows:
+ rows.append(html_row(row))
- # Check we have enough rows to truncate to head and tail
- if len(table_rows) <= (headers + (head or 0) + (tail or 0)):
- return table_repr
+ if ellipsis:
+ ellipsis_columns = "… | " * len(records[0])
+ rows.append(f"{ellipsis_columns}
")
- # compose an "ellipsis row" to mark the place of truncated rows
- cell_count = table_rows[0].count("|") + 1
- ellipsis_row = " | ".join("...".ljust(17) for i in range(cell_count)).strip()
+ for row in tail_rows:
+ rows.append(html_row(row))
- # Build the list of rows we need to include, with ellipsis rows where necessary
- truncated_rows = table_rows[:headers]
+ html_rows = "".join(rows)
- head_rows = (
- table_rows[headers : headers + head] if head is not None else [ellipsis_row]
- )
- truncated_rows.extend(head_rows)
+ return f"{START_MARKER}{END_MARKER}"
- if head is not None and tail is not None:
- truncated_rows.append(ellipsis_row)
- tail_rows = table_rows[-tail:] if tail is not None else [ellipsis_row]
- truncated_rows.extend(tail_rows)
+def records_to_ascii_table(records: list[dict], head: int = 0, tail: int = 0):
+ head_rows, ellipsis, tail_rows = headtail(records, head, tail)
+ width = 17
+ lines = []
- return "\n".join(truncated_rows)
+ headers = records[0].keys() if records else []
+ lines.append(" | ".join(name.ljust(width) for name in headers))
+ lines.append("-+-".join("-" * width for _ in headers))
-def truncate_table(table_repr: str, head: int | None, tail: int | None):
- """
- Take a table repr (ascii or html format) and truncate it to show only the
- first and/or last N rows.
- """
- if head is None and tail is None:
- return table_repr
+ for line in head_rows:
+ lines.append(" | ".join(str(v).ljust(width) for v in line.values()))
- truncated_repr = None
+ if ellipsis:
+ ellipsis_row = " | ".join("...".ljust(width) for _ in headers)
+ lines.append(ellipsis_row)
- if "" in table_repr:
- truncated_repr = _truncate_html_table(table_repr, head=head, tail=tail)
- elif "---+---" in table_repr:
- truncated_repr = _truncate_lines(table_repr, headers=2, head=head, tail=tail)
+ for line in tail_rows:
+ lines.append(" | ".join(str(v).ljust(width) for v in line.values()))
- # if we didn't detect either an ascii or html table, or the html regex
- # didn't match as expected, fall back to a simple truncation of lines using
- # line breaks.
- if truncated_repr is None:
- truncated_repr = _truncate_lines(table_repr, head=head, tail=tail)
- return truncated_repr
+ return "\n".join(line.strip() for line in lines)
DISPLAY_RENDERERS = {
diff --git a/tests/unit/test_debugger.py b/tests/unit/test_debugger.py
index 2a3e97496..b04d88c39 100644
--- a/tests/unit/test_debugger.py
+++ b/tests/unit/test_debugger.py
@@ -9,7 +9,6 @@
activate_debug_context,
elements_are_related_series,
related_patient_columns_to_records,
- render,
)
from ehrql.query_engines.in_memory_database import PatientColumn
from ehrql.tables import EventFrame, PatientFrame, Series, table
@@ -21,15 +20,29 @@ def date_serializer(obj):
raise TypeError("Type not serializable") # pragma: no cover
+def json_render_function(sequence, head=0, tail=0):
+ """Render as JSON, useful for testing."""
+ return json.dumps(sequence, indent=4, default=date_serializer)
+
+
def test_show(capsys):
expected_output = textwrap.dedent(
"""
- Show line 32:
+ Show line 3:
"""
).strip()
- show("Hello")
+ exec(
+ textwrap.dedent(
+ """
+ # line 2
+ show("Hello")
+ # line 4
+ """
+ )
+ )
+
captured = capsys.readouterr()
assert captured.err.strip().startswith(expected_output), captured.err
@@ -37,12 +50,21 @@ def test_show(capsys):
def test_show_with_label(capsys):
expected_output = textwrap.dedent(
"""
- Show line 45: Number
+ Show line 3: Number
"""
).strip()
- show(14, label="Number")
+ exec(
+ textwrap.dedent(
+ """
+ # line 2
+ show(14, label="Number")
+ # line 4
+ """
+ )
+ )
+
captured = capsys.readouterr()
assert captured.err.strip().startswith(expected_output), captured.err
@@ -50,7 +72,7 @@ def test_show_with_label(capsys):
def test_show_fails_for_non_ehrql_object(dummy_tables_path):
with activate_debug_context(
dummy_tables_path=dummy_tables_path,
- render_function=lambda value: value,
+ render_function=json_render_function,
):
with pytest.raises(TypeError):
show("Hello")
@@ -81,48 +103,6 @@ def test_related_patient_columns_to_records_full_join():
assert r == r_expected
-def test_render_formatted_table():
- expected_output = textwrap.dedent(
- """
- patient_id | value
- ------------------+------------------
- 1 | 101
- 2 | 201
- """
- ).strip()
-
- c = PatientColumn.parse(
- """
- 1 | 101
- 2 | 201
- """
- )
- assert render(c).strip() == expected_output
-
-
-def test_render_truncated_table():
- expected_output = textwrap.dedent(
- """
- patient_id | value
- ------------------+------------------
- 1 | 101
- ... | ...
- 4 | 401
- """
- ).strip()
-
- c = PatientColumn.parse(
- """
- 1 | 101
- 2 | 201
- 3 | 301
- 4 | 401
- """
- )
-
- assert render(c, head=1, tail=1) == expected_output
-
-
@table
class patients(PatientFrame):
date_of_birth = Series(date)
@@ -177,14 +157,14 @@ def dummy_tables_path(tmp_path_factory):
[
{
"patient_id": 1,
- "date_of_birth": date(1970, 1, 1),
+ "date_of_birth": "1970-01-01",
"date_of_death": "",
"sex": "male",
},
{
"patient_id": 2,
- "date_of_birth": date(1980, 1, 1),
- "date_of_death": date(2020, 1, 1),
+ "date_of_birth": "1980-01-01",
+ "date_of_death": "2020-01-01",
"sex": "female",
},
],
@@ -192,8 +172,8 @@ def dummy_tables_path(tmp_path_factory):
(
patients.date_of_birth,
[
- {"patient_id": 1, "value": date(1970, 1, 1)},
- {"patient_id": 2, "value": date(1980, 1, 1)},
+ {"patient_id": 1, "value": "1970-01-01"},
+ {"patient_id": 2, "value": "1980-01-01"},
],
),
(
@@ -203,12 +183,12 @@ def dummy_tables_path(tmp_path_factory):
dod=patients.date_of_death,
),
[
- {"patient_id": 1, "dob": date(1970, 1, 1), "count": 2, "dod": ""},
+ {"patient_id": 1, "dob": "1970-01-01", "count": 2, "dod": ""},
{
"patient_id": 2,
- "dob": date(1980, 1, 1),
+ "dob": "1980-01-01",
"count": 1,
- "dod": date(2020, 1, 1),
+ "dod": "2020-01-01",
},
],
),
@@ -217,9 +197,9 @@ def dummy_tables_path(tmp_path_factory):
def test_activate_debug_context(dummy_tables_path, expression, contents):
with activate_debug_context(
dummy_tables_path=dummy_tables_path,
- render_function=lambda value: repr(list(value)),
- ):
- assert repr(expression) == repr(contents)
+ render_function=json_render_function,
+ ) as ctx:
+ assert json.loads(ctx.render(expression)) == contents
@pytest.mark.parametrize(
@@ -237,14 +217,12 @@ def test_elements_are_related_series(elements, expected):
assert elements_are_related_series(elements) == expected
-def test_repr_related_patient_series(dummy_tables_path):
+def test_render_related_patient_series(dummy_tables_path):
with activate_debug_context(
dummy_tables_path=dummy_tables_path,
- render_function=lambda value: json.dumps(
- list(value), indent=4, default=date_serializer
- ),
- ):
- rendered = render(
+ render_function=json_render_function,
+ ) as ctx:
+ rendered = ctx.render(
patients.date_of_birth,
patients.sex,
events.count_for_patient(),
@@ -268,14 +246,12 @@ def test_repr_related_patient_series(dummy_tables_path):
]
-def test_repr_related_event_series(dummy_tables_path):
+def test_render_related_event_series(dummy_tables_path):
with activate_debug_context(
dummy_tables_path=dummy_tables_path,
- render_function=lambda value: json.dumps(
- list(value), indent=4, default=date_serializer
- ),
- ):
- rendered = render(events.date, events.code, events.test_result)
+ render_function=json_render_function,
+ ) as ctx:
+ rendered = ctx.render(events.date, events.code, events.test_result)
assert json.loads(rendered) == [
{
"patient_id": 1,
@@ -301,12 +277,12 @@ def test_repr_related_event_series(dummy_tables_path):
]
-def test_repr_date_difference(dummy_tables_path):
+def test_render_date_difference(dummy_tables_path):
with activate_debug_context(
dummy_tables_path=dummy_tables_path,
- render_function=lambda value: json.dumps(list(value), indent=4),
- ):
- rendered = render(patients.date_of_death - events.date)
+ render_function=json_render_function,
+ ) as ctx:
+ rendered = ctx.render(patients.date_of_death - events.date)
assert json.loads(rendered) == [
{"patient_id": 1, "row_id": 1, "value": ""},
{"patient_id": 1, "row_id": 2, "value": ""},
@@ -314,12 +290,12 @@ def test_repr_date_difference(dummy_tables_path):
]
-def test_repr_related_date_difference_patient_series(dummy_tables_path):
+def test_render_related_date_difference_patient_series(dummy_tables_path):
with activate_debug_context(
dummy_tables_path=dummy_tables_path,
- render_function=lambda value: json.dumps(list(value), indent=4),
- ):
- rendered = render(
+ render_function=json_render_function,
+ ) as ctx:
+ rendered = ctx.render(
"2024-01-01" - patients.date_of_birth,
patients.sex,
)
@@ -329,12 +305,12 @@ def test_repr_related_date_difference_patient_series(dummy_tables_path):
]
-def test_repr_related_date_difference_event_series(dummy_tables_path):
+def test_render_related_date_difference_event_series(dummy_tables_path):
with activate_debug_context(
dummy_tables_path=dummy_tables_path,
- render_function=lambda value: json.dumps(list(value), indent=4),
- ):
- rendered = render(
+ render_function=json_render_function,
+ ) as ctx:
+ rendered = ctx.render(
events.date - patients.date_of_birth,
events.code,
)
@@ -357,7 +333,7 @@ def test_repr_related_date_difference_event_series(dummy_tables_path):
def test_show_fails_for_mismatched_inputs(example_input):
with activate_debug_context(
dummy_tables_path=dummy_tables_path,
- render_function=lambda value: value,
+ render_function=json_render_function,
):
with pytest.raises(TypeError):
assert show(*example_input)
@@ -378,7 +354,7 @@ def test_show_does_not_raise_error_for_series_from_same_domain(
):
with activate_debug_context(
dummy_tables_path=dummy_tables_path,
- render_function=lambda value: value,
+ render_function=json_render_function,
):
show(example_input[0], *example_input[1:])
@@ -386,16 +362,20 @@ def test_show_does_not_raise_error_for_series_from_same_domain(
def test_show_not_run_outside_debug_context(capsys):
expected_output = textwrap.dedent(
"""
- Show line 394:
+ Show line 3:
- show() ignored because we're not running in debug mode
"""
).strip()
- show(patients.date_of_birth, patients.sex)
+ exec(
+ textwrap.dedent(
+ """
+ # line 2
+ show(patients.date_of_birth, patients.sex)
+ # line 4
+ """
+ )
+ )
+
captured = capsys.readouterr()
assert captured.err.strip() == expected_output, captured.err
-
-
-def test_render_multiple_without_repr_related_errors():
- with pytest.raises(TypeError):
- render("hello", "goodbye")
diff --git a/tests/unit/test_quiz.py b/tests/unit/test_quiz.py
index 7989be975..6517757c3 100644
--- a/tests/unit/test_quiz.py
+++ b/tests/unit/test_quiz.py
@@ -339,9 +339,9 @@ def test_set_dummy_tables_path_in_debug_context():
):
questions = quiz.Questions()
questions.set_dummy_tables_path("bar")
- assert debugger.DEBUG_QUERY_ENGINE.dsn.name == "bar"
+ assert debugger.DEBUG_CONTEXT.query_engine.dsn.name == "bar"
# This should be unset outside of the context manager
- assert debugger.DEBUG_QUERY_ENGINE is None
+ assert debugger.DEBUG_CONTEXT is None
def test_hint(capfd):
diff --git a/tests/unit/test_renderers.py b/tests/unit/test_renderers.py
index 81f87013a..c81eb96c5 100644
--- a/tests/unit/test_renderers.py
+++ b/tests/unit/test_renderers.py
@@ -7,7 +7,7 @@
PatientColumn,
PatientTable,
)
-from ehrql.renderers import DISPLAY_RENDERERS, truncate_table
+from ehrql.renderers import DISPLAY_RENDERERS
TABLE = PatientTable.parse(
@@ -41,7 +41,7 @@ def test_render_table(render_format):
""
""
""
- "patient_id | i1 | i2 | "
+ "patient_id | i1 | i2 |
"
""
""
"1 | 101 | 111 |
"
@@ -54,7 +54,7 @@ def test_render_table(render_format):
""
),
}
- rendered = DISPLAY_RENDERERS[render_format](TABLE.to_records()).strip()
+ rendered = DISPLAY_RENDERERS[render_format](list(TABLE.to_records())).strip()
assert rendered == expected_output[render_format], rendered
@@ -73,7 +73,7 @@ def test_render_column(render_format):
""
""
""
- "patient_id | value | "
+ "patient_id | value |
"
""
""
"1 | 101 |
"
@@ -90,7 +90,7 @@ def test_render_column(render_format):
2 | 201
"""
)
- rendered = DISPLAY_RENDERERS[render_format](c.to_records()).strip()
+ rendered = DISPLAY_RENDERERS[render_format](list(c.to_records())).strip()
assert rendered == expected_output[render_format], rendered
@@ -110,20 +110,19 @@ def test_render_table_head(render_format):
""
""
""
- "patient_id | i1 | i2 | "
+ "patient_id | i1 | i2 |
"
""
""
"1 | 101 | 111 |
"
"2 | 201 | 211 |
"
- "... | ... | ... |
"
+ "… | … | … |
"
""
"
"
""
),
}
- rendered = DISPLAY_RENDERERS[render_format](TABLE.to_records())
- truncated = truncate_table(rendered, head=2, tail=None)
+ truncated = DISPLAY_RENDERERS[render_format](list(TABLE.to_records()), head=2)
assert truncated == expected_output[render_format], truncated
@@ -143,10 +142,10 @@ def test_render_table_tail(render_format):
""
""
""
- "patient_id | i1 | i2 | "
+ "patient_id | i1 | i2 |
"
""
""
- "... | ... | ... |
"
+ "… | … | … |
"
"4 | 401 | 411 |
"
"5 | 501 | 511 |
"
""
@@ -155,8 +154,7 @@ def test_render_table_tail(render_format):
),
}
- rendered = DISPLAY_RENDERERS[render_format](TABLE.to_records())
- truncated = truncate_table(rendered, head=None, tail=2)
+ truncated = DISPLAY_RENDERERS[render_format](list(TABLE.to_records()), tail=2)
assert truncated == expected_output[render_format], truncated
@@ -178,12 +176,12 @@ def test_render_table_head_and_tail(render_format):
""
""
""
- "patient_id | i1 | i2 | "
+ "patient_id | i1 | i2 |
"
""
""
"1 | 101 | 111 |
"
"2 | 201 | 211 |
"
- "... | ... | ... |
"
+ "… | … | … |
"
"4 | 401 | 411 |
"
"5 | 501 | 511 |
"
""
@@ -192,16 +190,15 @@ def test_render_table_head_and_tail(render_format):
),
}
- rendered = DISPLAY_RENDERERS[render_format](TABLE.to_records())
- truncated = truncate_table(rendered, head=2, tail=2)
+ truncated = DISPLAY_RENDERERS[render_format](
+ list(TABLE.to_records()), head=2, tail=2
+ )
assert truncated == expected_output[render_format], truncated
@pytest.mark.parametrize(
"render_format,head_tail",
- list(
- product(["ascii", "html"], [(None, None), (2, 3), (5, None), (None, 6), (3, 3)])
- ),
+ list(product(["ascii"], [(0, 0), (2, 3), (5, 0), (0, 6), (3, 3)])),
)
def test_render_table_bad_head_tail(render_format, head_tail):
expected_output = {
@@ -220,7 +217,7 @@ def test_render_table_bad_head_tail(render_format, head_tail):
""
""
""
- "patient_id | i1 | i2 | "
+ "patient_id | i1 | i2 |
"
""
""
"1 | 101 | 111 |
"
@@ -234,41 +231,7 @@ def test_render_table_bad_head_tail(render_format, head_tail):
),
}
head, tail = head_tail
- rendered = DISPLAY_RENDERERS[render_format](TABLE.to_records())
- truncated = truncate_table(rendered, head=head, tail=tail).strip()
- assert truncated == expected_output[render_format], (truncated, head, tail)
-
-
-def test_render_head_and_tail_not_a_table():
- expected_output = textwrap.dedent(
- """
- a
- b
- ...
- d
- e
- """
- ).strip()
-
- input_string = "\n".join(["a", "b", "c", "d", "e"])
- truncated = truncate_table(input_string, head=2, tail=2).strip()
- assert truncated == expected_output, truncated
-
-
-def test_truncate_table_bad_html():
- # If we can't parse something that looks like an html
- # table as expected, we fall back to the basic line truncator
- bad_html = (
- "\n"
- "\n"
- "patient_id | i1 | i2 | \n"
- "\n"
- "\n"
- "\n"
- "
"
+ truncated = DISPLAY_RENDERERS[render_format](
+ list(TABLE.to_records()), head=head, tail=tail
)
-
- expected = "\n\n..."
-
- truncated = truncate_table(bad_html, head=2, tail=None)
- assert truncated == expected, truncated
+ assert truncated == expected_output[render_format], (truncated, head, tail)