diff --git a/examples/pytest/BUILD.bazel b/examples/pytest/BUILD.bazel index 24c0d607..781b378f 100644 --- a/examples/pytest/BUILD.bazel +++ b/examples/pytest/BUILD.bazel @@ -1,23 +1,15 @@ -load("@aspect_rules_py//py:defs.bzl", "py_pytest_main", "py_test") +load("@aspect_rules_py//pytest:defs.bzl", "py_pytest_test") -py_pytest_main( - name = "__test__", - deps = ["@pypi_pytest//:pkg"], -) - -py_test( +py_pytest_test( name = "pytest_test", srcs = [ "foo_test.py", - ":__test__", ], imports = ["../.."], - main = ":__test__.py", package_collisions = "warning", + pip_repo = "pypi", deps = [ - ":__test__", "@pypi_ftfy//:pkg", "@pypi_neptune//:pkg", - "@pypi_pytest//:pkg", ], ) diff --git a/py/defs.bzl b/py/defs.bzl index 6fabf75a..083329d5 100644 --- a/py/defs.bzl +++ b/py/defs.bzl @@ -22,22 +22,26 @@ py_unpacked_wheel = _py_unpacked_wheel resolutions = _resolutions def _py_binary_or_test(name, rule, srcs, main, deps = [], resolutions = {}, **kwargs): + if main and type(main) not in ["string", "Label"]: + fail("main must be a Label or a string, not {}".format(type(main))) + # Compatibility with rules_python, see docs in py_executable.bzl main_target = "_{}.find_main".format(name) - determine_main( - name = main_target, - target_name = name, - main = main, - srcs = srcs, - **propagate_common_rule_attributes(kwargs) - ) + if type(main) != "Label": + determine_main( + name = main_target, + target_name = name, + main = main, + srcs = srcs, + **propagate_common_rule_attributes(kwargs) + ) package_collisions = kwargs.pop("package_collisions", None) rule( name = name, srcs = srcs, - main = main_target, + main = main if type(main) == "Label" else main_target, deps = deps, resolutions = resolutions, package_collisions = package_collisions, diff --git a/pytest/BUILD.bazel b/pytest/BUILD.bazel new file mode 100644 index 00000000..d54da6ca --- /dev/null +++ b/pytest/BUILD.bazel @@ -0,0 +1,4 @@ +exports_files( + ["pytest_shim.py"], + visibility = ["//visibility:public"], +) diff --git a/pytest/defs.bzl b/pytest/defs.bzl new file mode 100644 index 00000000..bf6a2e4f --- /dev/null +++ b/pytest/defs.bzl @@ -0,0 +1,65 @@ +"""Use pytest to run tests, using a wrapper script to interface with Bazel. + +Example: + +```starlark +load("@aspect_rules_py//pytest:defs.bzl", "py_pytest_test") + +py_pytest_test( + name = "test_w_pytest", + size = "small", + srcs = ["test.py"], +) +``` + +By default, `@pip//pytest` is added to `deps`. +If sharding is used (when `shard_count > 1`) then `@pip//pytest_shard` is also added. +To instead provide explicit deps for the pytest library, set `pytest_deps`: + +```starlark +py_pytest_test( + name = "test_w_my_pytest", + shard_count = 2, + srcs = ["test.py"], + pytest_deps = [requirement("pytest"), requirement("pytest-shard"), ...], +) +``` +""" + +load("//py:defs.bzl", "py_test") + +def py_pytest_test(name, srcs, deps = [], args = [], pytest_deps = None, pip_repo = "pip", **kwargs): + """ + Wrapper macro for `py_test` which supports pytest. + + Args: + name: A unique name for this target. + srcs: Python source files. + deps: Dependencies, typically `py_library`. + args: Additional command-line arguments to pytest. + See https://docs.pytest.org/en/latest/how-to/usage.html + pytest_deps: Labels of the pytest tool and other packages it may import. + pip_repo: Name of the external repository where Python packages are installed. + It's typically created by `pip.parse`. + This attribute is used only when `pytest_deps` is unset. + **kwargs: Additional named parameters to py_test. + """ + shim_label = Label("//pytest:pytest_shim.py") + + if pytest_deps == None: + pytest_deps = ["@{}//pytest".format(pip_repo)] + if kwargs.get("shard_count", 1) > 1: + pytest_deps.append("@{}//pytest_shard".format(pip_repo)) + + py_test( + name = name, + srcs = [ + shim_label, + ] + srcs, + main = shim_label, + args = [ + "--capture=no", + ] + args + ["$(location :%s)" % x for x in srcs], + deps = deps + pytest_deps, + **kwargs + ) diff --git a/pytest/pytest_shim.py b/pytest/pytest_shim.py new file mode 100644 index 00000000..28d3feb2 --- /dev/null +++ b/pytest/pytest_shim.py @@ -0,0 +1,71 @@ +"""A shim for executing pytest that supports test filtering, sharding, and more. + +Copied from https://github.com/caseyduquettesc/rules_python_pytest/blob/331e0e511130cf4859b7589a479db6c553974abf/python_pytest/pytest_shim.py +""" + +import sys +import os + +import pytest + + +if __name__ == "__main__": + pytest_args = ["--ignore=external"] + + args = sys.argv[1:] + # pytest < 8.0 runs tests twice if __init__.py is passed explicitly as an argument. + # Remove any __init__.py file to avoid that. + # pytest.version_tuple is available since pytest 7.0 + # https://github.com/pytest-dev/pytest/issues/9313 + if not hasattr(pytest, "version_tuple") or pytest.version_tuple < (8, 0): + args = [arg for arg in args if arg.startswith("-") or os.path.basename(arg) != "__init__.py"] + + if os.environ.get("XML_OUTPUT_FILE"): + pytest_args.append("--junitxml={xml_output_file}".format(xml_output_file=os.environ.get("XML_OUTPUT_FILE"))) + + # Handle test sharding - requires pytest-shard plugin. + if os.environ.get("TEST_SHARD_INDEX") and os.environ.get("TEST_TOTAL_SHARDS"): + pytest_args.append("--shard-id={shard_id}".format(shard_id=os.environ.get("TEST_SHARD_INDEX"))) + pytest_args.append("--num-shards={num_shards}".format(num_shards=os.environ.get("TEST_TOTAL_SHARDS"))) + if os.environ.get("TEST_SHARD_STATUS_FILE"): + open(os.environ["TEST_SHARD_STATUS_FILE"], "a").close() + + # Handle plugins that generate reports - if they are provided with relative paths (via args), + # re-write it under bazel's test undeclared outputs dir. + if os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR"): + undeclared_output_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") + + # Flags that take file paths as value. + path_flags = [ + "--report-log", # pytest-reportlog + "--json-report-file", # pytest-json-report + "--html", # pytest-html + ] + for i, arg in enumerate(args): + for flag in path_flags: + if arg.startswith(f"{flag}="): + arg_split = arg.split("=", 1) + if len(arg_split) == 2 and not os.path.isabs(arg_split[1]): + args[i] = f"{flag}={undeclared_output_dir}/{arg_split[1]}" + + if os.environ.get("TESTBRIDGE_TEST_ONLY"): + test_filter = os.environ["TESTBRIDGE_TEST_ONLY"] + + # If the test filter does not start with a class-like name, then use test filtering instead + if not test_filter[0].isupper(): + # --test_filter=test_module.test_fn or --test_filter=test_module/test_file.py + pytest_args.extend(args) + pytest_args.append("-k={filter}".format(filter=test_filter)) + else: + # --test_filter=TestClass.test_fn + for arg in args: + if not arg.startswith("--"): + # arg is a src file. Add test class/method selection to it. + # test.py::TestClass::test_fn + arg = "{arg}::{module_fn}".format(arg=arg, module_fn=test_filter.replace(".", "::")) + pytest_args.append(arg) + else: + pytest_args.extend(args) + + print(pytest_args, file=sys.stderr) + raise SystemExit(pytest.main(pytest_args))