Skip to content

Commit

Permalink
Add a final fix for MFAv2 where we compute the gid ourselves.
Browse files Browse the repository at this point in the history
Some minor code fixes.
  • Loading branch information
liuliu committed Dec 20, 2024
1 parent 05c73e6 commit 85efcf8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
12 changes: 3 additions & 9 deletions lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,7 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
}

MTL::Size gridSize
(ceilDivide(int64_t(hash.R), kernel->blockDimensions[0]),
hash.Hq,
attentionDesc.batchDimension);
(ceilDivide(int64_t(hash.R), kernel->blockDimensions[0]) * hash.Hq * attentionDesc.batchDimension, 1, 1);
MTL::Size groupSize
(int64_t(kernel->threadgroupSize), 1, 1);

Expand Down Expand Up @@ -239,9 +237,7 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
}

MTL::Size backwardQueryGridSize
(ceilDivide(int64_t(hash.R), backwardQueryKernel->blockDimensions[0]),
hash.Hq,
attentionDesc.batchDimension);
(ceilDivide(int64_t(hash.R), backwardQueryKernel->blockDimensions[0]) * hash.Hq * attentionDesc.batchDimension, 1, 1);
MTL::Size backwardQueryGroupSize
(int64_t(backwardQueryKernel->threadgroupSize), 1, 1);

Expand Down Expand Up @@ -286,9 +282,7 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
}

MTL::Size backwardKeyValueGridSize
(ceilDivide(int64_t(hash.C), backwardKeyValueKernel->blockDimensions[0]),
hash.Hq,
attentionDesc.batchDimension);
(ceilDivide(int64_t(hash.C), backwardKeyValueKernel->blockDimensions[0]) * hash.Hq * attentionDesc.batchDimension, 1, 1);
MTL::Size backwardKeyValueGroupSize
(int64_t(backwardKeyValueKernel->threadgroupSize), 1, 1);

Expand Down
12 changes: 12 additions & 0 deletions lib/nnc/mfa/v2/AttentionKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,17 @@ std::string AttentionKernel::createSource() const noexcept {
kernel void attention(
)";
source += createBufferBindings() + "\n";
switch (type.value) {
case AttentionKernelType::forward:
source.SetValue("DISPATCH_DIMENSION", "R");
break;
case AttentionKernelType::backwardQuery:
source.SetValue("DISPATCH_DIMENSION", "R");
break;
case AttentionKernelType::backwardKeyValue:
source.SetValue("DISPATCH_DIMENSION", "C");
break;
}
source.SetValue("BLOCK_DIMENSIONS_PARALLELIZATION", std::to_string(blockDimensions[0]));
source.SetValue("PARALLELIZATION_GROUP_OFFSET", parallelizationGroupOffsetValue());
source.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue());
Expand All @@ -438,6 +449,7 @@ std::string AttentionKernel::createSource() const noexcept {
ushort lane_id [[thread_index_in_simdgroup]]
) {
ushort2 morton_offset = morton_order(lane_id);
gid = { gid.x % (({{DISPATCH_DIMENSION}} + {{BLOCK_DIMENSIONS_PARALLELIZATION}} - 1) / {{BLOCK_DIMENSIONS_PARALLELIZATION}}), (gid.x / (({{DISPATCH_DIMENSION}} + {{BLOCK_DIMENSIONS_PARALLELIZATION}} - 1) / {{BLOCK_DIMENSIONS_PARALLELIZATION}})) % Hq, gid.x / (Hq * (({{DISPATCH_DIMENSION}} + {{BLOCK_DIMENSIONS_PARALLELIZATION}} - 1) / {{BLOCK_DIMENSIONS_PARALLELIZATION}}))};
uint parallelization_group_offset = gid.x;
parallelization_group_offset *= {{BLOCK_DIMENSIONS_PARALLELIZATION}};
Expand Down

0 comments on commit 85efcf8

Please sign in to comment.