diff --git a/Makefile b/Makefile index 86769877..bdbb9474 100644 --- a/Makefile +++ b/Makefile @@ -148,7 +148,8 @@ generated-sources: \ $(LAPACK)/ilaenv.f \ $(LAPACK)/[sd]geqrf.f $(LAPACK)/[sd]ormqr.f \ $(LAPACK)/[sd]orgqr.f \ - $(LAPACK)/[sd]sygvx.f + $(LAPACK)/[sd]sygvx.f \ + $(LAPACK)/[cz]heev.f ant javah touch $@ diff --git a/src/main/c/NativeBlas.c b/src/main/c/NativeBlas.c index 083e3cfc..72c098c2 100644 --- a/src/main/c/NativeBlas.c +++ b/src/main/c/NativeBlas.c @@ -102,7 +102,7 @@ static ComplexDouble getComplexDouble(JNIEnv *env, jobject dc) /**********************************************************************/ static char *routine_names[] = { - "CAXPY", "CCOPY", "CDOTC", "CDOTU", "CGEEV", "CGEMM", "CGEMV", "CGERC", "CGERU", "CGESVD", "CSCAL", "CSSCAL", "CSWAP", "DASUM", "DAXPY", "DCOPY", "DDOT", "DGEEV", "DGELSD", "DGEMM", "DGEMV", "DGEQRF", "DGER", "DGESV", "DGESVD", "DGETRF", "DNRM2", "DORGQR", "DORMQR", "DPOSV", "DPOTRF", "DSCAL", "DSWAP", "DSYEV", "DSYEVD", "DSYEVR", "DSYEVX", "DSYGVD", "DSYGVX", "DSYSV", "DZASUM", "DZNRM2", "ICAMAX", "IDAMAX", "ILAENV", "ISAMAX", "IZAMAX", "SASUM", "SAXPY", "SCASUM", "SCNRM2", "SCOPY", "SDOT", "SGEEV", "SGELSD", "SGEMM", "SGEMV", "SGEQRF", "SGER", "SGESV", "SGESVD", "SGETRF", "SNRM2", "SORGQR", "SORMQR", "SPOSV", "SPOTRF", "SSCAL", "SSWAP", "SSYEV", "SSYEVD", "SSYEVR", "SSYEVX", "SSYGVD", "SSYGVX", "SSYSV", "ZAXPY", "ZCOPY", "ZDOTC", "ZDOTU", "ZDSCAL", "ZGEEV", "ZGEMM", "ZGEMV", "ZGERC", "ZGERU", "ZGESVD", "ZSCAL", "ZSWAP", 0 + "CAXPY", "CCOPY", "CDOTC", "CDOTU", "CGEEV", "CGEMM", "CGEMV", "CGERC", "CGERU", "CGESVD", "CHEEV", "CSCAL", "CSSCAL", "CSWAP", "DASUM", "DAXPY", "DCOPY", "DDOT", "DGEEV", "DGELSD", "DGEMM", "DGEMV", "DGEQRF", "DGER", "DGESV", "DGESVD", "DGETRF", "DNRM2", "DORGQR", "DORMQR", "DPOSV", "DPOTRF", "DSCAL", "DSWAP", "DSYEV", "DSYEVD", "DSYEVR", "DSYEVX", "DSYGVD", "DSYGVX", "DSYSV", "DZASUM", "DZNRM2", "ICAMAX", "IDAMAX", "ILAENV", "ISAMAX", "IZAMAX", "SASUM", "SAXPY", "SCASUM", "SCNRM2", "SCOPY", "SDOT", "SGEEV", "SGELSD", "SGEMM", "SGEMV", "SGEQRF", "SGER", "SGESV", "SGESVD", "SGETRF", "SNRM2", "SORGQR", "SORMQR", "SPOSV", "SPOTRF", "SSCAL", "SSWAP", "SSYEV", "SSYEVD", "SSYEVR", "SSYEVX", "SSYGVD", "SSYGVX", "SSYSV", "ZAXPY", "ZCOPY", "ZDOTC", "ZDOTU", "ZDSCAL", "ZGEEV", "ZGEMM", "ZGEMV", "ZGERC", "ZGERU", "ZGESVD", "ZHEEV", "ZSCAL", "ZSWAP", 0 }; static char *routine_arguments[][23] = { @@ -116,6 +116,7 @@ static char *routine_arguments[][23] = { { "M", "N", "ALPHA", "X", "INCX", "Y", "INCY", "A", "LDA" }, { "M", "N", "ALPHA", "X", "INCX", "Y", "INCY", "A", "LDA" }, { "JOBU", "JOBVT", "M", "N", "A", "LDA", "S", "U", "LDU", "VT", "LDVT", "WORK", "LWORK", "RWORK", "INFO" }, + { "JOBZ", "UPLO", "N", "A", "LDA", "W", "WORK", "LWORK", "RWORK", "INFO" }, { "N", "CA", "CX", "INCX" }, { "N", "SA", "CX", "INCX" }, { "N", "CX", "INCX", "CY", "INCY" }, @@ -193,6 +194,7 @@ static char *routine_arguments[][23] = { { "M", "N", "ALPHA", "X", "INCX", "Y", "INCY", "A", "LDA" }, { "M", "N", "ALPHA", "X", "INCX", "Y", "INCY", "A", "LDA" }, { "JOBU", "JOBVT", "M", "N", "A", "LDA", "S", "U", "LDU", "VT", "LDVT", "WORK", "LWORK", "RWORK", "INFO" }, + { "JOBZ", "UPLO", "N", "A", "LDA", "W", "WORK", "LWORK", "RWORK", "INFO" }, { "N", "ZA", "ZX", "INCX" }, { "N", "ZX", "INCX", "ZY", "INCY" }, }; @@ -5167,3 +5169,161 @@ JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_ssygvx(JNIEnv *env, jclass this return info; } +JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_cheev(JNIEnv *env, jclass this, jchar jobz, jchar uplo, jint n, jfloatArray a, jint aIdx, jint lda, jfloatArray w, jint wIdx, jfloatArray work, jint workIdx, jint lwork, jfloatArray rwork, jint rworkIdx) +{ + extern void cheev_(char *, char *, jint *, jfloat *, jint *, jfloat *, jfloat *, jint *, jfloat *, int *); + + char jobzChr = (char) jobz; + char uploChr = (char) uplo; + jfloat *rworkPtrBase = 0, *rworkPtr = 0; + if (rwork) { + rworkPtrBase = (*env)->GetFloatArrayElements(env, rwork, NULL); + rworkPtr = rworkPtrBase + rworkIdx; + } + jfloat *aPtrBase = 0, *aPtr = 0; + if (a) { + if((*env)->IsSameObject(env, a, rwork) == JNI_TRUE) + aPtrBase = rworkPtrBase; + else + aPtrBase = (*env)->GetFloatArrayElements(env, a, NULL); + aPtr = aPtrBase + 2*aIdx; + } + jfloat *wPtrBase = 0, *wPtr = 0; + if (w) { + if((*env)->IsSameObject(env, w, rwork) == JNI_TRUE) + wPtrBase = rworkPtrBase; + else + if((*env)->IsSameObject(env, w, a) == JNI_TRUE) + wPtrBase = aPtrBase; + else + wPtrBase = (*env)->GetFloatArrayElements(env, w, NULL); + wPtr = wPtrBase + wIdx; + } + jfloat *workPtrBase = 0, *workPtr = 0; + if (work) { + if((*env)->IsSameObject(env, work, rwork) == JNI_TRUE) + workPtrBase = rworkPtrBase; + else + if((*env)->IsSameObject(env, work, a) == JNI_TRUE) + workPtrBase = aPtrBase; + else + if((*env)->IsSameObject(env, work, w) == JNI_TRUE) + workPtrBase = wPtrBase; + else + workPtrBase = (*env)->GetFloatArrayElements(env, work, NULL); + workPtr = workPtrBase + 2*workIdx; + } + int info; + + cheev_(&jobzChr, &uploChr, &n, aPtr, &lda, wPtr, workPtr, &lwork, rworkPtr, &info); + if(workPtrBase) { + (*env)->ReleaseFloatArrayElements(env, work, workPtrBase, 0); + if (workPtrBase == rworkPtrBase) + rworkPtrBase = 0; + if (workPtrBase == aPtrBase) + aPtrBase = 0; + if (workPtrBase == wPtrBase) + wPtrBase = 0; + workPtrBase = 0; + } + if(wPtrBase) { + (*env)->ReleaseFloatArrayElements(env, w, wPtrBase, 0); + if (wPtrBase == rworkPtrBase) + rworkPtrBase = 0; + if (wPtrBase == aPtrBase) + aPtrBase = 0; + wPtrBase = 0; + } + if(aPtrBase) { + (*env)->ReleaseFloatArrayElements(env, a, aPtrBase, 0); + if (aPtrBase == rworkPtrBase) + rworkPtrBase = 0; + aPtrBase = 0; + } + if(rworkPtrBase) { + (*env)->ReleaseFloatArrayElements(env, rwork, rworkPtrBase, JNI_ABORT); + rworkPtrBase = 0; + } + + return info; +} + +JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_zheev(JNIEnv *env, jclass this, jchar jobz, jchar uplo, jint n, jdoubleArray a, jint aIdx, jint lda, jdoubleArray w, jint wIdx, jdoubleArray work, jint workIdx, jint lwork, jdoubleArray rwork, jint rworkIdx) +{ + extern void zheev_(char *, char *, jint *, jdouble *, jint *, jdouble *, jdouble *, jint *, jdouble *, int *); + + char jobzChr = (char) jobz; + char uploChr = (char) uplo; + jdouble *rworkPtrBase = 0, *rworkPtr = 0; + if (rwork) { + rworkPtrBase = (*env)->GetDoubleArrayElements(env, rwork, NULL); + rworkPtr = rworkPtrBase + rworkIdx; + } + jdouble *aPtrBase = 0, *aPtr = 0; + if (a) { + if((*env)->IsSameObject(env, a, rwork) == JNI_TRUE) + aPtrBase = rworkPtrBase; + else + aPtrBase = (*env)->GetDoubleArrayElements(env, a, NULL); + aPtr = aPtrBase + 2*aIdx; + } + jdouble *wPtrBase = 0, *wPtr = 0; + if (w) { + if((*env)->IsSameObject(env, w, rwork) == JNI_TRUE) + wPtrBase = rworkPtrBase; + else + if((*env)->IsSameObject(env, w, a) == JNI_TRUE) + wPtrBase = aPtrBase; + else + wPtrBase = (*env)->GetDoubleArrayElements(env, w, NULL); + wPtr = wPtrBase + wIdx; + } + jdouble *workPtrBase = 0, *workPtr = 0; + if (work) { + if((*env)->IsSameObject(env, work, rwork) == JNI_TRUE) + workPtrBase = rworkPtrBase; + else + if((*env)->IsSameObject(env, work, a) == JNI_TRUE) + workPtrBase = aPtrBase; + else + if((*env)->IsSameObject(env, work, w) == JNI_TRUE) + workPtrBase = wPtrBase; + else + workPtrBase = (*env)->GetDoubleArrayElements(env, work, NULL); + workPtr = workPtrBase + 2*workIdx; + } + int info; + + zheev_(&jobzChr, &uploChr, &n, aPtr, &lda, wPtr, workPtr, &lwork, rworkPtr, &info); + if(workPtrBase) { + (*env)->ReleaseDoubleArrayElements(env, work, workPtrBase, 0); + if (workPtrBase == rworkPtrBase) + rworkPtrBase = 0; + if (workPtrBase == aPtrBase) + aPtrBase = 0; + if (workPtrBase == wPtrBase) + wPtrBase = 0; + workPtrBase = 0; + } + if(wPtrBase) { + (*env)->ReleaseDoubleArrayElements(env, w, wPtrBase, 0); + if (wPtrBase == rworkPtrBase) + rworkPtrBase = 0; + if (wPtrBase == aPtrBase) + aPtrBase = 0; + wPtrBase = 0; + } + if(aPtrBase) { + (*env)->ReleaseDoubleArrayElements(env, a, aPtrBase, 0); + if (aPtrBase == rworkPtrBase) + rworkPtrBase = 0; + aPtrBase = 0; + } + if(rworkPtrBase) { + (*env)->ReleaseDoubleArrayElements(env, rwork, rworkPtrBase, JNI_ABORT); + rworkPtrBase = 0; + } + + return info; +} + diff --git a/src/main/c/org_jblas_NativeBlas.h b/src/main/c/org_jblas_NativeBlas.h index d2efd790..2ef7933d 100644 --- a/src/main/c/org_jblas_NativeBlas.h +++ b/src/main/c/org_jblas_NativeBlas.h @@ -723,6 +723,22 @@ JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_dsygvx JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_ssygvx (JNIEnv *, jclass, jint, jchar, jchar, jchar, jint, jfloatArray, jint, jint, jfloatArray, jint, jint, jfloat, jfloat, jint, jint, jfloat, jintArray, jint, jfloatArray, jint, jfloatArray, jint, jint, jfloatArray, jint, jint, jintArray, jint, jintArray, jint); +/* + * Class: org_jblas_NativeBlas + * Method: cheev + * Signature: (CCI[FII[FI[FII[FI)I + */ +JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_cheev + (JNIEnv *, jclass, jchar, jchar, jint, jfloatArray, jint, jint, jfloatArray, jint, jfloatArray, jint, jint, jfloatArray, jint); + +/* + * Class: org_jblas_NativeBlas + * Method: zheev + * Signature: (CCI[DII[DI[DII[DI)I + */ +JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_zheev + (JNIEnv *, jclass, jchar, jchar, jint, jdoubleArray, jint, jint, jdoubleArray, jint, jdoubleArray, jint, jint, jdoubleArray, jint); + #ifdef __cplusplus } #endif diff --git a/src/main/java/org/jblas/Eigen.java b/src/main/java/org/jblas/Eigen.java index 7a03f831..0fbfd847 100644 --- a/src/main/java/org/jblas/Eigen.java +++ b/src/main/java/org/jblas/Eigen.java @@ -51,6 +51,7 @@ */ public class Eigen { private static final DoubleMatrix dummyDouble = new DoubleMatrix(1); + private static final ComplexDoubleMatrix dummyComplexDouble = new ComplexDoubleMatrix(1); /** * Compute the eigenvalues for a symmetric matrix. @@ -87,7 +88,6 @@ public static ComplexDoubleMatrix eigenvalues(DoubleMatrix A) { DoubleMatrix WR = new DoubleMatrix(A.rows); DoubleMatrix WI = WR.dup(); SimpleBlas.geev('N', 'N', A.dup(), WR, WI, dummyDouble, dummyDouble); - return new ComplexDoubleMatrix(WR, WI); } @@ -306,10 +306,70 @@ public static DoubleMatrix[] symmetricGeneralizedEigenvectors(DoubleMatrix A, Do return result; } + /** + * Computes the eigenvalues of a complex matrix. + */ + public static ComplexDoubleMatrix eigenvalues(ComplexDoubleMatrix A) { + A.assertSquare(); + ComplexDoubleMatrix W = new ComplexDoubleMatrix(A.rows); + SimpleBlas.cgeev('N', 'N', A.dup(), W, dummyComplexDouble, dummyComplexDouble); + return W; + } + + /** + * Computes the eigenvalues and eigenvectors of a complex matrix. + * + * @return an array of ComplexDoubleMatrix objects containing the (right) eigenvectors + * stored as the columns of the first matrix, and the eigenvalues as the + * diagonal elements of the second matrix. + */ + public static ComplexDoubleMatrix[] eigenvectors(ComplexDoubleMatrix A) { + A.assertSquare(); + // setting up result arrays + ComplexDoubleMatrix W = new ComplexDoubleMatrix(A.rows); + ComplexDoubleMatrix VR = new ComplexDoubleMatrix(A.rows, A.rows); + + SimpleBlas.cgeev('N', 'V', A.dup(), W, dummyComplexDouble, VR); + return new ComplexDoubleMatrix[]{VR, ComplexDoubleMatrix.diag(W)}; +} + + /** + * Computes the eigenvalues of a complex Hermitian matrix. + * + * Assumes that the input is an Hermitian matrix. + */ + public static DoubleMatrix hermitianEigenvalues(ComplexDoubleMatrix A) { + A.assertSquare(); + DoubleMatrix W = new DoubleMatrix(A.rows); + SimpleBlas.cheev('N', 'U', A.dup(), W); + return W; + } + + /** + * Computes the eigenvalues and eigenvectors of a complex Hermitian matrix. + * + * Assumes that the input is an Hermitian matrix. + * + * @return an array of ComplexDoubleMatrix objects containing the orthonormal eigenvectors + * stored as the columns of the first matrix, and the eigenvalues (in ascending order) + * as the diagonal elements of the second matrix. + */ + public static ComplexDoubleMatrix[] hermitianEigenvectors(ComplexDoubleMatrix A) { + A.assertSquare(); + // setting up result arrays + DoubleMatrix W = new DoubleMatrix(A.rows); + ComplexDoubleMatrix eigenvectors = A.dup(); + + SimpleBlas.cheev('V', 'U', eigenvectors, W); + return new ComplexDoubleMatrix[]{eigenvectors, ComplexDoubleMatrix.diag(W.toComplex())}; + } + + //BEGIN // The code below has been automatically generated. // DO NOT EDIT! private static final FloatMatrix dummyFloat = new FloatMatrix(1); + private static final ComplexFloatMatrix dummyComplexFloat = new ComplexFloatMatrix(1); /** * Compute the eigenvalues for a symmetric matrix. @@ -346,7 +406,6 @@ public static ComplexFloatMatrix eigenvalues(FloatMatrix A) { FloatMatrix WR = new FloatMatrix(A.rows); FloatMatrix WI = WR.dup(); SimpleBlas.geev('N', 'N', A.dup(), WR, WI, dummyFloat, dummyFloat); - return new ComplexFloatMatrix(WR, WI); } @@ -565,5 +624,64 @@ public static FloatMatrix[] symmetricGeneralizedEigenvectors(FloatMatrix A, Floa return result; } + /** + * Computes the eigenvalues of a complex matrix. + */ + public static ComplexFloatMatrix eigenvalues(ComplexFloatMatrix A) { + A.assertSquare(); + ComplexFloatMatrix W = new ComplexFloatMatrix(A.rows); + SimpleBlas.cgeev('N', 'N', A.dup(), W, dummyComplexFloat, dummyComplexFloat); + return W; + } + + /** + * Computes the eigenvalues and eigenvectors of a complex matrix. + * + * @return an array of ComplexFloatMatrix objects containing the (right) eigenvectors + * stored as the columns of the first matrix, and the eigenvalues as the + * diagonal elements of the second matrix. + */ + public static ComplexFloatMatrix[] eigenvectors(ComplexFloatMatrix A) { + A.assertSquare(); + // setting up result arrays + ComplexFloatMatrix W = new ComplexFloatMatrix(A.rows); + ComplexFloatMatrix VR = new ComplexFloatMatrix(A.rows, A.rows); + + SimpleBlas.cgeev('N', 'V', A.dup(), W, dummyComplexFloat, VR); + return new ComplexFloatMatrix[]{VR, ComplexFloatMatrix.diag(W)}; +} + + /** + * Computes the eigenvalues of a complex Hermitian matrix. + * + * Assumes that the input is an Hermitian matrix. + */ + public static FloatMatrix hermitianEigenvalues(ComplexFloatMatrix A) { + A.assertSquare(); + FloatMatrix W = new FloatMatrix(A.rows); + SimpleBlas.cheev('N', 'U', A.dup(), W); + return W; + } + + /** + * Computes the eigenvalues and eigenvectors of a complex Hermitian matrix. + * + * Assumes that the input is an Hermitian matrix. + * + * @return an array of ComplexFloatMatrix objects containing the orthonormal eigenvectors + * stored as the columns of the first matrix, and the eigenvalues (in ascending order) + * as the diagonal elements of the second matrix. + */ + public static ComplexFloatMatrix[] hermitianEigenvectors(ComplexFloatMatrix A) { + A.assertSquare(); + // setting up result arrays + FloatMatrix W = new FloatMatrix(A.rows); + ComplexFloatMatrix eigenvectors = A.dup(); + + SimpleBlas.cheev('V', 'U', eigenvectors, W); + return new ComplexFloatMatrix[]{eigenvectors, ComplexFloatMatrix.diag(W.toComplex())}; + } + + //END } diff --git a/src/main/java/org/jblas/NativeBlas.java b/src/main/java/org/jblas/NativeBlas.java index 332d370e..047c8d92 100644 --- a/src/main/java/org/jblas/NativeBlas.java +++ b/src/main/java/org/jblas/NativeBlas.java @@ -548,5 +548,31 @@ public static int ssygvx(int itype, char jobz, char range, char uplo, int n, flo return info; } + public static native int cheev(char jobz, char uplo, int n, float[] a, int aIdx, int lda, float[] w, int wIdx, float[] work, int workIdx, int lwork, float[] rwork, int rworkIdx); + public static int cheev(char jobz, char uplo, int n, float[] a, int aIdx, int lda, float[] w, int wIdx, float[] rwork, int rworkIdx) { + int info; + float[] work = new float[1*2]; + int lwork; + info = cheev(jobz, uplo, n, floatDummy, 0, lda, floatDummy, 0, work, 0, -1, floatDummy, 0); + if (info != 0) + return info; + lwork = (int) work[0]; work = new float[lwork*2]; + info = cheev(jobz, uplo, n, a, aIdx, lda, w, wIdx, work, 0, lwork, rwork, rworkIdx); + return info; + } + + public static native int zheev(char jobz, char uplo, int n, double[] a, int aIdx, int lda, double[] w, int wIdx, double[] work, int workIdx, int lwork, double[] rwork, int rworkIdx); + public static int zheev(char jobz, char uplo, int n, double[] a, int aIdx, int lda, double[] w, int wIdx, double[] rwork, int rworkIdx) { + int info; + double[] work = new double[1*2]; + int lwork; + info = zheev(jobz, uplo, n, doubleDummy, 0, lda, doubleDummy, 0, work, 0, -1, doubleDummy, 0); + if (info != 0) + return info; + lwork = (int) work[0]; work = new double[lwork*2]; + info = zheev(jobz, uplo, n, a, aIdx, lda, w, wIdx, work, 0, lwork, rwork, rworkIdx); + return info; + } + } diff --git a/src/main/java/org/jblas/SimpleBlas.java b/src/main/java/org/jblas/SimpleBlas.java index 3240b4a6..ad1f0cea 100644 --- a/src/main/java/org/jblas/SimpleBlas.java +++ b/src/main/java/org/jblas/SimpleBlas.java @@ -379,6 +379,25 @@ public static int geev(char jobvl, char jobvr, DoubleMatrix A, return info; } + public static int cgeev(char jobvl, char jobvr, ComplexDoubleMatrix A, ComplexDoubleMatrix W, + ComplexDoubleMatrix VL, ComplexDoubleMatrix VR) { + double[] rwork = new double[A.rows*2]; + int info = NativeBlas.zgeev(jobvl, jobvr, A.rows, A.data, 0, A.rows, W.data, 0, VL.data, 0, VL.rows, + VR.data, 0, VR.rows, rwork, 0); + if (info > 0) + throw new LapackConvergenceException("CGEEV", "First " + info + " eigenvalues have not converged."); + return info; + } + + public static int cheev(char jobz, char uplo, ComplexDoubleMatrix A, DoubleMatrix W) { + double[] rwork = new double[3*A.rows-2]; + int info = NativeBlas.zheev(jobz, uplo, A.rows, A.data, 0, A.rows, W.data, 0, rwork, 0); + if (info > 0) + throw new LapackConvergenceException("CGEEV", "Eigenvalues could not be computed " + info + + " off-diagonal elements did not converge"); + return info; + } + public static int sygvd(int itype, char jobz, char uplo, DoubleMatrix A, DoubleMatrix B, DoubleMatrix W) { int info = NativeBlas.dsygvd(itype, jobz, uplo, A.rows, A.data, 0, A.rows, B.data, 0, B.rows, W.data, 0); if (info == 0) @@ -800,6 +819,25 @@ public static int geev(char jobvl, char jobvr, FloatMatrix A, return info; } + public static int cgeev(char jobvl, char jobvr, ComplexFloatMatrix A, ComplexFloatMatrix W, + ComplexFloatMatrix VL, ComplexFloatMatrix VR) { + float[] rwork = new float[A.rows*2]; + int info = NativeBlas.cgeev(jobvl, jobvr, A.rows, A.data, 0, A.rows, W.data, 0, VL.data, 0, VL.rows, + VR.data, 0, VR.rows, rwork, 0); + if (info > 0) + throw new LapackConvergenceException("CGEEV", "First " + info + " eigenvalues have not converged."); + return info; + } + + public static int cheev(char jobz, char uplo, ComplexFloatMatrix A, FloatMatrix W) { + float[] rwork = new float[3*A.rows-2]; + int info = NativeBlas.cheev(jobz, uplo, A.rows, A.data, 0, A.rows, W.data, 0, rwork, 0); + if (info > 0) + throw new LapackConvergenceException("CGEEV", "Eigenvalues could not be computed " + info + + " off-diagonal elements did not converge"); + return info; + } + public static int sygvd(int itype, char jobz, char uplo, FloatMatrix A, FloatMatrix B, FloatMatrix W) { int info = NativeBlas.ssygvd(itype, jobz, uplo, A.rows, A.data, 0, A.rows, B.data, 0, B.rows, W.data, 0); if (info == 0) diff --git a/src/test/java/org/jblas/TestEigen.java b/src/test/java/org/jblas/TestEigen.java index 95bf9281..77f9764c 100644 --- a/src/test/java/org/jblas/TestEigen.java +++ b/src/test/java/org/jblas/TestEigen.java @@ -96,4 +96,37 @@ public void testSymmetricEigenvalues() { assertEquals(0.0, eigenvalues.sub(L).normmax(), eps); } + + @Test + public void testComplexEigenvalues() { + ComplexDoubleMatrix A = new ComplexDoubleMatrix( + new DoubleMatrix(new double[][]{ {1.0, 0}, {0, 1.0}}), + new DoubleMatrix(new double[][]{ {0, 1}, {-1, 0}}) + ); + + ComplexDoubleMatrix E = Eigen.eigenvalues(A); + ComplexDoubleMatrix[] EV = Eigen.eigenvectors(A); + ComplexDoubleMatrix X = EV[0]; + ComplexDoubleMatrix L = EV[1]; + + assertEquals(2.0, E.get(0).real(), eps); + assertEquals(0.0, E.get(1).real(), eps); + assertEquals(0.0, A.mmul(X).sub(X.mmul(L)).norm2(), eps); + } + + @Test + public void testComplexHermitianEigenvalues() { + ComplexDoubleMatrix A = new ComplexDoubleMatrix( + new DoubleMatrix(new double[][]{{2.0, 2.0, 4.0}, {2.0, 3.0, 0.0}, {4.0, 0.0, 1.0}}), + new DoubleMatrix(new double[][]{{0.0, 1.0, 0.0}, {-1.0, 0.0, 1.0}, {0.0, -1.0, 0.0}}) + ); + + DoubleMatrix E = Eigen.hermitianEigenvalues(A); + ComplexDoubleMatrix[] EV = Eigen.hermitianEigenvectors(A); + ComplexDoubleMatrix X = EV[0]; + ComplexDoubleMatrix L = EV[1]; + + assertEquals(0.0, L.diag().sub(E.toComplex()).norm2(), eps); + assertEquals(0.0, A.mmul(X).sub(X.mmul(L)).norm2(), eps); + } }