diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/dynamic_batching/rotary_position_embedding_kernel.cc b/src/ppl/nn/engines/llm_cuda/kernels/opmx/dynamic_batching/rotary_position_embedding_kernel.cc
index 558a4d308..f2c99279b 100644
--- a/src/ppl/nn/engines/llm_cuda/kernels/opmx/dynamic_batching/rotary_position_embedding_kernel.cc
+++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/dynamic_batching/rotary_position_embedding_kernel.cc
@@ -83,11 +83,6 @@ ppl::common::RetCode DynamicBatchingRotaryPositionEmbeddingKernel::DoExecute(Ker
         return ppl::common::RC_UNSUPPORTED;
     }
 
-    if (param_->scaling_type != param_->SCALING_TYPE_NONE) {
-        LOG(ERROR) << "currently only support scaling_type == ''";
-        return ppl::common::RC_UNSUPPORTED;
-    }
-
     int64_t max_seqlen_val = 0;
     if (ppl::common::RC_SUCCESS != max_seqlen->CopyToHost(&max_seqlen_val)) {
         LOG(ERROR) << "max_seqlen->CopyToHost() failed";
@@ -100,6 +95,19 @@ ppl::common::RetCode DynamicBatchingRotaryPositionEmbeddingKernel::DoExecute(Ker
     int64_t num_heads = query_shape->GetDim(1);
     int64_t num_key_heads = key_shape->GetDim(1);
 
+    // Map the op parameter's scaling type onto the kernel library's rope_scaling enum; unknown values are rejected.
+    ppl::kernel::llm::cuda::pmx::rope_scaling_t scaling_type;
+    if (param_->scaling_type == param_->SCALING_TYPE_NONE) {
+        scaling_type = ppl::kernel::llm::cuda::pmx::rope_scaling::NONE;
+    } else if (param_->scaling_type == param_->SCALING_TYPE_LINEAR) {
+        scaling_type = ppl::kernel::llm::cuda::pmx::rope_scaling::LINEAR;
+    } else if (param_->scaling_type == param_->SCALING_TYPE_DYNAMIC) {
+        scaling_type = ppl::kernel::llm::cuda::pmx::rope_scaling::DYNAMIC;
+    } else {
+        LOG(ERROR) << "invalid scaling type: " << param_->scaling_type;
+        return ppl::common::RC_INVALID_VALUE;
+    }
+
     return ppl::kernel::llm::cuda::pmx::dynamic_batching_rotary_position_embedding(
         GetStream(),
         query_shape,
@@ -115,6 +123,9 @@ ppl::common::RetCode DynamicBatchingRotaryPositionEmbeddingKernel::DoExecute(Ker
         num_heads,
         num_key_heads,
         max_seqlen_val,
+        param_->max_position_embeddings,
+        scaling_type,
+        param_->scaling_factor,
         rotated_query->GetShape(),
         rotated_query->GetBufferPtr(),
         rotated_key->GetShape(),