From e9db1b059733a02e1fb726d22a0489471044ad98 Mon Sep 17 00:00:00 2001
From: Adnan Akhundov
Date: Tue, 23 Jul 2024 13:04:14 -0700
Subject: [PATCH] Add flag to ignore unsupported @triton.autotune args in user-written kernel compilation (#131431)

Summary: We currently don't support some of the `@triton.autotune` arguments
when compiling user-written Triton kernels with PT2. This PR adds a flag that
skips the check for those unsupported arguments, to unblock internal
compilation in some cases. The flag's documentation explains why setting it
is not a good idea.

Test Plan:
```
python test/inductor/test_triton_kernels.py -k test_triton_kernel_autotune_with_unsupported_args
...
----------------------------------------------------------------------
Ran 3 tests in 3.636s

OK
```

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131431
Approved by: https://github.com/oulgen, https://github.com/zou3519
---
 test/inductor/test_triton_kernels.py          | 23 ++++++++++
 torch/_higher_order_ops/triton_kernel_wrap.py | 43 +++++++++++--------
 torch/_inductor/config.py                     |  6 +++
 torch/testing/_internal/triton_utils.py       | 27 ++++++++++++
 4 files changed, 81 insertions(+), 18 deletions(-)

diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py
index c4eb6a849ee35e..aa8dab6298bdbf 100644
--- a/test/inductor/test_triton_kernels.py
+++ b/test/inductor/test_triton_kernels.py
@@ -672,6 +672,29 @@ def grid_fn(meta):
         output2 = torch.zeros_like(t1, requires_grad=grad)
         self.assertEqual(compiled_func(t1, t2, output2), torch_add)
 
+    @requires_gpu
+    @skipIfRocm  # https://github.com/pytorch/pytorch/actions/runs/10051552819/job/27782048305?pr=131431
+    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
+    @patch.object(
+        torch._inductor.config, "unsafe_ignore_unsupported_triton_autotune_args", True
+    )
+    def test_triton_kernel_autotune_with_unsupported_args(self, backend):
+        def call_triton(x: torch.Tensor, y: torch.Tensor):
+            output = torch.zeros_like(x)
+            n_elements = output.numel()
+            add_kernel_autotuned_with_unsupported_args[(n_elements,)](
+                x, y, output, n_elements
+            )
+            return output
+
+        t1 = torch.rand(256, device=GPU_TYPE)
+        t2 = torch.rand(256, device=GPU_TYPE)
+
+        torch_add = call_triton(t1, t2)
+        compiled_func = torch.compile(call_triton, backend=backend, fullgraph=True)
+        compiled_add = compiled_func(t1, t2)
+        self.assertEqual(compiled_add, torch_add)
+
     @requires_gpu
     @common_utils.parametrize("grad", [False, True])
     @common_utils.parametrize("dynamic", [False, True])
diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py
index 99706003d96f54..602bdd6fa2a198 100644
--- a/torch/_higher_order_ops/triton_kernel_wrap.py
+++ b/torch/_higher_order_ops/triton_kernel_wrap.py
@@ -809,26 +809,33 @@ def init_variable(self, variable, kernel, kernel_idx, grid):
         # Newer version of triton change attribute name from warmup to num_warmup and rep to num_rep.
         # The call to get_first_attr is to maintain backward-compatibility.
         if (
-            (
-                "warmup" in defaults
-                and defaults["warmup"].default
-                != torch._dynamo.utils.get_first_attr(
-                    kernel, "num_warmups", "warmup"
+            not torch._inductor.config.unsafe_ignore_unsupported_triton_autotune_args
+            and (
+                (
+                    "warmup" in defaults
+                    and defaults["warmup"].default
+                    != torch._dynamo.utils.get_first_attr(
+                        kernel, "num_warmups", "warmup"
+                    )
+                )
+                or (
+                    "rep" in defaults
+                    and defaults["rep"].default
+                    != torch._dynamo.utils.get_first_attr(kernel, "num_reps", "rep")
+                )
+                or (
+                    "prune_configs_by" in defaults
+                    and defaults["prune_configs_by"].default
+                    != kernel.early_config_prune
+                )
+                # Set via reset_to_zero argument
+                or len(kernel.reset_idx) != 0
+                or len(kernel.restore_idx) != 0
+                or (
+                    "use_cuda_graph" in defaults
+                    and defaults["use_cuda_graph"].default != kernel.use_cuda_graph
                 )
             )
-            or (
-                "rep" in defaults
-                and defaults["rep"].default
-                != torch._dynamo.utils.get_first_attr(kernel, "num_reps", "rep")
-            )
-            or (
-                "prune_configs_by" in defaults
-                and defaults["prune_configs_by"].default
-                != kernel.early_config_prune
-            )
-            # Set via reset_to_zero argument
-            or len(kernel.reset_idx) != 0
-            or len(kernel.restore_idx) != 0
         ):
             self.raise_unsupported(
                 "Only configs and keys are supported for triton.autotune"
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index 6bb3ab529c0585..1efd4a1371f02b 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -626,6 +626,12 @@ def decide_compile_threads():
 # In the common case, most inputs will be aligned.
 assume_aligned_inputs: bool = False
 
+# For the user-written Triton kernels compiled with the model, ignore the unsupported
+# arguments passed to the @triton.autotune in the user's code; this is unsafe, as
+# ignoring the unsupported args may lead to unexpected autotuning behavior: don't
+# set unless you know what you're doing.
+unsafe_ignore_unsupported_triton_autotune_args: bool = False
+
 
 # config specific to codegen/cpp.py
 class cpp:
diff --git a/torch/testing/_internal/triton_utils.py b/torch/testing/_internal/triton_utils.py
index ab54e8b8bec99a..f06a11217edcf8 100644
--- a/torch/testing/_internal/triton_utils.py
+++ b/torch/testing/_internal/triton_utils.py
@@ -117,6 +117,33 @@ def add_kernel_2d_autotuned(
         tmp2 = tmp0 + tmp1
         tl.store(out_ptr + (x1 + (x_elements * y0)), tmp2, xmask & ymask)
 
+    @triton.autotune(
+        configs=[
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
+        ],
+        key=[],
+        warmup=10,
+        rep=20,
+        prune_configs_by={"early_config_prune": lambda configs, *_, **__: configs},
+    )
+    @triton.jit
+    def add_kernel_autotuned_with_unsupported_args(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
     @triton.jit
     def add_kernel_with_scaling(
         in_ptr0,
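
Not part of the patch: a minimal usage sketch of the new config flag, assuming a CUDA-capable GPU with Triton installed. The kernel mirrors `add_kernel_autotuned_with_unsupported_args` from `triton_utils.py` above; the kernel name, grid choice, and device are illustrative, and depending on the installed Triton version the `warmup`/`rep` arguments may themselves be renamed or deprecated by Triton.

```
import torch
import torch._inductor.config
import triton
import triton.language as tl

# Opt in: let PT2 ignore @triton.autotune args it does not support
# (warmup, rep, prune_configs_by, reset_to_zero, restore_value, use_cuda_graph).
# Unsafe: the ignored args may change autotuning behavior vs. eager Triton.
torch._inductor.config.unsafe_ignore_unsupported_triton_autotune_args = True


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
    ],
    key=[],
    warmup=10,  # unsupported by PT2; compilation proceeds only because of the flag above
    rep=20,  # unsupported by PT2; compilation proceeds only because of the flag above
    prune_configs_by={"early_config_prune": lambda configs, *_, **__: configs},
)
@triton.jit
def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def call_triton(x, y):
    out = torch.zeros_like(x)
    n_elements = out.numel()
    # Over-provisioned grid (one program per element); the mask handles the tail.
    add_kernel[(n_elements,)](x, y, out, n_elements)
    return out


x = torch.rand(256, device="cuda")
y = torch.rand(256, device="cuda")
compiled = torch.compile(call_triton, fullgraph=True)
torch.testing.assert_close(compiled(x, y), x + y)
```

Without the flag, Dynamo takes the `raise_unsupported("Only configs and keys are supported for triton.autotune")` path shown in the `triton_kernel_wrap.py` hunk above.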