ggml-cuda : add TQ2_0 kernels, for ternary inference on GPU #11183

Open · wants to merge 4 commits into master
3 changes: 3 additions & 0 deletions ggml/src/ggml-common.h
@@ -126,6 +126,9 @@ typedef sycl::half2 ggml_half2;
#define QI6_K (QK_K / (4*QR6_K))
#define QR6_K 2

#define QI2_0 (QK_K / (4*QR2_0))
#define QR2_0 4

#define QI2_XXS (QK_K / (4*QR2_XXS))
#define QR2_XXS 4

7 changes: 7 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -440,6 +440,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {
static constexpr int qi = QI6_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_TQ2_0> {
static constexpr int qk = QK_K;
static constexpr int qr = QR2_0;
static constexpr int qi = QI2_0;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {
static constexpr int qk = QK_K;
30 changes: 30 additions & 0 deletions ggml/src/ggml-cuda/convert.cu
@@ -277,6 +277,26 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
}

template<typename dst_t>
static __global__ void dequantize_block_tq2_0(const void * __restrict__ vx, dst_t * __restrict__ yy) {

const int64_t i = blockIdx.x;
const block_tq2_0 * x = (const block_tq2_0 *) vx;

const int64_t tid = threadIdx.x; // 0..63
const int64_t n = tid/32; // 0 or 1
const int64_t l = tid - 32*n; // 0..31

const uint8_t q = x[i].qs[32*n + l];
dst_t * y = yy + i*QK_K + 128*n;

float d = __half2float(x[i].d);
y[l+ 0] = d * ((q >> 0) & 3) - d;
y[l+32] = d * ((q >> 2) & 3) - d;
y[l+64] = d * ((q >> 4) & 3) - d;
y[l+96] = d * ((q >> 6) & 3) - d;
}
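For reference, a host-side scalar sketch of the same unpacking (illustration only, not part of this diff): each of the 64 qs bytes stores four 2-bit values in {0, 1, 2}, one per 64-element quarter of the block, and subtracting 1 maps them to the ternary values {-1, 0, 1} before scaling by d. It assumes QK_K == 256, the block_tq2_0 layout from ggml-common.h, and a half-to-float helper such as GGML_FP16_TO_FP32.

// Hedged sketch, not part of the PR: scalar dequantization of one TQ2_0 block.
static void dequantize_tq2_0_block_ref(const block_tq2_0 * x, float * y) {
    const float d = GGML_FP16_TO_FP32(x->d);
    for (int j = 0; j < QK_K/4; ++j) {        // byte index 0..63
        const uint8_t q = x->qs[j];
        for (int p = 0; p < 4; ++p) {         // 2-bit plane, matches QR2_0
            // same output ordering as the kernel above: y[128*n + 32*p + l]
            y[128*(j/32) + 32*p + (j%32)] = d * (((q >> (2*p)) & 3) - 1);
        }
    }
}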

template<typename dst_t>
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

@@ -515,6 +535,12 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_tq2_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_tq2_0<<<nb, 64, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
@@ -613,6 +639,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_q5_K_cuda;
case GGML_TYPE_Q6_K:
return dequantize_row_q6_K_cuda;
case GGML_TYPE_TQ2_0:
return dequantize_row_tq2_0_cuda;
case GGML_TYPE_IQ2_XXS:
return dequantize_row_iq2_xxs_cuda;
case GGML_TYPE_IQ2_XS:
@@ -660,6 +688,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_q5_K_cuda;
case GGML_TYPE_Q6_K:
return dequantize_row_q6_K_cuda;
case GGML_TYPE_TQ2_0:
return dequantize_row_tq2_0_cuda;
case GGML_TYPE_IQ2_XXS:
return dequantize_row_iq2_xxs_cuda;
case GGML_TYPE_IQ2_XS:
1 change: 1 addition & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2860,6 +2860,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q8_K:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ2_S:
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/mmq.cu
@@ -61,6 +61,9 @@ void ggml_cuda_op_mul_mat_q(
case GGML_TYPE_Q6_K:
mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
break;
case GGML_TYPE_TQ2_0:
mul_mat_q_case<GGML_TYPE_TQ2_0>(ctx, args, stream);
break;
case GGML_TYPE_IQ2_XXS:
mul_mat_q_case<GGML_TYPE_IQ2_XXS>(ctx, args, stream);
break;
@@ -113,6 +116,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
74 changes: 74 additions & 0 deletions ggml/src/ggml-cuda/mmq.cuh
@@ -63,6 +63,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
case GGML_TYPE_Q5_K:
return MMQ_Q8_1_DS_LAYOUT_DS4;
case GGML_TYPE_Q6_K:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
@@ -161,6 +162,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
type == GGML_TYPE_TQ2_0 ? MMQ_DP4A_TXS_Q8_0 :
type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 :
type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 :
@@ -195,6 +197,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 :
type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 :
type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
type == GGML_TYPE_TQ2_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K :
type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K :
@@ -1808,6 +1811,68 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
#endif // INT8_MMA_AVAILABLE
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_tq2_0(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

#ifdef INT8_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_tile + 2*WARP_SIZE);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_TQ2_0, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE

const int kqsx = threadIdx.x % QI2_0;

#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_0) {
int i = i0 + threadIdx.y*(WARP_SIZE/QI2_0) + threadIdx.x/QI2_0;

if (need_check) {
i = min(i, i_max);
}

const block_tq2_0 * bxi = (const block_tq2_0 *) x + kbx0 + i*stride;
const int qs0 = get_int_b2(bxi->qs, kqsx);

#pragma unroll
for (int l = 0; l < QR2_0; ++l) {
// 0..7, 32..39
// 8..15, 40..47
// 16..23, 48..55
// 24..31, 56..63
const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
const int q = __vsub4((qs0 >> (2*l)) & 0x03030303, 0x01010101);

#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = q;
#else
x_qs[i*(2*WARP_SIZE + 1) + k] = q;
#endif // INT8_MMA_AVAILABLE
}
}

#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_0/2)) {
int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_0) + threadIdx.x/(QI2_0/2);

if (need_check) {
i = min(i, i_max);
}

const block_tq2_0 * bxi = (const block_tq2_0 *) x + kbx0 + i*stride;

const int k = threadIdx.x % (QI2_0/2);

#ifdef INT8_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = bxi->d;
#else
x_df[i*(WARP_SIZE/4) + i/4 + k] = bxi->d;
#endif // INT8_MMA_AVAILABLE
}
}
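The index mapping described in the comments above (kqsx in 0..QI2_0-1 and l in 0..QR2_0-1 covering k = 0..63 exactly once) can be checked with a small host program; this is purely illustrative and assumes QI2_0 == 16 and QR2_0 == 4 as defined in ggml-common.h.

// Hedged sketch, not part of the PR: enumerate the k values produced per (kqsx, l).
#include <cstdio>
int main() {
    for (int kqsx = 0; kqsx < 16; ++kqsx) {
        for (int l = 0; l < 4; ++l) {
            const int k = (kqsx/8)*32 + l*8 + kqsx % 8; // kqsx 0..7 -> k 0..31, kqsx 8..15 -> k 32..63
            std::printf("kqsx=%2d l=%d -> k=%2d\n", kqsx, l, k);
        }
    }
    return 0;
}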

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

@@ -2427,6 +2492,14 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_TQ2_0> {
static constexpr int vdr = VDR_TQ2_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_tq2_0<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> {
static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
@@ -2916,6 +2989,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
extern DECL_MMQ_CASE(GGML_TYPE_TQ2_0);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
12 changes: 12 additions & 0 deletions ggml/src/ggml-cuda/mmvq.cu
@@ -14,6 +14,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 :
type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 :
type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 :
type == GGML_TYPE_TQ2_0 ? vec_dot_tq2_0_q8_1 :
type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 :
type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 :
@@ -37,6 +38,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ :
type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ :
type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ :
type == GGML_TYPE_TQ2_0 ? VDR_TQ2_0_Q8_1_MMVQ :
type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ :
type == GGML_TYPE_IQ2_XS ? VDR_IQ2_XS_Q8_1_MMVQ :
type == GGML_TYPE_IQ2_S ? VDR_IQ2_S_Q8_1_MMVQ :
@@ -271,6 +273,13 @@ static void mul_mat_vec_q6_K_q8_1_cuda(
mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_tq2_0_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

mul_mat_vec_q_cuda<GGML_TYPE_TQ2_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq2_xxs_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
@@ -385,6 +394,9 @@ void ggml_cuda_op_mul_mat_vec_q(
case GGML_TYPE_Q6_K:
mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_TQ2_0:
mul_mat_vec_tq2_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ2_XXS:
mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
1 change: 1 addition & 0 deletions ggml/src/ggml-cuda/template-instances/generate_cu_files.py
@@ -23,6 +23,7 @@
TYPES_MMQ = [
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
"GGML_TYPE_TQ2_0",
"GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
"GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS"
]
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/template-instances/mmq-instance-tq2_0.cu
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_TQ2_0);
57 changes: 57 additions & 0 deletions ggml/src/ggml-cuda/vecdotq.cuh
@@ -524,6 +524,32 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
return d6 * sumf_d;
}

#define VDR_TQ2_0_Q8_1_MMVQ 2
#define VDR_TQ2_0_Q8_1_MMQ 8

// Can use the same for both mmvq and mmq, because there are no sub-scales in a TQ2_0 block
template <int vdr> static __device__ __forceinline__ float vec_dot_tq2_0_q8_1_impl(
const int * __restrict__ v, const int * __restrict__ u, const float & d2, const float * __restrict__ d8) {

float sumf = 0.0f;

#pragma unroll
for (int i0 = 0; i0 < QR2_0; ++i0) {
int sumi = 0;

#pragma unroll
for (int i = 0; i < vdr; ++i) {
const int vi = (v[i] >> (2*i0)) & 0x03030303;

sumi = ggml_cuda_dp4a(__vsub4(vi, 0x01010101), u[vdr*i0 + i], sumi); // SIMD dot product
}

sumf += d8[i0] * sumi;
}

return d2 * sumf;
}
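Each ggml_cuda_dp4a call above is a 4-way int8 dot product accumulated into sumi; a scalar equivalent of the whole impl function (hedged illustration only, not part of the diff) would unpack the 2-bit value from each byte, shift it to {-1, 0, 1}, and multiply by the corresponding signed q8_1 byte:

// Hedged sketch, not part of the PR: scalar reference for vec_dot_tq2_0_q8_1_impl.
#include <cstdint>
template <int vdr>
static float vec_dot_tq2_0_q8_1_scalar_ref(const int * v, const int * u, float d2, const float * d8) {
    float sumf = 0.0f;
    for (int i0 = 0; i0 < 4; ++i0) {                              // QR2_0 == 4 planes
        int sumi = 0;
        for (int i = 0; i < vdr; ++i) {
            for (int b = 0; b < 4; ++b) {                         // 4 bytes per int
                const int q2 = ((v[i] >> (8*b + 2*i0)) & 3) - 1;  // ternary value in {-1, 0, 1}
                const int q8 = (int8_t)(u[vdr*i0 + i] >> (8*b));  // signed q8_1 value
                sumi += q2 * q8;
            }
        }
        sumf += d8[i0] * sumi;
    }
    return d2 * sumf;
}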

static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

@@ -786,6 +812,37 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
}

static __device__ __forceinline__ float vec_dot_tq2_0_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

const block_tq2_0 * btq2_0 = (const block_tq2_0 *) vbq + kbx;

// iqs 0..7 all need bq8_offset 0, 1, 2, 3
// iqs 8..15 all need bq8_offset 4, 5, 6, 7
const int bq8_offset = QR2_0 * (iqs / 8);

int v[VDR_TQ2_0_Q8_1_MMVQ];
int u[QR2_0*VDR_TQ2_0_Q8_1_MMVQ];
float d8[QR2_0];

#pragma unroll
for (int i = 0; i < VDR_TQ2_0_Q8_1_MMVQ; ++i) {
v[i] = get_int_b2(btq2_0->qs, iqs + i);
}

#pragma unroll
for (int i0 = 0; i0 < QR2_0; ++i0) {
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i0;

for (int i = 0; i < VDR_TQ2_0_Q8_1_MMVQ; ++i) {
u[VDR_TQ2_0_Q8_1_MMVQ*i0 + i] = get_int_b4(bq8i->qs, (iqs % QI8_1) + i);
}
d8[i0] = __low2float(bq8i->ds);
}

return vec_dot_tq2_0_q8_1_impl<VDR_TQ2_0_Q8_1_MMVQ>(v, u, btq2_0->d, d8);
}

#define VDR_IQ2_XXS_Q8_1_MMVQ 2
#define VDR_IQ2_XXS_Q8_1_MMQ 2

7 changes: 5 additions & 2 deletions tests/test-backend-ops.cpp
@@ -3375,7 +3375,8 @@ static const ggml_type all_types[] = {
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
GGML_TYPE_Q6_K,
// GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
// GGML_TYPE_TQ1_0,
GGML_TYPE_TQ2_0,
Comment on lines -3378 to +3379
@compilade (Collaborator, Author) · Jan 10, 2025

An unintended side effect of un-commenting TQ2_0 here is that the Metal tests fail, as in https://github.com/ggerganov/llama.cpp/actions/runs/12716518343/job/35451025034?pr=11183#step:5:13921, because operations on that type are not yet implemented there and ggml_metal_supports_op isn't representative of the types the Metal backend actually supports.

Some solutions are:

  • Implement all relevant TQ2_0 support for Metal
  • Make the ggml_metal_supports_op correctly return false when it should
    • Should be done for correctness
    • An "easy" way to temporarily do this would be similar to what was done for BF16 and simply return false when a TQ2_0 tensor is encountered. The same should be done for the other not-yet-supported types like TQ1_0.
  • Avoid testing TQ2_0 to hide the error
    • This doesn't fix the problem.

Most of these solutions (apart from hiding the problem) are out of scope for this PR, which focuses on the CUDA implementation of TQ2_0, but I don't want this to make the Metal CI fail.
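For the second option, a possible shape of the check is sketched below; this is purely illustrative and does not reproduce the actual ggml_metal_supports_op code, whose signature and structure may differ.

// Hedged sketch, not part of the PR: reject types the Metal backend has no kernels for yet,
// similar in spirit to how unsupported BF16 cases were handled.
static bool ggml_metal_type_has_kernels(enum ggml_type type) {
    switch (type) {
        case GGML_TYPE_TQ1_0:
        case GGML_TYPE_TQ2_0:
            return false; // no Metal kernels for the ternary types yet
        default:
            return true;
    }
}
// ggml_metal_supports_op would then return false whenever an operand of the op
// uses a type for which this returns false.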

GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
@@ -3387,6 +3388,7 @@ static const ggml_type base_types[] = {
GGML_TYPE_Q4_0,
GGML_TYPE_Q4_1, // for I8MM tests
GGML_TYPE_Q4_K,
GGML_TYPE_TQ2_0,
GGML_TYPE_IQ2_XXS
};

@@ -3397,7 +3399,8 @@ static const ggml_type other_types[] = {
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
GGML_TYPE_Q5_K,
GGML_TYPE_Q6_K,
// GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
// GGML_TYPE_TQ1_0,
GGML_TYPE_TQ2_0,
GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,