forked from ggerganov/llama.cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10 from JohannesGaessler/cuda-fa-no-tc-11
Cuda fa no tc 11
- Loading branch information
Showing
9 changed files
with
899 additions
and
458 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
#define FATTN_KQ_STRIDE 256 | ||
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. | ||
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs. | ||
|
||
template<int D, int parallel_blocks> // D == head size | ||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) | ||
__launch_bounds__(D, 1) | ||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) | ||
static __global__ void flash_attn_combine_results( | ||
const float * __restrict__ VKQ_parts, | ||
const float2 * __restrict__ VKQ_meta, | ||
float * __restrict__ dst) { | ||
VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; | ||
VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; | ||
dst += D * gridDim.y*blockIdx.x; | ||
|
||
const int tid = threadIdx.x; | ||
__builtin_assume(tid < D); | ||
|
||
__shared__ float2 meta[parallel_blocks]; | ||
if (tid < 2*parallel_blocks) { | ||
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid]; | ||
} | ||
|
||
__syncthreads(); | ||
|
||
float kqmax = meta[0].x; | ||
#pragma unroll | ||
for (int l = 1; l < parallel_blocks; ++l) { | ||
kqmax = max(kqmax, meta[l].x); | ||
} | ||
|
||
float VKQ_numerator = 0.0f; | ||
float VKQ_denominator = 0.0f; | ||
#pragma unroll | ||
for (int l = 0; l < parallel_blocks; ++l) { | ||
const float diff = meta[l].x - kqmax; | ||
const float KQ_max_scale = expf(diff); | ||
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); | ||
*((uint32_t *) &KQ_max_scale) &= ftz_mask; | ||
|
||
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; | ||
VKQ_denominator += KQ_max_scale * meta[l].y; | ||
} | ||
|
||
dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; | ||
} |
Oops, something went wrong.