Skip to content

Commit

Permalink
Use black python API to fetch the list of checked files
Browse files Browse the repository at this point in the history
  • Loading branch information
nachovizzo committed Dec 15, 2023
1 parent 099652c commit 5079183
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions ament_black/ament_black/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,19 @@
import time
from xml.sax.saxutils import escape, quoteattr

import click
from black import get_sources
from black import main as black
from black import re_compile_maybe_verbose
from black.concurrency import maybe_install_uvloop
from black.const import DEFAULT_EXCLUDES, DEFAULT_INCLUDES
from black.report import Report
from unidiff import PatchSet


def patched_black(*args, **kwargs) -> None:
from multiprocessing import freeze_support

from black import main as black
from black.concurrency import maybe_install_uvloop

maybe_install_uvloop()
freeze_support()
black(*args, **kwargs)
Expand Down Expand Up @@ -71,9 +75,23 @@ def main(argv=sys.argv[1:]):
print("Could not find config file '%s'" % args.config_file, file=sys.stderr)
return 1

# TODO(Nacho): Inject the config file results into the ctx (use read_pyproject_toml)
sources = get_sources(
ctx=click.Context(black),
src=tuple(args.paths),
quiet=True,
verbose=False,
include=re_compile_maybe_verbose(DEFAULT_INCLUDES),
exclude=re_compile_maybe_verbose(DEFAULT_EXCLUDES),
extend_exclude=None,
force_exclude=None,
report=Report(),
stdin_filename="",
)
checked_files = [str(path) for path in sources]

if args.xunit_file:
start_time = time.time()
report = []

# invoke black
black_args_withouth_path = []
Expand Down Expand Up @@ -127,7 +145,9 @@ def main(argv=sys.argv[1:]):
file_name = file_name.split(suffix)[0]
testname = "%s.%s" % (folder_name, file_name)

xml = get_xunit_content(report, testname, time.time() - start_time)
xml = get_xunit_content(
report, testname, time.time() - start_time, checked_files
)
path = os.path.dirname(os.path.abspath(args.xunit_file))
if not os.path.exists(path):
os.makedirs(path)
Expand Down Expand Up @@ -157,7 +177,7 @@ def get_line_number(data, offset):
return data[0:offset].count("\n") + data[0:offset].count("\r") + 1


def get_xunit_content(report, testname, elapsed):
def get_xunit_content(report, testname, elapsed, checked_files):
test_count = sum(max(len(r), 1) for r in report.values())
error_count = sum(len(r) for r in report.values())
data = {
Expand Down Expand Up @@ -215,10 +235,10 @@ def get_xunit_content(report, testname, elapsed):

# output list of checked files
data = {
"escaped_files": escape("".join(["\n* %s" % r for r in sorted(report.keys())]))
"checked_files": escape("".join(["\n* %s" % r for r in sorted(checked_files)]))
}
xml += (
""" <system-out>Checked files:%(escaped_files)s</system-out>
""" <system-out>Checked files:%(checked_files)s</system-out>
"""
% data
)
Expand Down

0 comments on commit 5079183

Please sign in to comment.