Skip to content

Commit

Permalink
[feature] RMS-Norm fp16 support skip_term=False
Browse files Browse the repository at this point in the history
  • Loading branch information
yimmmin authored and Alcanderian committed Jul 15, 2024
1 parent dc9a5b4 commit 1b04f8f
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions src/ppl/nn/engines/llm_cuda/kernels/opmx/rms_norm_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ ppl::common::RetCode RMSNormKernel::DoExecute(KernelExecContext* ctx) {
PPLNN_LLM_CUDA_DEBUG_TRACE("Input [weight]:\n");
PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(weight);

void *skip_in_data = nullptr;
void *skip_in_ptr = nullptr;
if (skip_in) {
PPLNN_LLM_CUDA_DEBUG_TRACE("Input [skip_in]:\n");
PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(skip_in);
skip_in_data = skip_in->GetBufferPtr();
skip_in_ptr = skip_in->GetBufferPtr();
}

PPLNN_LLM_CUDA_DEBUG_TRACE("eps: %f\n", param_->eps);
Expand All @@ -52,11 +52,6 @@ ppl::common::RetCode RMSNormKernel::DoExecute(KernelExecContext* ctx) {

auto input_shape = input->GetShape();

if (param_->skip_term == false) {
LOG(ERROR) << "currently only support skip_term == true.";
return ppl::common::RC_UNSUPPORTED;
}

if (param_->axis != -1 && param_->axis != input_shape->GetDim(input_shape->GetDimCount() - 1)) {
LOG(ERROR) << "currently only support axis == -1 or input's last dim.";
return ppl::common::RC_UNSUPPORTED;
Expand All @@ -74,6 +69,7 @@ ppl::common::RetCode RMSNormKernel::DoExecute(KernelExecContext* ctx) {
PPLNN_LLM_CUDA_DEBUG_TRACE("Output [output]:\n");
PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(output);

void *skip_out_ptr = nullptr;
if (skip_out) {
if (can_trans_skip_in) {
skip_out->TransferBufferFrom(skip_in);
Expand All @@ -82,6 +78,7 @@ ppl::common::RetCode RMSNormKernel::DoExecute(KernelExecContext* ctx) {
}
PPLNN_LLM_CUDA_DEBUG_TRACE("Output [skip_out]:\n");
PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(skip_out);
skip_out_ptr = skip_out->GetBufferPtr();
}

if (param_->skip_term && !skip_out) {
Expand All @@ -94,17 +91,22 @@ ppl::common::RetCode RMSNormKernel::DoExecute(KernelExecContext* ctx) {
return ppl::common::RC_UNSUPPORTED;
}

return ppl::kernel::llm::cuda::pmx::rms_norm(
const int64_t dim_count = input_shape->GetDimCount();
const int64_t real_axis = param_->axis > 0 ? param_->axis : (param_->axis + dim_count);

const int64_t batch = input_shape->CalcElementsToDimensionIncludingPadding(real_axis);
const int64_t norm_dim = input_shape->CalcElementsFromDimensionIncludingPadding(real_axis);

return ppl::kernel::llm::cuda::pmx::rms_norm_fp16(
GetStream(),
input_shape,
input_data,
weight->GetBufferPtr(),
skip_in_data,
param_->axis,
skip_in_ptr,
param_->eps,
param_->skip_term,
output->GetBufferPtr(),
skip_out->GetBufferPtr()
batch,
norm_dim,
skip_out_ptr,
output->GetBufferPtr()
);
}

Expand Down

0 comments on commit 1b04f8f

Please sign in to comment.