Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Torch dtypes in test plan #1052

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 45 additions & 3 deletions forge/test/operators/pytorch/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from test.operators.utils import PyTestUtils
from test.operators.utils import FailingReasonsValidation

from ..utils import TestPlanUtils


@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_makereport(item: _pytest.python.Function, call: _pytest.runner.CallInfo):
Expand All @@ -23,17 +25,57 @@ def pytest_runtest_makereport(item: _pytest.python.Function, call: _pytest.runne
# This hook function is called after each step of the test execution (setup, call, teardown)
if call.when == "call": # 'call' is a phase when the test is actually executed

xfail_reason = PyTestUtils.get_xfail_reason(item)

if call.excinfo is not None: # an exception occurred during the test execution

logger.trace(
f"Test: skipped: {report.skipped} failed: {report.failed} passed: {report.passed} report: {report}"
)

exception_value = call.excinfo.value
xfail_reason = PyTestUtils.get_xfail_reason(item)

if xfail_reason is not None: # an xfail reason is defined for the test
valid_reason = FailingReasonsValidation.validate_exception(exception_value, xfail_reason)

# if reason is not valid, mark the test as failed
# if reason is not valid, mark the test as failed and keep the original exception
if valid_reason == False:
report.outcome = "failed"
# Replace test report with a new one with outcome set to 'failed' and exception details
new_report = _pytest.reports.TestReport(
item=item,
when=call.when,
outcome="failed",
longrepr=call.excinfo.getrepr(style="long"),
sections=report.sections,
nodeid=report.nodeid,
location=report.location,
keywords=report.keywords,
)
outcome.force_result(new_report)
else:
logger.debug(f"Test '{item.name}' failed with exception: {type(exception_value)} '{exception_value}'")

log_test_vector_properties(item, report, xfail_reason)


def log_test_vector_properties(item: _pytest.python.Function, report: _pytest.reports.TestReport, xfail_reason: str):
    """Record the test vector behind a parametrized test as pytest user properties.

    The parametrization id is extracted from ``item.name`` by dropping the leading
    ``<originalname>[`` and the closing ``]``. It is then parsed with
    ``TestPlanUtils.test_id_to_test_vector`` and the resulting fields are appended
    to ``item.user_properties`` together with the report outcome.

    Args:
        item: Test item whose ``user_properties`` list is extended.
        report: Report of the current test phase; its ``outcome`` is recorded.
        xfail_reason: Declared xfail reason, or ``None``. The ``xfail_reason``
            property is only added when a reason is given.
    """
    test_id = item.name

    # Strip only the wrapping "<originalname>[" ... "]" pair. Removing every "]"
    # (as str.replace does) would also damage ids that contain brackets themselves.
    id_prefix = f"{item.originalname}["
    if test_id.startswith(id_prefix):
        test_id = test_id[len(id_prefix) :]
    if test_id.endswith("]"):
        test_id = test_id[:-1]

    test_vector = TestPlanUtils.test_id_to_test_vector(test_id)

    input_source = test_vector.input_source
    math_fidelity = test_vector.math_fidelity

    add_property = item.user_properties.append
    add_property(("id", test_id))
    add_property(("operator", test_vector.operator))
    add_property(("input_source", input_source.name if input_source is not None else None))
    add_property(("dev_data_format", TestPlanUtils.dev_data_format_to_str(test_vector.dev_data_format)))
    add_property(("math_fidelity", math_fidelity.name if math_fidelity is not None else None))
    add_property(("input_shape", test_vector.input_shape))
    add_property(("kwargs", test_vector.kwargs))
    if xfail_reason is not None:
        add_property(("xfail_reason", xfail_reason))
    add_property(("outcome", report.outcome))
2 changes: 1 addition & 1 deletion forge/test/operators/pytorch/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def build_filtered_collection(cls) -> TestCollection:
dev_data_formats = os.getenv("DEV_DATA_FORMATS", None)
if dev_data_formats:
dev_data_formats = dev_data_formats.split(",")
dev_data_formats = [getattr(forge.DataFormat, dev_data_format) for dev_data_format in dev_data_formats]
dev_data_formats = [TestPlanUtils.dev_data_format_from_str(dev_data_format) for dev_data_format in dev_data_formats]

math_fidelities = os.getenv("MATH_FIDELITIES", None)
if math_fidelities:
Expand Down
4 changes: 4 additions & 0 deletions forge/test/operators/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .datatypes import OperatorParameterTypes
from .datatypes import ValueRange
from .datatypes import ValueRanges
from .datatypes import DataTypeTestPlan
from .utils import ShapeUtils
from .utils import InputSourceFlag, InputSourceFlags
from .utils import CompilerUtils
Expand All @@ -24,6 +25,7 @@
from .plan import FailingRulesConverter
from .plan import TestPlanScanner
from .test_data import TestCollectionCommon
from .test_data import TestCollectionTorch
from .failing_reasons import FailingReasons
from .failing_reasons import FailingReasonsValidation
from .pytest import PyTestUtils
Expand All @@ -33,6 +35,7 @@
"OperatorParameterTypes",
"ValueRange",
"ValueRanges",
"DataTypeTestPlan",
"ShapeUtils",
"InputSourceFlag",
"InputSourceFlags",
Expand All @@ -53,6 +56,7 @@
"FailingRulesConverter",
"TestPlanScanner",
"TestCollectionCommon",
"TestCollectionTorch",
"FailingReasons",
"FailingReasonsValidation",
"PyTestUtils",
Expand Down
21 changes: 18 additions & 3 deletions forge/test/operators/utils/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from forge.verify.compare import compare_with_golden

from .datatypes import OperatorParameterTypes, ValueRanges, ValueRange
from .datatypes import DataTypeTestPlan


# TODO - Remove this class once TestDevice is available in Forge
Expand Down Expand Up @@ -89,10 +90,12 @@ class TestTensorsUtils:
torch.bfloat16: (-10000, 10000),
torch.float16: (-10000, 10000),
torch.float32: (-10000, 10000),
torch.float64: (-10000, 10000),
torch.uint8: (0, 2**8 - 1),
torch.int8: (-(2**7), 2**7 - 1),
torch.int16: (-(2**15), 2**15 - 1),
torch.int32: (-(2**31), 2**31 - 1),
torch.int64: (-(2**63), 2**63 - 1),
}

class DTypes:
Expand All @@ -102,12 +105,14 @@ class DTypes:
torch.float16,
torch.bfloat16,
torch.float32,
torch.float64,
)
integers = (
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
)
booleans = (torch.bool,)

Expand Down Expand Up @@ -170,6 +175,8 @@ def get_dtype_for_df(cls, dev_data_format: forge.DataFormat = None) -> torch.dty

if dev_data_format is None:
dtype = None
elif isinstance(dev_data_format, torch.dtype):
dtype = dev_data_format
else:
# dtype = torch.float32
if dev_data_format in cls.dev_data_format_to_dtype:
Expand Down Expand Up @@ -226,7 +233,7 @@ def verify_module(
model: Module,
input_shapes: List[TensorShape],
pcc: Optional[float] = None,
dev_data_format: forge.DataFormat = None,
dev_data_format: DataTypeTestPlan = None,
value_range: Optional[Union[ValueRanges, ValueRange, OperatorParameterTypes.RangeValue]] = None,
random_seed: int = 42,
):
Expand All @@ -243,7 +250,7 @@ def verify_module(
# TODO move to class TestTensorsUtils
def create_torch_inputs(
input_shapes: List[TensorShape],
dev_data_format: forge.DataFormat = None,
dev_data_format: DataTypeTestPlan = None,
value_range: Optional[Union[ValueRanges, ValueRange, OperatorParameterTypes.RangeValue]] = None,
random_seed: Optional[int] = None,
) -> List[torch.Tensor]:
Expand All @@ -254,6 +261,8 @@ def create_torch_inputs(
generator = torch.Generator().manual_seed(random_seed)

dtype = TestTensorsUtils.get_dtype_for_df(dev_data_format)
# dev_data_format is no longer used
dev_data_format = None

# if dtype is not None:
# torch.set_default_dtype(dtype)
Expand All @@ -280,7 +289,13 @@ def verify_module_for_inputs(

fw_out = model(*inputs)

forge_inputs = [forge.Tensor.create_from_torch(input, dev_data_format=dev_data_format) for input in inputs]
# TODO introduce flag to check framework of forge data formats

if dev_data_format is None or isinstance(dev_data_format, forge.DataFormat):
forge_inputs = [forge.Tensor.create_from_torch(input, dev_data_format=dev_data_format) for input in inputs]
elif isinstance(dev_data_format, torch.dtype):
forge_inputs = inputs
# forge_inputs = inputs

compiled_model = forge.compile(model, sample_inputs=forge_inputs)
co_out = compiled_model(*forge_inputs)
Expand Down
6 changes: 6 additions & 0 deletions forge/test/operators/utils/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
from enum import Enum
from typing import Optional, Dict, Union, Tuple, TypeAlias

import torch
import forge


DataTypeTestPlan = Union[forge.DataFormat, torch.dtype]


class OperatorParameterTypes:
SingleValue: TypeAlias = Union[int, float]
Expand Down
39 changes: 33 additions & 6 deletions forge/test/operators/utils/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from forge.op_repo import TensorShape

from .datatypes import OperatorParameterTypes
from .datatypes import DataTypeTestPlan
from .pytest import PytestParamsUtils
from .compat import TestDevice

Expand Down Expand Up @@ -84,7 +85,7 @@ class TestVector:
input_source: InputSource
input_shape: TensorShape # TODO - Support multiple input shapes
number_of_operands: Optional[int] = None
dev_data_format: Optional[DataFormat] = None
dev_data_format: Optional[DataTypeTestPlan] = None
math_fidelity: Optional[MathFidelity] = None
kwargs: Optional[OperatorParameterTypes.Kwargs] = None
pcc: Optional[float] = None
Expand All @@ -94,7 +95,7 @@ class TestVector:
def get_id(self, fields: Optional[List[str]] = None) -> str:
"""Get test vector id"""
if fields is None:
return f"{self.operator}-{self.input_source.name}-{self.kwargs}-{self.input_shape}{'-' + str(self.number_of_operands) + '-' if self.number_of_operands else '-'}{self.dev_data_format.name if self.dev_data_format else None}-{self.math_fidelity.name if self.math_fidelity else None}"
return f"{self.operator}-{self.input_source.name}-{self.kwargs}-{self.input_shape}{'-' + str(self.number_of_operands) + '-' if self.number_of_operands else '-'}{TestPlanUtils.dev_data_format_to_str(self.dev_data_format)}-{self.math_fidelity.name if self.math_fidelity else None}"
else:
attr = [
(getattr(self, field).name if getattr(self, field) is not None else None)
Expand Down Expand Up @@ -142,7 +143,7 @@ class TestCollection:
input_sources: Optional[List[InputSource]] = None
input_shapes: Optional[List[TensorShape]] = None # TODO - Support multiple input shapes
numbers_of_operands: Optional[List[int]] = None
dev_data_formats: Optional[List[DataFormat]] = None
dev_data_formats: Optional[List[DataTypeTestPlan]] = None
math_fidelities: Optional[List[MathFidelity]] = None
kwargs: Optional[
Union[List[OperatorParameterTypes.Kwargs], Callable[["TestVector"], List[OperatorParameterTypes.Kwargs]]]
Expand Down Expand Up @@ -494,6 +495,32 @@ class TestPlanUtils:
Utility functions for test vectors
"""

@classmethod
def dev_data_format_to_str(cls, dev_data_format: DataTypeTestPlan) -> Optional[str]:
    """Return the short text name of a test plan data format.

    A ``DataFormat`` value is rendered as its ``name``. A ``torch.dtype`` is
    rendered as its string form with the ``torch.`` prefix removed. ``None``
    is passed through unchanged.

    Raises:
        ValueError: If the value is neither ``None``, a ``DataFormat`` nor a
            ``torch.dtype``.
    """
    if dev_data_format is None:
        return None
    if isinstance(dev_data_format, DataFormat):
        return dev_data_format.name
    if not isinstance(dev_data_format, torch.dtype):
        raise ValueError(f"Unsupported data format: {dev_data_format}")
    # Keep only the text after the last dot, i.e. drop the "torch." prefix
    _, _, short_name = str(dev_data_format).rpartition(".")
    return short_name

@classmethod
def dev_data_format_from_str(cls, dev_data_format_str: str) -> DataTypeTestPlan:
    """Resolve a data format name to a Forge data format or a torch dtype.

    The name is looked up on ``forge.DataFormat`` first and on ``torch`` second.
    A lookup only counts when the attribute found is really a
    ``forge.DataFormat`` / ``torch.dtype`` instance, so names of unrelated
    attributes (for example ``"nn"`` or ``"Tensor"`` on ``torch``) are rejected
    instead of being returned as if they were data formats.

    Args:
        dev_data_format_str: Name of the data format, or ``None``.

    Returns:
        The matching ``forge.DataFormat`` or ``torch.dtype``; ``None`` when the
        input is ``None``.

    Raises:
        ValueError: If the name is not a data format in Forge or PyTorch.
    """
    if dev_data_format_str is None:
        return None

    # Forge formats take precedence over torch dtypes of the same name
    forge_format = getattr(forge.DataFormat, dev_data_format_str, None)
    if isinstance(forge_format, forge.DataFormat):
        return forge_format

    torch_dtype = getattr(torch, dev_data_format_str, None)
    if isinstance(torch_dtype, torch.dtype):
        return torch_dtype

    raise ValueError(f"Unsupported data format: {dev_data_format_str} in Forge and PyTorch")

@classmethod
def _match(cls, rule_collection: Optional[List], vector_value):
"""
Expand Down Expand Up @@ -624,7 +651,7 @@ def test_id_to_test_vector(cls, test_id: str) -> TestVector:
dev_data_format_part = parts[dev_data_format_index]
if dev_data_format_part == "None":
dev_data_format_part = None
dev_data_format = eval(f"forge._C.{dev_data_format_part}") if dev_data_format_part is not None else None
dev_data_format = cls.dev_data_format_from_str(dev_data_format_part)

math_fidelity_part = parts[math_fidelity_index]
if math_fidelity_part == "None":
Expand Down Expand Up @@ -663,7 +690,7 @@ def build_rules(
Union[Optional[InputSource], List[InputSource]],
Union[Optional[TensorShape], List[TensorShape]],
Union[Optional[OperatorParameterTypes.Kwargs], List[OperatorParameterTypes.Kwargs]],
Union[Optional[forge.DataFormat], List[forge.DataFormat]],
Union[Optional[DataTypeTestPlan], List[DataTypeTestPlan]],
Union[Optional[forge.MathFidelity], List[forge.MathFidelity]],
Optional[TestResultFailing],
],
Expand Down Expand Up @@ -708,7 +735,7 @@ def build_rule(
input_source: Optional[Union[InputSource, List[InputSource]]],
input_shape: Optional[Union[TensorShape, List[TensorShape]]],
kwargs: Optional[Union[OperatorParameterTypes.Kwargs, List[OperatorParameterTypes.Kwargs]]],
dev_data_format: Optional[Union[forge.DataFormat, List[forge.DataFormat]]],
dev_data_format: Optional[Union[DataTypeTestPlan, List[DataTypeTestPlan]]],
math_fidelity: Optional[Union[forge.MathFidelity, List[forge.MathFidelity]]],
result_failing: Optional[TestResultFailing],
) -> TestCollection:
Expand Down
42 changes: 42 additions & 0 deletions forge/test/operators/utils/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import pytest
import forge
import torch

from forge import MathFidelity, DataFormat

Expand Down Expand Up @@ -309,3 +310,44 @@ class TestCollectionCommon:
(14, 13, 89, 3), # 4.2 Prime numbers
]
)


class TestCollectionTorch:
    """
    Reusable test collections whose data formats are torch dtypes.

    Each attribute is a TestCollection that sets only dev_data_formats.
    """

    # Keep pytest from collecting this class as a test because of its "Test" prefix
    __test__ = False

    # Floating point dtypes (torch.float64 is currently left out)
    float = TestCollection(dev_data_formats=[torch.float16, torch.float32, torch.bfloat16])

    # Integer dtypes (torch.int16 and torch.uint8 are currently left out)
    int = TestCollection(dev_data_formats=[torch.int8, torch.int32, torch.int64])

    # Boolean dtype
    bool = TestCollection(dev_data_formats=[torch.bool])

    # Floating point dtypes followed by integer dtypes; the "float" and "int" names
    # here are the collections defined above, not the builtins
    all = TestCollection(dev_data_formats=[*float.dev_data_formats, *int.dev_data_formats])

    # One representative dtype
    single = TestCollection(dev_data_formats=[torch.float16])
Loading