diff --git a/include/dtype.h b/include/dtype.h
index de72bcee..c1220366 100644
--- a/include/dtype.h
+++ b/include/dtype.h
@@ -51,4 +51,8 @@ enum ActivationType {
     SILU,
 };
 
+enum RopeType {
+    LLAMA_ROPE = 0,
+};
+
 } // namespace xft
diff --git a/include/layers_decoder.h b/include/layers_decoder.h
index 34f6aa52..969a9ac5 100644
--- a/include/layers_decoder.h
+++ b/include/layers_decoder.h
@@ -18,12 +18,13 @@
 namespace xft {
 
-void invokeLayerLLaMA(DataType dt, ActivationType at, NormType nt, int batchSize, int inputSeqLen, int attHeadDim,
-        int attHeadNum, int kvHeadNum, int maxPositions, int maxPosEmbed, int pastSeqLen, int currentSeqLen, int step,
-        int hiddenSize, int intermediateSize, void *output, int outputStride, const void *input, int inputStride,
-        const float *ln1Gamma, const float *ln1Beta, const void *queryWeight, const void *keyWeight,
-        const void *valueWeight, const void *attnOutWeight, const float *ln2Gamma, const float *ln2Beta,
-        const void *gateWeight, const void *upWeight, const void *downWeight, const float *queryBias = nullptr,
-        const float *keyBias = nullptr, const float *valueBias = nullptr, const float *attnOutBias = nullptr);
+void invokeLayerLLaMA(DataType dt, DataType kvcdt, RopeType rt, ActivationType at, NormType nt, int batchSize,
+        int inputSeqLen, int attHeadDim, int attHeadNum, int kvHeadNum, int maxPositions, int maxPosEmbed,
+        int pastSeqLen, int currentSeqLen, int step, int hiddenSize, int intermediateSize, void *output,
+        int outputStride, const void *input, int inputStride, const float *ln1Gamma, const float *ln1Beta,
+        const void *queryWeight, const void *keyWeight, const void *valueWeight, const void *attnOutWeight,
+        const float *ln2Gamma, const float *ln2Beta, const void *gateWeight, const void *upWeight,
+        const void *downWeight, const float *queryBias = nullptr, const float *keyBias = nullptr,
+        const float *valueBias = nullptr, const float *attnOutBias = nullptr);
 
-} // namespace xft
\ No newline at end of file
+} // namespace xft
diff --git a/src/layers/decoder_layer.cpp b/src/layers/decoder_layer.cpp
index 02f13cbf..60d15818 100644
--- a/src/layers/decoder_layer.cpp
+++ b/src/layers/decoder_layer.cpp
@@ -21,19 +21,21 @@
 #include "layers_mlp.h"
 #include "mlp_llama.h"
 #include "rms_norm.h"
+#include "numa_allocator.h"
 #include
 
 namespace xft {
 
-template
+template
 void LayerLLaMAImpl(DataType dt, ActivationType at, NormType nt, int batchSize, int inputSeqLen, int attHeadDim,
         int attHeadNum, int kvHeadNum, int maxPositions, int maxPosEmbed, int pastSeqLen, int currentSeqLen, int step,
         int hiddenSize, int intermediateSize, void *output, int outputStride, const void *input, int inputStride,
         const float *ln1Gamma, const float *ln1Beta, const void *queryWeight, const void *keyWeight,
         const void *valueWeight, const void *attnOutWeight, const float *ln2Gamma, const float *ln2Beta,
         const void *gateWeight, const void *upWeight, const void *downWeight, const float *queryBias,
-        const float *keyBias, const float *valueBias, const float *attnOutBias) {
+        const float *keyBias, const float *valueBias, const float *attnOutBias,
+        MMHelper *mmHelper, DecoderContext *ctx, KVCacheManager *kvCacheMgr) {
 
     // TODO: will deprecate attention mask in future, so need to change this
     auto prepareAttnMask = [&](DecoderContext *ctx, int step) {
@@ -83,71 +85,52 @@ void LayerLLaMAImpl(DataType dt, ActivationType at, NormType nt, int batchSize,
         return mask;
     };
 
-    using DECODER = Decoder, LlamaMLP>;
-    static std::unordered_map<std::string, DECODER *> llama_layer_hub;
-    static MMHelper *mmHelper;
-    static DecoderContext *ctx;
-    static KVCacheManager *kvCacheMgr;
-
-    std::string actType;
-    if (at == ActivationType::SILU)
-        actType = "silu";
-    else if (at == ActivationType::RELU)
-        actType = "relu";
-    else if (at == ActivationType::GELU)
-        actType = "gelu";
-    else if (at == ActivationType::SWIGLU)
-        actType = "swiglu";
-    else
-        printf(">> unsupported activation type\n");
-
-    if (ctx == nullptr
-            || (ctx != nullptr && (ctx->hiddenSize != hiddenSize || ctx->intermediateSize != intermediateSize))) {
-        if (ctx != nullptr) delete ctx;
-        printf(">> create context: %d %d\n", hiddenSize, intermediateSize);
-        mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex());
-        ctx = new DecoderContext(1, hiddenSize, attHeadDim, attHeadNum, kvHeadNum, intermediateSize, actType, 1e-6, 0,
-                0, maxPositions, maxPosEmbed, -1, 0, 1, mmHelper);
-        if (kvCacheMgr != nullptr) delete kvCacheMgr;
-        kvCacheMgr = new KVCacheManager(1);
-    }
-
+    using DECODER = Decoder, LlamaMLP>;
+    DECODER *llama_layer;
+    xft::Matrix<float> *actBuffers = new xft::Matrix<float>;
+    //static std::unordered_map<std::string, DECODER *> llama_layer_hub;
+    static std::unordered_map<std::string, std::tuple<DECODER *, xft::Matrix<float> *>> llama_layer_hub;
 
     // create hash key and value: if hidden and intermediateSize is changed , then memory pointer is also changed.
     std::stringstream weights_addr;
     weights_addr << queryWeight << "_" << keyWeight << "_" << valueWeight << "_" << attnOutWeight << "_" << gateWeight
                  << "_" << upWeight << "_" << downWeight << "_" << dt << "_" << at << "_" << nt << "_" << attHeadDim
                  << "_" << attHeadNum << "_" << kvHeadNum;
     std::string llama_layer_key = weights_addr.str();
-    DECODER *llama_layer;
 
     auto it_created = llama_layer_hub.find(llama_layer_key);
     if (it_created == llama_layer_hub.end()) {
+        int firstNode = getenv("FIRST_TOKEN_WEIGHT_LOCATION") ? atoi(getenv("FIRST_TOKEN_WEIGHT_LOCATION")) : -1;
+        int nextNode = getenv("NEXT_TOKEN_WEIGHT_LOCATION") ? atoi(getenv("NEXT_TOKEN_WEIGHT_LOCATION")) : -1;
+        if (step == 0)
+            xft_set_preferred_node(firstNode);
+        else
+            xft_set_preferred_node(nextNode);
         llama_layer = new DECODER(ctx, 0);
         llama_layer->setWeights(ctx, (const float *)queryWeight, nullptr, nullptr, queryBias, (const float *)keyWeight,
                 nullptr, nullptr, keyBias, (const float *)valueWeight, nullptr, nullptr, valueBias,
                 (const float *)attnOutWeight, nullptr, nullptr, attnOutBias, ln1Gamma, ln1Beta,
                 (const float *)gateWeight, nullptr, nullptr, nullptr, (const float *)upWeight, nullptr, nullptr,
                 nullptr, ln2Gamma, ln2Beta, (const float *)downWeight, nullptr, nullptr, false);
-        llama_layer_hub[llama_layer_key] = llama_layer;
+
+        actBuffers->Resize(batchSize * inputSeqLen * 2, hiddenSize);
+
+        llama_layer_hub[llama_layer_key] = std::make_tuple(llama_layer, actBuffers);
         printf(">> create llama_layer_key: %s\n", llama_layer_key.c_str());
+        xft_set_preferred_node(-1);
     } else {
-        llama_layer = it_created->second;
+        llama_layer = std::get<0>(it_created->second);
+        actBuffers = std::get<1>(it_created->second);
     }
 
     ctx->resize(batchSize, inputSeqLen, pastSeqLen);
-    xft::Matrix<float> actBuffers;
-    actBuffers.Resize(batchSize * inputSeqLen * 2, hiddenSize);
     float *attnMask = prepareAttnMask(ctx, step);
 
-    int workers = 1;
-    int headsPerSplit = (ctx->kvHeadNum + workers - 1) / workers;
-    kvCacheMgr->resize(maxPositions, batchSize, headsPerSplit, attHeadDim);
-    KVCacheTensor &presentKey = kvCacheMgr->getKey(0);
-    KVCacheTensor &presentValue = kvCacheMgr->getValue(0);
+    KVCacheTensor &presentKey = kvCacheMgr->getKey(0);
+    KVCacheTensor &presentValue = kvCacheMgr->getValue(0);
 
     float *attnOut = (float *)(ctx->tmpBuf.Data());
-    llama_layer->forwardAttention(ctx, (float *)input, actBuffers.Data(), attnOut, attnMask,
+    llama_layer->forwardAttention(ctx, (float *)input, actBuffers->Data(), attnOut, attnMask,
             presentKey, // presentKey,
             presentValue, // presentValue,
             inputSeqLen, // inputSeqLen,
@@ -159,7 +142,8 @@ void LayerLLaMAImpl(DataType dt, ActivationType at, NormType nt, int batchSize,
     llama_layer->forwardFFN(ctx, attnOut, (float *)output, inputStride, outputStride, true);
 }
 
-void invokeLayerLLaMA(DataType dt, ActivationType at, NormType nt, int batchSize, int inputSeqLen, int attHeadDim,
+template
+void LayerLLaMAWrapper(DataType dt, ActivationType at, NormType nt, int batchSize, int inputSeqLen, int attHeadDim,
         int attHeadNum, int kvHeadNum, int maxPositions, int maxPosEmbed, int pastSeqLen, int currentSeqLen, int step,
         int hiddenSize, int intermediateSize, void *output, int outputStride, const void *input, int inputStride,
         const float *ln1Gamma, const float *ln1Beta, const void *queryWeight, const void *keyWeight,
@@ -169,35 +153,121 @@ void invokeLayerLLaMA(DataType dt, ActivationType at, NormType nt, int batchSize
     static std::mutex mutex;
     std::lock_guard<std::mutex> lock(mutex);
 
+    std::string actType;
+    if (at == ActivationType::SILU)
+        actType = "silu";
+    else if (at == ActivationType::RELU)
+        actType = "relu";
+    else if (at == ActivationType::GELU)
+        actType = "gelu";
+    else if (at == ActivationType::SWIGLU)
+        actType = "swiglu";
+    else {
+        printf(">> unsupported activation type\n");
+        return;
+    }
+
+    static MMHelper *mmHelper;
+    static DecoderContext *ctx;
+    if (ctx == nullptr
+            || (ctx != nullptr && (ctx->hiddenSize != hiddenSize || ctx->intermediateSize != intermediateSize))) {
+        if (ctx != nullptr) delete ctx;
+        printf(">> create context: %d %d\n", hiddenSize, intermediateSize);
+        mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex());
+        ctx = new DecoderContext(1, hiddenSize, attHeadDim, attHeadNum, kvHeadNum, intermediateSize, actType, 1e-6, 0,
+                0, maxPositions, maxPosEmbed, -1, 0, 1, mmHelper);
+    }
+
+    KVCacheManager *kvCacheMgr;
+    static std::unordered_map<std::string, KVCacheManager *> kv_hub;
+
+    // create hash key and value: if hidden and intermediateSize is changed , then memory pointer is also changed.
+    std::stringstream layer_key;
+    layer_key << queryWeight << "_" << keyWeight << "_" << valueWeight << "_" << attnOutWeight << "_" << gateWeight
+              << "_" << upWeight << "_" << downWeight << "_" << dt << "_" << at << "_" << nt << "_" << attHeadDim
+              << "_" << attHeadNum << "_" << kvHeadNum;
+    std::string kv_hub_key = layer_key.str();
+
+    auto it_created = kv_hub.find(kv_hub_key);
+    if (it_created == kv_hub.end()) {
+        int kvcNode = getenv("KVCACHE_LOCATION") ? atoi(getenv("KVCACHE_LOCATION")) : -1;
+        xft_set_preferred_node(kvcNode);
+        kvCacheMgr = new KVCacheManager(1);
+        int workers = 1;
+        int headsPerSplit = (ctx->kvHeadNum + workers - 1) / workers;
+        kvCacheMgr->resize(maxPositions, batchSize, headsPerSplit, attHeadDim);
+        kv_hub[kv_hub_key] = kvCacheMgr;
+        printf(">> create kv_hub_key: %s\n", kv_hub_key.c_str());
+        xft_set_preferred_node(-1);
+    } else {
+        kvCacheMgr = it_created->second;
+    }
+
     if (dt == DataType::bf16) {
-        if (nt == NormType::RMS)
-            LayerLLaMAImpl(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
+        if (nt == NormType::RMS) {
+            LayerLLaMAImpl(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
                 maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step, hiddenSize, intermediateSize, output,
                 outputStride, input, inputStride, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
                 attnOutWeight, ln2Gamma, ln2Beta, gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias,
-                attnOutBias);
-        else if (nt == NormType::LN) {
-            LayerLLaMAImpl(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
+                attnOutBias, mmHelper, ctx, kvCacheMgr);
+        } else if (nt == NormType::LN) {
+            LayerLLaMAImpl(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
                 maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step, hiddenSize, intermediateSize, output,
                 outputStride, input, inputStride, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
                 attnOutWeight, ln2Gamma, ln2Beta, gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias,
-                attnOutBias);
+                attnOutBias, mmHelper, ctx, kvCacheMgr);
         } else {
             printf(">> unsupported norm type\n");
         }
     } else if (dt == DataType::fp16) {
-        if (nt == NormType::RMS)
-            LayerLLaMAImpl(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
+        if (nt == NormType::RMS) {
+            LayerLLaMAImpl(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
                 maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step, hiddenSize, intermediateSize, output,
                 outputStride, input, inputStride, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
                 attnOutWeight, ln2Gamma, ln2Beta, gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias,
-                attnOutBias);
-        else if (nt == NormType::LN) {
-            LayerLLaMAImpl(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
+                attnOutBias, mmHelper, ctx, kvCacheMgr);
+        } else if (nt == NormType::LN) {
+            LayerLLaMAImpl(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
                 maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step, hiddenSize, intermediateSize, output,
                 outputStride, input, inputStride, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
                 attnOutWeight, ln2Gamma, ln2Beta, gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias,
-                attnOutBias);
+                attnOutBias, mmHelper, ctx, kvCacheMgr);
+        } else {
+            printf(">> unsupported norm type\n");
+        }
+    } else if (dt == DataType::bf16_int8) {
+        if (nt == NormType::RMS) {
+            auto firstTokenFunc = LayerLLaMAImpl;
+            auto nextTokenFunc = LayerLLaMAImpl;
+            if (step == 0) {
+                firstTokenFunc(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
+                        maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step, hiddenSize, intermediateSize, output,
+                        outputStride, input, inputStride, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
+                        attnOutWeight, ln2Gamma, ln2Beta, gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias,
+                        attnOutBias, mmHelper, ctx, kvCacheMgr);
+
+            } else {
+                nextTokenFunc(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
+                        maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step, hiddenSize, intermediateSize, output,
+                        outputStride, input, inputStride, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
+                        attnOutWeight, ln2Gamma, ln2Beta, gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias,
+                        attnOutBias, mmHelper, ctx, kvCacheMgr);
+            }
+        } else if (nt == NormType::LN) {
+            auto firstTokenFunc = LayerLLaMAImpl;
+            auto nextTokenFunc = LayerLLaMAImpl;
+            if (step == 0)
+                firstTokenFunc(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
+                        maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step, hiddenSize, intermediateSize, output,
+                        outputStride, input, inputStride, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
+                        attnOutWeight, ln2Gamma, ln2Beta, gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias,
+                        attnOutBias, mmHelper, ctx, kvCacheMgr);
+            else
+                nextTokenFunc(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
+                        maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step, hiddenSize, intermediateSize, output,
+                        outputStride, input, inputStride, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
+                        attnOutWeight, ln2Gamma, ln2Beta, gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias,
+                        attnOutBias, mmHelper, ctx, kvCacheMgr);
         } else {
             printf(">> unsupported norm type\n");
         }
@@ -206,4 +276,39 @@ void invokeLayerLLaMA(DataType dt, ActivationType at, NormType nt, int batchSize
     }
 }
 
+void invokeLayerLLaMA(DataType dt, DataType kvcdt, RopeType rt, ActivationType at, NormType nt, int batchSize, int inputSeqLen, int attHeadDim,
+        int attHeadNum, int kvHeadNum, int maxPositions, int maxPosEmbed, int pastSeqLen, int currentSeqLen, int step,
+        int hiddenSize, int intermediateSize, void *output, int outputStride, const void *input, int inputStride,
+        const float *ln1Gamma, const float *ln1Beta, const void *queryWeight, const void *keyWeight,
+        const void *valueWeight, const void *attnOutWeight, const float *ln2Gamma, const float *ln2Beta,
+        const void *gateWeight, const void *upWeight, const void *downWeight, const float *queryBias,
+        const float *keyBias, const float *valueBias, const float *attnOutBias) {
+
+    if (kvcdt == DataType::fp16) {
+        if (rt == RopeType::LLAMA_ROPE)
+            return LayerLLaMAWrapper(dt, at, nt, batchSize, inputSeqLen, attHeadDim,
+                    attHeadNum, kvHeadNum, maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step,
+                    hiddenSize, intermediateSize, output, outputStride, input, inputStride,
+                    ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight, attnOutWeight, ln2Gamma, ln2Beta,
+                    gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias, attnOutBias);
+        else {
+            printf(">> unsupported Rope type: %d\n", rt);
+        }
+    } else if (kvcdt == DataType::int8) {
+        if (rt == RopeType::LLAMA_ROPE)
+            return LayerLLaMAWrapper(dt, at, nt, batchSize, inputSeqLen, attHeadDim,
+                    attHeadNum, kvHeadNum, maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step,
+                    hiddenSize, intermediateSize, output, outputStride, input, inputStride,
+                    ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight, attnOutWeight, ln2Gamma, ln2Beta,
+                    gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias, attnOutBias);
+        else {
+            printf(">> unsupported Rope type: %d\n", rt);
+        }
+    } else {
+        printf(">> unsupported KVcache data type: %d\n", kvcdt);
+        return;
+    }
+
+}
+
 } // namespace xft
diff --git a/tests/ut/layers_decoder_test.cpp b/tests/ut/layers_decoder_test.cpp
index be75d941..f008dc9f 100644
--- a/tests/ut/layers_decoder_test.cpp
+++ b/tests/ut/layers_decoder_test.cpp
@@ -21,8 +21,8 @@
 #include "layers_decoder.h"
 #include "gtest/gtest.h"
 
-template
-static void compareLayerLLaMA(int step, int batchSize, int inputSeqLen, int pastSeqLen, int currentSeqLen,
+static void compareLayerLLaMA(xft::DataType dt, xft::DataType kvcdt, int step,
+        int batchSize, int inputSeqLen, int pastSeqLen, int currentSeqLen,
         int attHeadDim, int attHeadNum, int kvHeadNum, int maxPositions, int maxPosEmbed, int hiddenSize,
         int intermediateSize, const float *ln1Gamma, const float *ln1Beta, const void *queryWeight,
         const void *keyWeight, const void *valueWeight, const void *attnOutWeight, const float *ln2Gamma,
@@ -36,19 +36,8 @@ static void compareLayerLLaMA(int step, int batchSize, int inputSeqLen, int past
         input[i] = static_cast<float>(1.0f * rand() / RAND_MAX);
     }
 
-    xft::DataType dt = xft::DataType::unknown;
-    if constexpr (std::is_same::value) {
-        dt = xft::DataType::bf16;
-    } else if constexpr (std::is_same::value) {
-        dt = xft::DataType::fp16;
-    } else {
-        printf("Unsupported data type\n");
-        GTEST_FAIL();
-        return;
-    }
-
     auto start = std::chrono::high_resolution_clock::now();
-    invokeLayerLLaMA(dt, xft::ActivationType::SILU, xft::NormType::RMS, batchSize, inputSeqLen, attHeadDim, attHeadNum,
+    invokeLayerLLaMA(dt, kvcdt, xft::RopeType::LLAMA_ROPE, xft::ActivationType::SILU, xft::NormType::RMS, batchSize, inputSeqLen, attHeadDim, attHeadNum,
         kvHeadNum, maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step, hiddenSize, intermediateSize,
         (void *)ourOutput, hiddenSize, input, hiddenSize, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
         attnOutWeight, ln2Gamma, ln2Beta, gateW, upW, downW);
@@ -60,8 +49,7 @@ static void compareLayerLLaMA(int step, int batchSize, int inputSeqLen, int past
     free(ourOutput);
 }
 
-template
-void test_LayerLLaMA(void) {
+void test_LayerLLaMA(xft::DataType dt, xft::DataType kvcdt) {
     int maxPosEmbed = 4096;
     int maxPositions = maxPosEmbed;
     int hiddenSize = 4096;
@@ -111,16 +99,16 @@ void test_LayerLLaMA(void) {
     int currentSeqLen = inputSeqLen;
     int nextTokenNum = 1;
 
-    compareLayerLLaMA(step++, batchSize, inputSeqLen, pastSeqLen, currentSeqLen, attHeadDim, attHeadNum, kvHeadNum,
+    compareLayerLLaMA(dt, kvcdt, step++, batchSize, inputSeqLen, pastSeqLen, currentSeqLen, attHeadDim, attHeadNum, kvHeadNum,
         maxPositions, maxPosEmbed, hiddenSize, intermediateSize, ln1Gamma, ln1Beta, qkvProj, qkvProj + qSize,
         qkvProj + kvSize, oProj, ln2Gamma, ln2Beta, gateW, upW, downW);
 
     pastSeqLen += inputSeqLen;
     currentSeqLen = nextTokenNum;
-    compareLayerLLaMA(step++, batchSize, inputSeqLen, pastSeqLen, currentSeqLen, attHeadDim, attHeadNum, kvHeadNum,
+    compareLayerLLaMA(dt, kvcdt, step++, batchSize, inputSeqLen, pastSeqLen, currentSeqLen, attHeadDim, attHeadNum, kvHeadNum,
         maxPositions, maxPosEmbed, hiddenSize, intermediateSize, ln1Gamma, ln1Beta, qkvProj, qkvProj + qSize,
         qkvProj + kvSize, oProj, ln2Gamma, ln2Beta, gateW, upW, downW);
 
     pastSeqLen += nextTokenNum;
-    compareLayerLLaMA(step++, batchSize, inputSeqLen, pastSeqLen, currentSeqLen, attHeadDim, attHeadNum, kvHeadNum,
+    compareLayerLLaMA(dt, kvcdt, step++, batchSize, inputSeqLen, pastSeqLen, currentSeqLen, attHeadDim, attHeadNum, kvHeadNum,
         maxPositions, maxPosEmbed, hiddenSize, intermediateSize, ln1Gamma, ln1Beta, qkvProj, qkvProj + qSize,
         qkvProj + kvSize, oProj, ln2Gamma, ln2Beta, gateW, upW, downW);
 
@@ -135,15 +123,31 @@ void test_LayerLLaMA(void) {
     free(downW);
 }
 
-TEST(LayerLLaMA, bfloat16_t) {
-    test_LayerLLaMA();
+TEST(LayerLLaMA, w_bf16_kv_fp16) {
+    test_LayerLLaMA(xft::DataType::bf16, xft::DataType::fp16);
+}
+
+TEST(LayerLLaMA, w_bf16_kv_int8) {
+    test_LayerLLaMA(xft::DataType::bf16, xft::DataType::int8);
+}
+
+TEST(LayerLLaMA, w_fp16_kv_fp16) {
+    test_LayerLLaMA(xft::DataType::fp16, xft::DataType::fp16);
 }
 
-TEST(LayerLLaMA, float16_t) {
-    test_LayerLLaMA();
+TEST(LayerLLaMA, w_fp16_kv_int8) {
+    test_LayerLLaMA(xft::DataType::fp16, xft::DataType::int8);
+}
+
+TEST(LayerLLaMA, w_bf16_int8_kv_fp16) {
+    test_LayerLLaMA(xft::DataType::bf16_int8, xft::DataType::fp16);
+}
+
+TEST(LayerLLaMA, w_bf16_int8_kv_int8) {
+    test_LayerLLaMA(xft::DataType::bf16_int8, xft::DataType::int8);
 }
 
 int main(int argc, char **argv) {
     ::testing::InitGoogleTest(&argc, argv);
     return RUN_ALL_TESTS();
-}
\ No newline at end of file
+}
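
For orientation only, a minimal sketch of how a caller might drive the updated entry point: the helper name runOneLayer, the shape constants, and the NUMA node numbers below are illustrative (not part of the patch), and all weight, bias, and activation buffers are assumed to be allocated and packed by the caller exactly as in tests/ut/layers_decoder_test.cpp.

    #include <cstdlib>          // setenv
    #include "layers_decoder.h" // xft::invokeLayerLLaMA (new signature from this patch)

    // Hypothetical helper: runs one LLaMA decoder layer with bf16 weights and an fp16 KV cache.
    void runOneLayer(void *output, const void *input, const float *ln1Gamma, const float *ln1Beta,
            const void *qW, const void *kW, const void *vW, const void *oW, const float *ln2Gamma,
            const float *ln2Beta, const void *gateW, const void *upW, const void *downW) {
        const int attHeadDim = 128, attHeadNum = 32, kvHeadNum = 32;
        const int hiddenSize = 4096, intermediateSize = 11008;
        const int maxPosEmbed = 4096, maxPositions = maxPosEmbed;

        // Optional NUMA placement hints read by the implementation; leave unset for no preference.
        setenv("FIRST_TOKEN_WEIGHT_LOCATION", "0", 1);
        setenv("NEXT_TOKEN_WEIGHT_LOCATION", "1", 1);
        setenv("KVCACHE_LOCATION", "1", 1);

        xft::invokeLayerLLaMA(xft::DataType::bf16, /*kvcdt=*/xft::DataType::fp16, xft::RopeType::LLAMA_ROPE,
                xft::ActivationType::SILU, xft::NormType::RMS, /*batchSize=*/1, /*inputSeqLen=*/128,
                attHeadDim, attHeadNum, kvHeadNum, maxPositions, maxPosEmbed, /*pastSeqLen=*/0,
                /*currentSeqLen=*/128, /*step=*/0, hiddenSize, intermediateSize, output, hiddenSize,
                input, hiddenSize, ln1Gamma, ln1Beta, qW, kW, vW, oW, ln2Gamma, ln2Beta,
                gateW, upW, downW); // query/key/value/attnOut biases default to nullptr
    }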