diff --git a/kernel/riscv64/gemm_tcopy_8_rvv.c b/kernel/riscv64/gemm_tcopy_8_rvv.c index 4742ae6a75..c50b0d5b42 100644 --- a/kernel/riscv64/gemm_tcopy_8_rvv.c +++ b/kernel/riscv64/gemm_tcopy_8_rvv.c @@ -28,35 +28,19 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "common.h" #if !defined(DOUBLE) -#define VSETVL(n) __riscv_vsetvl_e32m1(n) -#define FLOAT_V_T vfloat32m1_t -#define FLOAT_VX2_T vfloat32m1x2_t -#define FLOAT_VX4_T vfloat32m1x4_t -#define FLOAT_VX8_T vfloat32m1x8_t -#define VLEV_FLOAT __riscv_vle32_v_f32m1 -#define VLSEV_FLOAT __riscv_vlse32_v_f32m1 -#define VSEV_FLOAT __riscv_vse32_v_f32m1 -#define VLSSEG2_FLOAT __riscv_vlsseg2e32_v_f32m1x2 -#define VSSEG2_FLOAT __riscv_vsseg2e32_v_f32m1x2 -#define VLSSEG4_FLOAT __riscv_vlsseg4e32_v_f32m1x4 -#define VSSEG4_FLOAT __riscv_vsseg4e32_v_f32m1x4 -#define VLSSEG8_FLOAT __riscv_vlsseg8e32_v_f32m1x8 -#define VSSEG8_FLOAT __riscv_vsseg8e32_v_f32m1x8 +#define FLOAT_V_T vfloat32m2_t +#define FLOAT_V_T_HALF vfloat32m1_t +#define VLEV_FLOAT __riscv_vle32_v_f32m2 +#define VLEV_FLOAT_HALF __riscv_vle32_v_f32m1 +#define VSEV_FLOAT __riscv_vse32_v_f32m2 +#define VSEV_FLOAT_HALF __riscv_vse32_v_f32m1 #else -#define VSETVL(n) __riscv_vsetvl_e64m1(n) -#define FLOAT_V_T vfloat64m1_t -#define FLOAT_VX2_T vfloat64m1x2_t -#define FLOAT_VX4_T vfloat64m1x4_t -#define FLOAT_VX8_T vfloat64m1x8_t -#define VLEV_FLOAT __riscv_vle64_v_f64m1 -#define VLSEV_FLOAT __riscv_vlse64_v_f64m1 -#define VSEV_FLOAT __riscv_vse64_v_f64m1 -#define VLSSEG2_FLOAT __riscv_vlsseg2e64_v_f64m1x2 -#define VSSEG2_FLOAT __riscv_vsseg2e64_v_f64m1x2 -#define VLSSEG4_FLOAT __riscv_vlsseg4e64_v_f64m1x4 -#define VSSEG4_FLOAT __riscv_vsseg4e64_v_f64m1x4 -#define VLSSEG8_FLOAT __riscv_vlsseg8e64_v_f64m1x8 -#define VSSEG8_FLOAT __riscv_vsseg8e64_v_f64m1x8 +#define FLOAT_V_T vfloat64m4_t +#define FLOAT_V_T_HALF vfloat64m2_t +#define VLEV_FLOAT __riscv_vle64_v_f64m4 +#define VLEV_FLOAT_HALF __riscv_vle64_v_f64m2 +#define VSEV_FLOAT __riscv_vse64_v_f64m4 +#define VSEV_FLOAT_HALF __riscv_vse64_v_f64m2 #endif int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) @@ -69,9 +53,7 @@ int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) IFLOAT *boffset, *boffset1, *boffset2, *boffset3, *boffset4; FLOAT_V_T v0; - FLOAT_VX2_T vx2; - FLOAT_VX4_T vx4; - FLOAT_VX8_T vx8; + FLOAT_V_T_HALF v1; // fprintf(stderr, "gemm_tcopy_8 m=%ld n=%ld lda=%ld\n", m, n, lda); @@ -81,156 +63,12 @@ int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) boffset3 = b + m * (n & ~3); boffset4 = b + m * (n & ~1); - for(j = (m >> 3); j > 0; j--) { - - aoffset1 = aoffset; - aoffset += 8 * lda; - - boffset1 = boffset; - boffset += 64; - - for(i = (n >> 3); i > 0; i--) { - size_t vl = 8; - - vx8 = VLSSEG8_FLOAT(aoffset1, lda * sizeof(FLOAT), vl); - VSSEG8_FLOAT(boffset1, vx8, vl); - - aoffset1 += 8; - boffset1 += m * 8; - } - - if (n & 4) { - size_t vl = 8; - - vx4 = VLSSEG4_FLOAT(aoffset1, lda * sizeof(FLOAT), vl); - VSSEG4_FLOAT(boffset2, vx4, vl); - - aoffset1 += 4; - boffset2 += 32; - } - - if (n & 2) { - size_t vl = 8; - - vx2 = VLSSEG2_FLOAT(aoffset1, lda * sizeof(FLOAT), vl); - VSSEG2_FLOAT(boffset3, vx2, vl); - - aoffset1 += 2; - boffset3 += 16; - } - - if (n & 1) { - size_t vl = 8; - - v0 = VLSEV_FLOAT(aoffset1, lda * sizeof(FLOAT), vl); - VSEV_FLOAT(boffset4, v0, vl); - - aoffset1 += 1; - boffset4 += 8; - } - - } - - if (m & 4) { - - aoffset1 = aoffset; - aoffset += 4 * lda; - - boffset1 = boffset; - boffset += 32; - - for(i = (n >> 3); i > 0; i--) { - size_t vl = 4; - - vx8 = VLSSEG8_FLOAT(aoffset1, lda * sizeof(FLOAT), vl); - VSSEG8_FLOAT(boffset1, vx8, vl); - - aoffset1 += 8; - boffset1 += m * 8; - } - - if (n & 4) { - size_t vl = 4; - - vx4 = VLSSEG4_FLOAT(aoffset1, lda * sizeof(FLOAT), vl); - VSSEG4_FLOAT(boffset2, vx4, vl); - - aoffset1 += 4; - boffset2 += 16; - } - - if (n & 2) { - size_t vl = 4; - - vx2 = VLSSEG2_FLOAT(aoffset1, lda * sizeof(FLOAT), vl); - VSSEG2_FLOAT(boffset3, vx2, vl); - - aoffset1 += 2; - boffset3 += 8; - } - - if (n & 1) { - size_t vl = 4; - - v0 = VLSEV_FLOAT(aoffset1, lda * sizeof(FLOAT), vl); - VSEV_FLOAT(boffset4, v0, vl); - - aoffset1 += 1; - boffset4 += 4; - } - } - - if (m & 2) { + for(j = m; j > 0; j--) { aoffset1 = aoffset; - aoffset += 2 * lda; - boffset1 = boffset; - boffset += 16; - - for(i = (n >> 3); i > 0; i--) { - size_t vl = 2; - vx8 = VLSSEG8_FLOAT(aoffset1, lda * sizeof(FLOAT), vl); - VSSEG8_FLOAT(boffset1, vx8, vl); - - aoffset1 += 8; - boffset1 += m * 8; - } - - if (n & 4) { - size_t vl = 2; - - vx4 = VLSSEG4_FLOAT(aoffset1, lda * sizeof(FLOAT), vl); - VSSEG4_FLOAT(boffset2, vx4, vl); - - aoffset1 += 4; - boffset2 += 8; - } - - if (n & 2) { - size_t vl = 2; - - vx2 = VLSSEG2_FLOAT(aoffset1, lda * sizeof(FLOAT), vl); - VSSEG2_FLOAT(boffset3, vx2, vl); - - aoffset1 += 2; - boffset3 += 4; - } - - if (n & 1) { - size_t vl = 2; - - v0 = VLSEV_FLOAT(aoffset1, lda * sizeof(FLOAT), vl); - VSEV_FLOAT(boffset4, v0, vl); - - aoffset1 += 1; - boffset4 += 2; - } - } - - if (m & 1) { - aoffset1 = aoffset; - boffset1 = boffset; + aoffset += lda; + boffset += 8; for(i = (n >> 3); i > 0; i--) { size_t vl = 8; @@ -245,27 +83,25 @@ int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) if (n & 4) { size_t vl = 4; - v0 = VLEV_FLOAT(aoffset1, vl); - VSEV_FLOAT(boffset2, v0, vl); + v1 = VLEV_FLOAT_HALF(aoffset1, vl); + VSEV_FLOAT_HALF(boffset2, v1, vl); aoffset1 += 4; - //boffset2 += 4; + boffset2 += 4; } if (n & 2) { - size_t vl = 2; - - v0 = VLEV_FLOAT(aoffset1, vl); - VSEV_FLOAT(boffset3, v0, vl); + *(boffset3) = *(aoffset1); + *(boffset3 + 1) = *(aoffset1 + 1); aoffset1 += 2; - // boffset3 += 2; + boffset3 += 2; } if (n & 1) { - *(boffset4) = *(aoffset1); - // aoffset1 ++; - // boffset4 ++; + *(boffset4) = *(aoffset1); + aoffset1 ++; + boffset4 ++; } }