From ffd6604049a0ffff612b1632aa6133e533caf094 Mon Sep 17 00:00:00 2001 From: Liu Liu Date: Sun, 15 Sep 2024 18:39:08 -0400 Subject: [PATCH] Temporarily gate against BF16. --- lib/nnc/mfa/v2/AttentionDescriptor.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/nnc/mfa/v2/AttentionDescriptor.cpp b/lib/nnc/mfa/v2/AttentionDescriptor.cpp index e4da836b0..32d8f825f 100644 --- a/lib/nnc/mfa/v2/AttentionDescriptor.cpp +++ b/lib/nnc/mfa/v2/AttentionDescriptor.cpp @@ -197,7 +197,7 @@ AttentionOperands AttentionDescriptor::createMemoryPrecisi memoryPrecisions[AttentionOperand::Q] = GEMMOperandPrecision::FP16; memoryPrecisions[AttentionOperand::K] = GEMMOperandPrecision::FP16; memoryPrecisions[AttentionOperand::V] = GEMMOperandPrecision::FP16; - memoryPrecisions[AttentionOperand::dO] = GEMMOperandPrecision::BF16; + memoryPrecisions[AttentionOperand::dO] = GEMMOperandPrecision::FP32; // GEMMOperandPrecision::BF16; } else { memoryPrecisions[AttentionOperand::Q] = GEMMOperandPrecision::FP32; memoryPrecisions[AttentionOperand::K] = GEMMOperandPrecision::FP32; @@ -263,7 +263,7 @@ AttentionOperands AttentionDescriptor::createMemoryPrecisi // unrolled (head dimension vastly exceeds head block dimension). if (lowPrecisionIntermediates) { memoryPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP16; - memoryPrecisions[AttentionOperand::D] = GEMMOperandPrecision::BF16; + memoryPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32; // GEMMOperandPrecision::BF16; } else { memoryPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP32; memoryPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32; @@ -340,7 +340,7 @@ AttentionOperands AttentionDescriptor::createRegisterPreci registerPrecisions[AttentionOperand::Q] = GEMMOperandPrecision::FP16; registerPrecisions[AttentionOperand::K] = GEMMOperandPrecision::FP16; registerPrecisions[AttentionOperand::V] = GEMMOperandPrecision::FP16; - registerPrecisions[AttentionOperand::dO] = hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32; + registerPrecisions[AttentionOperand::dO] = GEMMOperandPrecision::FP32; // hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32; } else { registerPrecisions[AttentionOperand::Q] = GEMMOperandPrecision::FP32; registerPrecisions[AttentionOperand::K] = GEMMOperandPrecision::FP32; @@ -351,7 +351,7 @@ AttentionOperands AttentionDescriptor::createRegisterPreci // The register precision of L/D only counts for backward key-value. if (lowPrecisionIntermediates) { registerPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP16; - registerPrecisions[AttentionOperand::D] = hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32; + registerPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32; // hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32; } else { registerPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP32; registerPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32; @@ -378,7 +378,7 @@ AttentionOperands AttentionDescriptor::createRegisterPreci registerPrecisions[AttentionOperand::S] = lowPrecisionInputs ? GEMMOperandPrecision::FP16 : GEMMOperandPrecision::FP32; registerPrecisions[AttentionOperand::P] = GEMMOperandPrecision::FP16; registerPrecisions[AttentionOperand::dP] = GEMMOperandPrecision::FP32; - registerPrecisions[AttentionOperand::dS] = hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32; + registerPrecisions[AttentionOperand::dS] = GEMMOperandPrecision::FP32; // hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32; } else { registerPrecisions[AttentionOperand::S] = GEMMOperandPrecision::FP32; registerPrecisions[AttentionOperand::P] = GEMMOperandPrecision::FP32;