Skip to content

Commit

Permalink
Temporarily gate against BF16.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Sep 15, 2024
1 parent 2134068 commit ffd6604
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions lib/nnc/mfa/v2/AttentionDescriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createMemoryPrecisi
memoryPrecisions[AttentionOperand::Q] = GEMMOperandPrecision::FP16;
memoryPrecisions[AttentionOperand::K] = GEMMOperandPrecision::FP16;
memoryPrecisions[AttentionOperand::V] = GEMMOperandPrecision::FP16;
memoryPrecisions[AttentionOperand::dO] = GEMMOperandPrecision::BF16;
memoryPrecisions[AttentionOperand::dO] = GEMMOperandPrecision::FP32; // GEMMOperandPrecision::BF16;
} else {
memoryPrecisions[AttentionOperand::Q] = GEMMOperandPrecision::FP32;
memoryPrecisions[AttentionOperand::K] = GEMMOperandPrecision::FP32;
Expand Down Expand Up @@ -263,7 +263,7 @@ AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createMemoryPrecisi
// unrolled (head dimension vastly exceeds head block dimension).
if (lowPrecisionIntermediates) {
memoryPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP16;
memoryPrecisions[AttentionOperand::D] = GEMMOperandPrecision::BF16;
memoryPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32; // GEMMOperandPrecision::BF16;
} else {
memoryPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP32;
memoryPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32;
Expand Down Expand Up @@ -340,7 +340,7 @@ AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createRegisterPreci
registerPrecisions[AttentionOperand::Q] = GEMMOperandPrecision::FP16;
registerPrecisions[AttentionOperand::K] = GEMMOperandPrecision::FP16;
registerPrecisions[AttentionOperand::V] = GEMMOperandPrecision::FP16;
registerPrecisions[AttentionOperand::dO] = hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32;
registerPrecisions[AttentionOperand::dO] = GEMMOperandPrecision::FP32; // hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32;
} else {
registerPrecisions[AttentionOperand::Q] = GEMMOperandPrecision::FP32;
registerPrecisions[AttentionOperand::K] = GEMMOperandPrecision::FP32;
Expand All @@ -351,7 +351,7 @@ AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createRegisterPreci
// The register precision of L/D only counts for backward key-value.
if (lowPrecisionIntermediates) {
registerPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP16;
registerPrecisions[AttentionOperand::D] = hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32;
registerPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32; // hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32;
} else {
registerPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP32;
registerPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32;
Expand All @@ -378,7 +378,7 @@ AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createRegisterPreci
registerPrecisions[AttentionOperand::S] = lowPrecisionInputs ? GEMMOperandPrecision::FP16 : GEMMOperandPrecision::FP32;
registerPrecisions[AttentionOperand::P] = GEMMOperandPrecision::FP16;
registerPrecisions[AttentionOperand::dP] = GEMMOperandPrecision::FP32;
registerPrecisions[AttentionOperand::dS] = hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32;
registerPrecisions[AttentionOperand::dS] = GEMMOperandPrecision::FP32; // hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32;
} else {
registerPrecisions[AttentionOperand::S] = GEMMOperandPrecision::FP32;
registerPrecisions[AttentionOperand::P] = GEMMOperandPrecision::FP32;
Expand Down

0 comments on commit ffd6604

Please sign in to comment.