Skip to content

Commit

Permalink
Speed up test plan querying (#915)
Browse files Browse the repository at this point in the history
Speed up test plan querying via filtering test plans by OPERATORS filter
prior to vector filtering.
  • Loading branch information
vbrkicTT authored Jan 16, 2025
1 parent 12390bd commit 626ca63
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 78 deletions.
112 changes: 34 additions & 78 deletions forge/test/operators/pytorch/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,6 @@

# Examples
# pytest -svv forge/test/operators/pytorch/test_all.py::test_plan
# pytest -svv forge/test/operators/pytorch/test_all.py::test_failed
# pytest -svv forge/test/operators/pytorch/test_all.py::test_skipped
# pytest -svv forge/test/operators/pytorch/test_all.py::test_fatal
# pytest -svv forge/test/operators/pytorch/test_all.py::test_not_implemented
# pytest -svv forge/test/operators/pytorch/test_all.py::test_data_mismatch
# pytest -svv forge/test/operators/pytorch/test_all.py::test_unsupported_df
# pytest -svv forge/test/operators/pytorch/test_all.py::test_custom
# pytest -svv forge/test/operators/pytorch/test_all.py::test_query
# pytest -svv forge/test/operators/pytorch/test_all.py::test_unique
Expand All @@ -26,12 +20,14 @@

from loguru import logger
from tabulate import tabulate
from typing import List

from test.operators.utils import DeviceUtils
from test.operators.utils import InputSource
from test.operators.utils import TestVector
from test.operators.utils import TestCollection
from test.operators.utils import TestCollectionCommon
from test.operators.utils import TestSuite
from test.operators.utils import TestPlanScanner
from test.operators.utils import FailingReasons

Expand All @@ -58,8 +54,6 @@ class TestParamsData:

__test__ = False # Avoid collecting TestParamsData as a pytest test

test_suite = TestPlanScanner.build_test_suite(scan_file=__file__, scan_package=__package__)

@classmethod
def get_single_list(cls) -> list[str]:
"""Provide a list of test ids to run for test_single method"""
Expand Down Expand Up @@ -174,6 +168,19 @@ def get_filter_range(cls) -> tuple[int, int]:

return 0, 100000

@classmethod
def filter_suite_by_operators(cls, test_suite: TestSuite, operators: List[str]) -> TestSuite:
"""Filter test plans based on operator list to speed up test filtering"""
if operators is None:
return test_suite
else:
test_plans = [
test_plan
for test_plan in test_suite.test_plans
if len(list(set(test_plan.collections[0].operators) & set(operators))) > 0
]
return TestSuite(test_plans)


class TestCollectionData:
"""Helper test collections"""
Expand Down Expand Up @@ -208,6 +215,15 @@ class TestCollectionData:
)


class TestSuiteData:

__test__ = False # Avoid collecting TestSuiteData as a pytest test

all = TestPlanScanner.build_test_suite(scan_file=__file__, scan_package=__package__)

filtered = TestParamsData.filter_suite_by_operators(all, TestCollectionData.all.operators)


class VectorLambdas:
"""Helper lambdas for filtering test vectors"""

Expand Down Expand Up @@ -262,12 +278,9 @@ class VectorLambdas:
)


test_suite = TestParamsData.test_suite


@pytest.mark.parametrize(
"test_vector",
test_suite.query_all().filter(
TestSuiteData.filtered.query_all().filter(
VectorLambdas.ALL_OPERATORS,
VectorLambdas.QUICK,
# VectorLambdas.FILTERED,
Expand Down Expand Up @@ -304,7 +317,7 @@ def test_custom(test_vector: TestVector, test_device):

@pytest.mark.parametrize(
"test_vector",
test_suite.query_all()
TestSuiteData.filtered.query_all()
.filter(VectorLambdas.FILTERED)
.filter(*TestParamsData.build_filter_lambdas())
.range(*TestParamsData.get_filter_range())
Expand All @@ -316,7 +329,7 @@ def test_query(test_vector: TestVector, test_device):

@pytest.mark.parametrize(
"test_vector",
test_suite.query_all()
TestSuiteData.filtered.query_all()
.filter(VectorLambdas.ALL_OPERATORS, VectorLambdas.SINGLE_SHAPE)
.filter(
lambda test_vector: test_vector.input_source in [InputSource.FROM_HOST]
Expand All @@ -330,13 +343,17 @@ def test_unique(test_vector: TestVector, test_device):
TestVerification.verify(test_vector, test_device)


@pytest.mark.parametrize("test_vector", test_suite.query_from_id_list(TestParamsData.get_single_list()).to_params())
@pytest.mark.parametrize(
"test_vector", TestSuiteData.all.query_from_id_list(TestParamsData.get_single_list()).to_params()
)
def test_single(test_vector: TestVector, test_device):
TestVerification.verify(test_vector, test_device)


@pytest.mark.nightly_sweeps
@pytest.mark.parametrize("test_vector", test_suite.query_all().filter(VectorLambdas.ALL_OPERATORS).to_params())
@pytest.mark.parametrize(
"test_vector", TestSuiteData.filtered.query_all().filter(VectorLambdas.ALL_OPERATORS).to_params()
)
def test_plan(test_vector: TestVector, test_device):
TestVerification.verify(test_vector, test_device)

Expand All @@ -361,67 +378,6 @@ def test_plan(test_vector: TestVector, test_device):
# Below are examples of custom test functions that utilize filtering lambdas to run specific tests


@pytest.mark.parametrize(
"test_vector",
test_suite.query_all()
.filter(VectorLambdas.ALL_OPERATORS, VectorLambdas.FAILING, VectorLambdas.NOT_SKIPED)
.to_params(),
)
def test_failed(test_vector: TestVector, test_device):
TestVerification.verify(test_vector, test_device)


@pytest.mark.parametrize(
"test_vector",
test_suite.query_all().filter(VectorLambdas.ALL_OPERATORS, VectorLambdas.SKIPED).to_params(),
)
def test_skipped(test_vector: TestVector, test_device):
TestVerification.verify(test_vector, test_device)


@pytest.mark.parametrize(
"test_vector",
test_suite.query_all()
.filter(VectorLambdas.ALL_OPERATORS, VectorLambdas.FAILING, VectorLambdas.SKIPED_FATAL)
.to_params(),
)
def test_fatal(test_vector: TestVector, test_device):
TestVerification.verify(test_vector, test_device)


@pytest.mark.parametrize(
"test_vector",
test_suite.query_all()
.filter(VectorLambdas.ALL_OPERATORS, VectorLambdas.FAILING, VectorLambdas.NOT_SKIPED)
.filter(VectorLambdas.UNSUPPORTED_DATA_FORMAT)
.to_params(),
)
def test_unsupported_df(test_vector: TestVector, test_device):
TestVerification.verify(test_vector, test_device)


@pytest.mark.parametrize(
"test_vector",
test_suite.query_all()
.filter(VectorLambdas.ALL_OPERATORS, VectorLambdas.FAILING, VectorLambdas.NOT_SKIPED)
.filter(VectorLambdas.NOT_IMPLEMENTED)
.to_params(),
)
def test_not_implemented(test_vector: TestVector, test_device):
TestVerification.verify(test_vector, test_device)


@pytest.mark.parametrize(
"test_vector",
test_suite.query_all()
.filter(VectorLambdas.ALL_OPERATORS, VectorLambdas.FAILING, VectorLambdas.NOT_SKIPED)
.filter(VectorLambdas.DATA_MISMATCH)
.to_params(),
)
def test_data_mismatch(test_vector: TestVector, test_device):
TestVerification.verify(test_vector, test_device)


class InfoUtils:
@classmethod
def print_query_params(cls, max_width=80):
Expand All @@ -433,7 +389,7 @@ def print_query_params(cls, max_width=80):
@classmethod
def print_query_values(cls, max_width=80):

operators = [key for key in test_suite.indices]
operators = [key for key in TestSuiteData.all.indices]
operators = sorted(operators)
operators = ", ".join(operators)

Expand Down
4 changes: 4 additions & 0 deletions forge/test/operators/utils/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ def to_params(self) -> Generator[ParameterSet, None, None]:
test_vectors = self.test_vectors
for test_vector in test_vectors:
yield test_vector.to_param()
logger.trace("To params done")

@classmethod
def all(cls, test_plan: Union["TestPlan", "TestSuite"]) -> "TestQuery":
Expand Down Expand Up @@ -480,6 +481,7 @@ def load_test_vectors_from_id_list(self, test_ids: List[str]) -> List[TestVector
return test_vectors

def query_all(self) -> TestQuery:
logger.trace("Query all test vectors")
return TestQuery.all(self)

def query_from_id_file(self, test_ids_file: str) -> TestQuery:
Expand Down Expand Up @@ -819,6 +821,8 @@ def get_all_test_plans(cls, scan_file: str, scan_package: str) -> Generator[Test
@classmethod
def build_test_suite(cls, scan_file: str, scan_package: str) -> TestSuite:
"""Build test suite from scaned test plans."""
logger.trace(f"Building test suite from file: {scan_file} and package: {scan_package}")
test_plans = cls.get_all_test_plans(scan_file, scan_package)
test_plans = list(test_plans)
logger.trace(f"Found test plans: {len(test_plans)}")
return TestSuite(test_plans=test_plans)

0 comments on commit 626ca63

Please sign in to comment.