Skip to content

Commit

Permalink
w
Browse files — browse the repository at this point in the history
  • Loading branch information
nihui committed Dec 6, 2024
1 parent a41edb0 commit 7d1958a
Showing 1 changed file with 121 additions and 62 deletions.
183 changes: 121 additions & 62 deletions src/layer/x86/gemm_int8.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -2191,6 +2191,19 @@ static void compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_s
const int max_ii_packed = max_ii / elempack;
const int size = A.w * elempack;

#if __SSE2__
#if __AVX__
#if __AVX512F__
__m512 _v127_avx512 = _mm512_set1_ps(127.f);
__m512 _v127_B_scale_avx512 = _mm512_set1_ps(v127_B_scale);
#endif // __AVX512F__
__m256 _v127_avx = _mm256_set1_ps(127.f);
__m256 _v127_B_scale_avx = _mm256_set1_ps(v127_B_scale);
#endif // __AVX__
__m128 _v127 = _mm_set1_ps(127.f);
__m128 _v127_B_scale = _mm_set1_ps(v127_B_scale);
#endif // __SSE2__

for (int ii = 0; ii < max_ii_packed; ii++)
{
const float* ptr = (const float*)A + (i + ii * elempack) * A_hstep;
Expand Down Expand Up @@ -2242,8 +2255,8 @@ static void compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_s
#if __AVX512F__
if (elempack == 16)
{
__m512 _scale = _mm512_div_ps(_mm512_set1_ps(127.f), _absmax_avx512);
__m512 _out_descale = _mm512_div_ps(_absmax_avx512, _mm512_set1_ps(v127_B_scale));
__m512 _scale = _mm512_div_ps(_v127_avx512, _absmax_avx512);
__m512 _out_descale = _mm512_div_ps(_absmax_avx512, _v127_B_scale_avx512);
_mm512_store_ps(ps, _scale);
_mm512_store_ps(pods, _out_descale);
ps += 16;
Expand All @@ -2261,8 +2274,8 @@ static void compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_s
}
#endif // __AVX512F__

__m256 _scale = _mm256_div_ps(_mm256_set1_ps(127.f), _absmax_avx);
__m256 _out_descale = _mm256_div_ps(_absmax_avx, _mm256_set1_ps(v127_B_scale));
__m256 _scale = _mm256_div_ps(_v127_avx, _absmax_avx);
__m256 _out_descale = _mm256_div_ps(_absmax_avx, _v127_B_scale_avx);
_mm256_store_ps(ps, _scale);
_mm256_store_ps(pods, _out_descale);
ps += 8;
Expand All @@ -2288,8 +2301,8 @@ static void compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_s
}
#endif // __AVX__

__m128 _scale = _mm_div_ps(_mm_set1_ps(127.f), _absmax);
__m128 _out_descale = _mm_div_ps(_absmax, _mm_set1_ps(v127_B_scale));
__m128 _scale = _mm_div_ps(_v127, _absmax);
__m128 _out_descale = _mm_div_ps(_absmax, _v127_B_scale);
_mm_store_ps(ps, _scale);
_mm_store_ps(pods, _out_descale);
ps += 4;
Expand Down Expand Up @@ -4092,6 +4105,19 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales,

const int max_ii_unpacked = max_ii * elempack;

#if __SSE2__
#if __AVX__
#if __AVX512F__
__m512 _v127_avx512 = _mm512_set1_ps(127.f);
__m512 _v127_B_scale_avx512 = _mm512_set1_ps(v127_B_scale);
#endif // __AVX512F__
__m256 _v127_avx = _mm256_set1_ps(127.f);
__m256 _v127_B_scale_avx = _mm256_set1_ps(v127_B_scale);
#endif // __AVX__
__m128 _v127 = _mm_set1_ps(127.f);
__m128 _v127_B_scale = _mm_set1_ps(v127_B_scale);
#endif // __SSE2__

int ii = 0;
#if __SSE2__
#if __AVX__
Expand Down Expand Up @@ -4135,8 +4161,8 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales,
__m128 _absmax0 = _mm256_extractf128_ps(_absmax_avx, 0);
__m128 _absmax1 = _mm256_extractf128_ps(_absmax_avx, 1);
__m128 _absmax = _mm_max_ps(_absmax0, _absmax1);
__m128 _scale0 = _mm_div_ps(_mm_set1_ps(127.f), _absmax);
__m128 _out_descale0 = _mm_div_ps(_absmax, _mm_set1_ps(v127_B_scale));
__m128 _scale0 = _mm_div_ps(_v127, _absmax);
__m128 _out_descale0 = _mm_div_ps(_absmax, _v127_B_scale);
_mm_store_ps(ps, _scale0);
_mm_store_ps(pods, _out_descale0);
ps += 4;
Expand All @@ -4158,8 +4184,8 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales,
__m256 _absmax0_avx = _mm512_extractf32x8_ps(_absmax_avx512, 0);
__m256 _absmax1_avx = _mm512_extractf32x8_ps(_absmax_avx512, 1);
__m256 _absmax_avx = _mm256_max_ps(_absmax0_avx, _absmax1_avx);
__m256 _scale = _mm256_div_ps(_mm256_set1_ps(127.f), _absmax_avx);
__m256 _out_descale = _mm256_div_ps(_absmax_avx, _mm256_set1_ps(v127_B_scale));
__m256 _scale = _mm256_div_ps(_v127_avx, _absmax_avx);
__m256 _out_descale = _mm256_div_ps(_absmax_avx, _v127_B_scale_avx);
_mm256_store_ps(ps, _scale);
_mm256_store_ps(pods, _out_descale);
ps += 8;
Expand All @@ -4178,23 +4204,23 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales,
_tmp1 = _mm512_unpackhi_ps(_absmax0_avx512, _absmax1_avx512);
__m512 _absmax_avx512 = _mm512_max_ps(_tmp0, _tmp1);
_absmax_avx512 = _mm512_permutexvar_ps(_mm512_setr_epi32(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15), _absmax_avx512);
__m512 _scale0 = _mm512_div_ps(_mm512_set1_ps(127.f), _absmax_avx512);
__m512 _out_descale0 = _mm512_div_ps(_absmax_avx512, _mm512_set1_ps(v127_B_scale));
__m512 _scale0 = _mm512_div_ps(_v127_avx512, _absmax_avx512);
__m512 _out_descale0 = _mm512_div_ps(_absmax_avx512, _v127_B_scale_avx512);
_mm512_store_ps(ps, _scale0);
_mm512_store_ps(pods, _out_descale0);
ps += 16;
pods += 16;
}
if (elempack == 1)
{
__m512 _scale0 = _mm512_div_ps(_mm512_set1_ps(127.f), _absmax0_avx512);
__m512 _scale1 = _mm512_div_ps(_mm512_set1_ps(127.f), _absmax1_avx512);
__m512 _scale2 = _mm512_div_ps(_mm512_set1_ps(127.f), _absmax2_avx512);
__m512 _scale3 = _mm512_div_ps(_mm512_set1_ps(127.f), _absmax3_avx512);
__m512 _out_descale0 = _mm512_div_ps(_absmax0_avx512, _mm512_set1_ps(v127_B_scale));
__m512 _out_descale1 = _mm512_div_ps(_absmax1_avx512, _mm512_set1_ps(v127_B_scale));
__m512 _out_descale2 = _mm512_div_ps(_absmax2_avx512, _mm512_set1_ps(v127_B_scale));
__m512 _out_descale3 = _mm512_div_ps(_absmax3_avx512, _mm512_set1_ps(v127_B_scale));
__m512 _scale0 = _mm512_div_ps(_v127_avx512, _absmax0_avx512);
__m512 _scale1 = _mm512_div_ps(_v127_avx512, _absmax1_avx512);
__m512 _scale2 = _mm512_div_ps(_v127_avx512, _absmax2_avx512);
__m512 _scale3 = _mm512_div_ps(_v127_avx512, _absmax3_avx512);
__m512 _out_descale0 = _mm512_div_ps(_absmax0_avx512, _v127_B_scale_avx512);
__m512 _out_descale1 = _mm512_div_ps(_absmax1_avx512, _v127_B_scale_avx512);
__m512 _out_descale2 = _mm512_div_ps(_absmax2_avx512, _v127_B_scale_avx512);
__m512 _out_descale3 = _mm512_div_ps(_absmax3_avx512, _v127_B_scale_avx512);
_mm512_store_ps(ps, _scale0);
_mm512_store_ps(ps + 16, _scale1);
_mm512_store_ps(ps + 32, _scale2);
Expand Down Expand Up @@ -4298,8 +4324,8 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales,
__m128 _absmax0 = _mm256_extractf128_ps(_t0, 0);
__m128 _absmax1 = _mm256_extractf128_ps(_t0, 1);
__m128 _absmax = _mm_max_ps(_absmax0, _absmax1);
__m128 _scale0 = _mm_div_ps(_mm_set1_ps(127.f), _absmax);
__m128 _out_descale0 = _mm_div_ps(_absmax, _mm_set1_ps(v127_B_scale));
__m128 _scale0 = _mm_div_ps(_v127, _absmax);
__m128 _out_descale0 = _mm_div_ps(_absmax, _v127_B_scale);
_mm_store_ps(ps, _scale0);
_mm_store_ps(pods, _out_descale0);
ps += 4;
Expand Down Expand Up @@ -4331,8 +4357,8 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales,
__m128 _absmax0 = _mm_unpacklo_ps(_tt0, _tt1);
__m128 _absmax1 = _mm_unpackhi_ps(_tt0, _tt1);
_absmax_avx = _mm256_insertf128_ps(_mm256_castps128_ps256(_absmax0), _absmax1, 1);
__m256 _scale = _mm256_div_ps(_mm256_set1_ps(127.f), _absmax_avx);
__m256 _out_descale = _mm256_div_ps(_absmax_avx, _mm256_set1_ps(v127_B_scale));
__m256 _scale = _mm256_div_ps(_v127_avx, _absmax_avx);
__m256 _out_descale = _mm256_div_ps(_absmax_avx, _v127_B_scale_avx);
_mm256_store_ps(ps, _scale);
_mm256_store_ps(pods, _out_descale);
ps += 8;
Expand All @@ -4341,23 +4367,23 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales,
if (elempack == 1)
{
#if __AVX512F__
__m512 _scale0 = _mm512_div_ps(_mm512_set1_ps(127.f), _absmax0_avx512);
__m512 _scale1 = _mm512_div_ps(_mm512_set1_ps(127.f), _absmax1_avx512);
__m512 _out_descale0 = _mm512_div_ps(_absmax0_avx512, _mm512_set1_ps(v127_B_scale));
__m512 _out_descale1 = _mm512_div_ps(_absmax1_avx512, _mm512_set1_ps(v127_B_scale));
__m512 _scale0 = _mm512_div_ps(_v127_avx512, _absmax0_avx512);
__m512 _scale1 = _mm512_div_ps(_v127_avx512, _absmax1_avx512);
__m512 _out_descale0 = _mm512_div_ps(_absmax0_avx512, _v127_B_scale_avx512);
__m512 _out_descale1 = _mm512_div_ps(_absmax1_avx512, _v127_B_scale_avx512);
_mm512_store_ps(ps, _scale0);
_mm512_store_ps(ps + 16, _scale1);
_mm512_store_ps(pods, _out_descale0);
_mm512_store_ps(pods + 16, _out_descale1);
#else
__m256 _scale0 = _mm256_div_ps(_mm256_set1_ps(127.f), _absmax0_avx);
__m256 _scale1 = _mm256_div_ps(_mm256_set1_ps(127.f), _absmax1_avx);
__m256 _scale2 = _mm256_div_ps(_mm256_set1_ps(127.f), _absmax2_avx);
__m256 _scale3 = _mm256_div_ps(_mm256_set1_ps(127.f), _absmax3_avx);
__m256 _out_descale0 = _mm256_div_ps(_absmax0_avx, _mm256_set1_ps(v127_B_scale));
__m256 _out_descale1 = _mm256_div_ps(_absmax1_avx, _mm256_set1_ps(v127_B_scale));
__m256 _out_descale2 = _mm256_div_ps(_absmax2_avx, _mm256_set1_ps(v127_B_scale));
__m256 _out_descale3 = _mm256_div_ps(_absmax3_avx, _mm256_set1_ps(v127_B_scale));
__m256 _scale0 = _mm256_div_ps(_v127_avx, _absmax0_avx);
__m256 _scale1 = _mm256_div_ps(_v127_avx, _absmax1_avx);
__m256 _scale2 = _mm256_div_ps(_v127_avx, _absmax2_avx);
__m256 _scale3 = _mm256_div_ps(_v127_avx, _absmax3_avx);
__m256 _out_descale0 = _mm256_div_ps(_absmax0_avx, _v127_B_scale_avx);
__m256 _out_descale1 = _mm256_div_ps(_absmax1_avx, _v127_B_scale_avx);
__m256 _out_descale2 = _mm256_div_ps(_absmax2_avx, _v127_B_scale_avx);
__m256 _out_descale3 = _mm256_div_ps(_absmax3_avx, _v127_B_scale_avx);
_mm256_store_ps(ps, _scale0);
_mm256_store_ps(ps + 8, _scale1);
_mm256_store_ps(ps + 16, _scale2);
Expand Down Expand Up @@ -4510,8 +4536,8 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales,
__m128 _t0 = _mm_unpacklo_ps(_absmax0, _absmax1);
__m128 _t1 = _mm_unpackhi_ps(_absmax0, _absmax1);
__m128 _absmax = _mm_max_ps(_t0, _t1);
__m128 _scale0 = _mm_div_ps(_mm_set1_ps(127.f), _absmax);
__m128 _out_descale0 = _mm_div_ps(_absmax, _mm_set1_ps(v127_B_scale));
__m128 _scale0 = _mm_div_ps(_v127, _absmax);
__m128 _out_descale0 = _mm_div_ps(_absmax, _v127_B_scale);
_mm_store_ps(ps, _scale0);
_mm_store_ps(pods, _out_descale0);
ps += 4;
Expand All @@ -4520,28 +4546,28 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales,
if (elempack == 1)
{
#if __AVX512F__
__m512 _scale = _mm512_div_ps(_mm512_set1_ps(127.f), _absmax_avx512);
__m512 _out_descale = _mm512_div_ps(_absmax_avx512, _mm512_set1_ps(v127_B_scale));
__m512 _scale = _mm512_div_ps(_v127_avx512, _absmax_avx512);
__m512 _out_descale = _mm512_div_ps(_absmax_avx512, _v127_B_scale_avx512);
_mm512_store_ps(ps, _scale);
_mm512_store_ps(pods, _out_descale);
#elif __AVX__
__m256 _scale0 = _mm256_div_ps(_mm256_set1_ps(127.f), _absmax0_avx);
__m256 _scale1 = _mm256_div_ps(_mm256_set1_ps(127.f), _absmax1_avx);
__m256 _out_descale0 = _mm256_div_ps(_absmax0_avx, _mm256_set1_ps(v127_B_scale));
__m256 _out_descale1 = _mm256_div_ps(_absmax1_avx, _mm256_set1_ps(v127_B_scale));
__m256 _scale0 = _mm256_div_ps(_v127_avx, _absmax0_avx);
__m256 _scale1 = _mm256_div_ps(_v127_avx, _absmax1_avx);
__m256 _out_descale0 = _mm256_div_ps(_absmax0_avx, _v127_B_scale_avx);
__m256 _out_descale1 = _mm256_div_ps(_absmax1_avx, _v127_B_scale_avx);
_mm256_store_ps(ps, _scale0);
_mm256_store_ps(ps + 8, _scale1);
_mm256_store_ps(pods, _out_descale0);
_mm256_store_ps(pods + 8, _out_descale1);
#else
__m128 _scale0 = _mm_div_ps(_mm_set1_ps(127.f), _absmax0);
__m128 _scale1 = _mm_div_ps(_mm_set1_ps(127.f), _absmax1);
__m128 _scale2 = _mm_div_ps(_mm_set1_ps(127.f), _absmax2);
__m128 _scale3 = _mm_div_ps(_mm_set1_ps(127.f), _absmax3);
__m128 _out_descale0 = _mm_div_ps(_absmax0, _mm_set1_ps(v127_B_scale));
__m128 _out_descale1 = _mm_div_ps(_absmax1, _mm_set1_ps(v127_B_scale));
__m128 _out_descale2 = _mm_div_ps(_absmax2, _mm_set1_ps(v127_B_scale));
__m128 _out_descale3 = _mm_div_ps(_absmax3, _mm_set1_ps(v127_B_scale));
__m128 _scale0 = _mm_div_ps(_v127, _absmax0);
__m128 _scale1 = _mm_div_ps(_v127, _absmax1);
__m128 _scale2 = _mm_div_ps(_v127, _absmax2);
__m128 _scale3 = _mm_div_ps(_v127, _absmax3);
__m128 _out_descale0 = _mm_div_ps(_absmax0, _v127_B_scale);
__m128 _out_descale1 = _mm_div_ps(_absmax1, _v127_B_scale);
__m128 _out_descale2 = _mm_div_ps(_absmax2, _v127_B_scale);
__m128 _out_descale3 = _mm_div_ps(_absmax3, _v127_B_scale);
_mm_store_ps(ps, _scale0);
_mm_store_ps(ps + 4, _scale1);
_mm_store_ps(ps + 8, _scale2);
Expand Down Expand Up @@ -4655,15 +4681,15 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales,
if (elempack == 1)
{
#if __AVX__
__m256 _scale = _mm256_div_ps(_mm256_set1_ps(127.f), _absmax_avx);
__m256 _out_descale = _mm256_div_ps(_absmax_avx, _mm256_set1_ps(v127_B_scale));
__m256 _scale = _mm256_div_ps(_v127_avx, _absmax_avx);
__m256 _out_descale = _mm256_div_ps(_absmax_avx, _v127_B_scale_avx);
_mm256_store_ps(ps, _scale);
_mm256_store_ps(pods, _out_descale);
#else
__m128 _scale0 = _mm_div_ps(_mm_set1_ps(127.f), _absmax0);
__m128 _scale1 = _mm_div_ps(_mm_set1_ps(127.f), _absmax1);
__m128 _out_descale0 = _mm_div_ps(_absmax0, _mm_set1_ps(v127_B_scale));
__m128 _out_descale1 = _mm_div_ps(_absmax1, _mm_set1_ps(v127_B_scale));
__m128 _scale0 = _mm_div_ps(_v127, _absmax0);
__m128 _scale1 = _mm_div_ps(_v127, _absmax1);
__m128 _out_descale0 = _mm_div_ps(_absmax0, _v127_B_scale);
__m128 _out_descale1 = _mm_div_ps(_absmax1, _v127_B_scale);
_mm_store_ps(ps, _scale0);
_mm_store_ps(ps + 4, _scale1);
_mm_store_ps(pods, _out_descale0);
Expand Down Expand Up @@ -4743,8 +4769,8 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales,
if (elempack == 1)
{
#if __SSE2__
__m128 _scale = _mm_div_ps(_mm_set1_ps(127.f), _absmax);
__m128 _out_descale = _mm_div_ps(_absmax, _mm_set1_ps(v127_B_scale));
__m128 _scale = _mm_div_ps(_v127, _absmax);
__m128 _out_descale = _mm_div_ps(_absmax, _v127_B_scale);
_mm_store_ps(ps, _scale);
_mm_store_ps(pods, _out_descale);
#else
Expand All @@ -4768,7 +4794,26 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales,
float absmax0 = 0.f;
float absmax1 = 0.f;

for (int kk = 0; kk < K; kk++)
int kk = 0;
#if __AVX512F__
__m512i _vindex0 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
_vindex0 = _mm512_mullo_epi32(_vindex0, _mm512_set1_epi32(A_hstep));
__m512i _vindex1 = _mm512_add_epi32(_vindex0, _mm512_set1_epi32(1));

__m512 _absmax0_avx512 = _mm512_setzero_ps();
__m512 _absmax1_avx512 = _mm512_setzero_ps();
for (; kk + 15 < K; kk += 16)
{
__m512 _p0 = _mm512_i32gather_ps(_vindex0, ptr, sizeof(float));
__m512 _p1 = _mm512_i32gather_ps(_vindex1, ptr, sizeof(float));
_absmax0_avx512 = _mm512_max_ps(_absmax0_avx512, abs512_ps(_p0));
_absmax1_avx512 = _mm512_max_ps(_absmax1_avx512, abs512_ps(_p1));
ptr += A_hstep * 16;
}
absmax0 = _mm512_comp_reduce_max_ps(_absmax0_avx512);
absmax1 = _mm512_comp_reduce_max_ps(_absmax1_avx512);
#endif // __AVX512F__
for (; kk < K; kk++)
{
absmax0 = std::max(absmax0, (float)fabsf(ptr[0]));
absmax1 = std::max(absmax1, (float)fabsf(ptr[1]));
Expand All @@ -4788,7 +4833,21 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales,

float absmax = 0.f;

for (int kk = 0; kk < K; kk++)
int kk = 0;
#if __AVX512F__
__m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
_vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(A_hstep));

__m512 _absmax_avx512 = _mm512_setzero_ps();
for (; kk + 15 < K; kk += 16)
{
__m512 _p = _mm512_i32gather_ps(_vindex, ptr, sizeof(float));
_absmax_avx512 = _mm512_max_ps(_absmax_avx512, abs512_ps(_p));
ptr += A_hstep * 16;
}
absmax = _mm512_comp_reduce_max_ps(_absmax_avx512);
#endif // __AVX512F__
for (; kk < K; kk++)
{
absmax = std::max(absmax, (float)fabsf(ptr[0]));
ptr += A_hstep;
Expand Down

0 comments on commit 7d1958a

Please sign in to comment.