ggml-cuda : add TQ2_0 kernels, for ternary inference on GPU (ggergano…)
Nexesenex committed Jan 10, 2025
1 parent 7d91690 commit 79cde2c
Showing 10 changed files with 195 additions and 0 deletions.
3 changes: 3 additions & 0 deletions ggml/src/ggml-common.h
@@ -129,6 +129,9 @@ typedef sycl::half2 ggml_half2;
#define QI6_K (QK_K / (4*QR6_K))
#define QR6_K 2

#define QI2_0 (QK_K / (4*QR2_0))
#define QR2_0 4

#define QI2_XXS (QK_K / (4*QR2_XXS))
#define QR2_XXS 4

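As a quick aside (not part of the commit): assuming the usual QK_K == 256, these macros say that TQ2_0 packs four 2-bit values per byte (QR2_0) and that one block's quantized payload occupies QI2_0 32-bit words. A minimal sanity-check sketch:

// 16 words * 4 bytes/word * 4 values/byte == 256 values == QK_K
static_assert(QK_K / (4*QR2_0) == 16, "QI2_0 is expected to be 16 for 256-value TQ2_0 blocks");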
7 changes: 7 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -453,6 +453,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {
static constexpr int qi = QI6_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_TQ2_0> {
static constexpr int qk = QK_K;
static constexpr int qr = QR2_0;
static constexpr int qi = QI2_0;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {
static constexpr int qk = QK_K;
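For orientation, a hedged illustration (not part of the commit) of how the generic CUDA kernels consume the new traits at compile time; it assumes QK_K == 256, so the specialization resolves to 256 values per block, 4 values per byte, and 16 packed 32-bit words:

using tq2_0_traits = ggml_cuda_type_traits<GGML_TYPE_TQ2_0>;
static_assert(tq2_0_traits::qk == 256 && tq2_0_traits::qr == 4 && tq2_0_traits::qi == 16,
              "TQ2_0: 256 values per block, 4 per byte, 16 packed words");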
31 changes: 31 additions & 0 deletions ggml/src/ggml-cuda/convert.cu
@@ -311,6 +311,26 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
}

template<typename dst_t>
static __global__ void dequantize_block_tq2_0(const void * __restrict__ vx, dst_t * __restrict__ yy) {

const int64_t i = blockIdx.x;
const block_tq2_0 * x = (const block_tq2_0 *) vx;

const int64_t tid = threadIdx.x; // 0..63
const int64_t n = tid/32; // 0 or 1
const int64_t l = tid - 32*n; // 0..31

const uint8_t q = x[i].qs[32*n + l];
dst_t * y = yy + i*QK_K + 128*n;

float d = __half2float(x[i].d);
y[l+ 0] = d * ((q >> 0) & 3) - d;
y[l+32] = d * ((q >> 2) & 3) - d;
y[l+64] = d * ((q >> 4) & 3) - d;
y[l+96] = d * ((q >> 6) & 3) - d;
}
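For readers less familiar with the TQ2_0 layout, here is a hedged host-side reference of the same unpacking (not part of the commit); it relies only on the qs[] and d members used by the kernel above and assumes a host half-to-float conversion such as GGML_FP16_TO_FP32 is available:

static void dequantize_row_tq2_0_ref(const block_tq2_0 * x, float * y, int64_t k) {
    const int64_t nb = k / QK_K;                       // one block covers QK_K == 256 values
    for (int64_t i = 0; i < nb; ++i) {
        const float d = GGML_FP16_TO_FP32(x[i].d);     // single scale per block
        for (int n = 0; n < 2; ++n) {                  // two 128-value halves per block
            float * dst = y + i*QK_K + 128*n;
            for (int l = 0; l < 32; ++l) {             // 32 bytes per half, 4 values per byte
                const uint8_t q = x[i].qs[32*n + l];
                dst[l +  0] = d * ((q >> 0) & 3) - d;  // each 2-bit value {0,1,2} maps to {-d, 0, +d}
                dst[l + 32] = d * ((q >> 2) & 3) - d;
                dst[l + 64] = d * ((q >> 4) & 3) - d;
                dst[l + 96] = d * ((q >> 6) & 3) - d;
            }
        }
    }
}

On the GPU side, the launcher added further down starts one 64-thread block per TQ2_0 block, so each thread unpacks one byte into four outputs (64 threads * 4 values == 256 == QK_K).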

template<typename dst_t>
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

@@ -1008,6 +1028,13 @@ static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_
dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_tq2_0_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_tq2_0<<<nb, 64, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq2_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
@@ -1268,6 +1295,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_q5_K_cuda;
case GGML_TYPE_Q6_K:
return dequantize_row_q6_K_cuda;
case GGML_TYPE_TQ2_0:
return dequantize_row_tq2_0_cuda;
case GGML_TYPE_IQ2_XXS:
return dequantize_row_iq2_xxs_cuda;
case GGML_TYPE_IQ2_KT:
@@ -1345,6 +1374,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_q5_K_cuda;
case GGML_TYPE_Q6_K:
return dequantize_row_q6_K_cuda;
case GGML_TYPE_TQ2_0:
return dequantize_row_tq2_0_cuda;
case GGML_TYPE_IQ2_XXS:
return dequantize_row_iq2_xxs_cuda;
case GGML_TYPE_IQ2_KT:
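A hedged usage sketch (not part of the commit) of the dispatch added above; dev_q, dev_f32, nrows, n_per_row and stream are placeholder names for a device-resident TQ2_0 tensor, its float destination, its dimensions and a valid CUDA stream:

const to_fp32_cuda_t to_fp32 = ggml_get_to_fp32_cuda(GGML_TYPE_TQ2_0);
to_fp32(dev_q, dev_f32, nrows, n_per_row, stream);   // now routes to dequantize_row_tq2_0_cuda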
1 change: 1 addition & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3112,6 +3112,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q8_K:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ2_S:
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/mmq.cu
@@ -64,6 +64,9 @@ void ggml_cuda_op_mul_mat_q(
case GGML_TYPE_Q6_K:
mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
break;
case GGML_TYPE_TQ2_0:
mul_mat_q_case<GGML_TYPE_TQ2_0>(ctx, args, stream);
break;
case GGML_TYPE_IQ2_XXS:
mul_mat_q_case<GGML_TYPE_IQ2_XXS>(ctx, args, stream);
break;
@@ -119,6 +122,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
74 changes: 74 additions & 0 deletions ggml/src/ggml-cuda/mmq.cuh
@@ -66,6 +66,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
case GGML_TYPE_Q5_K:
return MMQ_Q8_1_DS_LAYOUT_DS4;
case GGML_TYPE_Q6_K:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
@@ -165,6 +166,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
type == GGML_TYPE_TQ2_0 ? MMQ_DP4A_TXS_Q8_0 :
type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 :
type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 :
@@ -200,6 +202,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 :
type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 :
type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
type == GGML_TYPE_TQ2_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K :
type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K :
@@ -1876,6 +1879,68 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
#endif // INT8_MMA_AVAILABLE
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_tq2_0(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

#ifdef INT8_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_tile + 2*WARP_SIZE);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_TQ2_0, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE

const int kqsx = threadIdx.x % QI2_0;

#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_0) {
int i = i0 + threadIdx.y*(WARP_SIZE/QI2_0) + threadIdx.x/QI2_0;

if (need_check) {
i = min(i, i_max);
}

const block_tq2_0 * bxi = (const block_tq2_0 *) x + kbx0 + i*stride;
const int qs0 = get_int_b2(bxi->qs, kqsx);

#pragma unroll
for (int l = 0; l < QR2_0; ++l) {
// 0..7, 32..39
// 8..15, 40..47
// 16..23, 48..55
// 24..31, 56..63
const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
const int q = __vsub4((qs0 >> (2*l)) & 0x03030303, 0x01010101);

#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = q;
#else
x_qs[i*(2*WARP_SIZE + 1) + k] = q;
#endif // INT8_MMA_AVAILABLE
}
}

#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_0/2)) {
int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_0) + threadIdx.x/(QI2_0/2);

if (need_check) {
i = min(i, i_max);
}

const block_tq2_0 * bxi = (const block_tq2_0 *) x + kbx0 + i*stride;

const int k = threadIdx.x % (QI2_0/2);

#ifdef INT8_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = bxi->d;
#else
x_df[i*(WARP_SIZE/4) + i/4 + k] = bxi->d;
#endif // INT8_MMA_AVAILABLE
}
}
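A small hedged check (not part of the commit) of the index arithmetic above: k = (kqsx/8)*32 + l*8 + kqsx%8 places, for each bit-plane l, the low 8 packed words at k = 8*l..8*l+7 and the high 8 at 32+8*l..32+8*l+7, exactly the ranges listed in the comment inside the loop:

#include <assert.h>

static void check_tq2_0_tile_indexing(void) {
    for (int kqsx = 0; kqsx < 16; ++kqsx) {      // QI2_0 == 16 packed words per block
        for (int l = 0; l < 4; ++l) {            // QR2_0 == 4 bit-planes per byte
            const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
            assert(k == (kqsx < 8 ? 8*l + kqsx : 32 + 8*l + (kqsx - 8)));
        }
    }
}

Because the tile ends up holding plain signed bytes (via __vsub4) with a single float scale per block, the TQ2_0 path can then reuse the Q8_0 vec_dot kernels, which is what the mmq_type_traits specialization further down selects.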

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

@@ -2503,6 +2568,14 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_TQ2_0> {
static constexpr int vdr = VDR_TQ2_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_tq2_0<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> {
static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
@@ -2993,6 +3066,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
extern DECL_MMQ_CASE(GGML_TYPE_TQ2_0);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
12 changes: 12 additions & 0 deletions ggml/src/ggml-cuda/mmvq.cu
@@ -16,6 +16,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 :
type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 :
type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 :
type == GGML_TYPE_TQ2_0 ? vec_dot_tq2_0_q8_1 :
type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 :
type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 :
@@ -40,6 +41,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ :
type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ :
type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ :
type == GGML_TYPE_TQ2_0 ? VDR_TQ2_0_Q8_1_MMVQ :
type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ :
type == GGML_TYPE_IQ2_XS ? VDR_IQ2_XS_Q8_1_MMVQ :
type == GGML_TYPE_IQ2_S ? VDR_IQ2_S_Q8_1_MMVQ :
@@ -281,6 +283,13 @@ static void mul_mat_vec_q6_K_q8_1_cuda(
mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_tq2_0_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

mul_mat_vec_q_cuda<GGML_TYPE_TQ2_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq2_xxs_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
@@ -398,6 +407,9 @@ void ggml_cuda_op_mul_mat_vec_q(
case GGML_TYPE_Q6_K:
mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_TQ2_0:
mul_mat_vec_tq2_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ2_XXS:
mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
1 change: 1 addition & 0 deletions ggml/src/ggml-cuda/template-instances/generate_cu_files.py
@@ -23,6 +23,7 @@
TYPES_MMQ = [
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
"GGML_TYPE_TQ2_0",
"GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
"GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_Q6_0"
]
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/template-instances/mmq-instance-tq2_0.cu
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_TQ2_0);
57 changes: 57 additions & 0 deletions ggml/src/ggml-cuda/vecdotq.cuh
@@ -555,6 +555,32 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
return d6 * sumf_d;
}

#define VDR_TQ2_0_Q8_1_MMVQ 2
#define VDR_TQ2_0_Q8_1_MMQ 8

// Can use the same for both mmvq and mmq, because there are no sub-scales in a TQ2_0 block
template <int vdr> static __device__ __forceinline__ float vec_dot_tq2_0_q8_1_impl(
const int * __restrict__ v, const int * __restrict__ u, const float & d2, const float * __restrict__ d8) {

float sumf = 0.0f;

#pragma unroll
for (int i0 = 0; i0 < QR2_0; ++i0) {
int sumi = 0;

#pragma unroll
for (int i = 0; i < vdr; ++i) {
const int vi = (v[i] >> (2*i0)) & 0x03030303;

sumi = ggml_cuda_dp4a(__vsub4(vi, 0x01010101), u[vdr*i0 + i], sumi); // SIMD dot product
}

sumf += d8[i0] * sumi;
}

return d2 * sumf;
}
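A hedged scalar equivalent (not part of the commit) of the helper above, spelling out what __vsub4 and ggml_cuda_dp4a do byte by byte; vdr is passed as a runtime parameter here purely for readability:

#include <stdint.h>

static float vec_dot_tq2_0_q8_1_ref(const int * v, const int * u, float d2, const float * d8, int vdr) {
    float sumf = 0.0f;
    for (int i0 = 0; i0 < 4; ++i0) {                               // QR2_0 bit-planes
        int sumi = 0;
        for (int i = 0; i < vdr; ++i) {
            for (int b = 0; b < 4; ++b) {                          // four packed bytes per 32-bit word
                const int tq = ((v[i] >> (8*b + 2*i0)) & 3) - 1;   // ternary value in {-1, 0, +1}
                const int q8 = (int8_t)(u[vdr*i0 + i] >> (8*b));   // signed q8_1 quant
                sumi += tq * q8;
            }
        }
        sumf += d8[i0] * sumi;                                     // per-q8-block scale
    }
    return d2 * sumf;                                              // single TQ2_0 scale
}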

static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

@@ -837,6 +863,37 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
}

static __device__ __forceinline__ float vec_dot_tq2_0_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

const block_tq2_0 * btq2_0 = (const block_tq2_0 *) vbq + kbx;

// iqs 0..7 all need bq8_offset 0, 1, 2, 3
// iqs 8..15 all need bq8_offset 4, 5, 6, 7
const int bq8_offset = QR2_0 * (iqs / 8);

int v[VDR_TQ2_0_Q8_1_MMVQ];
int u[QR2_0*VDR_TQ2_0_Q8_1_MMVQ];
float d8[QR2_0];

#pragma unroll
for (int i = 0; i < VDR_TQ2_0_Q8_1_MMVQ; ++i) {
v[i] = get_int_b2(btq2_0->qs, iqs + i);
}

#pragma unroll
for (int i0 = 0; i0 < QR2_0; ++i0) {
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i0;

for (int i = 0; i < VDR_TQ2_0_Q8_1_MMVQ; ++i) {
u[VDR_TQ2_0_Q8_1_MMVQ*i0 + i] = get_int_b4(bq8i->qs, (iqs % QI8_1) + i);
}
d8[i0] = __low2float(bq8i->ds);
}

return vec_dot_tq2_0_q8_1_impl<VDR_TQ2_0_Q8_1_MMVQ>(v, u, btq2_0->d, d8);
}

#define VDR_IQ2_XXS_Q8_1_MMVQ 2
#define VDR_IQ2_XXS_Q8_1_MMQ 2

