Add flag to ignore unsupported @triton.autotune args in user-written kernel compilation (pytorch#131431)

Summary: We currently don't support some of the `@triton.autotune` arguments when compiling user-written Triton kernels with PT2. This PR adds a flag that skips the check for those unsupported arguments, to unblock internal compilation in some cases. The flag is documented with a note explaining why setting it is not a good idea.
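
For illustration only (not part of the commit), a minimal sketch of how the flag is meant to be used, assuming a CUDA-capable GPU and the `add_kernel_autotuned_with_unsupported_args` test kernel that this PR adds to `torch/testing/_internal/triton_utils.py`:

```
import torch
from torch.testing._internal.triton_utils import add_kernel_autotuned_with_unsupported_args

# Unsafe opt-in: skip the check that would otherwise reject the unsupported
# @triton.autotune args (warmup, rep, prune_configs_by, ...) during compilation.
torch._inductor.config.unsafe_ignore_unsupported_triton_autotune_args = True

def add(x, y):
    output = torch.zeros_like(x)
    n_elements = output.numel()
    add_kernel_autotuned_with_unsupported_args[(n_elements,)](x, y, output, n_elements)
    return output

x = torch.rand(256, device="cuda")
y = torch.rand(256, device="cuda")
compiled_add = torch.compile(add, fullgraph=True)
torch.testing.assert_close(compiled_add(x, y), add(x, y))
```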

Test Plan:

```
python test/inductor/test_triton_kernels.py -k test_triton_kernel_autotune_with_unsupported_args
...
----------------------------------------------------------------------
Ran 3 tests in 3.636s

OK
```

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: pytorch#131431
Approved by: https://github.com/oulgen, https://github.com/zou3519
aakhundov authored and pytorchmergebot committed Jul 24, 2024
1 parent eafbd20 commit e9db1b0
Showing 4 changed files with 81 additions and 18 deletions.
23 changes: 23 additions & 0 deletions test/inductor/test_triton_kernels.py
@@ -672,6 +672,29 @@ def grid_fn(meta):
         output2 = torch.zeros_like(t1, requires_grad=grad)
         self.assertEqual(compiled_func(t1, t2, output2), torch_add)
 
+    @requires_gpu
+    @skipIfRocm  # https://github.com/pytorch/pytorch/actions/runs/10051552819/job/27782048305?pr=131431
+    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
+    @patch.object(
+        torch._inductor.config, "unsafe_ignore_unsupported_triton_autotune_args", True
+    )
+    def test_triton_kernel_autotune_with_unsupported_args(self, backend):
+        def call_triton(x: torch.Tensor, y: torch.Tensor):
+            output = torch.zeros_like(x)
+            n_elements = output.numel()
+            add_kernel_autotuned_with_unsupported_args[(n_elements,)](
+                x, y, output, n_elements
+            )
+            return output
+
+        t1 = torch.rand(256, device=GPU_TYPE)
+        t2 = torch.rand(256, device=GPU_TYPE)
+
+        torch_add = call_triton(t1, t2)
+        compiled_func = torch.compile(call_triton, backend=backend, fullgraph=True)
+        compiled_add = compiled_func(t1, t2)
+        self.assertEqual(compiled_add, torch_add)
+
     @requires_gpu
     @common_utils.parametrize("grad", [False, True])
     @common_utils.parametrize("dynamic", [False, True])
43 changes: 25 additions & 18 deletions torch/_higher_order_ops/triton_kernel_wrap.py
@@ -809,26 +809,33 @@ def init_variable(self, variable, kernel, kernel_idx, grid):
             # Newer version of triton change attribute name from warmup to num_warmup and rep to num_rep.
             # The call to get_first_attr is to maintain backward-compatibility.
             if (
-                (
-                    "warmup" in defaults
-                    and defaults["warmup"].default
-                    != torch._dynamo.utils.get_first_attr(
-                        kernel, "num_warmups", "warmup"
+                not torch._inductor.config.unsafe_ignore_unsupported_triton_autotune_args
+                and (
+                    (
+                        "warmup" in defaults
+                        and defaults["warmup"].default
+                        != torch._dynamo.utils.get_first_attr(
+                            kernel, "num_warmups", "warmup"
+                        )
+                    )
+                    or (
+                        "rep" in defaults
+                        and defaults["rep"].default
+                        != torch._dynamo.utils.get_first_attr(kernel, "num_reps", "rep")
+                    )
+                    or (
+                        "prune_configs_by" in defaults
+                        and defaults["prune_configs_by"].default
+                        != kernel.early_config_prune
+                    )
+                    # Set via reset_to_zero argument
+                    or len(kernel.reset_idx) != 0
+                    or len(kernel.restore_idx) != 0
+                    or (
+                        "use_cuda_graph" in defaults
+                        and defaults["use_cuda_graph"].default != kernel.use_cuda_graph
                     )
                 )
-                or (
-                    "rep" in defaults
-                    and defaults["rep"].default
-                    != torch._dynamo.utils.get_first_attr(kernel, "num_reps", "rep")
-                )
-                or (
-                    "prune_configs_by" in defaults
-                    and defaults["prune_configs_by"].default
-                    != kernel.early_config_prune
-                )
-                # Set via reset_to_zero argument
-                or len(kernel.reset_idx) != 0
-                or len(kernel.restore_idx) != 0
             ):
                 self.raise_unsupported(
                     "Only configs and keys are supported for triton.autotune"
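
For context (not part of this diff), `torch._dynamo.utils.get_first_attr` used in the condition above returns the first attribute that exists on the object, which is how the check stays compatible with both the old (`warmup`/`rep`) and new (`num_warmups`/`num_reps`) Triton attribute names. A minimal sketch of the idea, not the actual implementation:

```
def get_first_attr(obj, *attr_names):
    # Return the value of the first attribute present on obj, e.g.
    # get_first_attr(kernel, "num_warmups", "warmup") works across Triton versions.
    for name in attr_names:
        if hasattr(obj, name):
            return getattr(obj, name)
    raise AttributeError(f"{obj!r} has none of the attributes {attr_names}")
```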
6 changes: 6 additions & 0 deletions torch/_inductor/config.py
@@ -626,6 +626,12 @@ def decide_compile_threads():
 # In the common case, most inputs will be aligned.
 assume_aligned_inputs: bool = False
 
+# For the user-written Triton kernels compiled with the model, ignore the unsupported
+# arguments passed to the @triton.autotune in the user's code; this is unsafe, as
+# ignoring the unsupported args may lead to unexpected autotuning behavior: don't
+# set unless you know what you're doing.
+unsafe_ignore_unsupported_triton_autotune_args: bool = False
+
 
 # config specific to codegen/cpp.py
 class cpp:
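
The new test above enables this flag via `patch.object`; outside of tests, the override can also be scoped with the inductor config module's `patch` helper rather than set globally. A hedged sketch, reusing `add`, `x`, and `y` from the earlier example:

```
import torch

# Scope the unsafe override to this block instead of flipping the flag globally.
with torch._inductor.config.patch(unsafe_ignore_unsupported_triton_autotune_args=True):
    compiled_add = torch.compile(add, fullgraph=True)
    result = compiled_add(x, y)  # compilation happens on the first call, inside the patch
```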
27 changes: 27 additions & 0 deletions torch/testing/_internal/triton_utils.py
@@ -117,6 +117,33 @@ def add_kernel_2d_autotuned(
         tmp2 = tmp0 + tmp1
         tl.store(out_ptr + (x1 + (x_elements * y0)), tmp2, xmask & ymask)
 
+    @triton.autotune(
+        configs=[
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
+        ],
+        key=[],
+        warmup=10,
+        rep=20,
+        prune_configs_by={"early_config_prune": lambda configs, *_, **__: configs},
+    )
+    @triton.jit
+    def add_kernel_autotuned_with_unsupported_args(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
     @triton.jit
     def add_kernel_with_scaling(
         in_ptr0,
