Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
RMeli committed Jun 17, 2024
1 parent 962a037 commit 4c5ebe1
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 135 deletions.
60 changes: 37 additions & 23 deletions include/dlaf/eigensolver/gen_eigensolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

/// @file

#include "gen_eigensolver/api.h"
#include <utility>

#include <blas.hh>
Expand All @@ -21,13 +20,17 @@
#include <dlaf/types.h>
#include <dlaf/util_matrix.h>

#include "gen_eigensolver/api.h"

namespace dlaf {

namespace eigensolver::internal{
namespace eigensolver::internal {

template <Backend B, Device D, class T>
void hermitian_generalized_eigensolver_helper(blas::Uplo uplo, Matrix<T, D>& mat_a, Matrix<T, D>& mat_b,
Matrix<BaseType<T>, D>& eigenvalues, Matrix<T, D>& eigenvectors, const Factorization factorization) {
Matrix<BaseType<T>, D>& eigenvalues,
Matrix<T, D>& eigenvectors,
const Factorization factorization) {
DLAF_ASSERT(matrix::local_matrix(mat_a), mat_a);
DLAF_ASSERT(matrix::local_matrix(mat_b), mat_b);
DLAF_ASSERT(matrix::local_matrix(eigenvalues), eigenvalues);
Expand All @@ -50,13 +53,14 @@ void hermitian_generalized_eigensolver_helper(blas::Uplo uplo, Matrix<T, D>& mat
DLAF_ASSERT(matrix::single_tile_per_block(eigenvalues), eigenvalues);
DLAF_ASSERT(matrix::single_tile_per_block(eigenvectors), eigenvectors);

eigensolver::internal::GenEigensolver<B, D, T>::call(uplo, mat_a, mat_b, eigenvalues, eigenvectors, factorization);
eigensolver::internal::GenEigensolver<B, D, T>::call(uplo, mat_a, mat_b, eigenvalues, eigenvectors,
factorization);
}

template <Backend B, Device D, class T>
void hermitian_generalized_eigensolver_helper(comm::CommunicatorGrid& grid, blas::Uplo uplo,
Matrix<T, D>& mat_a, Matrix<T, D>& mat_b,
Matrix<BaseType<T>, D>& eigenvalues, Matrix<T, D>& eigenvectors, const Factorization factorization) {
void hermitian_generalized_eigensolver_helper(
comm::CommunicatorGrid& grid, blas::Uplo uplo, Matrix<T, D>& mat_a, Matrix<T, D>& mat_b,
Matrix<BaseType<T>, D>& eigenvalues, Matrix<T, D>& eigenvectors, const Factorization factorization) {
DLAF_ASSERT(matrix::equal_process_grid(mat_a, grid), mat_a, grid);
DLAF_ASSERT(matrix::equal_process_grid(mat_b, grid), mat_b, grid);
DLAF_ASSERT(matrix::local_matrix(eigenvalues), eigenvalues);
Expand All @@ -82,7 +86,6 @@ void hermitian_generalized_eigensolver_helper(comm::CommunicatorGrid& grid, blas
eigensolver::internal::GenEigensolver<B, D, T>::call(grid, uplo, mat_a, mat_b, eigenvalues,
eigenvectors, factorization);
}


}

Expand Down Expand Up @@ -127,7 +130,9 @@ void hermitian_generalized_eigensolver_helper(comm::CommunicatorGrid& grid, blas
template <Backend B, Device D, class T>
void hermitian_generalized_eigensolver(blas::Uplo uplo, Matrix<T, D>& mat_a, Matrix<T, D>& mat_b,
Matrix<BaseType<T>, D>& eigenvalues, Matrix<T, D>& eigenvectors) {
eigensolver::internal::hermitian_generalized_eigensolver_helper<B, D, T>(uplo, mat_a, mat_b, eigenvalues, eigenvectors, eigensolver::internal::Factorization::do_factorization);
eigensolver::internal::hermitian_generalized_eigensolver_helper<B, D, T>(
uplo, mat_a, mat_b, eigenvalues, eigenvectors,
eigensolver::internal::Factorization::do_factorization);
}

/// Generalized Eigensolver.
Expand Down Expand Up @@ -218,9 +223,13 @@ EigensolverResult<T, D> hermitian_generalized_eigensolver(blas::Uplo uplo, Matri
/// @pre @p eigenvectors has blocksize (NB x NB)
/// @pre @p eigenvectors has tilesize (NB x NB)
template <Backend B, Device D, class T>
void hermitian_generalized_eigensolver_factorized(blas::Uplo uplo, Matrix<T, D>& mat_a, Matrix<T, D>& mat_b,
Matrix<BaseType<T>, D>& eigenvalues, Matrix<T, D>& eigenvectors) {
eigensolver::internal::hermitian_generalized_eigensolver_helper<B, D, T>(uplo, mat_a, mat_b, eigenvalues, eigenvectors, eigensolver::internal::Factorization::already_factorized);
void hermitian_generalized_eigensolver_factorized(blas::Uplo uplo, Matrix<T, D>& mat_a,
Matrix<T, D>& mat_b,
Matrix<BaseType<T>, D>& eigenvalues,
Matrix<T, D>& eigenvectors) {
eigensolver::internal::hermitian_generalized_eigensolver_helper<B, D, T>(
uplo, mat_a, mat_b, eigenvalues, eigenvectors,
eigensolver::internal::Factorization::already_factorized);
}

/// Generalized Eigensolver.
Expand Down Expand Up @@ -250,8 +259,9 @@ void hermitian_generalized_eigensolver_factorized(blas::Uplo uplo, Matrix<T, D>&
/// @pre @p mat_b has tilesize (NB x NB)
/// @pre @p mat_b is the result of a Cholesky factorization
template <Backend B, Device D, class T>
EigensolverResult<T, D> hermitian_generalized_eigensolver_factorized(blas::Uplo uplo, Matrix<T, D>& mat_a,
Matrix<T, D>& mat_b) {
EigensolverResult<T, D> hermitian_generalized_eigensolver_factorized(blas::Uplo uplo,
Matrix<T, D>& mat_a,
Matrix<T, D>& mat_b) {
DLAF_ASSERT(matrix::local_matrix(mat_a), mat_a);
DLAF_ASSERT(matrix::local_matrix(mat_b), mat_b);
DLAF_ASSERT(matrix::square_size(mat_a), mat_a);
Expand Down Expand Up @@ -315,8 +325,9 @@ template <Backend B, Device D, class T>
void hermitian_generalized_eigensolver(comm::CommunicatorGrid& grid, blas::Uplo uplo,
Matrix<T, D>& mat_a, Matrix<T, D>& mat_b,
Matrix<BaseType<T>, D>& eigenvalues, Matrix<T, D>& eigenvectors) {
eigensolver::internal::hermitian_generalized_eigensolver_helper<B, D, T>(grid, uplo, mat_a, mat_b, eigenvalues,
eigenvectors, eigensolver::internal::Factorization::do_factorization);
eigensolver::internal::hermitian_generalized_eigensolver_helper<B, D, T>(
grid, uplo, mat_a, mat_b, eigenvalues, eigenvectors,
eigensolver::internal::Factorization::do_factorization);
}

/// Generalized Eigensolver.
Expand Down Expand Up @@ -410,10 +421,12 @@ EigensolverResult<T, D> hermitian_generalized_eigensolver(comm::CommunicatorGrid
/// @pre @p eigenvectors has tilesize (NB x NB)
template <Backend B, Device D, class T>
void hermitian_generalized_eigensolver_factorized(comm::CommunicatorGrid& grid, blas::Uplo uplo,
Matrix<T, D>& mat_a, Matrix<T, D>& mat_b,
Matrix<BaseType<T>, D>& eigenvalues, Matrix<T, D>& eigenvectors) {
eigensolver::internal::hermitian_generalized_eigensolver_helper<B, D, T>(grid, uplo, mat_a, mat_b, eigenvalues,
eigenvectors, eigensolver::internal::Factorization::already_factorized);
Matrix<T, D>& mat_a, Matrix<T, D>& mat_b,
Matrix<BaseType<T>, D>& eigenvalues,
Matrix<T, D>& eigenvectors) {
eigensolver::internal::hermitian_generalized_eigensolver_helper<B, D, T>(
grid, uplo, mat_a, mat_b, eigenvalues, eigenvectors,
eigensolver::internal::Factorization::already_factorized);
}

/// Generalized Eigensolver.
Expand Down Expand Up @@ -444,8 +457,8 @@ void hermitian_generalized_eigensolver_factorized(comm::CommunicatorGrid& grid,
/// @pre @p mat_b has tilesize (NB x NB)
/// @pre @p mat_b is the result of a Cholesky factorization
template <Backend B, Device D, class T>
EigensolverResult<T, D> hermitian_generalized_eigensolver_factorized(comm::CommunicatorGrid& grid, blas::Uplo uplo,
Matrix<T, D>& mat_a, Matrix<T, D>& mat_b) {
EigensolverResult<T, D> hermitian_generalized_eigensolver_factorized(
comm::CommunicatorGrid& grid, blas::Uplo uplo, Matrix<T, D>& mat_a, Matrix<T, D>& mat_b) {
DLAF_ASSERT(matrix::equal_process_grid(mat_a, grid), mat_a, grid);
DLAF_ASSERT(matrix::equal_process_grid(mat_b, grid), mat_b, grid);
DLAF_ASSERT(matrix::square_size(mat_a), mat_a);
Expand All @@ -461,7 +474,8 @@ EigensolverResult<T, D> hermitian_generalized_eigensolver_factorized(comm::Commu
TileElementSize(mat_a.blockSize().rows(), 1));
matrix::Matrix<T, D> eigenvectors(GlobalElementSize(size, size), mat_a.blockSize(), grid);

hermitian_generalized_eigensolver_factorized<B, D, T>(grid, uplo, mat_a, mat_b, eigenvalues, eigenvectors);
hermitian_generalized_eigensolver_factorized<B, D, T>(grid, uplo, mat_a, mat_b, eigenvalues,
eigenvectors);

return {std::move(eigenvalues), std::move(eigenvectors)};
}
Expand Down
5 changes: 3 additions & 2 deletions include/dlaf/eigensolver/gen_eigensolver/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@

namespace dlaf::eigensolver::internal {

enum class Factorization {do_factorization, already_factorized};
enum class Factorization { do_factorization, already_factorized };

template <Backend backend, Device device, class T>
struct GenEigensolver {
static void call(blas::Uplo uplo, Matrix<T, device>& mat_a, Matrix<T, device>& mat_b,
Matrix<BaseType<T>, device>& eigenvalues, Matrix<T, device>& eigenvectors, const Factorization factorization);
Matrix<BaseType<T>, device>& eigenvalues, Matrix<T, device>& eigenvectors,
const Factorization factorization);
static void call(comm::CommunicatorGrid& grid, blas::Uplo uplo, Matrix<T, device>& mat_a,
Matrix<T, device>& mat_b, Matrix<BaseType<T>, device>& eigenvalues,
Matrix<T, device>& eigenvectors, const Factorization factorizatio);
Expand Down
10 changes: 6 additions & 4 deletions include/dlaf/eigensolver/gen_eigensolver/impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
//
#pragma once

#include "api.h"
#include <dlaf/eigensolver/eigensolver.h>
#include <dlaf/eigensolver/gen_eigensolver/api.h>
#include <dlaf/eigensolver/gen_to_std.h>
Expand All @@ -18,12 +17,15 @@
#include <dlaf/solver/triangular.h>
#include <dlaf/util_matrix.h>

#include "api.h"

namespace dlaf::eigensolver::internal {

template <Backend B, Device D, class T>
void GenEigensolver<B, D, T>::call(blas::Uplo uplo, Matrix<T, D>& mat_a, Matrix<T, D>& mat_b,
Matrix<BaseType<T>, D>& eigenvalues, Matrix<T, D>& eigenvectors, const Factorization factorization) {
if(factorization == Factorization::do_factorization){
Matrix<BaseType<T>, D>& eigenvalues, Matrix<T, D>& eigenvectors,
const Factorization factorization) {
if (factorization == Factorization::do_factorization) {
cholesky_factorization<B>(uplo, mat_b);
}
generalized_to_standard<B>(uplo, mat_a, mat_b);
Expand All @@ -38,7 +40,7 @@ template <Backend B, Device D, class T>
void GenEigensolver<B, D, T>::call(comm::CommunicatorGrid& grid, blas::Uplo uplo, Matrix<T, D>& mat_a,
Matrix<T, D>& mat_b, Matrix<BaseType<T>, D>& eigenvalues,
Matrix<T, D>& eigenvectors, const Factorization factorization) {
if(factorization == Factorization::do_factorization){
if (factorization == Factorization::do_factorization) {
cholesky_factorization<B>(grid, uplo, mat_b);
}
generalized_to_standard<B>(grid, uplo, mat_a, mat_b);
Expand Down
85 changes: 43 additions & 42 deletions src/c_api/eigensolver/gen_eigensolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,39 +52,38 @@ int dlaf_hermitian_generalized_eigensolver_z(const int dlaf_context, const char
dlaf_descb, w, z, dlaf_descz);
}

int dlaf_symmetric_generalized_eigensolver_factorized_s(const int dlaf_context, const char uplo, float* a,
const struct DLAF_descriptor dlaf_desca, float* b,
const struct DLAF_descriptor dlaf_descb, float* w, float* z,
const struct DLAF_descriptor dlaf_descz) noexcept {
return hermitian_generalized_eigensolver_factorized<float>(dlaf_context, uplo, a, dlaf_desca, b, dlaf_descb, w, z,
dlaf_descz);
int dlaf_symmetric_generalized_eigensolver_factorized_s(
const int dlaf_context, const char uplo, float* a, const struct DLAF_descriptor dlaf_desca, float* b,
const struct DLAF_descriptor dlaf_descb, float* w, float* z,
const struct DLAF_descriptor dlaf_descz) noexcept {
return hermitian_generalized_eigensolver_factorized<float>(dlaf_context, uplo, a, dlaf_desca, b,
dlaf_descb, w, z, dlaf_descz);
}

int dlaf_symmetric_generalized_eigensolver_factorized_d(const int dlaf_context, const char uplo, double* a,
const struct DLAF_descriptor dlaf_desca, double* b,
const struct DLAF_descriptor dlaf_descb, double* w,
double* z,
const struct DLAF_descriptor dlaf_descz) noexcept {
return hermitian_generalized_eigensolver_factorized<double>(dlaf_context, uplo, a, dlaf_desca, b, dlaf_descb, w,
z, dlaf_descz);
int dlaf_symmetric_generalized_eigensolver_factorized_d(
const int dlaf_context, const char uplo, double* a, const struct DLAF_descriptor dlaf_desca,
double* b, const struct DLAF_descriptor dlaf_descb, double* w, double* z,
const struct DLAF_descriptor dlaf_descz) noexcept {
return hermitian_generalized_eigensolver_factorized<double>(dlaf_context, uplo, a, dlaf_desca, b,
dlaf_descb, w, z, dlaf_descz);
}

int dlaf_hermitian_generalized_eigensolver_factorized_c(const int dlaf_context, const char uplo, dlaf_complex_c* a,
const struct DLAF_descriptor dlaf_desca, dlaf_complex_c* b,
const struct DLAF_descriptor dlaf_descb, float* w,
dlaf_complex_c* z,
const struct DLAF_descriptor dlaf_descz) noexcept {
return hermitian_generalized_eigensolver_factorized<std::complex<float>>(dlaf_context, uplo, a, dlaf_desca, b,
dlaf_descb, w, z, dlaf_descz);
int dlaf_hermitian_generalized_eigensolver_factorized_c(
const int dlaf_context, const char uplo, dlaf_complex_c* a, const struct DLAF_descriptor dlaf_desca,
dlaf_complex_c* b, const struct DLAF_descriptor dlaf_descb, float* w, dlaf_complex_c* z,
const struct DLAF_descriptor dlaf_descz) noexcept {
return hermitian_generalized_eigensolver_factorized<std::complex<float>>(dlaf_context, uplo, a,
dlaf_desca, b, dlaf_descb, w,
z, dlaf_descz);
}

int dlaf_hermitian_generalized_eigensolver_factorized_z(const int dlaf_context, const char uplo, dlaf_complex_z* a,
const struct DLAF_descriptor dlaf_desca, dlaf_complex_z* b,
const struct DLAF_descriptor dlaf_descb, double* w,
dlaf_complex_z* z,
const struct DLAF_descriptor dlaf_descz) noexcept {
return hermitian_generalized_eigensolver_factorized<std::complex<double>>(dlaf_context, uplo, a, dlaf_desca, b,
dlaf_descb, w, z, dlaf_descz);
int dlaf_hermitian_generalized_eigensolver_factorized_z(
const int dlaf_context, const char uplo, dlaf_complex_z* a, const struct DLAF_descriptor dlaf_desca,
dlaf_complex_z* b, const struct DLAF_descriptor dlaf_descb, double* w, dlaf_complex_z* z,
const struct DLAF_descriptor dlaf_descz) noexcept {
return hermitian_generalized_eigensolver_factorized<std::complex<double>>(dlaf_context, uplo, a,
dlaf_desca, b, dlaf_descb, w,
z, dlaf_descz);
}

#ifdef DLAF_WITH_SCALAPACK
Expand Down Expand Up @@ -116,32 +115,34 @@ void dlaf_pzhegvx(const char uplo, const int m, dlaf_complex_z* a, const int ia,
pxhegvx<std::complex<double>>(uplo, m, a, ia, ja, desca, b, ib, jb, descb, w, z, iz, jz, descz, *info);
}

void dlaf_pssygvx_factorized(const char uplo, const int m, float* a, const int ia, const int ja, const int desca[9],
float* b, const int ib, const int jb, const int descb[9], float* w, float* z,
const int iz, const int jz, const int descz[9], int* info) noexcept {
void dlaf_pssygvx_factorized(const char uplo, const int m, float* a, const int ia, const int ja,
const int desca[9], float* b, const int ib, const int jb,
const int descb[9], float* w, float* z, const int iz, const int jz,
const int descz[9], int* info) noexcept {
pxhegvx_factorized<float>(uplo, m, a, ia, ja, desca, b, ib, jb, descb, w, z, iz, jz, descz, *info);
}

void dlaf_pdsygvx_factorized(const char uplo, const int m, double* a, const int ia, const int ja,
const int desca[9], double* b, const int ib, const int jb, const int descb[9],
double* w, double* z, const int iz, const int jz, const int descz[9],
int* info) noexcept {
const int desca[9], double* b, const int ib, const int jb,
const int descb[9], double* w, double* z, const int iz, const int jz,
const int descz[9], int* info) noexcept {
pxhegvx_factorized<double>(uplo, m, a, ia, ja, desca, b, ib, jb, descb, w, z, iz, jz, descz, *info);
}

void dlaf_pchegvx_factorized(const char uplo, const int m, dlaf_complex_c* a, const int ia, const int ja,
const int desca[9], dlaf_complex_c* b, const int ib, const int jb, const int descb[9],
float* w, dlaf_complex_c* z, const int iz, const int jz, const int descz[9],
int* info) noexcept {
pxhegvx_factorized<std::complex<float>>(uplo, m, a, ia, ja, desca, b, ib, jb, descb, w, z, iz, jz, descz, *info);
const int desca[9], dlaf_complex_c* b, const int ib, const int jb,
const int descb[9], float* w, dlaf_complex_c* z, const int iz, const int jz,
const int descz[9], int* info) noexcept {
pxhegvx_factorized<std::complex<float>>(uplo, m, a, ia, ja, desca, b, ib, jb, descb, w, z, iz, jz,
descz, *info);
}

void dlaf_pzhegvx_factorized(const char uplo, const int m, dlaf_complex_z* a, const int ia, const int ja,
const int desca[9], dlaf_complex_z* b, const int ib, const int jb, const int descb[9],
double* w, dlaf_complex_z* z, const int iz, const int jz, const int descz[9],
int* info) noexcept {
pxhegvx_factorized<std::complex<double>>(uplo, m, a, ia, ja, desca, b, ib, jb, descb, w, z, iz, jz, descz, *info);
const int desca[9], dlaf_complex_z* b, const int ib, const int jb,
const int descb[9], double* w, dlaf_complex_z* z, const int iz,
const int jz, const int descz[9], int* info) noexcept {
pxhegvx_factorized<std::complex<double>>(uplo, m, a, ia, ja, desca, b, ib, jb, descb, w, z, iz, jz,
descz, *info);
}


#endif
Loading

0 comments on commit 4c5ebe1

Please sign in to comment.