diff --git a/forge/test/operators/pytorch/conftest.py b/forge/test/operators/pytorch/conftest.py index 6ed95c6ea..a1b7bb73c 100644 --- a/forge/test/operators/pytorch/conftest.py +++ b/forge/test/operators/pytorch/conftest.py @@ -14,6 +14,8 @@ from test.operators.utils import PyTestUtils from test.operators.utils import FailingReasonsValidation +from ..utils import TestPlanUtils + @pytest.hookimpl(hookwrapper=True) def pytest_runtest_makereport(item: _pytest.python.Function, call: _pytest.runner.CallInfo): @@ -23,6 +25,8 @@ def pytest_runtest_makereport(item: _pytest.python.Function, call: _pytest.runne # This hook function is called after each step of the test execution (setup, call, teardown) if call.when == "call": # 'call' is a phase when the test is actually executed + xfail_reason = PyTestUtils.get_xfail_reason(item) + if call.excinfo is not None: # an exception occurred during the test execution logger.trace( @@ -30,10 +34,48 @@ def pytest_runtest_makereport(item: _pytest.python.Function, call: _pytest.runne ) exception_value = call.excinfo.value - xfail_reason = PyTestUtils.get_xfail_reason(item) + if xfail_reason is not None: # an xfail reason is defined for the test valid_reason = FailingReasonsValidation.validate_exception(exception_value, xfail_reason) - # if reason is not valid, mark the test as failed + # if reason is not valid, mark the test as failed and keep the original exception if valid_reason == False: - report.outcome = "failed" + # Replace test report with a new one with outcome set to 'failed' and exception details + new_report = _pytest.reports.TestReport( + item=item, + when=call.when, + outcome="failed", + longrepr=call.excinfo.getrepr(style="long"), + sections=report.sections, + nodeid=report.nodeid, + location=report.location, + keywords=report.keywords, + ) + outcome.force_result(new_report) + else: + logger.debug(f"Test '{item.name}' failed with exception: {type(exception_value)} '{exception_value}'") + + log_test_vector_properties(item, report, xfail_reason) + + +def log_test_vector_properties(item: _pytest.python.Function, report: _pytest.reports.TestReport, xfail_reason: str): + original_name = item.originalname + test_id = item.name + test_id = test_id.replace(f"{original_name}[", "") + test_id = test_id.replace("]", "") + test_vector = TestPlanUtils.test_id_to_test_vector(test_id) + + item.user_properties.append(("id", test_id)) + item.user_properties.append(("operator", test_vector.operator)) + item.user_properties.append( + ("input_source", test_vector.input_source.name if test_vector.input_source is not None else None) + ) + item.user_properties.append(("dev_data_format", TestPlanUtils.dev_data_format_to_str(test_vector.dev_data_format))) + item.user_properties.append( + ("math_fidelity", test_vector.math_fidelity.name if test_vector.math_fidelity is not None else None) + ) + item.user_properties.append(("input_shape", test_vector.input_shape)) + item.user_properties.append(("kwargs", test_vector.kwargs)) + if xfail_reason is not None: + item.user_properties.append(("xfail_reason", xfail_reason)) + item.user_properties.append(("outcome", report.outcome)) diff --git a/forge/test/operators/pytorch/test_all.py b/forge/test/operators/pytorch/test_all.py index ddf39ca25..c9357a696 100644 --- a/forge/test/operators/pytorch/test_all.py +++ b/forge/test/operators/pytorch/test_all.py @@ -96,7 +96,7 @@ def build_filtered_collection(cls) -> TestCollection: dev_data_formats = os.getenv("DEV_DATA_FORMATS", None) if dev_data_formats: dev_data_formats = dev_data_formats.split(",") - dev_data_formats = [getattr(forge.DataFormat, dev_data_format) for dev_data_format in dev_data_formats] + dev_data_formats = [TestPlanUtils.dev_data_format_from_str(dev_data_format) for dev_data_format in dev_data_formats] math_fidelities = os.getenv("MATH_FIDELITIES", None) if math_fidelities: diff --git a/forge/test/operators/utils/__init__.py b/forge/test/operators/utils/__init__.py index 79780d22c..c173e4297 100644 --- a/forge/test/operators/utils/__init__.py +++ b/forge/test/operators/utils/__init__.py @@ -5,6 +5,7 @@ from .datatypes import OperatorParameterTypes from .datatypes import ValueRange from .datatypes import ValueRanges +from .datatypes import DataTypeTestPlan from .utils import ShapeUtils from .utils import InputSourceFlag, InputSourceFlags from .utils import CompilerUtils @@ -24,6 +25,7 @@ from .plan import FailingRulesConverter from .plan import TestPlanScanner from .test_data import TestCollectionCommon +from .test_data import TestCollectionTorch from .failing_reasons import FailingReasons from .failing_reasons import FailingReasonsValidation from .pytest import PyTestUtils @@ -33,6 +35,7 @@ "OperatorParameterTypes", "ValueRange", "ValueRanges", + "DataTypeTestPlan", "ShapeUtils", "InputSourceFlag", "InputSourceFlags", @@ -53,6 +56,7 @@ "FailingRulesConverter", "TestPlanScanner", "TestCollectionCommon", + "TestCollectionTorch", "FailingReasons", "FailingReasonsValidation", "PyTestUtils", diff --git a/forge/test/operators/utils/compat.py b/forge/test/operators/utils/compat.py index 9a546ce5e..5ceb76167 100644 --- a/forge/test/operators/utils/compat.py +++ b/forge/test/operators/utils/compat.py @@ -15,6 +15,7 @@ from forge.verify.compare import compare_with_golden from .datatypes import OperatorParameterTypes, ValueRanges, ValueRange +from .datatypes import DataTypeTestPlan # TODO - Remove this class once TestDevice is available in Forge @@ -89,10 +90,12 @@ class TestTensorsUtils: torch.bfloat16: (-10000, 10000), torch.float16: (-10000, 10000), torch.float32: (-10000, 10000), + torch.float64: (-10000, 10000), torch.uint8: (0, 2**8 - 1), torch.int8: (-(2**7), 2**7 - 1), torch.int16: (-(2**15), 2**15 - 1), torch.int32: (-(2**31), 2**31 - 1), + torch.int64: (-(2**63), 2**63 - 1), } class DTypes: @@ -102,12 +105,14 @@ class DTypes: torch.float16, torch.bfloat16, torch.float32, + torch.float64, ) integers = ( torch.uint8, torch.int8, torch.int16, torch.int32, + torch.int64, ) booleans = (torch.bool,) @@ -170,6 +175,8 @@ def get_dtype_for_df(cls, dev_data_format: forge.DataFormat = None) -> torch.dty if dev_data_format is None: dtype = None + elif isinstance(dev_data_format, torch.dtype): + dtype = dev_data_format else: # dtype = torch.float32 if dev_data_format in cls.dev_data_format_to_dtype: @@ -226,7 +233,7 @@ def verify_module( model: Module, input_shapes: List[TensorShape], pcc: Optional[float] = None, - dev_data_format: forge.DataFormat = None, + dev_data_format: DataTypeTestPlan = None, value_range: Optional[Union[ValueRanges, ValueRange, OperatorParameterTypes.RangeValue]] = None, random_seed: int = 42, ): @@ -243,7 +250,7 @@ def verify_module( # TODO move to class TestTensorsUtils def create_torch_inputs( input_shapes: List[TensorShape], - dev_data_format: forge.DataFormat = None, + dev_data_format: DataTypeTestPlan = None, value_range: Optional[Union[ValueRanges, ValueRange, OperatorParameterTypes.RangeValue]] = None, random_seed: Optional[int] = None, ) -> List[torch.Tensor]: @@ -254,6 +261,8 @@ def create_torch_inputs( generator = torch.Generator().manual_seed(random_seed) dtype = TestTensorsUtils.get_dtype_for_df(dev_data_format) + # dev_data_format is no more used + dev_data_format = None # if dtype is not None: # torch.set_default_dtype(dtype) @@ -280,7 +289,13 @@ def verify_module_for_inputs( fw_out = model(*inputs) - forge_inputs = [forge.Tensor.create_from_torch(input, dev_data_format=dev_data_format) for input in inputs] + # TODO introduce flag to check framework of forge data formats + + if dev_data_format is None or isinstance(dev_data_format, forge.DataFormat): + forge_inputs = [forge.Tensor.create_from_torch(input, dev_data_format=dev_data_format) for input in inputs] + elif isinstance(dev_data_format, torch.dtype): + forge_inputs = inputs + # forge_inputs = inputs compiled_model = forge.compile(model, sample_inputs=forge_inputs) co_out = compiled_model(*forge_inputs) diff --git a/forge/test/operators/utils/datatypes.py b/forge/test/operators/utils/datatypes.py index 15be739b0..a50908bc8 100644 --- a/forge/test/operators/utils/datatypes.py +++ b/forge/test/operators/utils/datatypes.py @@ -8,6 +8,12 @@ from enum import Enum from typing import Optional, Dict, Union, Tuple, TypeAlias +import torch +import forge + + +DataTypeTestPlan = Union[forge.DataFormat, torch.dtype] + class OperatorParameterTypes: SingleValue: TypeAlias = Union[int, float] diff --git a/forge/test/operators/utils/plan.py b/forge/test/operators/utils/plan.py index 03aaf99e5..31fdbd6d0 100644 --- a/forge/test/operators/utils/plan.py +++ b/forge/test/operators/utils/plan.py @@ -27,6 +27,7 @@ from forge.op_repo import TensorShape from .datatypes import OperatorParameterTypes +from .datatypes import DataTypeTestPlan from .pytest import PytestParamsUtils from .compat import TestDevice @@ -84,7 +85,7 @@ class TestVector: input_source: InputSource input_shape: TensorShape # TODO - Support multiple input shapes number_of_operands: Optional[int] = None - dev_data_format: Optional[DataFormat] = None + dev_data_format: Optional[DataTypeTestPlan] = None math_fidelity: Optional[MathFidelity] = None kwargs: Optional[OperatorParameterTypes.Kwargs] = None pcc: Optional[float] = None @@ -94,7 +95,7 @@ class TestVector: def get_id(self, fields: Optional[List[str]] = None) -> str: """Get test vector id""" if fields is None: - return f"{self.operator}-{self.input_source.name}-{self.kwargs}-{self.input_shape}{'-' + str(self.number_of_operands) + '-' if self.number_of_operands else '-'}{self.dev_data_format.name if self.dev_data_format else None}-{self.math_fidelity.name if self.math_fidelity else None}" + return f"{self.operator}-{self.input_source.name}-{self.kwargs}-{self.input_shape}{'-' + str(self.number_of_operands) + '-' if self.number_of_operands else '-'}{TestPlanUtils.dev_data_format_to_str(self.dev_data_format)}-{self.math_fidelity.name if self.math_fidelity else None}" else: attr = [ (getattr(self, field).name if getattr(self, field) is not None else None) @@ -142,7 +143,7 @@ class TestCollection: input_sources: Optional[List[InputSource]] = None input_shapes: Optional[List[TensorShape]] = None # TODO - Support multiple input shapes numbers_of_operands: Optional[List[int]] = None - dev_data_formats: Optional[List[DataFormat]] = None + dev_data_formats: Optional[List[DataTypeTestPlan]] = None math_fidelities: Optional[List[MathFidelity]] = None kwargs: Optional[ Union[List[OperatorParameterTypes.Kwargs], Callable[["TestVector"], List[OperatorParameterTypes.Kwargs]]] @@ -494,6 +495,32 @@ class TestPlanUtils: Utility functions for test vectors """ + @classmethod + def dev_data_format_to_str(cls, dev_data_format: DataTypeTestPlan) -> Optional[str]: + """Convert data format to string""" + if dev_data_format is None: + return None + if isinstance(dev_data_format, DataFormat): + return dev_data_format.name + if isinstance(dev_data_format, torch.dtype): + # Remove torch. prefix + return str(dev_data_format).split(".")[-1] + else: + raise ValueError(f"Unsupported data format: {dev_data_format}") + + @classmethod + def dev_data_format_from_str(cls, dev_data_format_str: str) -> DataTypeTestPlan: + """Convert string to data format""" + if dev_data_format_str is None: + return None + if hasattr(forge.DataFormat, dev_data_format_str): + dev_data_format = getattr(forge.DataFormat, dev_data_format_str) + elif hasattr(torch, dev_data_format_str): + dev_data_format = getattr(torch, dev_data_format_str) + else: + raise ValueError(f"Unsupported data format: {dev_data_format_str} in Forge and PyTorch") + return dev_data_format + @classmethod def _match(cls, rule_collection: Optional[List], vector_value): """ @@ -624,7 +651,7 @@ def test_id_to_test_vector(cls, test_id: str) -> TestVector: dev_data_format_part = parts[dev_data_format_index] if dev_data_format_part == "None": dev_data_format_part = None - dev_data_format = eval(f"forge._C.{dev_data_format_part}") if dev_data_format_part is not None else None + dev_data_format = cls.dev_data_format_from_str(dev_data_format_part) math_fidelity_part = parts[math_fidelity_index] if math_fidelity_part == "None": @@ -663,7 +690,7 @@ def build_rules( Union[Optional[InputSource], List[InputSource]], Union[Optional[TensorShape], List[TensorShape]], Union[Optional[OperatorParameterTypes.Kwargs], List[OperatorParameterTypes.Kwargs]], - Union[Optional[forge.DataFormat], List[forge.DataFormat]], + Union[Optional[DataTypeTestPlan], List[DataTypeTestPlan]], Union[Optional[forge.MathFidelity], List[forge.MathFidelity]], Optional[TestResultFailing], ], @@ -708,7 +735,7 @@ def build_rule( input_source: Optional[Union[InputSource, List[InputSource]]], input_shape: Optional[Union[TensorShape, List[TensorShape]]], kwargs: Optional[Union[OperatorParameterTypes.Kwargs, List[OperatorParameterTypes.Kwargs]]], - dev_data_format: Optional[Union[forge.DataFormat, List[forge.DataFormat]]], + dev_data_format: Optional[Union[DataTypeTestPlan, List[DataTypeTestPlan]]], math_fidelity: Optional[Union[forge.MathFidelity, List[forge.MathFidelity]]], result_failing: Optional[TestResultFailing], ) -> TestCollection: diff --git a/forge/test/operators/utils/test_data.py b/forge/test/operators/utils/test_data.py index b8ca77138..cddfc7272 100644 --- a/forge/test/operators/utils/test_data.py +++ b/forge/test/operators/utils/test_data.py @@ -6,6 +6,7 @@ import pytest import forge +import torch from forge import MathFidelity, DataFormat @@ -309,3 +310,44 @@ class TestCollectionCommon: (14, 13, 89, 3), # 4.2 Prime numbers ] ) + + +class TestCollectionTorch: + """ + Shared test collection for torch data types. + """ + + __test__ = False # Avoid collecting TestCollectionTorch as a pytest test + + float = TestCollection( + dev_data_formats=[ + torch.float16, + torch.float32, + # torch.float64, + torch.bfloat16, + ], + ) + + int = TestCollection( + dev_data_formats=[ + torch.int8, + # torch.int16, + torch.int32, + torch.int64, + # torch.uint8, + ], + ) + + bool = TestCollection( + dev_data_formats=[ + torch.bool, + ], + ) + + all = TestCollection(dev_data_formats=float.dev_data_formats + int.dev_data_formats) + + single = TestCollection( + dev_data_formats=[ + torch.float16, + ], + )