Skip to content

Commit

Permalink
Merge pull request #4720 from martin-frbg/issue3039
Browse files Browse the repository at this point in the history
Resurrect and complete cblas_?gemm_batch
  • Loading branch information
martin-frbg authored May 31, 2024
2 parents 6b564d5 + db070a9 commit 56bd57c
Show file tree
Hide file tree
Showing 11 changed files with 620 additions and 16 deletions.
2 changes: 1 addition & 1 deletion azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ jobs:
mkdir build
cd build
call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
cmake -G "Ninja" -DCMAKE_C_COMPILER=cl -DCMAKE_Fortran_COMPILER=flang -DC_LAPACK=1 -DCMAKE_MT=mt -DCMAKE_BUILD_TYPE=Release -DMSVC_STATIC_CRT=ON ..
cmake -G "Ninja" -DCMAKE_C_COMPILER=cl -DCMAKE_Fortran_COMPILER=flang-new -DC_LAPACK=1 -DCMAKE_MT=mt -DCMAKE_BUILD_TYPE=Release -DMSVC_STATIC_CRT=ON ..
cmake --build . --config Release
ctest
Expand Down
15 changes: 15 additions & 0 deletions cblas.h
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,18 @@ void cblas_cgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint
void cblas_zgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double *calpha, double *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST double *cbeta,
double *c, OPENBLAS_CONST blasint cldc);

void cblas_sgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array,
OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST float ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST float ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size);

void cblas_dgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array,
OPENBLAS_CONST double * alpha_array, OPENBLAS_CONST double ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST double ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST double * beta_array, double ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size);

void cblas_cgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array,
OPENBLAS_CONST void * alpha_array, OPENBLAS_CONST void ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST void ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST void * beta_array, void ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size);

void cblas_zgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array,
OPENBLAS_CONST void * alpha_array, OPENBLAS_CONST void ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST void ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST void * beta_array, void ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size);

/*** BFLOAT16 and INT8 extensions ***/
/* convert float array to BFLOAT16 array by rounding */
void cblas_sbstobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST float *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout);
Expand All @@ -431,6 +443,9 @@ void cblas_sbgemv(OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST enum

void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc);
void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array,
OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size);

#ifdef __cplusplus
}
#endif /* __cplusplus */
Expand Down
7 changes: 6 additions & 1 deletion common_level3.h
Original file line number Diff line number Diff line change
Expand Up @@ -1937,8 +1937,13 @@ int zimatcopy_k_rtc(BLASLONG, BLASLONG, double, double, double *, BLASLONG);
int sgeadd_k(BLASLONG, BLASLONG, float, float*, BLASLONG, float, float *, BLASLONG);
int dgeadd_k(BLASLONG, BLASLONG, double, double*, BLASLONG, double, double *, BLASLONG);
int cgeadd_k(BLASLONG, BLASLONG, float, float, float*, BLASLONG, float, float, float *, BLASLONG);
int zgeadd_k(BLASLONG, BLASLONG, double,double, double*, BLASLONG, double, double, double *, BLASLONG);
int zgeadd_k(BLASLONG, BLASLONG, double,double, double*, BLASLONG, double, double, double *, BLASLONG);

/* Batched GEMM drivers, one entry point per precision build (s, d, c, z, sb).
 * Each receives an array of `nums` blas_arg_t descriptors; the `routine` and
 * `routine_mode` members of blas_arg_t tell the driver which kernel to run for
 * an entry.  All five are built from driver/level3/gemm_batch_thread.c.
 * NOTE(review): the parameter is named `queue` here but holds blas_arg_t
 * descriptors, not blas_queue_t entries. */
int sgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
int dgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
int cgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
int zgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
int sbgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);

#ifdef __CUDACC__
}
Expand Down
11 changes: 11 additions & 0 deletions common_macro.h
Original file line number Diff line number Diff line change
Expand Up @@ -2655,9 +2655,20 @@ typedef struct {
BLASLONG prea, preb, prec, pred;
#endif


//for gemm_batch
void * routine;
int routine_mode;

} blas_arg_t;
#endif

#ifdef SMALL_MATRIX_OPT
#define BLAS_SMALL_OPT 0x10000U
#define BLAS_SMALL_B0_OPT 0x30000U
#endif


#ifdef XDOUBLE

#define TRSV_NUU qtrsv_NUU
Expand Down
2 changes: 2 additions & 0 deletions driver/level3/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ if (USE_THREAD)
endif ()

foreach (float_type ${FLOAT_TYPES})
GenerateNamedObjects("gemm_batch_thread.c" "" "gemm_batch_thread" 0 "" "" false ${float_type})

if (${float_type} STREQUAL "COMPLEX" OR ${float_type} STREQUAL "ZCOMPLEX")
GenerateCombinationObjects("zherk_kernel.c" "LOWER;CONJ" "U;N" "HERK" 2 "herk_kernel" false ${float_type})
# TRANS needs to be set/unset when CONJ is set/unset, so can't use it as a combination
Expand Down
23 changes: 19 additions & 4 deletions driver/level3/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ SBLASOBJS += \
ssyrk_UN.$(SUFFIX) ssyrk_UT.$(SUFFIX) ssyrk_LN.$(SUFFIX) ssyrk_LT.$(SUFFIX) \
ssyr2k_UN.$(SUFFIX) ssyr2k_UT.$(SUFFIX) ssyr2k_LN.$(SUFFIX) ssyr2k_LT.$(SUFFIX) \
ssyrk_kernel_U.$(SUFFIX) ssyrk_kernel_L.$(SUFFIX) \
ssyr2k_kernel_U.$(SUFFIX) ssyr2k_kernel_L.$(SUFFIX)
ssyr2k_kernel_U.$(SUFFIX) ssyr2k_kernel_L.$(SUFFIX) sgemm_batch_thread.$(SUFFIX)

DBLASOBJS += \
dgemm_nn.$(SUFFIX) dgemm_nt.$(SUFFIX) dgemm_tn.$(SUFFIX) dgemm_tt.$(SUFFIX) \
Expand All @@ -53,7 +53,7 @@ DBLASOBJS += \
dsyrk_UN.$(SUFFIX) dsyrk_UT.$(SUFFIX) dsyrk_LN.$(SUFFIX) dsyrk_LT.$(SUFFIX) \
dsyr2k_UN.$(SUFFIX) dsyr2k_UT.$(SUFFIX) dsyr2k_LN.$(SUFFIX) dsyr2k_LT.$(SUFFIX) \
dsyrk_kernel_U.$(SUFFIX) dsyrk_kernel_L.$(SUFFIX) \
dsyr2k_kernel_U.$(SUFFIX) dsyr2k_kernel_L.$(SUFFIX)
dsyr2k_kernel_U.$(SUFFIX) dsyr2k_kernel_L.$(SUFFIX) dgemm_batch_thread.$(SUFFIX)

QBLASOBJS += \
qgemm_nn.$(SUFFIX) qgemm_nt.$(SUFFIX) qgemm_tn.$(SUFFIX) qgemm_tt.$(SUFFIX) \
Expand Down Expand Up @@ -103,7 +103,7 @@ CBLASOBJS += \
cherk_kernel_LN.$(SUFFIX) cherk_kernel_LC.$(SUFFIX) \
csyr2k_kernel_U.$(SUFFIX) csyr2k_kernel_L.$(SUFFIX) \
cher2k_kernel_UN.$(SUFFIX) cher2k_kernel_UC.$(SUFFIX) \
cher2k_kernel_LN.$(SUFFIX) cher2k_kernel_LC.$(SUFFIX)
cher2k_kernel_LN.$(SUFFIX) cher2k_kernel_LC.$(SUFFIX) cgemm_batch_thread.$(SUFFIX)

ZBLASOBJS += \
zgemm_nn.$(SUFFIX) zgemm_cn.$(SUFFIX) zgemm_tn.$(SUFFIX) zgemm_nc.$(SUFFIX) \
Expand Down Expand Up @@ -137,7 +137,7 @@ ZBLASOBJS += \
zherk_kernel_LN.$(SUFFIX) zherk_kernel_LC.$(SUFFIX) \
zsyr2k_kernel_U.$(SUFFIX) zsyr2k_kernel_L.$(SUFFIX) \
zher2k_kernel_UN.$(SUFFIX) zher2k_kernel_UC.$(SUFFIX) \
zher2k_kernel_LN.$(SUFFIX) zher2k_kernel_LC.$(SUFFIX)
zher2k_kernel_LN.$(SUFFIX) zher2k_kernel_LC.$(SUFFIX) zgemm_batch_thread.$(SUFFIX)


XBLASOBJS += \
Expand Down Expand Up @@ -2942,6 +2942,21 @@ gemm_thread_variable.$(PSUFFIX) : gemm_thread_variable.c ../../common.h
beta_thread.$(PSUFFIX) : beta_thread.c ../../common.h
$(CC) -c $(PFLAGS) $< -o $(@F)

sbgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h
$(CC) -c $(CFLAGS) $< -o $(@F)

sgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h
$(CC) -c $(CFLAGS) $< -o $(@F)

dgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h
$(CC) -c $(CFLAGS) $< -o $(@F)

cgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h
$(CC) -c $(CFLAGS) $< -o $(@F)

zgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h
$(CC) -c $(CFLAGS) $< -o $(@F)


sbgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
Expand Down
156 changes: 156 additions & 0 deletions driver/level3/gemm_batch_thread.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
/*****************************************************************************
Copyright (c) 2020, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written
permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
**********************************************************************************/

#include "common.h"

void openblas_warning(int verbose, const char * msg);

#ifdef SMALL_MATRIX_OPT
/*
 * Adapter that lets a small-matrix GEMM kernel be invoked through the generic
 * (args, range_m, range_n, sa, sb, mypos) worker signature.
 *
 * The kernel to call is taken from args->routine and the variant from
 * args->routine_mode:
 *   - all BLAS_SMALL_B0_OPT bits set -> kernel variant that takes no beta argument
 *   - otherwise BLAS_SMALL_OPT set   -> kernel variant that takes alpha and beta
 * range_m, range_n, sa, sb and mypos are not used.
 *
 * Returns 0 after running a kernel, 1 if neither flag is present in routine_mode.
 */
static int inner_small_matrix_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){
	const int mode = args->routine_mode;
	const int use_b0_kernel = ((mode & BLAS_SMALL_B0_OPT) == BLAS_SMALL_B0_OPT);
#ifndef COMPLEX
	int (*kernel)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT ,FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG);
	int (*kernel_b0)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG);
#else
	int (*kernel)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG);
	int (*kernel_b0)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG);
	FLOAT alpha_r, alpha_i, beta_r, beta_i;
#endif

	/* Nothing to do unless one of the small-matrix flags is set. */
	if(!use_b0_kernel && !(mode & BLAS_SMALL_OPT)){
		return(1);
	}

#ifndef COMPLEX
	if(use_b0_kernel){
		kernel_b0 = args->routine;
		kernel_b0(args->m, args->n, args->k, args->a, args->lda, *(FLOAT *)(args->alpha), args->b, args->ldb, args->c, args->ldc);
	}else{
		kernel = args->routine;
		kernel(args->m, args->n, args->k, args->a, args->lda, *(FLOAT *)(args->alpha), args->b, args->ldb, *(FLOAT *)(args->beta), args->c, args->ldc);
	}
#else
	/* Complex scalars are stored as two consecutive FLOATs and passed by value. */
	alpha_r = ((FLOAT *)args->alpha)[0];
	alpha_i = ((FLOAT *)args->alpha)[1];
	if(use_b0_kernel){
		kernel_b0 = args->routine;
		kernel_b0(args->m, args->n, args->k, args->a, args->lda, alpha_r, alpha_i, args->b, args->ldb, args->c, args->ldc);
	}else{
		beta_r = ((FLOAT *)args->beta)[0];
		beta_i = ((FLOAT *)args->beta)[1];
		kernel = args->routine;
		kernel(args->m, args->n, args->k, args->a, args->lda, alpha_r, alpha_i, args->b, args->ldb, beta_r, beta_i, args->c, args->ldc);
	}
#endif
	return(0);
}
#endif

/*
 * Execute a batch of independent GEMM problems.
 *
 * args_array : `nums` descriptors; each carries its own kernel pointer in
 *              `routine` and its flags in `routine_mode`.
 * nums       : number of descriptors; nums <= 0 is a no-op.
 *
 * Without SMP, or when num_cpu_avail() reports a single thread, the entries are
 * run one after another on the calling thread.  Otherwise every entry is wrapped
 * in a blas_queue_t and the queue is handed to exec_blas() in chunks of at most
 * `nthreads` entries.
 *
 * Returns 0 on success and 1 if the work queue cannot be allocated.
 */
int CNAME(blas_arg_t * args_array, BLASLONG nums){
	XFLOAT *buffer;
	XFLOAT *sa, *sb;
	int nthreads=1;
	int (*routine)(blas_arg_t *, void *, void *, XFLOAT *, XFLOAT *, BLASLONG);
	int i=0;

#ifdef SMP
	int current_nums;
	blas_queue_t * queue=NULL;
#endif

	if(nums <=0 ) return 0;

	/* One work buffer, carved into the sa and sb areas passed to the kernels. */
	buffer = (XFLOAT *)blas_memory_alloc(0);
	sa = (XFLOAT *)((BLASLONG)buffer +GEMM_OFFSET_A);
	sb = (XFLOAT *)(((BLASLONG)sa + ((GEMM_P * GEMM_Q * COMPSIZE * SIZE + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);

#ifdef SMP
	nthreads=num_cpu_avail(3);

	if(nthreads==1){

#endif
		//single thread
		for(i=0; i<nums; i++){
			routine=args_array[i].routine;
#ifdef SMALL_MATRIX_OPT
			/* BLAS_SMALL_B0_OPT contains the BLAS_SMALL_OPT bit, so this test covers both variants. */
			if(args_array[i].routine_mode & BLAS_SMALL_OPT){
				inner_small_matrix_thread(&args_array[i], NULL, NULL, NULL, NULL, 0);
			}else{
#endif
				routine(&args_array[i], NULL, NULL, sa, sb, 0);
#ifdef SMALL_MATRIX_OPT
			}
#endif
		}
#ifdef SMP
	} else {
		//multi thread

		/* nums+1 entries: the setup loop below stores &queue[i+1] for every i < nums. */
		queue=(blas_queue_t *)malloc((nums+1) * sizeof(blas_queue_t));
		if(queue == NULL){
			openblas_warning(0, "memory alloc failed!\n");
			/* Release the work buffer acquired above; returning without this leaks it. */
			blas_memory_free(buffer);
			return(1);
		}
		for(i=0; i<nums; i++){
			queue[i].args=&args_array[i];
			queue[i].range_m=NULL;
			queue[i].range_n=NULL;
			queue[i].sa=NULL;
			queue[i].sb=NULL;
			queue[i].next=&queue[i+1];

			queue[i].mode=args_array[i].routine_mode;
			queue[i].routine=args_array[i].routine;

#ifdef SMALL_MATRIX_OPT
			/* Small-matrix kernels have a different signature, so route them through the adapter. */
			if((args_array[i].routine_mode & BLAS_SMALL_B0_OPT) || (args_array[i].routine_mode & BLAS_SMALL_OPT)){
				queue[i].routine=inner_small_matrix_thread;
			}
#endif
		}

		/* Dispatch in chunks of at most nthreads entries. */
		for(i=0; i<nums; i+=nthreads){
			current_nums=((nums-i)>nthreads)? nthreads: (nums-i);

			/* Only the first entry of a chunk is given the preallocated sa/sb areas. */
			queue[i].sa=sa;
			queue[i].sb=sb;
			/* Terminate the linked list at the end of this chunk. */
			queue[i+current_nums-1].next=NULL;

			exec_blas(current_nums, &queue[i]);
		}
		free(queue);
	}
#endif
	blas_memory_free(buffer);
	return 0;
}
10 changes: 5 additions & 5 deletions exports/gensymbol
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ cblasobjsc="
cblas_ctbsv cblas_ctpmv cblas_ctpsv cblas_ctrmm cblas_ctrmv cblas_ctrsm cblas_ctrsv
cblas_scnrm2 cblas_scasum cblas_cgemmt
cblas_icamax cblas_icamin cblas_icmin cblas_icmax cblas_scsum cblas_cimatcopy cblas_comatcopy
cblas_caxpyc cblas_crotg cblas_csrot cblas_scamax cblas_scamin
cblas_caxpyc cblas_crotg cblas_csrot cblas_scamax cblas_scamin cblas_cgemm_batch
"
cblasobjsd="
cblas_dasum cblas_daxpy cblas_dcopy cblas_ddot
Expand All @@ -70,7 +70,7 @@ cblasobjsd="
cblas_dsyr2k cblas_dsyr cblas_dsyrk cblas_dtbmv cblas_dtbsv cblas_dtpmv cblas_dtpsv
cblas_dtrmm cblas_dtrmv cblas_dtrsm cblas_dtrsv cblas_daxpby cblas_dgeadd cblas_dgemmt
cblas_idamax cblas_idamin cblas_idmin cblas_idmax cblas_dsum cblas_dimatcopy cblas_domatcopy
cblas_damax cblas_damin
cblas_damax cblas_damin cblas_dgemm_batch
"

cblasobjss="
Expand All @@ -82,7 +82,7 @@ cblasobjss="
cblas_stbmv cblas_stbsv cblas_stpmv cblas_stpsv cblas_strmm cblas_strmv cblas_strsm
cblas_strsv cblas_sgeadd cblas_sgemmt
cblas_isamax cblas_isamin cblas_ismin cblas_ismax cblas_ssum cblas_simatcopy cblas_somatcopy
cblas_samax cblas_samin
cblas_samax cblas_samin cblas_sgemm_batch
"

cblasobjsz="
Expand All @@ -94,12 +94,12 @@ cblasobjsz="
cblas_ztrsv cblas_cdotc_sub cblas_cdotu_sub cblas_zdotc_sub cblas_zdotu_sub
cblas_zaxpby cblas_zgeadd cblas_zgemmt
cblas_izamax cblas_izamin cblas_izmin cblas_izmax cblas_dzsum cblas_zimatcopy cblas_zomatcopy
cblas_zaxpyc cblas_zdrot cblas_zrotg cblas_dzamax cblas_dzamin
cblas_zaxpyc cblas_zdrot cblas_zrotg cblas_dzamax cblas_dzamin cblas_zgemm_batch
"

cblasobjs="cblas_xerbla"

bfcblasobjs="cblas_sbgemm cblas_sbgemv cblas_sbdot cblas_sbstobf16 cblas_sbdtobf16 cblas_sbf16tos cblas_dbf16tod"
bfcblasobjs="cblas_sbgemm cblas_sbgemv cblas_sbdot cblas_sbstobf16 cblas_sbdtobf16 cblas_sbf16tos cblas_dbf16tod cblas_sbgemm_batch"

exblasobjs="
qamax qamin qasum qaxpy qcabs1 qcopy qdot qgbmv qgemm
Expand Down
14 changes: 14 additions & 0 deletions interface/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ foreach (CBLAS_FLAG ${CBLAS_FLAGS})
#sdsdot, dsdot
if (BUILD_SINGLE OR BUILD_DOUBLE)
GenerateNamedObjects("sdsdot.c" "" "sdsdot" ${CBLAS_FLAG} "" "" true "SINGLE")
if(CBLAS_FLAG EQUAL 1)
GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" ${CBLAS_FLAG} "" "" false)
endif ()
endif ()
if (BUILD_DOUBLE)
GenerateNamedObjects("dsdot.c" "" "dsdot" ${CBLAS_FLAG} "" "" true "SINGLE")
Expand Down Expand Up @@ -125,6 +128,9 @@ if (BUILD_BFLOAT16)
GenerateNamedObjects("tobf16.c" "DOUBLE_PREC" "sbdtobf16" ${CBLAS_FLAG} "" "" true "BFLOAT16")
GenerateNamedObjects("bf16to.c" "SINGLE_PREC" "sbf16tos" ${CBLAS_FLAG} "" "" true "BFLOAT16")
GenerateNamedObjects("bf16to.c" "DOUBLE_PREC" "dbf16tod" ${CBLAS_FLAG} "" "" true "BFLOAT16")
if(CBLAS_FLAG EQUAL 1)
GenerateNamedObjects("gemm_batch.c" "" "sbgemm_batch" ${CBLAS_FLAG} "" "" true "BFLOAT16")
endif ()
endif ()

# complex-specific sources
Expand Down Expand Up @@ -154,6 +160,9 @@ foreach (float_type ${FLOAT_TYPES})
GenerateNamedObjects("max.c" "USE_ABS" "scamax" ${CBLAS_FLAG} "" "" true "COMPLEX")
GenerateNamedObjects("asum.c" "" "scasum" ${CBLAS_FLAG} "" "" true "COMPLEX")
GenerateNamedObjects("sum.c" "" "scsum" ${CBLAS_FLAG} "" "" true "COMPLEX")
if(CBLAS_FLAG EQUAL 1)
GenerateNamedObjects("gemm_batch.c" "" "cgemm_batch" ${CBLAS_FLAG} "" "" true "COMPLEX")
endif ()
endif ()
if (${float_type} STREQUAL "ZCOMPLEX")
GenerateNamedObjects("zscal.c" "SSCAL" "dscal" ${CBLAS_FLAG} "" "" false "ZCOMPLEX")
Expand All @@ -163,6 +172,9 @@ foreach (float_type ${FLOAT_TYPES})
GenerateNamedObjects("max.c" "USE_ABS" "dzamax" ${CBLAS_FLAG} "" "" true "ZCOMPLEX")
GenerateNamedObjects("asum.c" "" "dzasum" ${CBLAS_FLAG} "" "" true "ZCOMPLEX")
GenerateNamedObjects("sum.c" "" "dzsum" ${CBLAS_FLAG} "" "" true "ZCOMPLEX")
if(CBLAS_FLAG EQUAL 1)
GenerateNamedObjects("gemm_batch.c" "" "zgemm_batch" ${CBLAS_FLAG} "" "" true "ZCOMPLEX")
endif ()
endif ()
endforeach ()

Expand Down Expand Up @@ -212,6 +224,7 @@ if ( BUILD_COMPLEX AND NOT BUILD_SINGLE)
GenerateNamedObjects("nrm2.c" "" "nrm2" 0 "" "" false "SINGLE")
GenerateNamedObjects("gemv.c" "" "gemv" 0 "" "" false "SINGLE")
GenerateNamedObjects("gemm.c" "" "gemm" 0 "" "" false "SINGLE")
GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" 1 "" "" false "SINGLE")
GenerateNamedObjects("asum.c" "" "asum" 0 "" "" false "SINGLE")
GenerateNamedObjects("swap.c" "" "swap" 0 "" "" false "SINGLE")
GenerateNamedObjects("axpy.c" "" "axpy" 0 "" "" false "SINGLE")
Expand All @@ -225,6 +238,7 @@ if ( BUILD_COMPLEX16 AND NOT BUILD_DOUBLE)
GenerateNamedObjects("nrm2.c" "" "nrm2" 0 "" "" false "DOUBLE")
GenerateNamedObjects("gemv.c" "" "gemv" 0 "" "" false "DOUBLE")
GenerateNamedObjects("gemm.c" "" "gemm" 0 "" "" false "DOUBLE")
GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" 1 "" "" false "DOUBLE")
GenerateNamedObjects("asum.c" "" "asum" 0 "" "" false "DOUBLE")
GenerateNamedObjects("swap.c" "" "swap" 0 "" "" false "DOUBLE")
GenerateNamedObjects("axpy.c" "" "axpy" 0 "" "" false "DOUBLE")
Expand Down
Loading

0 comments on commit 56bd57c

Please sign in to comment.