add bf16_int8 support for invokeLayerLLaMA API #470

Open
wants to merge 1 commit into base: main
4 changes: 4 additions & 0 deletions include/dtype.h
@@ -51,4 +51,8 @@ enum ActivationType {
SILU,
};

enum RopeType {
LLAMA_ROPE = 0,
};

} // namespace xft
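
The new RopeType enum is consumed by the widened invokeLayerLLaMA signature in layers_decoder.h below; LLAMA_ROPE is its only value so far. A hedged calling sketch follows that file's diff.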
17 changes: 9 additions & 8 deletions include/layers_decoder.h
@@ -18,12 +18,13 @@

namespace xft {

void invokeLayerLLaMA(DataType dt, ActivationType at, NormType nt, int batchSize, int inputSeqLen, int attHeadDim,
int attHeadNum, int kvHeadNum, int maxPositions, int maxPosEmbed, int pastSeqLen, int currentSeqLen, int step,
int hiddenSize, int intermediateSize, void *output, int outputStride, const void *input, int inputStride,
const float *ln1Gamma, const float *ln1Beta, const void *queryWeight, const void *keyWeight,
const void *valueWeight, const void *attnOutWeight, const float *ln2Gamma, const float *ln2Beta,
const void *gateWeight, const void *upWeight, const void *downWeight, const float *queryBias = nullptr,
const float *keyBias = nullptr, const float *valueBias = nullptr, const float *attnOutBias = nullptr);
void invokeLayerLLaMA(DataType dt, DataType kvcdt, RopeType rt, ActivationType at, NormType nt, int batchSize,
int inputSeqLen, int attHeadDim, int attHeadNum, int kvHeadNum, int maxPositions, int maxPosEmbed,
int pastSeqLen, int currentSeqLen, int step, int hiddenSize, int intermediateSize, void *output,
int outputStride, const void *input, int inputStride, const float *ln1Gamma, const float *ln1Beta,
const void *queryWeight, const void *keyWeight, const void *valueWeight, const void *attnOutWeight,
const float *ln2Gamma, const float *ln2Beta, const void *gateWeight, const void *upWeight,
const void *downWeight, const float *queryBias = nullptr, const float *keyBias = nullptr,
const float *valueBias = nullptr, const float *attnOutBias = nullptr);

} // namespace xft
} // namespace xft
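
For context, here is a minimal calling sketch against the widened signature (not part of this PR). All sizes are illustrative, the zero-filled buffers merely stand in for real weights so the output is not meaningful, and bf16_int8 selects the bf16 path at step 0 and the int8 path on later steps, per the dispatch in decoder_layer.cpp below:

// Sketch only: exercises the new entry point with placeholder buffers.
// Assumes the two headers above; all dimensions are illustrative.
#include <vector>
#include "dtype.h"
#include "layers_decoder.h"

int main() {
    const int batch = 1, seq = 8, headDim = 64, qHeads = 8, kvHeads = 8;
    const int hidden = headDim * qHeads, inter = 4 * hidden, maxPos = 512;

    std::vector<float> in(batch * seq * hidden), out(batch * seq * hidden);
    std::vector<float> ln1(hidden, 1.0f), ln2(hidden, 1.0f);
    std::vector<float> qW(hidden * hidden), kW(hidden * hidden), vW(hidden * hidden),
            oW(hidden * hidden), gateW(hidden * inter), upW(hidden * inter),
            downW(inter * hidden);

    xft::invokeLayerLLaMA(
            xft::DataType::bf16_int8,  // compute: bf16 at step 0, int8 afterwards
            xft::DataType::fp16,       // KV cache dtype; int8 is the other option
            xft::RopeType::LLAMA_ROPE, xft::ActivationType::SILU, xft::NormType::RMS,
            batch, seq, headDim, qHeads, kvHeads, maxPos, maxPos,
            /*pastSeqLen=*/0, /*currentSeqLen=*/seq, /*step=*/0, hidden, inter,
            out.data(), hidden, in.data(), hidden, ln1.data(), /*ln1Beta=*/nullptr,
            qW.data(), kW.data(), vW.data(), oW.data(), ln2.data(), /*ln2Beta=*/nullptr,
            gateW.data(), upW.data(), downW.data());
    return 0;
}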
215 changes: 160 additions & 55 deletions src/layers/decoder_layer.cpp
@@ -21,19 +21,21 @@
#include "layers_mlp.h"
#include "mlp_llama.h"
#include "rms_norm.h"
#include "numa_allocator.h"

#include <sstream>
#include <unordered_map>

namespace xft {

template <typename DataT, typename NormT>
template <typename DataT, typename KVCacheT, typename RopeT, typename NormT>
void LayerLLaMAImpl(DataType dt, ActivationType at, NormType nt, int batchSize, int inputSeqLen, int attHeadDim,
int attHeadNum, int kvHeadNum, int maxPositions, int maxPosEmbed, int pastSeqLen, int currentSeqLen, int step,
int hiddenSize, int intermediateSize, void *output, int outputStride, const void *input, int inputStride,
const float *ln1Gamma, const float *ln1Beta, const void *queryWeight, const void *keyWeight,
const void *valueWeight, const void *attnOutWeight, const float *ln2Gamma, const float *ln2Beta,
const void *gateWeight, const void *upWeight, const void *downWeight, const float *queryBias,
const float *keyBias, const float *valueBias, const float *attnOutBias) {
const float *keyBias, const float *valueBias, const float *attnOutBias,
MMHelper *mmHelper, DecoderContext *ctx, KVCacheManager<KVCacheT> *kvCacheMgr) {

// TODO: the attention mask will be deprecated in the future, so this will need to change
auto prepareAttnMask = [&](DecoderContext *ctx, int step) {
@@ -83,71 +85,52 @@ void LayerLLaMAImpl(DataType dt, ActivationType at, NormType nt, int batchSize,
return mask;
};

using DECODER = Decoder<Attention<DataT, LlamaRotaryEmbedding, NormT>, LlamaMLP<DataT>>;
static std::unordered_map<std::string, DECODER *> llama_layer_hub;
static MMHelper *mmHelper;
static DecoderContext *ctx;
static KVCacheManager<float16_t> *kvCacheMgr;

std::string actType;
if (at == ActivationType::SILU)
actType = "silu";
else if (at == ActivationType::RELU)
actType = "relu";
else if (at == ActivationType::GELU)
actType = "gelu";
else if (at == ActivationType::SWIGLU)
actType = "swiglu";
else
printf(">> unsupported activation type\n");

if (ctx == nullptr
|| (ctx != nullptr && (ctx->hiddenSize != hiddenSize || ctx->intermediateSize != intermediateSize))) {
if (ctx != nullptr) delete ctx;
printf(">> create context: %d %d\n", hiddenSize, intermediateSize);
mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex());
ctx = new DecoderContext(1, hiddenSize, attHeadDim, attHeadNum, kvHeadNum, intermediateSize, actType, 1e-6, 0,
0, maxPositions, maxPosEmbed, -1, 0, 1, mmHelper);
if (kvCacheMgr != nullptr) delete kvCacheMgr;
kvCacheMgr = new KVCacheManager<float16_t>(1);
}

using DECODER = Decoder<Attention<DataT, RopeT, NormT>, LlamaMLP<DataT>>;
DECODER *llama_layer;
xft::Matrix<float> *actBuffers = nullptr;
static std::unordered_map<std::string, std::tuple<DECODER *, xft::Matrix<float> *>> llama_layer_hub;
// Create the hash key from the weight addresses: if hiddenSize or intermediateSize
// changes, the weight memory pointers change as well.
std::stringstream weights_addr;
weights_addr << queryWeight << "_" << keyWeight << "_" << valueWeight << "_" << attnOutWeight << "_" << gateWeight
<< "_" << upWeight << "_" << downWeight << "_" << dt << "_" << at << "_" << nt << "_" << attHeadDim
<< "_" << attHeadNum << "_" << kvHeadNum;
std::string llama_layer_key = weights_addr.str();
DECODER *llama_layer;

auto it_created = llama_layer_hub.find(llama_layer_key);
if (it_created == llama_layer_hub.end()) {
int firstNode = getenv("FIRST_TOKEN_WEIGHT_LOCATION") ? atoi(getenv("FIRST_TOKEN_WEIGHT_LOCATION")) : -1;
int nextNode = getenv("NEXT_TOKEN_WEIGHT_LOCATION") ? atoi(getenv("NEXT_TOKEN_WEIGHT_LOCATION")) : -1;
if (step == 0)
xft_set_preferred_node(firstNode);
else
xft_set_preferred_node(nextNode);
llama_layer = new DECODER(ctx, 0);
llama_layer->setWeights(ctx, (const float *)queryWeight, nullptr, nullptr, queryBias, (const float *)keyWeight,
nullptr, nullptr, keyBias, (const float *)valueWeight, nullptr, nullptr, valueBias,
(const float *)attnOutWeight, nullptr, nullptr, attnOutBias, ln1Gamma, ln1Beta,
(const float *)gateWeight, nullptr, nullptr, nullptr, (const float *)upWeight, nullptr, nullptr,
nullptr, ln2Gamma, ln2Beta, (const float *)downWeight, nullptr, nullptr, false);
llama_layer_hub[llama_layer_key] = llama_layer;

actBuffers = new xft::Matrix<float>;
actBuffers->Resize(batchSize * inputSeqLen * 2, hiddenSize);

llama_layer_hub[llama_layer_key] = std::make_tuple(llama_layer, actBuffers);
printf(">> create llama_layer_key: %s\n", llama_layer_key.c_str());
xft_set_preferred_node(-1);
} else {
llama_layer = it_created->second;
llama_layer = std::get<0>(it_created->second);
actBuffers = std::get<1>(it_created->second);
}
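// Later calls that hit the cache reuse both the decoder instance and its
// activation buffer, so steady-state invocations allocate nothing new here.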

ctx->resize(batchSize, inputSeqLen, pastSeqLen);
xft::Matrix<float> actBuffers;
actBuffers.Resize(batchSize * inputSeqLen * 2, hiddenSize);
float *attnMask = prepareAttnMask(ctx, step);

int workers = 1;
int headsPerSplit = (ctx->kvHeadNum + workers - 1) / workers;
kvCacheMgr->resize(maxPositions, batchSize, headsPerSplit, attHeadDim);
KVCacheTensor<float16_t> &presentKey = kvCacheMgr->getKey(0);
KVCacheTensor<float16_t> &presentValue = kvCacheMgr->getValue(0);
KVCacheTensor<KVCacheT> &presentKey = kvCacheMgr->getKey(0);
KVCacheTensor<KVCacheT> &presentValue = kvCacheMgr->getValue(0);

float *attnOut = (float *)(ctx->tmpBuf.Data());

llama_layer->forwardAttention(ctx, (float *)input, actBuffers.Data(), attnOut, attnMask,
llama_layer->forwardAttention(ctx, (float *)input, actBuffers->Data(), attnOut, attnMask,
presentKey, // presentKey,
presentValue, // presentValue,
inputSeqLen, // inputSeqLen,
@@ -159,7 +142,8 @@ void LayerLLaMAImpl(DataType dt, ActivationType at, NormType nt, int batchSize,
llama_layer->forwardFFN(ctx, attnOut, (float *)output, inputStride, outputStride, true);
}

void invokeLayerLLaMA(DataType dt, ActivationType at, NormType nt, int batchSize, int inputSeqLen, int attHeadDim,
template <typename KVCacheT, typename RopeT>
void LayerLLaMAWrapper(DataType dt, ActivationType at, NormType nt, int batchSize, int inputSeqLen, int attHeadDim,
int attHeadNum, int kvHeadNum, int maxPositions, int maxPosEmbed, int pastSeqLen, int currentSeqLen, int step,
int hiddenSize, int intermediateSize, void *output, int outputStride, const void *input, int inputStride,
const float *ln1Gamma, const float *ln1Beta, const void *queryWeight, const void *keyWeight,
@@ -169,35 +153,121 @@ void invokeLayerLLaMA(DataType dt, ActivationType at, NormType nt, int batchSize
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);

std::string actType;
if (at == ActivationType::SILU)
actType = "silu";
else if (at == ActivationType::RELU)
actType = "relu";
else if (at == ActivationType::GELU)
actType = "gelu";
else if (at == ActivationType::SWIGLU)
actType = "swiglu";
else {
printf(">> unsupported activation type\n");
return;
}

static MMHelper *mmHelper;
static DecoderContext *ctx;
if (ctx == nullptr
|| (ctx != nullptr && (ctx->hiddenSize != hiddenSize || ctx->intermediateSize != intermediateSize))) {
if (ctx != nullptr) delete ctx;
printf(">> create context: %d %d\n", hiddenSize, intermediateSize);
mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex());
ctx = new DecoderContext(1, hiddenSize, attHeadDim, attHeadNum, kvHeadNum, intermediateSize, actType, 1e-6, 0,
0, maxPositions, maxPosEmbed, -1, 0, 1, mmHelper);
}

KVCacheManager<KVCacheT> *kvCacheMgr;
static std::unordered_map<std::string, KVCacheManager<KVCacheT> *> kv_hub;

// Create the hash key from the weight addresses: if hiddenSize or intermediateSize
// changes, the weight memory pointers change as well.
std::stringstream layer_key;
layer_key << queryWeight << "_" << keyWeight << "_" << valueWeight << "_" << attnOutWeight << "_" << gateWeight
<< "_" << upWeight << "_" << downWeight << "_" << dt << "_" << at << "_" << nt << "_" << attHeadDim
<< "_" << attHeadNum << "_" << kvHeadNum;
std::string kv_hub_key = layer_key.str();

auto it_created = kv_hub.find(kv_hub_key);
if (it_created == kv_hub.end()) {
int kvcNode = getenv("KVCACHE_LOCATION") ? atoi(getenv("KVCACHE_LOCATION")) : -1;
xft_set_preferred_node(kvcNode);
kvCacheMgr = new KVCacheManager<KVCacheT>(1);
int workers = 1;
int headsPerSplit = (ctx->kvHeadNum + workers - 1) / workers;
kvCacheMgr->resize(maxPositions, batchSize, headsPerSplit, attHeadDim);
kv_hub[kv_hub_key] = kvCacheMgr;
printf(">> create kv_hub_key: %s\n", kv_hub_key.c_str());
xft_set_preferred_node(-1);
} else {
kvCacheMgr = it_created->second;
}
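// The KV cache is keyed by the same weight-address string as the decoder
// cache in LayerLLaMAImpl, so the bf16 (step == 0) and int8 (step > 0)
// instantiations of one layer share a single KVCacheManager.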

if (dt == DataType::bf16) {
if (nt == NormType::RMS)
LayerLLaMAImpl<bfloat16_t, RmsNorm>(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
if (nt == NormType::RMS) {
LayerLLaMAImpl<bfloat16_t, KVCacheT, RopeT, RmsNorm>(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step, hiddenSize, intermediateSize, output,
outputStride, input, inputStride, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
attnOutWeight, ln2Gamma, ln2Beta, gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias,
attnOutBias);
else if (nt == NormType::LN) {
LayerLLaMAImpl<bfloat16_t, LayerNorm>(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
attnOutBias, mmHelper, ctx, kvCacheMgr);
} else if (nt == NormType::LN) {
LayerLLaMAImpl<bfloat16_t, KVCacheT, RopeT, LayerNorm>(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step, hiddenSize, intermediateSize, output,
outputStride, input, inputStride, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
attnOutWeight, ln2Gamma, ln2Beta, gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias,
attnOutBias);
attnOutBias, mmHelper, ctx, kvCacheMgr);
} else {
printf(">> unsupported norm type\n");
}
} else if (dt == DataType::fp16) {
if (nt == NormType::RMS)
LayerLLaMAImpl<float16_t, RmsNorm>(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
if (nt == NormType::RMS) {
LayerLLaMAImpl<float16_t, KVCacheT, RopeT, RmsNorm>(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step, hiddenSize, intermediateSize, output,
outputStride, input, inputStride, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
attnOutWeight, ln2Gamma, ln2Beta, gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias,
attnOutBias);
else if (nt == NormType::LN) {
LayerLLaMAImpl<float16_t, LayerNorm>(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
attnOutBias, mmHelper, ctx, kvCacheMgr);
} else if (nt == NormType::LN) {
LayerLLaMAImpl<float16_t, KVCacheT, RopeT, LayerNorm>(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step, hiddenSize, intermediateSize, output,
outputStride, input, inputStride, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
attnOutWeight, ln2Gamma, ln2Beta, gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias,
attnOutBias);
attnOutBias, mmHelper, ctx, kvCacheMgr);
} else {
printf(">> unsupported norm type\n");
}
} else if (dt == DataType::bf16_int8) {
if (nt == NormType::RMS) {
auto firstTokenFunc = LayerLLaMAImpl<bfloat16_t, KVCacheT, RopeT, RmsNorm>;
auto nextTokenFunc = LayerLLaMAImpl<int8_t, KVCacheT, RopeT, RmsNorm>;
if (step == 0) {
firstTokenFunc(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step, hiddenSize, intermediateSize, output,
outputStride, input, inputStride, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
attnOutWeight, ln2Gamma, ln2Beta, gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias,
attnOutBias, mmHelper, ctx, kvCacheMgr);

} else {
nextTokenFunc(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step, hiddenSize, intermediateSize, output,
outputStride, input, inputStride, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
attnOutWeight, ln2Gamma, ln2Beta, gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias,
attnOutBias, mmHelper, ctx, kvCacheMgr);
}
} else if (nt == NormType::LN) {
auto firstTokenFunc = LayerLLaMAImpl<bfloat16_t, KVCacheT, RopeT, LayerNorm>;
auto nextTokenFunc = LayerLLaMAImpl<int8_t, KVCacheT, RopeT, LayerNorm>;
if (step == 0)
firstTokenFunc(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step, hiddenSize, intermediateSize, output,
outputStride, input, inputStride, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
attnOutWeight, ln2Gamma, ln2Beta, gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias,
attnOutBias, mmHelper, ctx, kvCacheMgr);
else
nextTokenFunc(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step, hiddenSize, intermediateSize, output,
outputStride, input, inputStride, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
attnOutWeight, ln2Gamma, ln2Beta, gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias,
attnOutBias, mmHelper, ctx, kvCacheMgr);
} else {
printf(">> unsupported norm type\n");
}
@@ -206,4 +276,39 @@ void invokeLayerLLaMA(DataType dt, ActivationType at, NormType nt, int batchSize
}
}

void invokeLayerLLaMA(DataType dt, DataType kvcdt, RopeType rt, ActivationType at, NormType nt, int batchSize, int inputSeqLen, int attHeadDim,
int attHeadNum, int kvHeadNum, int maxPositions, int maxPosEmbed, int pastSeqLen, int currentSeqLen, int step,
int hiddenSize, int intermediateSize, void *output, int outputStride, const void *input, int inputStride,
const float *ln1Gamma, const float *ln1Beta, const void *queryWeight, const void *keyWeight,
const void *valueWeight, const void *attnOutWeight, const float *ln2Gamma, const float *ln2Beta,
const void *gateWeight, const void *upWeight, const void *downWeight, const float *queryBias,
const float *keyBias, const float *valueBias, const float *attnOutBias) {

if (kvcdt == DataType::fp16) {
if (rt == RopeType::LLAMA_ROPE)
return LayerLLaMAWrapper<float16_t, LlamaRotaryEmbedding>(dt, at, nt, batchSize, inputSeqLen, attHeadDim,
attHeadNum, kvHeadNum, maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step,
hiddenSize, intermediateSize, output, outputStride, input, inputStride,
ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight, attnOutWeight, ln2Gamma, ln2Beta,
gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias, attnOutBias);
else {
printf(">> unsupported Rope type: %d\n", rt);
}
} else if (kvcdt == DataType::int8) {
if (rt == RopeType::LLAMA_ROPE)
return LayerLLaMAWrapper<int8_t, LlamaRotaryEmbedding>(dt, at, nt, batchSize, inputSeqLen, attHeadDim,
attHeadNum, kvHeadNum, maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step,
hiddenSize, intermediateSize, output, outputStride, input, inputStride,
ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight, attnOutWeight, ln2Gamma, ln2Beta,
gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias, attnOutBias);
else {
printf(">> unsupported Rope type: %d\n", rt);
}
} else {
printf(">> unsupported KVcache data type: %d\n", kvcdt);
return;
}

}

} // namespace xft
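
One deployment note, grounded in the wrapper above: weight and KV-cache placement is steered by three environment variables read when a layer is first created and applied through xft_set_preferred_node (unset variables default to -1, i.e. no preference). A small setup sketch, with illustrative node ids:

#include <cstdlib>

// Sketch: NUMA placement knobs consumed when the layer is first created.
// The node ids below are illustrative, not a recommendation.
void configurePlacement() {
    setenv("FIRST_TOKEN_WEIGHT_LOCATION", "0", /*overwrite=*/1);  // bf16 (step 0) weights
    setenv("NEXT_TOKEN_WEIGHT_LOCATION", "1", 1);                 // int8 (step > 0) weights
    setenv("KVCACHE_LOCATION", "1", 1);                           // KV cache allocation
}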