diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 7b838b17a279..327a2470752b 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -161,11 +161,6 @@ def is_cuda(): return triton.runtime.driver.active.get_current_target().backend == "cuda" -def is_hip_mi200(): - target = triton.runtime.driver.active.get_current_target() - return target.backend == 'hip' and target.arch == 'gfx90a' - - def get_cuda_autotune_config(): return [ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, @@ -231,6 +226,15 @@ def get_autotune_config(): return get_hip_autotune_config() +def estimate_matmul_error(a, b): + c = torch.matmul(a, b).to(torch.float64) + a_double = a.to(torch.float64) + b_double = b.to(torch.float64) + c_accurate = torch.matmul(a_double, b_double) + forward_error = torch.norm(c_accurate - c) + return forward_error + + # `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: # - A list of `triton.Config` objects that define different configurations of # meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try @@ -363,11 +367,11 @@ def matmul(a, b, activation=""): torch_output = torch.matmul(a, b) print(f"triton_output_with_fp16_inputs={triton_output}") print(f"torch_output_with_fp16_inputs={torch_output}") -# Bigger tolerance for AMD MI200 devices. -# MI200 devices use reduced precision fp16 and bf16 and flush input and -# output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices -rtol = 1e-2 if is_hip_mi200() else 0 -if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): + +forward_error = estimate_matmul_error(a, b) +compute_error = torch.norm(triton_output.to(torch.float64) - torch_output.to(torch.float64)) + +if compute_error < forward_error: print("✅ Triton and Torch match") else: print("❌ Triton and Torch differ")