Skip to content

Commit

Permalink
[Kernel] Fix marlin divide-by-zero warnings (#6904)
Browse files Browse the repository at this point in the history
  • Loading branch information
tlrmchlsmth authored Jul 30, 2024
1 parent 4fbf4aa commit 61a97c3
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 39 deletions.
61 changes: 32 additions & 29 deletions csrc/quantization/gptq_marlin/gptq_marlin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1128,44 +1128,47 @@ __global__ void Marlin(
};

auto fetch_zp_to_registers = [&](int k, int full_pipe) {
if constexpr (!has_zp) {
return;
}
if constexpr (has_zp) {
// This code does not handle group_blocks == 0,
// which signifies act_order.
// has_zp implies AWQ, which doesn't have act_order,
static_assert(group_blocks != 0);

int pipe = full_pipe % stages;
int pipe = full_pipe % stages;

if constexpr (group_blocks == -1) {
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
}
if constexpr (group_blocks == -1) {
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
}

} else if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
} else {
int warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
} else if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
} else {
int warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;

int warp_row = warp_id / n_warps;
int warp_row = warp_id / n_warps;

int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);

int k_blocks = cur_k / 16;
int cur_group_id = k_blocks / group_blocks;
int k_blocks = cur_k / 16;
int cur_group_id = k_blocks / group_blocks;

int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;

sh_zp_stage += cur_group_id * zp_sh_stride;
sh_zp_stage += cur_group_id * zp_sh_stride;

for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
}
}
};
Expand Down
18 changes: 13 additions & 5 deletions csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -452,10 +452,15 @@ __global__ void Marlin(
B_ptr[i] += b_gl_rd_delta_o;
}
// Only fetch scales if this tile starts a new group
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta;
if constexpr (group_blocks != -1) {
// This assumes group_blocks >= thread_k_blocks
// and would need to be modified to support smaller groups.
static_assert(group_blocks >= thread_k_blocks);
if (pipe % (group_blocks / thread_k_blocks) == 0) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta;
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
Expand All @@ -480,7 +485,10 @@ __global__ void Marlin(
// however, this does not seem to be a significant bottleneck, while some
// theoretically better attempts have lead to bad instruction ordering by
// the compiler and correspondingly a noticeable drop in performance.
if (group_blocks != -1) {
if constexpr (group_blocks != -1) {
// This assumes group_blocks >= thread_k_blocks
// and would need to be modified to support smaller groups.
static_assert(group_blocks >= thread_k_blocks);
int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
Expand Down
18 changes: 13 additions & 5 deletions csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -404,10 +404,15 @@ __global__ void Marlin_24(
meta_ptr[i] += m_gl_rd_delta_o;
}
// Only fetch scales if this tile starts a new group
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta;
if constexpr (group_blocks != -1) {
// This assumes group_blocks >= thread_k_blocks
// and would need to be modified to support smaller groups.
static_assert(group_blocks >= thread_k_blocks);
if (pipe % (group_blocks / thread_k_blocks) == 0) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta;
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
Expand All @@ -432,7 +437,10 @@ __global__ void Marlin_24(
// however, this does not seem to be a significant bottleneck, while some
// theoretically better attempts have lead to bad instruction ordering by
// the compiler and correspondingly a noticeable drop in performance.
if (group_blocks != -1) {
if constexpr (group_blocks != -1) {
// This assumes group_blocks >= thread_k_blocks
// and would need to be modified to support smaller groups.
static_assert(group_blocks >= thread_k_blocks);
int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
Expand Down

0 comments on commit 61a97c3

Please sign in to comment.