[AMD] Improved error estimate for the matmul example
ravil-mobile committed Jan 10, 2025
1 parent 2b41842 commit a59dbfa
Showing 1 changed file with 14 additions and 10 deletions.

python/tutorials/03-matrix-multiplication.py
@@ -161,11 +161,6 @@ def is_cuda():
     return triton.runtime.driver.active.get_current_target().backend == "cuda"
 
 
-def is_hip_mi200():
-    target = triton.runtime.driver.active.get_current_target()
-    return target.backend == 'hip' and target.arch == 'gfx90a'
-
-
 def get_cuda_autotune_config():
     return [
         triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
@@ -231,6 +226,15 @@ def get_autotune_config():
         return get_hip_autotune_config()
 
 
+def estimate_matmul_error(a, b):
+    c = torch.matmul(a, b).to(torch.float64)
+    a_double = a.to(torch.float64)
+    b_double = b.to(torch.float64)
+    c_accurate = torch.matmul(a_double, b_double)
+    forward_error = torch.norm(c_accurate - c)
+    return forward_error
+
+
 # `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
 #   - A list of `triton.Config` objects that define different configurations of
 #     meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
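The quantity computed by the new `estimate_matmul_error` is the Frobenius norm of the gap between the reduced-precision product and an fp64 reference, i.e. the forward error inherent in the fp16 GEMM itself. A minimal standalone sketch of that computation; the device string, matrix sizes, and seed are illustrative assumptions, not part of the commit:

import torch

torch.manual_seed(0)
a = torch.randn((512, 512), device="cuda", dtype=torch.float16)  # assumed device/size
b = torch.randn((512, 512), device="cuda", dtype=torch.float16)

c_half = torch.matmul(a, b).to(torch.float64)                    # reduced-precision product
c_ref = torch.matmul(a.to(torch.float64), b.to(torch.float64))   # accurate fp64 reference
forward_error = torch.norm(c_ref - c_half)                       # Frobenius norm of the gap

rel = (forward_error / torch.norm(c_ref)).item()
print(f"forward error: {forward_error.item():.3e} (relative: {rel:.3e})")

Because this bound is derived from the actual inputs, it scales with the data instead of relying on a hard-coded, device-specific tolerance.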
@@ -363,11 +367,11 @@ def matmul(a, b, activation=""):
 torch_output = torch.matmul(a, b)
 print(f"triton_output_with_fp16_inputs={triton_output}")
 print(f"torch_output_with_fp16_inputs={torch_output}")
-# Bigger tolerance for AMD MI200 devices.
-# MI200 devices use reduced precision fp16 and bf16 and flush input and
-# output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
-rtol = 1e-2 if is_hip_mi200() else 0
-if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
+
+forward_error = estimate_matmul_error(a, b)
+compute_error = torch.norm(triton_output.to(torch.float64) - torch_output.to(torch.float64))
+
+if compute_error < forward_error:
     print("✅ Triton and Torch match")
 else:
     print("❌ Triton and Torch differ")
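Taken together, the new check accepts the Triton result when its deviation from Torch stays below the rounding error already present in the fp16 GEMM, which is why the MI200-specific `rtol` can be dropped. A self-contained sketch of the criterion; the device string, sizes, seed, the `check` helper, and the 1% perturbation used as a stand-in for a hypothetical mismatching kernel are all assumptions for illustration:

import torch

def estimate_matmul_error(a, b):
    # Forward error of the reduced-precision matmul vs. an fp64 reference.
    c = torch.matmul(a, b).to(torch.float64)
    c_accurate = torch.matmul(a.to(torch.float64), b.to(torch.float64))
    return torch.norm(c_accurate - c)

torch.manual_seed(0)
a = torch.randn((512, 512), device="cuda", dtype=torch.float16)  # assumed device/size
b = torch.randn((512, 512), device="cuda", dtype=torch.float16)
torch_output = torch.matmul(a, b)
forward_error = estimate_matmul_error(a, b)

def check(candidate):
    # Accept if the deviation from torch stays below the inherent fp16 error.
    compute_error = torch.norm(candidate.to(torch.float64) - torch_output.to(torch.float64))
    return bool(compute_error < forward_error)

print(check(torch.matmul(a, b)))   # True: same computation, zero deviation
print(check(torch_output * 1.01))  # False: a 1% perturbation exceeds the fp16 noise floor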
