From 61a97c32f64641738d2cc623708f28046768224e Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 29 Jul 2024 21:26:07 -0400 Subject: [PATCH] [Kernel] Fix marlin divide-by-zero warnings (#6904) --- csrc/quantization/gptq_marlin/gptq_marlin.cu | 61 ++++++++++--------- .../marlin/dense/marlin_cuda_kernel.cu | 18 ++++-- .../marlin/sparse/marlin_24_cuda_kernel.cu | 18 ++++-- 3 files changed, 58 insertions(+), 39 deletions(-) diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 36ae2bfafa7c2..26cc248e6ac5d 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -1128,44 +1128,47 @@ __global__ void Marlin( }; auto fetch_zp_to_registers = [&](int k, int full_pipe) { - if constexpr (!has_zp) { - return; - } + if constexpr (has_zp) { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(group_blocks != 0); - int pipe = full_pipe % stages; + int pipe = full_pipe % stages; - if constexpr (group_blocks == -1) { - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; - } + if constexpr (group_blocks == -1) { + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } - } else if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; + } else if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; + int warp_row = warp_id / n_warps; - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - sh_zp_stage += cur_group_id * zp_sh_stride; + sh_zp_stage += cur_group_id * zp_sh_stride; - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } } } }; diff --git a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu index 37339b84ae25b..efbcc182a3ae4 100644 --- a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu @@ -452,10 +452,15 @@ __global__ void Marlin( B_ptr[i] += b_gl_rd_delta_o; } // Only fetch scales if this tile starts a new group - if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); - s_gl_rd += s_gl_rd_delta; + if constexpr (group_blocks != -1) { + // This assumes group_blocks >= thread_k_blocks + // and would need to be modified to support smaller groups. + static_assert(group_blocks >= thread_k_blocks); + if (pipe % (group_blocks / thread_k_blocks) == 0) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + s_gl_rd += s_gl_rd_delta; + } } } // Insert a fence even when we are winding down the pipeline to ensure that @@ -480,7 +485,10 @@ __global__ void Marlin( // however, this does not seem to be a significant bottleneck, while some // theoretically better attempts have lead to bad instruction ordering by // the compiler and correspondingly a noticeable drop in performance. - if (group_blocks != -1) { + if constexpr (group_blocks != -1) { + // This assumes group_blocks >= thread_k_blocks + // and would need to be modified to support smaller groups. + static_assert(group_blocks >= thread_k_blocks); int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu index b5effc3055441..3c50f1786bc68 100644 --- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu +++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu @@ -404,10 +404,15 @@ __global__ void Marlin_24( meta_ptr[i] += m_gl_rd_delta_o; } // Only fetch scales if this tile starts a new group - if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); - s_gl_rd += s_gl_rd_delta; + if constexpr (group_blocks != -1) { + // This assumes group_blocks >= thread_k_blocks + // and would need to be modified to support smaller groups. + static_assert(group_blocks >= thread_k_blocks); + if (pipe % (group_blocks / thread_k_blocks) == 0) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + s_gl_rd += s_gl_rd_delta; + } } } // Insert a fence even when we are winding down the pipeline to ensure that @@ -432,7 +437,10 @@ __global__ void Marlin_24( // however, this does not seem to be a significant bottleneck, while some // theoretically better attempts have lead to bad instruction ordering by // the compiler and correspondingly a noticeable drop in performance. - if (group_blocks != -1) { + if constexpr (group_blocks != -1) { + // This assumes group_blocks >= thread_k_blocks + // and would need to be modified to support smaller groups. + static_assert(group_blocks >= thread_k_blocks); int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks)));