-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add repeat_interleave operator test plan [skip ci]
- Loading branch information
1 parent
626ca63
commit f459d7f
Showing
1 changed file
with
166 additions
and
0 deletions.
There are no files selected for viewing
166 changes: 166 additions & 0 deletions
166
forge/test/operators/pytorch/tm/test_repeat_interleave.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import forge | ||
import math | ||
import torch | ||
import pytest | ||
import random | ||
import os | ||
|
||
from typing import List, Dict | ||
from loguru import logger | ||
|
||
from forge.verify.config import VerifyConfig | ||
|
||
from forge.verify.value_checkers import AllCloseValueChecker | ||
from forge.verify.verify import verify as forge_verify | ||
|
||
from test.operators.utils import InputSourceFlags, VerifyUtils | ||
from test.operators.utils import InputSource | ||
from test.operators.utils import TestVector | ||
from test.operators.utils import TestPlan | ||
from test.operators.utils import TestPlanUtils | ||
from test.operators.utils import FailingReasons | ||
from test.operators.utils.compat import TestDevice | ||
from test.operators.utils import TestCollection | ||
from test.operators.utils import TestCollectionCommon | ||
from test.operators.utils import ValueRanges | ||
|
||
from test.operators.pytorch.eltwise_unary import ModelFromAnotherOp, ModelDirect, ModelConstEvalPass | ||
|
||
|
||
class TestVerification: | ||
|
||
MODEL_TYPES = { | ||
InputSource.FROM_ANOTHER_OP: ModelFromAnotherOp, | ||
InputSource.FROM_HOST: ModelDirect, | ||
InputSource.FROM_DRAM_QUEUE: ModelDirect, | ||
InputSource.CONST_EVAL_PASS: ModelConstEvalPass, | ||
} | ||
|
||
@classmethod | ||
def verify( | ||
cls, | ||
test_device: TestDevice, | ||
test_vector: TestVector, | ||
input_params: List[Dict] = [], | ||
warm_reset: bool = False, | ||
): | ||
|
||
input_source_flag: InputSourceFlags = None | ||
if test_vector.input_source in (InputSource.FROM_DRAM_QUEUE,): | ||
input_source_flag = InputSourceFlags.FROM_DRAM | ||
|
||
operator = getattr(torch, test_vector.operator) | ||
kwargs = test_vector.kwargs if test_vector.kwargs else {} | ||
|
||
model_type = cls.MODEL_TYPES[test_vector.input_source] | ||
pytorch_model = ( | ||
model_type(operator, test_vector.input_shape, kwargs) | ||
if test_vector.input_source in (InputSource.CONST_EVAL_PASS,) | ||
else model_type(operator, kwargs) | ||
) | ||
|
||
input_shapes = tuple([test_vector.input_shape]) | ||
|
||
logger.trace(f"***input_shapes: {input_shapes}") | ||
|
||
VerifyUtils.verify( | ||
model=pytorch_model, | ||
test_device=test_device, | ||
input_shapes=input_shapes, | ||
input_params=input_params, | ||
input_source_flag=input_source_flag, | ||
dev_data_format=test_vector.dev_data_format, | ||
math_fidelity=test_vector.math_fidelity, | ||
warm_reset=warm_reset, | ||
value_range=ValueRanges.SMALL, | ||
deprecated_verification=False, | ||
verify_config=VerifyConfig(value_checker=AllCloseValueChecker()), | ||
) | ||
|
||
|
||
class TestParamsData: | ||
|
||
__test__ = False | ||
|
||
test_plan: TestPlan = None | ||
|
||
operator = ["repeat_interleave"] | ||
|
||
specific_cases = { | ||
# input_shape: [(repeats, dim)...] | ||
(1, 1, 1, 58): [(1, 0), (1, 1), (58, 2)], | ||
} | ||
|
||
@classmethod | ||
def generate_kwargs(cls, test_vector: TestVector): | ||
|
||
rng = random.Random(math.prod(test_vector.input_shape)) | ||
|
||
yield { | ||
# repeats is only int values, tensor is not supported yet | ||
"repeats": rng.randint(1, 10), | ||
"dim": rng.choice([None] + list(range(len(test_vector.input_shape)))), | ||
} | ||
|
||
@classmethod | ||
def generate_specific_kwargs(cls, test_vector: TestVector): | ||
|
||
for repeats, dim in cls.specific_cases[test_vector.input_shape]: | ||
yield { | ||
"repeats": repeats, | ||
"dim": dim, | ||
} | ||
|
||
|
||
TestParamsData.test_plan = TestPlan( | ||
verify=lambda test_device, test_vector: TestVerification.verify( | ||
test_device, | ||
test_vector, | ||
), | ||
collections=[ | ||
# Test operators with all shapes and input sources collection: | ||
TestCollection( | ||
operators=TestParamsData.operator, | ||
input_sources=TestCollectionCommon.all.input_sources, | ||
input_shapes=TestCollectionCommon.all.input_shapes, | ||
kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector), | ||
), | ||
# Test Data formats collection: | ||
TestCollection( | ||
operators=TestParamsData.operator, | ||
input_sources=TestCollectionCommon.single.input_sources, | ||
input_shapes=TestCollectionCommon.single.input_shapes, | ||
kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector), | ||
dev_data_formats=[ | ||
item | ||
for item in TestCollectionCommon.all.dev_data_formats | ||
if item not in TestCollectionCommon.single.dev_data_formats | ||
] | ||
), | ||
# Test math fidelity collection: | ||
TestCollection( | ||
operators=TestParamsData.operator, | ||
input_sources=TestCollectionCommon.single.input_sources, | ||
input_shapes=TestCollectionCommon.single.input_shapes, | ||
kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector), | ||
dev_data_formats=TestCollectionCommon.single.dev_data_formats, | ||
math_fidelities=TestCollectionCommon.all.math_fidelities, | ||
), | ||
# Test specific cases collection: | ||
TestCollection( | ||
operators=TestParamsData.operator, | ||
input_sources=TestCollectionCommon.all.input_sources, | ||
input_shapes=TestParamsData.specific_cases.keys(), | ||
kwargs=lambda test_vector: TestParamsData.generate_specific_kwargs(test_vector), | ||
), | ||
], | ||
failing_rules=[], | ||
) | ||
|
||
|
||
def get_test_plans() -> List[TestPlan]: | ||
return [TestParamsData.test_plan] |