Skip to content

Commit

Permalink
Make tests less cursed
Browse files Browse the repository at this point in the history
  • Loading branch information
inglesp committed Nov 25, 2024
1 parent ffd19db commit ef13d63
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 115 deletions.
6 changes: 5 additions & 1 deletion ehrql/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ def main(args, environ=None):
except DefinitionError as exc:
# Errors from definition files are already pre-formatted so we just write them
# directly to stderr and exit
print(str(exc), file=sys.stderr)
if kwargs.get("display_format") == "html":
msg = f"<pre>{exc}</pre>"
else:
msg = str(exc)
print(msg, file=sys.stderr)
sys.exit(1)
except FileValidationError as exc:
# Handle errors encountered while reading user-supplied data
Expand Down
16 changes: 12 additions & 4 deletions ehrql/debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,17 @@ def debug(
debug(<table>, head=5, tail=5)
"""
line_no = inspect.getframeinfo(sys._getframe(1))[1]
label = f" {label}" if label else ""
print(f"Debug line {line_no}:{label}", file=sys.stderr)
print(render(element, *other_elements, head=head, tail=tail), file=sys.stderr)


def render(
element,
*other_elements,
head: int | None = None,
tail: int | None = None,
):
elements = [element, *other_elements]

if hasattr(element, "__repr_related__") and elements_are_related_series(elements):
Expand All @@ -59,10 +70,7 @@ def debug(
element_reprs = [
truncate_table(el_repr, head, tail) for el_repr in element_reprs
]
label = f" {label}" if label else ""
print(f"Debug line {line_no}:{label}", file=sys.stderr)
for el_repr in element_reprs:
print(el_repr, file=sys.stderr)
return "\n".join(element_reprs)


@contextlib.contextmanager
Expand Down
161 changes: 51 additions & 110 deletions tests/unit/test_debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from ehrql import create_dataset, debug
from ehrql.debugger import activate_debug_context, elements_are_related_series
from ehrql.debugger import activate_debug_context, elements_are_related_series, render
from ehrql.query_engines.in_memory_database import PatientColumn
from ehrql.tables import EventFrame, PatientFrame, Series, table

Expand All @@ -16,7 +16,7 @@ def date_serializer(obj):
raise TypeError("Type not serializable") # pragma: no cover


def test_show_string(capsys):
def test_debug(capsys):
expected_output = textwrap.dedent(
"""
Debug line 27:
Expand All @@ -29,53 +29,41 @@ def test_show_string(capsys):
assert captured.err.strip() == expected_output, captured.err


def test_show_int_variable(capsys):
def test_debug_with_label(capsys):
expected_output = textwrap.dedent(
"""
Debug line 41:
12
Debug line 40: Number
14
"""
).strip()

foo = 12
debug(foo)
debug(14, label="Number")
captured = capsys.readouterr()
assert captured.err.strip() == expected_output, captured.err


def test_show_multiple_variables(capsys):
expected_output = textwrap.dedent(
"""
Debug line 57:
12
'Hello'
"""
).strip()
def test_render_string():
assert render("Hello") == "'Hello'"

foo = 12
bar = "Hello"
debug(foo, bar)
captured = capsys.readouterr()
assert captured.err.strip() == expected_output, captured.err

def test_render_int_variable():
assert render(12) == "12"


def test_show_with_label(capsys):
def test_render_multiple_variables():
expected_output = textwrap.dedent(
"""
Debug line 70: Number
14
12
'Hello'
"""
).strip()

debug(14, label="Number")
captured = capsys.readouterr()
assert captured.err.strip() == expected_output, captured.err
assert render(12, "Hello") == expected_output


def test_show_formatted_table(capsys):
def test_render_formatted_table():
expected_output = textwrap.dedent(
"""
Debug line 92:
patient_id | value
------------------+------------------
1 | 101
Expand All @@ -89,15 +77,12 @@ def test_show_formatted_table(capsys):
2 | 201
"""
)
debug(c)
captured = capsys.readouterr()
assert captured.err.strip() == expected_output, captured.err
assert render(c).strip() == expected_output


def test_show_truncated_table(capsys):
def test_render_truncated_table():
expected_output = textwrap.dedent(
"""
Debug line 118:
patient_id | value
------------------+------------------
1 | 101
Expand All @@ -115,9 +100,7 @@ def test_show_truncated_table(capsys):
"""
)

debug(c, head=1, tail=1)
captured = capsys.readouterr()
assert captured.err.strip() == expected_output, captured.err
assert render(c, head=1, tail=1) == expected_output


@table
Expand Down Expand Up @@ -213,99 +196,57 @@ def test_elements_are_related_series(elements, expected):
assert elements_are_related_series(elements) == expected


def test_repr_related_patient_series(dummy_tables_path, capsys):
def test_repr_related_patient_series(dummy_tables_path):
with activate_debug_context(
dummy_tables_path=dummy_tables_path,
render_function=lambda value: json.dumps(
list(value), indent=4, default=date_serializer
),
):
debug(
rendered = render(
patients.date_of_birth,
patients.sex,
events.count_for_patient(),
)
assert capsys.readouterr().err == textwrap.dedent(
"""\
Debug line 223:
[
{
"patient_id": 1,
"series_1": "1970-01-01",
"series_2": "male",
"series_3": 2
},
{
"patient_id": 2,
"series_1": "1980-01-01",
"series_2": "female",
"series_3": 1
}
]
"""
)


def test_repr_related_event_series(dummy_tables_path, capsys):
assert json.loads(rendered) == [
{
"patient_id": 1,
"series_1": "1970-01-01",
"series_2": "male",
"series_3": 2,
},
{
"patient_id": 2,
"series_1": "1980-01-01",
"series_2": "female",
"series_3": 1,
},
]


def test_repr_related_event_series(dummy_tables_path):
with activate_debug_context(
dummy_tables_path=dummy_tables_path,
render_function=lambda value: json.dumps(
list(value), indent=4, default=date_serializer
),
):
debug(events.date, events.code)
assert capsys.readouterr().err == textwrap.dedent(
"""\
Debug line 256:
[
{
"patient_id": 1,
"row_id": 1,
"series_1": "2010-01-01",
"series_2": "abc"
},
{
"patient_id": 1,
"row_id": 2,
"series_1": "2020-01-01",
"series_2": "def"
},
{
"patient_id": 2,
"row_id": 3,
"series_1": "2005-01-01",
"series_2": "abc"
}
]
"""
)
rendered = render(events.date, events.code)
assert json.loads(rendered) == [
{"patient_id": 1, "row_id": 1, "series_1": "2010-01-01", "series_2": "abc"},
{"patient_id": 1, "row_id": 2, "series_1": "2020-01-01", "series_2": "def"},
{"patient_id": 2, "row_id": 3, "series_1": "2005-01-01", "series_2": "abc"},
]


def test_repr_date_difference(dummy_tables_path, capsys):
def test_repr_date_difference(dummy_tables_path):
with activate_debug_context(
dummy_tables_path=dummy_tables_path,
render_function=lambda value: json.dumps(list(value), indent=4),
):
debug(events.date - patients.date_of_birth)
assert capsys.readouterr().err == textwrap.dedent(
"""\
Debug line 289:
[
{
"patient_id": 1,
"row_id": 1,
"value": "14610 days"
},
{
"patient_id": 1,
"row_id": 2,
"value": "18262 days"
},
{
"patient_id": 2,
"row_id": 3,
"value": "9132 days"
}
]
"""
)
rendered = render(events.date - patients.date_of_birth)
assert json.loads(rendered) == [
{"patient_id": 1, "row_id": 1, "value": "14610 days"},
{"patient_id": 1, "row_id": 2, "value": "18262 days"},
{"patient_id": 2, "row_id": 3, "value": "9132 days"},
]

0 comments on commit ef13d63

Please sign in to comment.