support torch dtype as data formats in test plan
vbrkicTT committed Jan 16, 2025
1 parent b96eb45 commit 6aba283
Showing 7 changed files with 105 additions and 13 deletions.
4 changes: 1 addition & 3 deletions forge/test/operators/pytorch/conftest.py
@@ -70,9 +70,7 @@ def log_test_vector_properties(item: _pytest.python.Function, report: _pytest.re
item.user_properties.append(
("input_source", test_vector.input_source.name if test_vector.input_source is not None else None)
)
item.user_properties.append(
("dev_data_format", test_vector.dev_data_format.name if test_vector.dev_data_format is not None else None)
)
item.user_properties.append(("dev_data_format", TestPlanUtils.dev_data_format_to_str(test_vector.dev_data_format)))
item.user_properties.append(
("math_fidelity", test_vector.math_fidelity.name if test_vector.math_fidelity is not None else None)
)
2 changes: 1 addition & 1 deletion forge/test/operators/pytorch/test_all.py
@@ -96,7 +96,7 @@ def build_filtered_collection(cls) -> TestCollection:
dev_data_formats = os.getenv("DEV_DATA_FORMATS", None)
if dev_data_formats:
dev_data_formats = dev_data_formats.split(",")
dev_data_formats = [getattr(forge.DataFormat, dev_data_format) for dev_data_format in dev_data_formats]
dev_data_formats = [TestPlanUtils.dev_data_format_from_str(dev_data_format) for dev_data_format in dev_data_formats]

math_fidelities = os.getenv("MATH_FIDELITIES", None)
if math_fidelities:
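
With the lookup switched to TestPlanUtils.dev_data_format_from_str, the DEV_DATA_FORMATS filter can now mix Forge data-format names and torch dtype names. A minimal sketch of the resolution each entry goes through, mirroring the new helper rather than importing it (the Float16_b member name is an assumption):

import os

import torch
import forge

# e.g. DEV_DATA_FORMATS="Float16_b,bfloat16,int32"
names = os.getenv("DEV_DATA_FORMATS", "Float16_b,bfloat16,int32").split(",")

resolved = []
for name in names:
    if hasattr(forge.DataFormat, name):  # Forge data format, e.g. Float16_b (assumed member)
        resolved.append(getattr(forge.DataFormat, name))
    elif hasattr(torch, name):  # torch dtype, e.g. torch.bfloat16
        resolved.append(getattr(torch, name))
    else:
        raise ValueError(f"Unsupported data format: {name}")
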
4 changes: 4 additions & 0 deletions forge/test/operators/utils/__init__.py
@@ -5,6 +5,7 @@
from .datatypes import OperatorParameterTypes
from .datatypes import ValueRange
from .datatypes import ValueRanges
from .datatypes import DataTypeTestPlan
from .utils import ShapeUtils
from .utils import InputSourceFlag, InputSourceFlags
from .utils import CompilerUtils
@@ -24,6 +25,7 @@
from .plan import FailingRulesConverter
from .plan import TestPlanScanner
from .test_data import TestCollectionCommon
from .test_data import TestCollectionTorch
from .failing_reasons import FailingReasons
from .failing_reasons import FailingReasonsValidation
from .pytest import PyTestUtils
@@ -33,6 +35,7 @@
"OperatorParameterTypes",
"ValueRange",
"ValueRanges",
"DataTypeTestPlan",
"ShapeUtils",
"InputSourceFlag",
"InputSourceFlags",
@@ -53,6 +56,7 @@
"FailingRulesConverter",
"TestPlanScanner",
"TestCollectionCommon",
"TestCollectionTorch",
"FailingReasons",
"FailingReasonsValidation",
"PyTestUtils",
21 changes: 18 additions & 3 deletions forge/test/operators/utils/compat.py
@@ -15,6 +15,7 @@
from forge.verify.compare import compare_with_golden

from .datatypes import OperatorParameterTypes, ValueRanges, ValueRange
from .datatypes import DataTypeTestPlan


# TODO - Remove this class once TestDevice is available in Forge
@@ -89,10 +90,12 @@ class TestTensorsUtils:
torch.bfloat16: (-10000, 10000),
torch.float16: (-10000, 10000),
torch.float32: (-10000, 10000),
torch.float64: (-10000, 10000),
torch.uint8: (0, 2**8 - 1),
torch.int8: (-(2**7), 2**7 - 1),
torch.int16: (-(2**15), 2**15 - 1),
torch.int32: (-(2**31), 2**31 - 1),
torch.int64: (-(2**63), 2**63 - 1),
}

class DTypes:
@@ -102,12 +105,14 @@ class DTypes:
torch.float16,
torch.bfloat16,
torch.float32,
torch.float64,
)
integers = (
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
)
booleans = (torch.bool,)

@@ -170,6 +175,8 @@ def get_dtype_for_df(cls, dev_data_format: forge.DataFormat = None) -> torch.dty

if dev_data_format is None:
dtype = None
elif isinstance(dev_data_format, torch.dtype):
dtype = dev_data_format
else:
# dtype = torch.float32
if dev_data_format in cls.dev_data_format_to_dtype:
@@ -226,7 +233,7 @@ def verify_module(
model: Module,
input_shapes: List[TensorShape],
pcc: Optional[float] = None,
dev_data_format: forge.DataFormat = None,
dev_data_format: DataTypeTestPlan = None,
value_range: Optional[Union[ValueRanges, ValueRange, OperatorParameterTypes.RangeValue]] = None,
random_seed: int = 42,
):
@@ -243,7 +250,7 @@
# TODO move to class TestTensorsUtils
def create_torch_inputs(
input_shapes: List[TensorShape],
dev_data_format: forge.DataFormat = None,
dev_data_format: DataTypeTestPlan = None,
value_range: Optional[Union[ValueRanges, ValueRange, OperatorParameterTypes.RangeValue]] = None,
random_seed: Optional[int] = None,
) -> List[torch.Tensor]:
@@ -254,6 +261,8 @@ def create_torch_inputs(
generator = torch.Generator().manual_seed(random_seed)

dtype = TestTensorsUtils.get_dtype_for_df(dev_data_format)
# dev_data_format is no longer used
dev_data_format = None

# if dtype is not None:
# torch.set_default_dtype(dtype)
@@ -280,7 +289,13 @@ def verify_module_for_inputs(

fw_out = model(*inputs)

forge_inputs = [forge.Tensor.create_from_torch(input, dev_data_format=dev_data_format) for input in inputs]
# TODO introduce a flag to distinguish framework (torch) data formats from forge data formats

if dev_data_format is None or isinstance(dev_data_format, forge.DataFormat):
forge_inputs = [forge.Tensor.create_from_torch(input, dev_data_format=dev_data_format) for input in inputs]
elif isinstance(dev_data_format, torch.dtype):
forge_inputs = inputs
# forge_inputs = inputs

compiled_model = forge.compile(model, sample_inputs=forge_inputs)
co_out = compiled_model(*forge_inputs)
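
The compatibility rule introduced here: a torch.dtype is used directly as the tensor dtype and the inputs go to forge.compile unwrapped, while a forge.DataFormat is still attached via forge.Tensor.create_from_torch. A condensed sketch of that branching, not the verbatim helper (the function name is made up):

import torch
import forge

def to_forge_inputs(inputs, dev_data_format):
    """Wrap torch tensors for forge.compile depending on the kind of data format."""
    if dev_data_format is None or isinstance(dev_data_format, forge.DataFormat):
        # Forge data formats are attached while converting to forge tensors.
        return [forge.Tensor.create_from_torch(t, dev_data_format=dev_data_format) for t in inputs]
    if isinstance(dev_data_format, torch.dtype):
        # torch dtypes are already baked into the tensors by create_torch_inputs.
        return list(inputs)
    raise ValueError(f"Unsupported data format: {dev_data_format}")
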
6 changes: 6 additions & 0 deletions forge/test/operators/utils/datatypes.py
@@ -8,6 +8,12 @@
from enum import Enum
from typing import Optional, Dict, Union, Tuple, TypeAlias

import torch
import forge


DataTypeTestPlan = Union[forge.DataFormat, torch.dtype]


class OperatorParameterTypes:
SingleValue: TypeAlias = Union[int, float]
39 changes: 33 additions & 6 deletions forge/test/operators/utils/plan.py
@@ -27,6 +27,7 @@
from forge.op_repo import TensorShape

from .datatypes import OperatorParameterTypes
from .datatypes import DataTypeTestPlan
from .pytest import PytestParamsUtils
from .compat import TestDevice

@@ -84,7 +85,7 @@ class TestVector:
input_source: InputSource
input_shape: TensorShape # TODO - Support multiple input shapes
number_of_operands: Optional[int] = None
dev_data_format: Optional[DataFormat] = None
dev_data_format: Optional[DataTypeTestPlan] = None
math_fidelity: Optional[MathFidelity] = None
kwargs: Optional[OperatorParameterTypes.Kwargs] = None
pcc: Optional[float] = None
@@ -94,7 +95,7 @@ class TestVector:
def get_id(self, fields: Optional[List[str]] = None) -> str:
"""Get test vector id"""
if fields is None:
return f"{self.operator}-{self.input_source.name}-{self.kwargs}-{self.input_shape}{'-' + str(self.number_of_operands) + '-' if self.number_of_operands else '-'}{self.dev_data_format.name if self.dev_data_format else None}-{self.math_fidelity.name if self.math_fidelity else None}"
return f"{self.operator}-{self.input_source.name}-{self.kwargs}-{self.input_shape}{'-' + str(self.number_of_operands) + '-' if self.number_of_operands else '-'}{TestPlanUtils.dev_data_format_to_str(self.dev_data_format)}-{self.math_fidelity.name if self.math_fidelity else None}"
else:
attr = [
(getattr(self, field).name if getattr(self, field) is not None else None)
@@ -142,7 +143,7 @@ class TestCollection:
input_sources: Optional[List[InputSource]] = None
input_shapes: Optional[List[TensorShape]] = None # TODO - Support multiple input shapes
numbers_of_operands: Optional[List[int]] = None
dev_data_formats: Optional[List[DataFormat]] = None
dev_data_formats: Optional[List[DataTypeTestPlan]] = None
math_fidelities: Optional[List[MathFidelity]] = None
kwargs: Optional[
Union[List[OperatorParameterTypes.Kwargs], Callable[["TestVector"], List[OperatorParameterTypes.Kwargs]]]
@@ -494,6 +495,32 @@ class TestPlanUtils:
Utility functions for test vectors
"""

@classmethod
def dev_data_format_to_str(cls, dev_data_format: DataTypeTestPlan) -> Optional[str]:
"""Convert data format to string"""
if dev_data_format is None:
return None
if isinstance(dev_data_format, DataFormat):
return dev_data_format.name
if isinstance(dev_data_format, torch.dtype):
# Remove torch. prefix
return str(dev_data_format).split('.')[-1]
else:
raise ValueError(f"Unsupported data format: {dev_data_format}")

@classmethod
def dev_data_format_from_str(cls, dev_data_format_str: str) -> DataTypeTestPlan:
"""Convert string to data format"""
if dev_data_format_str is None:
return None
if hasattr(forge.DataFormat, dev_data_format_str):
dev_data_format = getattr(forge.DataFormat, dev_data_format_str)
elif hasattr(torch, dev_data_format_str):
dev_data_format = getattr(torch, dev_data_format_str)
else:
raise ValueError(f"Unsupported data format: {dev_data_format_str} in Forge and PyTorch")
return dev_data_format

@classmethod
def _match(cls, rule_collection: Optional[List], vector_value):
"""
@@ -624,7 +651,7 @@ def test_id_to_test_vector(cls, test_id: str) -> TestVector:
dev_data_format_part = parts[dev_data_format_index]
if dev_data_format_part == "None":
dev_data_format_part = None
dev_data_format = eval(f"forge._C.{dev_data_format_part}") if dev_data_format_part is not None else None
dev_data_format = cls.dev_data_format_from_str(dev_data_format_part)

math_fidelity_part = parts[math_fidelity_index]
if math_fidelity_part == "None":
@@ -663,7 +690,7 @@ def build_rules(
Union[Optional[InputSource], List[InputSource]],
Union[Optional[TensorShape], List[TensorShape]],
Union[Optional[OperatorParameterTypes.Kwargs], List[OperatorParameterTypes.Kwargs]],
Union[Optional[forge.DataFormat], List[forge.DataFormat]],
Union[Optional[DataTypeTestPlan], List[DataTypeTestPlan]],
Union[Optional[forge.MathFidelity], List[forge.MathFidelity]],
Optional[TestResultFailing],
],
@@ -708,7 +735,7 @@ def build_rule(
input_source: Optional[Union[InputSource, List[InputSource]]],
input_shape: Optional[Union[TensorShape, List[TensorShape]]],
kwargs: Optional[Union[OperatorParameterTypes.Kwargs, List[OperatorParameterTypes.Kwargs]]],
dev_data_format: Optional[Union[forge.DataFormat, List[forge.DataFormat]]],
dev_data_format: Optional[Union[DataTypeTestPlan, List[DataTypeTestPlan]]],
math_fidelity: Optional[Union[forge.MathFidelity, List[forge.MathFidelity]]],
result_failing: Optional[TestResultFailing],
) -> TestCollection:
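
The same pair of helpers also drives test-vector ids: get_id now renders the dev_data_format segment through dev_data_format_to_str, and test_id_to_test_vector resolves it back with dev_data_format_from_str instead of eval on forge._C. A small illustration of the two string forms that segment can take and how each resolves back (Float16_b as a Forge member name is an assumption):

import torch
import forge

# dev_data_format segment of a test id, for a Forge format and for a torch dtype:
forge_segment = forge.DataFormat.Float16_b.name      # "Float16_b"
torch_segment = str(torch.bfloat16).split(".")[-1]   # "torch.bfloat16" -> "bfloat16"

# Both spellings resolve back to an object the same way the helper does:
assert getattr(forge.DataFormat, forge_segment) == forge.DataFormat.Float16_b
assert getattr(torch, torch_segment) is torch.bfloat16
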
42 changes: 42 additions & 0 deletions forge/test/operators/utils/test_data.py
@@ -6,6 +6,7 @@

import pytest
import forge
import torch

from forge import MathFidelity, DataFormat

@@ -309,3 +310,44 @@ class TestCollectionCommon:
(14, 13, 89, 3), # 4.2 Prime numbers
]
)


class TestCollectionTorch:
"""
Shared test collection for torch data types.
"""

__test__ = False # Avoid collecting TestCollectionTorch as a pytest test

float = TestCollection(
dev_data_formats=[
torch.float16,
torch.float32,
# torch.float64,
torch.bfloat16,
],
)

int = TestCollection(
dev_data_formats=[
torch.int8,
# torch.int16,
torch.int32,
torch.int64,
# torch.uint8,
],
)

bool = TestCollection(
dev_data_formats=[
torch.bool,
],
)

all = TestCollection(dev_data_formats=float.dev_data_formats + int.dev_data_formats)

single = TestCollection(
dev_data_formats=[
torch.float16,
],
)
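
A minimal sketch of pulling these dtypes into a plan-level collection. TestCollectionTorch is re-exported from the operators test utils package per the __init__.py change above; the exact import path and the TestCollection re-export are assumptions, and restricting to the integer dtypes plus bfloat16 is only an illustration:

import torch

from test.operators.utils import TestCollection, TestCollectionTorch  # assumed import path

# Exercise the default torch integer dtypes plus bfloat16.
int_and_bf16 = TestCollection(
    dev_data_formats=TestCollectionTorch.int.dev_data_formats + [torch.bfloat16],
)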
