From dbd962dd81977501d9ca103528441de9ab7996b4 Mon Sep 17 00:00:00 2001 From: "tingbo.liao" Date: Tue, 7 Jan 2025 11:17:28 +0800 Subject: [PATCH] Rearranged the rotm optimized codes to adapt to the architecture. Signed-off-by: tingbo.liao --- cmake/kernel.cmake | 2 + common_d.h | 1 + common_level1.h | 5 +- common_macro.h | 2 + common_param.h | 2 + common_s.h | 1 + interface/rotm.c | 140 +------------- kernel/CMakeLists.txt | 3 + kernel/Makefile.L1 | 18 +- kernel/riscv64/KERNEL.RISCV64_GENERIC | 3 + kernel/riscv64/KERNEL.x280 | 3 + kernel/riscv64/rotm.c | 159 +++++++++++++++ kernel/riscv64/rotm_rvv.c | 266 ++++++++++++++++++++++++++ utest/test_rot.c | 36 ++++ 14 files changed, 503 insertions(+), 138 deletions(-) create mode 100644 kernel/riscv64/rotm.c create mode 100644 kernel/riscv64/rotm_rvv.c diff --git a/cmake/kernel.cmake b/cmake/kernel.cmake index efededcf36..9f8f5d0d77 100644 --- a/cmake/kernel.cmake +++ b/cmake/kernel.cmake @@ -79,6 +79,8 @@ macro(SetDefaultL1) SetFallback(CROTKERNEL zrot.S) SetFallback(ZROTKERNEL zrot.S) SetFallback(XROTKERNEL zrot.S) + SetFallback(SROTMKERNEL rotm.S) + SetFallback(DROTMKERNEL rotm.S) SetFallback(SSCALKERNEL scal.S) SetFallback(DSCALKERNEL scal.S) SetFallback(CSCALKERNEL zscal.S) diff --git a/common_d.h b/common_d.h index 6f4bb2dedc..5b9cffca85 100644 --- a/common_d.h +++ b/common_d.h @@ -22,6 +22,7 @@ #define DSUM_K dsum_k #define DSWAP_K dswap_k #define DROT_K drot_k +#define DROTM_K drotm_k #define DGEMV_N dgemv_n #define DGEMV_T dgemv_t diff --git a/common_level1.h b/common_level1.h index d2ed47e567..afc1fff3de 100644 --- a/common_level1.h +++ b/common_level1.h @@ -1,3 +1,4 @@ + /*********************************************************************/ /* Copyright 2009, 2010 The University of Texas at Austin. */ /* All rights reserved. 
*/ @@ -213,8 +214,8 @@ int srotmg_k(float *, float *, float *, float *, float *); int drotmg_k(double *, double *, double *, double *, double *); int qrotmg_k(xdouble *, xdouble *, xdouble *, xdouble *, xdouble *); -int srotm_k (BLASLONG, float, BLASLONG, float, BLASLONG, float); -int drotm_k (BLASLONG, double, BLASLONG, double, BLASLONG, double); +int srotm_k (BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); +int drotm_k (BLASLONG, double *, BLASLONG, double *, BLASLONG, double *); int qrotm_k (BLASLONG, xdouble, BLASLONG, xdouble, BLASLONG, xdouble); diff --git a/common_macro.h b/common_macro.h index a924651de2..171ccc15d5 100644 --- a/common_macro.h +++ b/common_macro.h @@ -361,6 +361,7 @@ #define SUM_K DSUM_K #define SWAP_K DSWAP_K #define ROT_K DROT_K +#define ROTM_K DROTM_K #define GEMV_N DGEMV_N #define GEMV_T DGEMV_T @@ -977,6 +978,7 @@ #define SUM_K SSUM_K #define SWAP_K SSWAP_K #define ROT_K SROT_K +#define ROTM_K SROTM_K #define GEMV_N SGEMV_N #define GEMV_T SGEMV_T diff --git a/common_param.h b/common_param.h index c082d248e8..71df4ae2e5 100644 --- a/common_param.h +++ b/common_param.h @@ -197,6 +197,7 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG); //double (*dsdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); int (*srot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG, float, float); + int (*srotm_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); #endif #if (BUILD_SINGLE==1) || (BUILD_DOUBLE==1) || (BUILD_COMPLEX==1) int (*saxpy_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); @@ -330,6 +331,7 @@ BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG); #endif #if (BUILD_DOUBLE==1) || (BUILD_COMPLEX16==1) int (*drot_k) (BLASLONG, double *, BLASLONG, double *, BLASLONG, double, double); + int (*drotm_k) (BLASLONG, double *, BLASLONG, double *, BLASLONG, double *); int (*daxpy_k) (BLASLONG, BLASLONG, BLASLONG, double, double *, BLASLONG, double *, BLASLONG, double *, 
BLASLONG); int (*dscal_k) (BLASLONG, BLASLONG, BLASLONG, double, double *, BLASLONG, double *, BLASLONG, double *, BLASLONG); int (*dswap_k) (BLASLONG, BLASLONG, BLASLONG, double, double *, BLASLONG, double *, BLASLONG, double *, BLASLONG); diff --git a/common_s.h b/common_s.h index fdd80b62f6..e996fbd73e 100644 --- a/common_s.h +++ b/common_s.h @@ -24,6 +24,7 @@ #define SSCAL_K sscal_k #define SSWAP_K sswap_k #define SROT_K srot_k +#define SROTM_K srotm_k #define SGEMV_N sgemv_n #define SGEMV_T sgemv_t diff --git a/interface/rotm.c b/interface/rotm.c index 9dc08354ac..9ef87da329 100644 --- a/interface/rotm.c +++ b/interface/rotm.c @@ -7,149 +7,21 @@ void NAME(blasint *N, FLOAT *dx, blasint *INCX, FLOAT *dy, blasint *INCY, FLOAT *dparam){ - blasint n = *N; - blasint incx = *INCX; - blasint incy = *INCY; + blasint n = *N; + blasint incx = *INCX; + blasint incy = *INCY; + PRINT_DEBUG_NAME #else void CNAME(blasint n, FLOAT *dx, blasint incx, FLOAT *dy, blasint incy, FLOAT *dparam){ -#endif - - blasint i__1, i__2; + PRINT_DEBUG_CNAME; - blasint i__; - FLOAT w, z__; - blasint kx, ky; - FLOAT dh11, dh12, dh22, dh21, dflag; - blasint nsteps; - -#ifndef CBLAS - PRINT_DEBUG_CNAME; -#else - PRINT_DEBUG_CNAME; #endif - --dparam; - --dy; - --dx; - - dflag = dparam[1]; - if (n <= 0 || dflag == - 2.0) goto L140; - - if (! (incx == incy && incx > 0)) goto L70; - - nsteps = n * incx; - if (dflag < 0.) { - goto L50; - } else if (dflag == 0) { - goto L10; - } else { - goto L30; - } -L10: - dh12 = dparam[4]; - dh21 = dparam[3]; - i__1 = nsteps; - i__2 = incx; - for (i__ = 1; i__2 < 0 ? i__ >= i__1 : i__ <= i__1; i__ += i__2) { - w = dx[i__]; - z__ = dy[i__]; - dx[i__] = w + z__ * dh12; - dy[i__] = w * dh21 + z__; -/* L20: */ - } - goto L140; -L30: - dh11 = dparam[2]; - dh22 = dparam[5]; - i__2 = nsteps; - i__1 = incx; - for (i__ = 1; i__1 < 0 ? 
i__ >= i__2 : i__ <= i__2; i__ += i__1) { - w = dx[i__]; - z__ = dy[i__]; - dx[i__] = w * dh11 + z__; - dy[i__] = -w + dh22 * z__; -/* L40: */ - } - goto L140; -L50: - dh11 = dparam[2]; - dh12 = dparam[4]; - dh21 = dparam[3]; - dh22 = dparam[5]; - i__1 = nsteps; - i__2 = incx; - for (i__ = 1; i__2 < 0 ? i__ >= i__1 : i__ <= i__1; i__ += i__2) { - w = dx[i__]; - z__ = dy[i__]; - dx[i__] = w * dh11 + z__ * dh12; - dy[i__] = w * dh21 + z__ * dh22; -/* L60: */ - } - goto L140; -L70: - kx = 1; - ky = 1; - if (incx < 0) { - kx = (1 - n) * incx + 1; - } - if (incy < 0) { - ky = (1 - n) * incy + 1; - } + ROTM_K(n, dx, incx, dy, incy, dparam); - if (dflag < 0.) { - goto L120; - } else if (dflag == 0) { - goto L80; - } else { - goto L100; - } -L80: - dh12 = dparam[4]; - dh21 = dparam[3]; - i__2 = n; - for (i__ = 1; i__ <= i__2; ++i__) { - w = dx[kx]; - z__ = dy[ky]; - dx[kx] = w + z__ * dh12; - dy[ky] = w * dh21 + z__; - kx += incx; - ky += incy; -/* L90: */ - } - goto L140; -L100: - dh11 = dparam[2]; - dh22 = dparam[5]; - i__2 = n; - for (i__ = 1; i__ <= i__2; ++i__) { - w = dx[kx]; - z__ = dy[ky]; - dx[kx] = w * dh11 + z__; - dy[ky] = -w + dh22 * z__; - kx += incx; - ky += incy; -/* L110: */ - } - goto L140; -L120: - dh11 = dparam[2]; - dh12 = dparam[4]; - dh21 = dparam[3]; - dh22 = dparam[5]; - i__2 = n; - for (i__ = 1; i__ <= i__2; ++i__) { - w = dx[kx]; - z__ = dy[ky]; - dx[kx] = w * dh11 + z__ * dh12; - dy[ky] = w * dh21 + z__ * dh22; - kx += incx; - ky += incy; -/* L130: */ - } -L140: return; } diff --git a/kernel/CMakeLists.txt b/kernel/CMakeLists.txt index 74e6760c27..bc713e6033 100644 --- a/kernel/CMakeLists.txt +++ b/kernel/CMakeLists.txt @@ -125,6 +125,7 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) GenerateNamedObjects("${KERNELDIR}/${SNRM2KERNEL}" "" "nrm2_k" false "" "" false "SINGLE") GenerateNamedObjects("${KERNELDIR}/${SDOTKERNEL}" "" "dot_k" false "" "" false "SINGLE") GenerateNamedObjects("${KERNELDIR}/${SROTKERNEL}" "" "rot_k" 
false "" "" false "SINGLE") + GenerateNamedObjects("${KERNELDIR}/${SROTMKERNEL}" "" "rotm_k" false "" "" false "SINGLE") endif () if (BUILD_COMPLEX16 AND NOT BUILD_DOUBLE) GenerateNamedObjects("${KERNELDIR}/${DAMAXKERNEL}" "USE_ABS" "amax_k" false "" "" false "DOUBLE") @@ -148,6 +149,7 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) GenerateNamedObjects("${KERNELDIR}/${DCOPYKERNEL}" "C_INTERFACE" "copy_k" false "" "" false "DOUBLE") GenerateNamedObjects("${KERNELDIR}/${DNRM2KERNEL}" "" "nrm2_k" false "" "" false "DOUBLE") GenerateNamedObjects("${KERNELDIR}/${DROTKERNEL}" "" "rot_k" false "" "" false "DOUBLE") + GenerateNamedObjects("${KERNELDIR}/${DROTMKERNEL}" "" "rotm_k" false "" "" false "DOUBLE") GenerateNamedObjects("${KERNELDIR}/${DDOTKERNEL}" "" "dot_k" false "" "" false "DOUBLE") GenerateNamedObjects("${KERNELDIR}/${DSWAPKERNEL}" "" "swap_k" false "" "" false "DOUBLE") GenerateNamedObjects("${KERNELDIR}/${DAXPYKERNEL}" "" "axpy_k" false "" "" false "DOUBLE") @@ -1105,6 +1107,7 @@ endif () GenerateNamedObjects("${KERNELDIR}/${DCOPYKERNEL}" "C_INTERFACE" "copy_k" false "" "" false "DOUBLE") GenerateNamedObjects("${KERNELDIR}/${DNRM2KERNEL}" "" "nrm2_k" false "" "" false "DOUBLE") GenerateNamedObjects("${KERNELDIR}/${DROTKERNEL}" "" "rot_k" false "" "" false "DOUBLE") + GenerateNamedObjects("${KERNELDIR}/${DROTMKERNEL}" "" "rotm_k" false "" "" false "DOUBLE") GenerateNamedObjects("${KERNELDIR}/${DDOTKERNEL}" "" "dot_k" false "" "" false "DOUBLE") GenerateNamedObjects("${KERNELDIR}/${DSWAPKERNEL}" "" "swap_k" false "" "" false "DOUBLE") GenerateNamedObjects("${KERNELDIR}/${DAXPYKERNEL}" "" "axpy_k" false "" "" false "DOUBLE") diff --git a/kernel/Makefile.L1 b/kernel/Makefile.L1 index 09337363da..e67aea7980 100644 --- a/kernel/Makefile.L1 +++ b/kernel/Makefile.L1 @@ -336,6 +336,14 @@ ifndef XROTKERNEL XROTKERNEL = zrot.S endif +ifndef SROTMKERNEL +SROTMKERNEL = rotm.S +endif + +ifndef DROTMKERNEL +DROTMKERNEL = rotm.S +endif + ### SCAL ### 
ifndef SSCALKERNEL @@ -504,14 +512,14 @@ SBLASOBJS += \ sasum_k$(TSUFFIX).$(SUFFIX) ssum_k$(TSUFFIX).$(SUFFIX) saxpy_k$(TSUFFIX).$(SUFFIX) scopy_k$(TSUFFIX).$(SUFFIX) \ sdot_k$(TSUFFIX).$(SUFFIX) sdsdot_k$(TSUFFIX).$(SUFFIX) dsdot_k$(TSUFFIX).$(SUFFIX) \ snrm2_k$(TSUFFIX).$(SUFFIX) srot_k$(TSUFFIX).$(SUFFIX) sscal_k$(TSUFFIX).$(SUFFIX) sswap_k$(TSUFFIX).$(SUFFIX) \ - saxpby_k$(TSUFFIX).$(SUFFIX) + saxpby_k$(TSUFFIX).$(SUFFIX) srotm_k$(TSUFFIX).$(SUFFIX) DBLASOBJS += \ damax_k$(TSUFFIX).$(SUFFIX) damin_k$(TSUFFIX).$(SUFFIX) dmax_k$(TSUFFIX).$(SUFFIX) dmin_k$(TSUFFIX).$(SUFFIX) \ idamax_k$(TSUFFIX).$(SUFFIX) idamin_k$(TSUFFIX).$(SUFFIX) idmax_k$(TSUFFIX).$(SUFFIX) idmin_k$(TSUFFIX).$(SUFFIX) \ dasum_k$(TSUFFIX).$(SUFFIX) daxpy_k$(TSUFFIX).$(SUFFIX) dcopy_k$(TSUFFIX).$(SUFFIX) ddot_k$(TSUFFIX).$(SUFFIX) \ dnrm2_k$(TSUFFIX).$(SUFFIX) drot_k$(TSUFFIX).$(SUFFIX) dscal_k$(TSUFFIX).$(SUFFIX) dswap_k$(TSUFFIX).$(SUFFIX) \ - daxpby_k$(TSUFFIX).$(SUFFIX) dsum_k$(TSUFFIX).$(SUFFIX) + daxpby_k$(TSUFFIX).$(SUFFIX) dsum_k$(TSUFFIX).$(SUFFIX) drotm_k$(TSUFFIX).$(SUFFIX) QBLASOBJS += \ qamax_k$(TSUFFIX).$(SUFFIX) qamin_k$(TSUFFIX).$(SUFFIX) qmax_k$(TSUFFIX).$(SUFFIX) qmin_k$(TSUFFIX).$(SUFFIX) \ @@ -841,6 +849,12 @@ $(KDIR)srot_k$(TSUFFIX).$(SUFFIX) $(KDIR)srot_k$(TPSUFFIX).$(PSUFFIX) : $(KERN $(KDIR)drot_k$(TSUFFIX).$(SUFFIX) $(KDIR)drot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(DROTKERNEL) $(CC) -c $(CFLAGS) $(FMAFLAG) -UCOMPLEX -UCOMPLEX -DDOUBLE $< -o $@ +$(KDIR)srotm_k$(TSUFFIX).$(SUFFIX) $(KDIR)srotm_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SROTMKERNEL) + $(CC) -c $(CFLAGS) $(FMAFLAG) -UCOMPLEX -UCOMPLEX -UDOUBLE $< -o $@ + +$(KDIR)drotm_k$(TSUFFIX).$(SUFFIX) $(KDIR)drotm_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(DROTMKERNEL) + $(CC) -c $(CFLAGS) $(FMAFLAG) -UCOMPLEX -UCOMPLEX -DDOUBLE $< -o $@ + $(KDIR)qrot_k$(TSUFFIX).$(SUFFIX) $(KDIR)qrot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(QROTKERNEL) $(CC) -c $(CFLAGS) -UCOMPLEX -UCOMPLEX -DXDOUBLE $< -o $@ diff --git 
a/kernel/riscv64/KERNEL.RISCV64_GENERIC b/kernel/riscv64/KERNEL.RISCV64_GENERIC index 67f81cacda..e27e472e0b 100644 --- a/kernel/riscv64/KERNEL.RISCV64_GENERIC +++ b/kernel/riscv64/KERNEL.RISCV64_GENERIC @@ -71,6 +71,9 @@ DROTKERNEL = ../riscv64/rot.c CROTKERNEL = ../riscv64/zrot.c ZROTKERNEL = ../riscv64/zrot.c +SROTMKERNEL = ../riscv64/rotm.c +DROTMKERNEL = ../riscv64/rotm.c + SSCALKERNEL = ../riscv64/scal.c DSCALKERNEL = ../riscv64/scal.c CSCALKERNEL = ../riscv64/zscal.c diff --git a/kernel/riscv64/KERNEL.x280 b/kernel/riscv64/KERNEL.x280 index 86708fe015..d04ba2224e 100644 --- a/kernel/riscv64/KERNEL.x280 +++ b/kernel/riscv64/KERNEL.x280 @@ -98,6 +98,9 @@ DROTKERNEL = rot_rvv.c CROTKERNEL = zrot_rvv.c ZROTKERNEL = zrot_rvv.c +SROTMKERNEL = rotm_rvv.c +DROTMKERNEL = rotm_rvv.c + SSCALKERNEL = scal_rvv.c DSCALKERNEL = scal_rvv.c CSCALKERNEL = zscal_rvv.c diff --git a/kernel/riscv64/rotm.c b/kernel/riscv64/rotm.c new file mode 100644 index 0000000000..e151aa5f88 --- /dev/null +++ b/kernel/riscv64/rotm.c @@ -0,0 +1,159 @@ +/*************************************************************************** +Copyright (c) 2013, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. 
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" +/* Scalar rotm kernel: applies the 2x2 transform selected by dflag (the first entry of dparam) to each (dx, dy) pair in place and returns 0. */ +int CNAME(BLASLONG n, FLOAT *dx, BLASLONG incx, FLOAT *dy, BLASLONG incy, FLOAT *dparam) +{ + BLASLONG i__1, i__2; + BLASLONG i__; + FLOAT w, z__; + BLASLONG kx, ky; + FLOAT dh11, dh12, dh22, dh21, dflag; + BLASLONG nsteps; + + --dparam; /* f2c-style 1-based indexing from here on: dparam[1] is the first entry */ + --dy; + --dx; + + dflag = dparam[1]; + if (n <= 0 || dflag == - 2.0) goto L140; /* empty input or dflag == -2: dx and dy are left untouched */ + + if (! (incx == incy && incx > 0)) goto L70; /* L70 handles unequal, zero or negative increments */ + + nsteps = n * incx; + if (dflag < 0.) { + goto L50; + } else if (dflag == 0) { + goto L10; + } else { + goto L30; + } +L10: /* dflag == 0, equal positive increments: only dh12 and dh21 are read */ + dh12 = dparam[4]; + dh21 = dparam[3]; + i__1 = nsteps; + i__2 = incx; + for (i__ = 1; i__2 < 0 ? i__ >= i__1 : i__ <= i__1; i__ += i__2) { + w = dx[i__]; + z__ = dy[i__]; + dx[i__] = w + z__ * dh12; + dy[i__] = w * dh21 + z__; +/* L20: */ + } + goto L140; +L30: /* dflag > 0, equal positive increments: only dh11 and dh22 are read */ + dh11 = dparam[2]; + dh22 = dparam[5]; + i__2 = nsteps; + i__1 = incx; + for (i__ = 1; i__1 < 0 ? 
i__ >= i__2 : i__ <= i__2; i__ += i__1) { + w = dx[i__]; + z__ = dy[i__]; + dx[i__] = w * dh11 + z__; + dy[i__] = -w + dh22 * z__; +/* L40: */ + } + goto L140; +L50: /* dflag < 0 (other than -2), equal positive increments: all four entries are read */ + dh11 = dparam[2]; + dh12 = dparam[4]; + dh21 = dparam[3]; + dh22 = dparam[5]; + i__1 = nsteps; + i__2 = incx; + for (i__ = 1; i__2 < 0 ? i__ >= i__1 : i__ <= i__1; i__ += i__2) { + w = dx[i__]; + z__ = dy[i__]; + dx[i__] = w * dh11 + z__ * dh12; + dy[i__] = w * dh21 + z__ * dh22; +/* L60: */ + } + goto L140; +L70: /* general increments: kx and ky are the 1-based indices of the first pair */ + kx = 1; + ky = 1; + if (incx < 0) { + kx = (1 - n) * incx + 1; /* negative increment: start (n-1)*|incx| elements in and walk back */ + } + if (incy < 0) { + ky = (1 - n) * incy + 1; /* same for dy */ + } + + if (dflag < 0.) { + goto L120; + } else if (dflag == 0) { + goto L80; + } else { + goto L100; + } +L80: /* dflag == 0, general increments */ + dh12 = dparam[4]; + dh21 = dparam[3]; + i__2 = n; + for (i__ = 1; i__ <= i__2; ++i__) { + w = dx[kx]; + z__ = dy[ky]; + dx[kx] = w + z__ * dh12; + dy[ky] = w * dh21 + z__; + kx += incx; + ky += incy; +/* L90: */ + } + goto L140; +L100: /* dflag > 0, general increments */ + dh11 = dparam[2]; + dh22 = dparam[5]; + i__2 = n; + for (i__ = 1; i__ <= i__2; ++i__) { + w = dx[kx]; + z__ = dy[ky]; + dx[kx] = w * dh11 + z__; + dy[ky] = -w + dh22 * z__; + kx += incx; + ky += incy; +/* L110: */ + } + goto L140; +L120: /* dflag < 0 (other than -2), general increments */ + dh11 = dparam[2]; + dh12 = dparam[4]; + dh21 = dparam[3]; + dh22 = dparam[5]; + i__2 = n; + for (i__ = 1; i__ <= i__2; ++i__) { + w = dx[kx]; + z__ = dy[ky]; + dx[kx] = w * dh11 + z__ * dh12; + dy[ky] = w * dh21 + z__ * dh22; + kx += incx; + ky += incy; +/* L130: */ + } +L140: + return(0); +} diff --git a/kernel/riscv64/rotm_rvv.c b/kernel/riscv64/rotm_rvv.c new file mode 100644 index 0000000000..46c678ff63 --- /dev/null +++ b/kernel/riscv64/rotm_rvv.c @@ -0,0 +1,266 @@ +/*************************************************************************** +Copyright (c) 2013, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. 
Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ + +#include "common.h" + +#if !defined(DOUBLE) +#define VSETVL(n) __riscv_vsetvl_e32m8(n) +#define FLOAT_V_T vfloat32m8_t +#define VLSEV_FLOAT __riscv_vlse32_v_f32m8 +#define VSSEV_FLOAT __riscv_vsse32_v_f32m8 +#define VFMACCVF_FLOAT __riscv_vfmacc_vf_f32m8 +#define VFMULVF_FLOAT __riscv_vfmul_vf_f32m8 +#define VFMSACVF_FLOAT __riscv_vfmsac_vf_f32m8 +#else +#define VSETVL(n) __riscv_vsetvl_e64m8(n) +#define FLOAT_V_T vfloat64m8_t +#define VLSEV_FLOAT __riscv_vlse64_v_f64m8 +#define VSSEV_FLOAT __riscv_vsse64_v_f64m8 +#define VFMACCVF_FLOAT __riscv_vfmacc_vf_f64m8 +#define VFMULVF_FLOAT __riscv_vfmul_vf_f64m8 +#define VFMSACVF_FLOAT __riscv_vfmsac_vf_f64m8 +#endif +/* RVV rotm kernel: applies the 2x2 transform selected by dflag (the first entry of dparam) to each (dx, dy) pair in place, using strided vector loads and stores, and returns 0. */ +int CNAME(BLASLONG n, FLOAT *dx, BLASLONG incx, FLOAT *dy, BLASLONG incy, FLOAT *dparam) +{ + BLASLONG i__1, i__2; + BLASLONG kx, ky; + FLOAT dh11, dh12, dh22, dh21, dflag; + BLASLONG nsteps; + + --dparam; /* f2c-style 1-based indexing from here on: dparam[1] is the first entry */ + --dy; + --dx; + + FLOAT_V_T v_w, v_z__, v_dx, v_dy; + BLASLONG stride, stride_x, stride_y, offset; + + dflag = dparam[1]; + if (n <= 0 || dflag == - 2.0) goto L140; /* empty input or dflag == -2: dx and dy are left untouched */ + + if (!(incx == incy && incx > 0)) goto L70; /* L70 handles unequal, zero or negative increments */ + + nsteps = n * incx; + if (dflag < 0.) 
{ + goto L50; + } else if (dflag == 0) { + goto L10; + } else { + goto L30; + } +L10: /* dflag == 0, equal positive increments: only dh12 and dh21 are read */ + dh12 = dparam[4]; + dh21 = dparam[3]; + i__1 = nsteps; + i__2 = incx; + if(i__2 < 0){ /* not reachable: this path requires incx > 0 */ + offset = i__1 - 2; + dx += offset; + dy += offset; + i__1 = -i__1; + i__2 = -i__2; + } + stride = i__2 * sizeof(FLOAT); + n = i__1 / i__2; + // printf("L10 RVV, i__2: %d, i__1: %d, stride: %d, n: %d \n", i__2, i__1, stride, n); + for (size_t vl; n > 0; n -= vl, dx += vl*i__2, dy += vl*i__2) { + vl = VSETVL(n); + + v_w = VLSEV_FLOAT(&dx[1], stride, vl); + v_z__ = VLSEV_FLOAT(&dy[1], stride, vl); + + v_dx = VFMACCVF_FLOAT(v_w, dh12, v_z__, vl); + v_dy = VFMACCVF_FLOAT(v_z__, dh21, v_w, vl); + + VSSEV_FLOAT(&dx[1], stride, v_dx, vl); + VSSEV_FLOAT(&dy[1], stride, v_dy, vl); + } + goto L140; +L30: /* dflag > 0, equal positive increments: only dh11 and dh22 are read */ + dh11 = dparam[2]; + dh22 = dparam[5]; + i__2 = nsteps; + i__1 = incx; + if(i__1 < 0){ /* not reachable: this path requires incx > 0 */ + offset = i__2 - 2; + dx += offset; + dy += offset; + i__1 = -i__1; + i__2 = -i__2; + } + stride = i__1 * sizeof(FLOAT); + n = i__2 / i__1; + // printf("L30 RVV, i__2: %d, i__1: %d, stride: %d, n: %d \n", i__2, i__1, stride, n); + for (size_t vl; n > 0; n -= vl, dx += vl*i__1, dy += vl*i__1) { + vl = VSETVL(n); + + v_w = VLSEV_FLOAT(&dx[1], stride, vl); + v_z__ = VLSEV_FLOAT(&dy[1], stride, vl); + + v_dx = VFMACCVF_FLOAT(v_z__, dh11, v_w, vl); + v_dy = VFMSACVF_FLOAT(v_w, dh22, v_z__, vl); + + VSSEV_FLOAT(&dx[1], stride, v_dx, vl); + VSSEV_FLOAT(&dy[1], stride, v_dy, vl); + } + goto L140; +L50: /* dflag < 0 (other than -2), equal positive increments: all four entries are read */ + dh11 = dparam[2]; + dh12 = dparam[4]; + dh21 = dparam[3]; + dh22 = dparam[5]; + i__1 = nsteps; + i__2 = incx; + if(i__2 < 0){ /* not reachable: this path requires incx > 0 */ + offset = i__1 - 2; + dx += offset; + dy += offset; + i__1 = -i__1; + i__2 = -i__2; + } + stride = i__2 * sizeof(FLOAT); + n = i__1 / i__2; + // printf("L50 RVV, i__2: %d, i__1: %d, stride: %d, n: %d \n", i__2, i__1, stride, n); + for (size_t vl; n > 0; n -= vl, dx += vl*i__2, dy += vl*i__2) { + vl = VSETVL(n); + + v_w = VLSEV_FLOAT(&dx[1], stride, vl); + v_z__ = VLSEV_FLOAT(&dy[1], stride, vl); + + 
v_dx = VFMULVF_FLOAT(v_w, dh11, vl); + v_dx = VFMACCVF_FLOAT(v_dx, dh12, v_z__, vl); + VSSEV_FLOAT(&dx[1], stride, v_dx, vl); + + v_dy = VFMULVF_FLOAT(v_w, dh21, vl); + v_dy = VFMACCVF_FLOAT(v_dy, dh22, v_z__, vl); + VSSEV_FLOAT(&dy[1], stride, v_dy, vl); + } + goto L140; +L70: /* general increments: kx and ky are the 1-based indices of the first pair */ + kx = 1; + ky = 1; + if (incx < 0) { + kx = (1 - n) * incx + 1; + } + if (incy < 0) { + ky = (1 - n) * incy + 1; + } + + if (dflag < 0.) { + goto L120; + } else if (dflag == 0) { + goto L80; + } else { + goto L100; + } +L80: /* dflag == 0, general increments */ + dh12 = dparam[4]; + dh21 = dparam[3]; + /* + * kx/ky already index the first logical element for a negative increment + * and the strided intrinsics take a signed byte stride, so the signed + * increments are used directly. (Negating them and rebasing dx/dy started + * one stride before the data and mis-paired x/y for mixed signs.) + * A zero increment revisits one element, which must be updated strictly + * in sequence, so such calls are processed one pair per pass. + */ + stride_x = incx * (BLASLONG)sizeof(FLOAT); + stride_y = incy * (BLASLONG)sizeof(FLOAT); + // printf("L120 RVV, n: %d, i__1: %d, stride_x: %d, stride_y: %d, n: %d \n", n, i__1, stride_x, stride_y, n); + for (size_t vl; n > 0; n -= vl, dx += (BLASLONG)vl*incx, dy += (BLASLONG)vl*incy) { + vl = VSETVL((incx != 0 && incy != 0) ? n : 1); + + v_w = VLSEV_FLOAT(&dx[kx], stride_x, vl); + v_z__ = VLSEV_FLOAT(&dy[ky], stride_y, vl); + + v_dx = VFMACCVF_FLOAT(v_w, dh12, v_z__, vl); + v_dy = VFMACCVF_FLOAT(v_z__, dh21, v_w, vl); + + VSSEV_FLOAT(&dx[kx], stride_x, v_dx, vl); + VSSEV_FLOAT(&dy[ky], stride_y, v_dy, vl); + } + goto L140; +L100: /* dflag > 0, general increments */ + dh11 = dparam[2]; + dh22 = dparam[5]; + /* + * Addressing is identical to L80: kx/ky already index the first logical + * element, so the signed increments are used directly as element steps + * and byte strides, without negating them or rebasing dx/dy. + * A zero increment revisits one element, which must be updated strictly + * in sequence, so such calls are processed one pair per pass. + */ + + stride_x = incx * (BLASLONG)sizeof(FLOAT); + stride_y = incy * (BLASLONG)sizeof(FLOAT); + // printf("L120 RVV, n: %d, i__1: %d, stride_x: %d, stride_y: %d, n: %d \n", n, i__1, stride_x, stride_y, n); + for (size_t vl; n > 0; n -= vl, dx += (BLASLONG)vl*incx, dy += (BLASLONG)vl*incy) { + vl = VSETVL((incx != 0 && incy != 0) ? n : 1); + + v_w = VLSEV_FLOAT(&dx[kx], stride_x, vl); + v_z__ = VLSEV_FLOAT(&dy[ky], stride_y, vl); + + v_dx = VFMACCVF_FLOAT(v_z__, dh11, v_w, vl); + v_dy = VFMSACVF_FLOAT(v_w, dh22, v_z__, vl); + + VSSEV_FLOAT(&dx[kx], stride_x, v_dx, vl); + VSSEV_FLOAT(&dy[ky], stride_y, v_dy, vl); + } + goto L140; +L120: /* dflag < 0 (other than -2), general increments */ + dh11 = dparam[2]; + dh12 = dparam[4]; + dh21 = dparam[3]; + dh22 = 
dparam[5]; + /* + * Addressing is identical to L80: kx/ky already index the first logical + * element, so the signed increments are used directly as element steps + * and byte strides, without negating them or rebasing dx/dy. + * A zero increment revisits one element, which must be updated strictly + * in sequence, so such calls are processed one pair per pass. + */ + + stride_x = incx * (BLASLONG)sizeof(FLOAT); + stride_y = incy * (BLASLONG)sizeof(FLOAT); + // printf("L120 RVV, n: %d, i__1: %d, stride_x: %d, stride_y: %d, n: %d \n", n, i__1, stride_x, stride_y, n); + for (size_t vl; n > 0; n -= vl, dx += (BLASLONG)vl*incx, dy += (BLASLONG)vl*incy) { + vl = VSETVL((incx != 0 && incy != 0) ? n : 1); + + v_w = VLSEV_FLOAT(&dx[kx], stride_x, vl); + v_z__ = VLSEV_FLOAT(&dy[ky], stride_y, vl); + + v_dx = VFMULVF_FLOAT(v_w, dh11, vl); + v_dx = VFMACCVF_FLOAT(v_dx, dh12, v_z__, vl); + VSSEV_FLOAT(&dx[kx], stride_x, v_dx, vl); + + v_dy = VFMULVF_FLOAT(v_w, dh21, vl); + v_dy = VFMACCVF_FLOAT(v_dy, dh22, v_z__, vl); + VSSEV_FLOAT(&dy[ky], stride_y, v_dy, vl); + } +L140: + return(0); +} \ No newline at end of file diff --git a/utest/test_rot.c b/utest/test_rot.c index 0e74ecbb36..acd9ff1ce6 100644 --- a/utest/test_rot.c +++ b/utest/test_rot.c @@ -53,6 +53,24 @@ CTEST(rot,drot_inc_0) ASSERT_DBL_NEAR_TOL(y2[i], y1[i], DOUBLE_EPS); } } +CTEST(rot,drotm_inc_1) +{ + blasint i = 0; + blasint N = 12, incX = 1, incY = 1; + double param[5] = {1.0, 2.0, 3.0, 4.0, 5.0}; + double x_actual[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}; + double y_actual[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}; + double x_referece[] = {3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0, 27.0, 30.0, 33.0, 36.0}; + double y_referece[] = {4.0, 8.0, 12.0, 16.0, 20.0, 24.0, 28.0, 32.0, 36.0, 40.0, 44.0, 48.0}; + + //OpenBLAS + BLASFUNC(drotm)(&N, x_actual, &incX, y_actual, &incY, param); + + for(i = 0; i < N; i++){ + ASSERT_DBL_NEAR_TOL(x_referece[i], x_actual[i], DOUBLE_EPS); + ASSERT_DBL_NEAR_TOL(y_referece[i], y_actual[i], DOUBLE_EPS); + } +} #endif #ifdef BUILD_COMPLEX16 @@ -96,6 +114,24 @@ CTEST(rot,srot_inc_0) ASSERT_DBL_NEAR_TOL(y2[i], y1[i], SINGLE_EPS); } } +CTEST(rot,srotm_inc_1) +{ + blasint i = 0; + blasint N = 12, incX = 1, incY = 1; + float param[5] = {1.0, 2.0, 3.0, 4.0, 
5.0}; + float x_actual[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}; + float y_actual[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}; + float x_referece[] = {3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0, 27.0, 30.0, 33.0, 36.0}; + float y_referece[] = {4.0, 8.0, 12.0, 16.0, 20.0, 24.0, 28.0, 32.0, 36.0, 40.0, 44.0, 48.0}; + + //OpenBLAS + BLASFUNC(srotm)(&N, x_actual, &incX, y_actual, &incY, param); + + for(i = 0; i < N; i++){ + ASSERT_DBL_NEAR_TOL(x_referece[i], x_actual[i], SINGLE_EPS); + ASSERT_DBL_NEAR_TOL(y_referece[i], y_actual[i], SINGLE_EPS); + } +} #endif #ifdef BUILD_COMPLEX