Skip to content

Commit

Permalink
Merge pull request #2227 from opensafely-core/debug-updates
Browse files Browse the repository at this point in the history
Updates to the debug command
  • Loading branch information
rebkwok authored Nov 18, 2024
2 parents 443d544 + 4aaa128 commit 707fb3b
Show file tree
Hide file tree
Showing 14 changed files with 86 additions and 200 deletions.
2 changes: 0 additions & 2 deletions docs/includes/generated_docs/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -835,8 +835,6 @@ Path to directory of files (one per table) to use as dummy tables

Files may be in any supported format: `.arrow`, `.csv`, `.csv.gz`

This argument is ignored when running against real tables.

</div>

<div class="attr-heading" id="debug.display-format">
Expand Down
4 changes: 2 additions & 2 deletions ehrql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pathlib import Path

from ehrql.codes import codelist_from_csv
from ehrql.debug import show
from ehrql.debugger import debug
from ehrql.measures import INTERVAL, Measures, create_measures
from ehrql.query_language import (
Dataset,
Expand Down Expand Up @@ -32,10 +32,10 @@
"create_dataset",
"create_measures",
"days",
"debug",
"maximum_of",
"minimum_of",
"months",
"show",
"weeks",
"when",
"years",
Expand Down
16 changes: 15 additions & 1 deletion ehrql/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,21 @@ def add_debug_dataset_definition(subparsers, environ, user_args):
parser.set_defaults(environ=environ)
parser.set_defaults(user_args=user_args)
add_dataset_definition_file_argument(parser, environ)
add_dummy_tables_argument(parser, environ)

parser.add_argument(
"--dummy-tables",
help=strip_indent(
f"""
Path to directory of files (one per table) to use as dummy tables
(see [`create-dummy-tables`](#create-dummy-tables)).
Files may be in any supported format: {backtick_join(FILE_FORMATS)}
"""
),
type=existing_directory,
dest="dummy_tables_path",
)

add_display_renderer_argument(parser, environ)


Expand Down
18 changes: 4 additions & 14 deletions ehrql/debug.py → ehrql/debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


@exclude_from_docs
def show(
def debug(
element,
*other_elements,
label: str | None = None,
Expand All @@ -33,7 +33,7 @@ def show(
head and tail arguments can be combined, e.g. to show the first and last 5 lines of a table:
show(<table>, head=5, tail=5)
debug(<table>, head=5, tail=5)
"""
line_no = inspect.getframeinfo(sys._getframe(1))[1]
elements = [element, *other_elements]
Expand All @@ -48,19 +48,9 @@ def show(
print(el_repr, file=sys.stderr)


def stop(*, head: int | None = None, tail: int | None = None):
def stop():
"""
Stop loading the dataset definition and show the contents of the dataset at this point.
_head_<br>
Show only the first N lines of the dataset.
_tail_<br>
Show only the last N lines of the dataset.
head and tail arguments can be combined, e.g. to show the first and last 5 lines of the dataset:
stop(head=5, tail=5)
Stop loading the dataset definition at this point.
"""
line_no = inspect.getframeinfo(sys._getframe(1))[1]
print(f"Stopping at line {line_no}", file=sys.stderr)
37 changes: 4 additions & 33 deletions ehrql/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import logging
import os
import re
import shutil
import sys
from contextlib import nullcontext
Expand Down Expand Up @@ -34,16 +33,13 @@
get_column_specs_for_measures,
get_measure_results,
)
from ehrql.query_engines.in_memory_database import truncate_records
from ehrql.query_engines.local_file import LocalFileQueryEngine
from ehrql.query_engines.sandbox import SandboxQueryEngine
from ehrql.query_engines.sqlite import SQLiteQueryEngine
from ehrql.query_model.column_specs import (
get_column_specs,
get_column_specs_from_schema,
)
from ehrql.query_model.graphs import graph_to_svg
from ehrql.renderers import DISPLAY_RENDERERS
from ehrql.serializer import serialize
from ehrql.utils.itertools_utils import eager_iterator
from ehrql.utils.sqlalchemy_query_utils import (
Expand Down Expand Up @@ -377,29 +373,15 @@ def debug_dataset_definition(
dummy_tables_path=None,
render_format="ascii",
):
# Rewrite the dataset definition up to the first stop() command and load it
# Loading it will execute any show() commands.
with NamedTemporaryFile(suffix=".py", dir=definition_file.parent) as tmpfile:
stop_args = _write_debug_definition_to_temp_file(
definition_file, Path(tmpfile.name)
)
_write_debug_definition_to_temp_file(definition_file, Path(tmpfile.name))

variable_definitions = load_debug_definition(
load_debug_definition(
tmpfile.name, user_args, environ, dummy_tables_path, render_format
)

query_engine = SandboxQueryEngine(dummy_tables_path)
column_specs = list(get_column_specs(variable_definitions))
results = eager_iterator(query_engine.get_results(variable_definitions))
records = [
{column_specs[i]: value for i, value in enumerate(result)} for result in results
]

if stop_args is not None:
records = truncate_records(records, **stop_args)

dataset_as_table = DISPLAY_RENDERERS[render_format](records)

print(dataset_as_table)


def _write_debug_definition_to_temp_file(definition_file, tmpfile):
# Read the dataset definition up to the first point that a
Expand All @@ -412,19 +394,8 @@ def _write_debug_definition_to_temp_file(definition_file, tmpfile):
if line.strip().startswith("stop("):
break

last_line = lines[-1]
stop_args = None
if "stop" in last_line:
head_match = re.match(r"^stop\(.*head=(?P<head>\d+)", last_line, flags=re.X)
tail_match = re.match(r"^stop\(.*tail=(?P<tail>\d+)", last_line, flags=re.X)
stop_args = {
"head": int(head_match.group("head")) if head_match else None,
"tail": int(tail_match.group("tail")) if tail_match else None,
}

lines = "".join(lines)
tmpfile.write_text(lines)
return stop_args


def test_connection(backend_class, url, environ):
Expand Down
24 changes: 0 additions & 24 deletions ehrql/query_engines/in_memory_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,27 +516,3 @@ def parse_value(value):
def nulls_first_order(key):
# Usable as a key function to `sorted()` which sorts NULLs first
return (0 if key is None else 1, key)


def truncate_records(
records: list[dict], head: int | None = None, tail: int | None = None
):
"""
Truncate a list of records to the first/last N rows,
with a row of ... values to indicate where it's been truncated
These records will be passed to one of the display formatter functions
for rendering as ascii or html.
"""
if head is None and tail is None:
return records

if len(records) <= (head or 0) + (tail or 0):
return records

ellipsis_record = {k: "..." for k in records[0].keys()}
truncated_records = records[:head] if head is not None else [ellipsis_record]
if head and tail:
truncated_records.append(ellipsis_record)
truncated_records.extend(records[-tail:] if tail is not None else [ellipsis_record])
return truncated_records
3 changes: 3 additions & 0 deletions ehrql/query_engines/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,6 @@ class EmptyDataset:

def __repr__(self):
return "Dataset()"

def _render_(self, render_fn):
return render_fn([{"patient_id": ""}])
6 changes: 3 additions & 3 deletions tests/fixtures/good_definition_files/debug_definition.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# noqa: INP001
from ehrql import create_dataset
from ehrql.debug import show, stop
from ehrql import create_dataset, debug
from ehrql.debugger import stop
from ehrql.tables.core import patients


dataset = create_dataset()
dataset.sex = patients.sex
show("Hello")
debug("Hello")
dataset.define_population(patients.date_of_birth.is_on_or_after("2000-01-01"))
stop()
dataset.year_of_birth = patients.date_of_birth.year
19 changes: 18 additions & 1 deletion tests/integration/query_engines/test_sandbox.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from pathlib import Path

import pytest

from ehrql.query_engines.in_memory_database import PatientColumn
from ehrql.query_engines.sandbox import SandboxQueryEngine
from ehrql.query_engines.sandbox import EmptyDataset, SandboxQueryEngine
from ehrql.renderers import DISPLAY_RENDERERS
from ehrql.tables import PatientFrame, Series, table


Expand All @@ -17,3 +20,17 @@ def test_csv_query_engine_evaluate():
query_engine = SandboxQueryEngine(FIXTURES)
result = query_engine.evaluate(patients.sex)
assert result == PatientColumn({1: "M", 2: "F", 3: None})


@pytest.mark.parametrize(
"render_format,expected",
[
("ascii", "patient_id\n-----------------"),
(
"html",
"<table><thead><th>patient_id</th></thead><tbody><tr><td></td></tr></tbody></table>",
),
],
)
def test_empty_dataset_render_(render_format, expected):
assert EmptyDataset()._render_(DISPLAY_RENDERERS[render_format]).strip() == expected
71 changes: 20 additions & 51 deletions tests/integration/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,16 +241,15 @@ def test_generate_measures_dummy_tables(tmp_path, disclosure_control_enabled):
)


def test_debug_show(tmp_path, capsys):
def test_debug_debug(tmp_path, capsys):
definition = textwrap.dedent(
"""\
from ehrql import create_dataset
from ehrql.debug import show
from ehrql import create_dataset, debug
from ehrql.tables.core import patients
dataset = create_dataset()
year = patients.date_of_birth.year
show(6, label="Number")
debug(6, label="Number")
dataset.define_population(year>1980)
"""
)
Expand All @@ -277,75 +276,44 @@ def test_debug_show(tmp_path, capsys):

expected = textwrap.dedent(
"""\
Debug line 7: Number
Debug line 6: Number
6
"""
).strip()
assert capsys.readouterr().err.strip() == expected


@pytest.mark.parametrize(
"stop,expected_out",
"debug,stop,expected_out,expected_err",
[
(
"debug(dataset)",
"stop()",
"",
textwrap.dedent(
"""
Debug line 7:
patient_id
-----------------
1
2
3
4
"""
),
),
(
"stop(head=None, tail=None)",
textwrap.dedent(
"""
patient_id
-----------------
1
2
3
4
"""
),
),
(
"stop(head=1)",
textwrap.dedent(
"""
patient_id
-----------------
1
...
"""
),
),
(
"stop(tail=1)",
textwrap.dedent(
"""
patient_id
-----------------
...
4
Stopping at line 9
"""
),
),
("", "stop()", "", "Stopping at line 9"),
],
)
def test_debug_stop(tmp_path, capsys, stop, expected_out):
def test_debug_stop(tmp_path, capsys, debug, stop, expected_out, expected_err):
definition = textwrap.dedent(
f"""\
from ehrql import create_dataset
from ehrql.debug import show, stop
from ehrql import create_dataset, debug
from ehrql.debugger import stop
from ehrql.tables.core import patients
dataset = create_dataset()
year = patients.date_of_birth.year
{debug}
dataset.define_population(year>1900)
{stop}
"""
Expand Down Expand Up @@ -373,6 +341,7 @@ def test_debug_stop(tmp_path, capsys, stop, expected_out):
user_args=(),
)

assert (
capsys.readouterr().out.strip() == expected_out.strip()
), capsys.readouterr().out.strip()
captured = capsys.readouterr()

assert captured.out.strip() == expected_out.strip(), captured.out.strip()
assert captured.err.strip() == expected_err.strip(), captured.err.strip()
Loading

0 comments on commit 707fb3b

Please sign in to comment.