Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize the rotm kernel with RVV intrinsic. #5038

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions common_riscv64.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ static inline int blas_quickdivide(blasint x, blasint y){

/* These targets include the RVV intrinsics header. */
#if defined(C910V) || defined(RISCV64_ZVL256B) || defined(RISCV64_ZVL128B) || defined(x280)
# include <riscv_vector.h>
/*
 * RISCV_SIMD is defined for the x280 target only; interface/rotm.c tests it
 * to choose between its scalar loops and its RVV intrinsic loops.
 */
#if defined(x280)
#define RISCV_SIMD
#endif
#endif

#if defined( __riscv_xtheadc ) && defined( __riscv_v ) && ( __riscv_v <= 7000 )
Expand Down
176 changes: 176 additions & 0 deletions interface/rotm.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,26 @@
#include "functable.h"
#endif

/*
 * Precision-dependent aliases for the RVV intrinsics used by the vectorized
 * rotm loops below.  Without DOUBLE they resolve to the 32-bit (e32 / f32m8)
 * intrinsics, with DOUBLE to the 64-bit (e64 / f64m8) ones.
 *
 *   VSETVL          - vsetvl: vector length to use for n remaining elements
 *   FLOAT_V_T       - vector register-group type holding FLOAT elements
 *   VLSEV_FLOAT     - vlse: strided load
 *   VSSEV_FLOAT     - vsse: strided store
 *   VFMACCVF_FLOAT  - vfmacc.vf: vector-scalar multiply-accumulate
 *   VFMULVF_FLOAT   - vfmul.vf: vector-scalar multiply
 *   VFMSACVF_FLOAT  - vfmsac.vf: vector-scalar multiply-subtract-accumulator
 */
#if defined(RISCV_SIMD)
#if !defined(DOUBLE)
/* Single precision: 32-bit elements. */
#define VSETVL(n) __riscv_vsetvl_e32m8(n)
#define FLOAT_V_T vfloat32m8_t
#define VLSEV_FLOAT __riscv_vlse32_v_f32m8
#define VSSEV_FLOAT __riscv_vsse32_v_f32m8
#define VFMACCVF_FLOAT __riscv_vfmacc_vf_f32m8
#define VFMULVF_FLOAT __riscv_vfmul_vf_f32m8
#define VFMSACVF_FLOAT __riscv_vfmsac_vf_f32m8
#else
/* Double precision: 64-bit elements. */
#define VSETVL(n) __riscv_vsetvl_e64m8(n)
#define FLOAT_V_T vfloat64m8_t
#define VLSEV_FLOAT __riscv_vlse64_v_f64m8
#define VSSEV_FLOAT __riscv_vsse64_v_f64m8
#define VFMACCVF_FLOAT __riscv_vfmacc_vf_f64m8
#define VFMULVF_FLOAT __riscv_vfmul_vf_f64m8
#define VFMSACVF_FLOAT __riscv_vfmsac_vf_f64m8
#endif
#endif

#ifndef CBLAS

void NAME(blasint *N, FLOAT *dx, blasint *INCX, FLOAT *dy, blasint *INCY, FLOAT *dparam){
Expand All @@ -25,6 +45,11 @@ void CNAME(blasint n, FLOAT *dx, blasint incx, FLOAT *dy, blasint incy, FLOAT *d
FLOAT dh11, dh12, dh22, dh21, dflag;
blasint nsteps;

#if defined(RISCV_SIMD)
FLOAT_V_T v_w, v_z__, v_dx, v_dy;
blasint stride, stride_x, stride_y, offset;
#endif

#ifndef CBLAS
PRINT_DEBUG_CNAME;
#else
Expand Down Expand Up @@ -53,26 +78,74 @@ void CNAME(blasint n, FLOAT *dx, blasint incx, FLOAT *dy, blasint incy, FLOAT *d
dh21 = dparam[3];
i__1 = nsteps;
i__2 = incx;
#if !defined(RISCV_SIMD)
for (i__ = 1; i__2 < 0 ? i__ >= i__1 : i__ <= i__1; i__ += i__2) {
w = dx[i__];
z__ = dy[i__];
dx[i__] = w + z__ * dh12;
dy[i__] = w * dh21 + z__;
/* L20: */
}
#else
if(i__2 < 0){
offset = i__1 - 2;
dx += offset;
dy += offset;
i__1 = -i__1;
i__2 = -i__2;
}
stride = i__2 * sizeof(FLOAT);
n = i__1 / i__2;
for (size_t vl; n > 0; n -= vl, dx += vl*i__2, dy += vl*i__2) {
vl = VSETVL(n);

v_w = VLSEV_FLOAT(&dx[1], stride, vl);
v_z__ = VLSEV_FLOAT(&dy[1], stride, vl);

v_dx = VFMACCVF_FLOAT(v_w, dh12, v_z__, vl);
v_dy = VFMACCVF_FLOAT(v_z__, dh21, v_w, vl);

VSSEV_FLOAT(&dx[1], stride, v_dx, vl);
VSSEV_FLOAT(&dy[1], stride, v_dy, vl);
}
#endif
goto L140;
L30:
dh11 = dparam[2];
dh22 = dparam[5];
i__2 = nsteps;
i__1 = incx;
#if !defined(RISCV_SIMD)
for (i__ = 1; i__1 < 0 ? i__ >= i__2 : i__ <= i__2; i__ += i__1) {
w = dx[i__];
z__ = dy[i__];
dx[i__] = w * dh11 + z__;
dy[i__] = -w + dh22 * z__;
/* L40: */
}
#else
if(i__1 < 0){
offset = i__2 - 2;
dx += offset;
dy += offset;
i__1 = -i__1;
i__2 = -i__2;
}
stride = i__1 * sizeof(FLOAT);
n = i__2 / i__1;
for (size_t vl; n > 0; n -= vl, dx += vl*i__1, dy += vl*i__1) {
vl = VSETVL(n);

v_w = VLSEV_FLOAT(&dx[1], stride, vl);
v_z__ = VLSEV_FLOAT(&dy[1], stride, vl);

v_dx = VFMACCVF_FLOAT(v_z__, dh11, v_w, vl);
v_dy = VFMSACVF_FLOAT(v_w, dh22, v_z__, vl);

VSSEV_FLOAT(&dx[1], stride, v_dx, vl);
VSSEV_FLOAT(&dy[1], stride, v_dy, vl);
}
#endif
goto L140;
L50:
dh11 = dparam[2];
Expand All @@ -81,13 +154,39 @@ void CNAME(blasint n, FLOAT *dx, blasint incx, FLOAT *dy, blasint incy, FLOAT *d
dh22 = dparam[5];
i__1 = nsteps;
i__2 = incx;
#if !defined(RISCV_SIMD)
for (i__ = 1; i__2 < 0 ? i__ >= i__1 : i__ <= i__1; i__ += i__2) {
w = dx[i__];
z__ = dy[i__];
dx[i__] = w * dh11 + z__ * dh12;
dy[i__] = w * dh21 + z__ * dh22;
/* L60: */
}
#else
if(i__2 < 0){
offset = i__1 - 2;
dx += offset;
dy += offset;
i__1 = -i__1;
i__2 = -i__2;
}
stride = i__2 * sizeof(FLOAT);
n = i__1 / i__2;
for (size_t vl; n > 0; n -= vl, dx += vl*i__2, dy += vl*i__2) {
vl = VSETVL(n);

v_w = VLSEV_FLOAT(&dx[1], stride, vl);
v_z__ = VLSEV_FLOAT(&dy[1], stride, vl);

v_dx = VFMULVF_FLOAT(v_w, dh11, vl);
v_dx = VFMACCVF_FLOAT(v_dx, dh12, v_z__, vl);
VSSEV_FLOAT(&dx[1], stride, v_dx, vl);

v_dy = VFMULVF_FLOAT(v_w, dh21, vl);
v_dy = VFMACCVF_FLOAT(v_dy, dh22, v_z__, vl);
VSSEV_FLOAT(&dy[1], stride, v_dy, vl);
}
#endif
goto L140;
L70:
kx = 1;
Expand All @@ -110,6 +209,7 @@ void CNAME(blasint n, FLOAT *dx, blasint incx, FLOAT *dy, blasint incy, FLOAT *d
dh12 = dparam[4];
dh21 = dparam[3];
i__2 = n;
#if !defined(RISCV_SIMD)
for (i__ = 1; i__ <= i__2; ++i__) {
w = dx[kx];
z__ = dy[ky];
Expand All @@ -119,11 +219,36 @@ void CNAME(blasint n, FLOAT *dx, blasint incx, FLOAT *dy, blasint incy, FLOAT *d
ky += incy;
/* L90: */
}
#else
if(incx < 0){
incx = -incx;
dx -= n*incx;
}
if(incy < 0){
incy = -incy;
dy -= n*incy;
}
stride_x = incx * sizeof(FLOAT);
stride_y = incy * sizeof(FLOAT);
for (size_t vl; n > 0; n -= vl, dx += vl*incx, dy += vl*incy) {
vl = VSETVL(n);

v_w = VLSEV_FLOAT(&dx[kx], stride_x, vl);
v_z__ = VLSEV_FLOAT(&dy[ky], stride_y, vl);

v_dx = VFMACCVF_FLOAT(v_w, dh12, v_z__, vl);
v_dy = VFMACCVF_FLOAT(v_z__, dh21, v_w, vl);

VSSEV_FLOAT(&dx[kx], stride_x, v_dx, vl);
VSSEV_FLOAT(&dy[ky], stride_y, v_dy, vl);
}
#endif
goto L140;
L100:
dh11 = dparam[2];
dh22 = dparam[5];
i__2 = n;
#if !defined(RISCV_SIMD)
for (i__ = 1; i__ <= i__2; ++i__) {
w = dx[kx];
z__ = dy[ky];
Expand All @@ -133,8 +258,33 @@ void CNAME(blasint n, FLOAT *dx, blasint incx, FLOAT *dy, blasint incy, FLOAT *d
ky += incy;
/* L110: */
}
#else
if(incx < 0){
incx = -incx;
dx -= n*incx;
}
if(incy < 0){
incy = -incy;
dy -= n*incy;
}
stride_x = incx * sizeof(FLOAT);
stride_y = incy * sizeof(FLOAT);
for (size_t vl; n > 0; n -= vl, dx += vl*incx, dy += vl*incy) {
vl = VSETVL(n);

v_w = VLSEV_FLOAT(&dx[kx], stride_x, vl);
v_z__ = VLSEV_FLOAT(&dy[ky], stride_y, vl);

v_dx = VFMACCVF_FLOAT(v_z__, dh11, v_w, vl);
v_dy = VFMSACVF_FLOAT(v_w, dh22, v_z__, vl);

VSSEV_FLOAT(&dx[kx], stride_x, v_dx, vl);
VSSEV_FLOAT(&dy[ky], stride_y, v_dy, vl);
}
#endif
goto L140;
L120:
#if !defined(RISCV_SIMD)
dh11 = dparam[2];
dh12 = dparam[4];
dh21 = dparam[3];
Expand All @@ -149,6 +299,32 @@ void CNAME(blasint n, FLOAT *dx, blasint incx, FLOAT *dy, blasint incy, FLOAT *d
ky += incy;
/* L130: */
}
#else
if(incx < 0){
incx = -incx;
dx -= n*incx;
}
if(incy < 0){
incy = -incy;
dy -= n*incy;
}
stride_x = incx * sizeof(FLOAT);
stride_y = incy * sizeof(FLOAT);
for (size_t vl; n > 0; n -= vl, dx += vl*incx, dy += vl*incy) {
vl = VSETVL(n);

v_w = VLSEV_FLOAT(&dx[kx], stride_x, vl);
v_z__ = VLSEV_FLOAT(&dy[ky], stride_y, vl);

v_dx = VFMULVF_FLOAT(v_w, dh11, vl);
v_dx = VFMACCVF_FLOAT(v_dx, dh12, v_z__, vl);
VSSEV_FLOAT(&dx[kx], stride_x, v_dx, vl);

v_dy = VFMULVF_FLOAT(v_w, dh21, vl);
v_dy = VFMACCVF_FLOAT(v_dy, dh22, v_z__, vl);
VSSEV_FLOAT(&dy[ky], stride_y, v_dy, vl);
}
#endif
L140:
return;
}
Expand Down
72 changes: 72 additions & 0 deletions utest/test_rot.c
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,42 @@ CTEST(rot,drot_inc_0)
ASSERT_DBL_NEAR_TOL(y2[i], y1[i], DOUBLE_EPS);
}
}
CTEST(rot,drot_inc_1)
{
	/*
	 * Double-precision rot with unit strides and c = s = 1.
	 * The expected vectors are x + y and y - x of the inputs.
	 */
	blasint n_elems = 4, stride_x = 1, stride_y = 1;
	double cosine = 1.0, sine = 1.0;
	double x_actual[] = {1.0, 3.0, 5.0, 7.0};
	double y_actual[] = {2.0, 4.0, 6.0, 8.0};
	double x_expected[] = {3.0, 7.0, 11.0, 15.0};
	double y_expected[] = {1.0, 1.0, 1.0, 1.0};
	blasint k;

	/* Rotate in place through the Fortran-style interface. */
	BLASFUNC(drot)(&n_elems, x_actual, &stride_x, y_actual, &stride_y, &cosine, &sine);

	for (k = 0; k < n_elems; ++k) {
		ASSERT_DBL_NEAR_TOL(x_expected[k], x_actual[k], DOUBLE_EPS);
		ASSERT_DBL_NEAR_TOL(y_expected[k], y_actual[k], DOUBLE_EPS);
	}
}
CTEST(rot,drotm_inc_1)
{
	/*
	 * Double-precision rotm with unit strides on 12 elements, x == y on input.
	 * The reference vectors are 3x and 4x the inputs respectively.
	 */
	blasint n_elems = 12, stride_x = 1, stride_y = 1;
	double param[5] = {1.0, 2.0, 3.0, 4.0, 5.0};
	double x_actual[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0};
	double y_actual[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0};
	double x_reference[] = {3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0, 27.0, 30.0, 33.0, 36.0};
	double y_reference[] = {4.0, 8.0, 12.0, 16.0, 20.0, 24.0, 28.0, 32.0, 36.0, 40.0, 44.0, 48.0};
	blasint k;

	/* Apply the modified rotation in place through the Fortran-style interface. */
	BLASFUNC(drotm)(&n_elems, x_actual, &stride_x, y_actual, &stride_y, param);

	for (k = 0; k < n_elems; ++k) {
		ASSERT_DBL_NEAR_TOL(x_reference[k], x_actual[k], DOUBLE_EPS);
		ASSERT_DBL_NEAR_TOL(y_reference[k], y_actual[k], DOUBLE_EPS);
	}
}
#endif

#ifdef BUILD_COMPLEX16
Expand Down Expand Up @@ -96,6 +132,42 @@ CTEST(rot,srot_inc_0)
ASSERT_DBL_NEAR_TOL(y2[i], y1[i], SINGLE_EPS);
}
}
CTEST(rot,srot_inc_1)
{
	/*
	 * Single-precision rot with unit strides and c = s = 1.
	 * The expected vectors are x + y and y - x of the inputs.
	 */
	blasint n_elems = 4, stride_x = 1, stride_y = 1;
	float cosine = 1.0, sine = 1.0;
	float x_actual[] = {1.0, 3.0, 5.0, 7.0};
	float y_actual[] = {2.0, 4.0, 6.0, 8.0};
	float x_expected[] = {3.0, 7.0, 11.0, 15.0};
	float y_expected[] = {1.0, 1.0, 1.0, 1.0};
	blasint k;

	/* Rotate in place through the Fortran-style interface. */
	BLASFUNC(srot)(&n_elems, x_actual, &stride_x, y_actual, &stride_y, &cosine, &sine);

	for (k = 0; k < n_elems; ++k) {
		ASSERT_DBL_NEAR_TOL(x_expected[k], x_actual[k], SINGLE_EPS);
		ASSERT_DBL_NEAR_TOL(y_expected[k], y_actual[k], SINGLE_EPS);
	}
}
CTEST(rot,srotm_inc_1)
{
	/*
	 * Single-precision rotm with unit strides on 12 elements, x == y on input.
	 * The reference vectors are 3x and 4x the inputs respectively.
	 */
	blasint n_elems = 12, stride_x = 1, stride_y = 1;
	float param[5] = {1.0, 2.0, 3.0, 4.0, 5.0};
	float x_actual[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0};
	float y_actual[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0};
	float x_reference[] = {3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0, 27.0, 30.0, 33.0, 36.0};
	float y_reference[] = {4.0, 8.0, 12.0, 16.0, 20.0, 24.0, 28.0, 32.0, 36.0, 40.0, 44.0, 48.0};
	blasint k;

	/* Apply the modified rotation in place through the Fortran-style interface. */
	BLASFUNC(srotm)(&n_elems, x_actual, &stride_x, y_actual, &stride_y, param);

	for (k = 0; k < n_elems; ++k) {
		ASSERT_DBL_NEAR_TOL(x_reference[k], x_actual[k], SINGLE_EPS);
		ASSERT_DBL_NEAR_TOL(y_reference[k], y_actual[k], SINGLE_EPS);
	}
}
#endif

#ifdef BUILD_COMPLEX
Expand Down
Loading