Skip to content

Commit

Permalink
Merge pull request #4611 from kgasperich/offload-lcao-gemm
Browse files Browse the repository at this point in the history
Add LCAO mw_evaluateValue and mw_evaluateDetRatios
  • Loading branch information
prckent authored Jun 24, 2023
2 parents 4d29208 + 9368315 commit a6a3275
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 3 deletions.
3 changes: 3 additions & 0 deletions src/QMCWaveFunctions/BasisSetBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ struct SoaBasisSetBase
using vghgh_type = VectorSoaContainer<T, 20>;
using ValueType = QMCTraits::ValueType;
using OffloadMWVGLArray = Array<ValueType, 3, OffloadPinnedAllocator<ValueType>>; // [VGL, walker, Orbs]
using OffloadMWVArray = Array<ValueType, 2, OffloadPinnedAllocator<ValueType>>; // [walker, Orbs]

///size of the basis set
int BasisSetSize;
Expand All @@ -150,6 +151,8 @@ struct SoaBasisSetBase
virtual void evaluateVGL(const ParticleSet& P, int iat, vgl_type& vgl) = 0;
//Evaluates value, gradient, and laplacian for electron "iat". Places them in an offload array for batched code.
virtual void mw_evaluateVGL(const RefVectorWithLeader<ParticleSet>& P_list, int iat, OffloadMWVGLArray& vgl) = 0;
//Evaluates value for electron "iat". Places it in an offload array for batched code.
virtual void mw_evaluateValue(const RefVectorWithLeader<ParticleSet>& P_list, int iat, OffloadMWVArray& v) = 0;
//Evaluates value, gradient, and Hessian for electron "iat". Parks them into a temporary data structure "vgh".
virtual void evaluateVGH(const ParticleSet& P, int iat, vgh_type& vgh) = 0;
//Evaluates value, gradient, and Hessian, and Gradient Hessian for electron "iat". Parks them into a temporary data structure "vghgh".
Expand Down
75 changes: 72 additions & 3 deletions src/QMCWaveFunctions/LCAO/LCAOrbitalSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ struct LCAOrbitalSet::LCAOMultiWalkerMem : public Resource

std::unique_ptr<Resource> makeClone() const override { return std::make_unique<LCAOMultiWalkerMem>(*this); }

OffloadMWVGLArray phi_vgl_v;
// [5][NW][NumAO]
OffloadMWVGLArray basis_mw;
OffloadMWVGLArray phi_vgl_v; // [5][NW][NumMO]
OffloadMWVGLArray basis_mw; // [5][NW][NumAO]
OffloadMWVArray phi_v; // [NW][NumMO]
OffloadMWVArray basis_v_mw; // [NW][NumAO]
};

LCAOrbitalSet::LCAOrbitalSet(const std::string& my_name, std::unique_ptr<basis_type>&& bs)
Expand Down Expand Up @@ -453,6 +454,8 @@ void LCAOrbitalSet::mw_evaluateVGLImplGEMM(const RefVectorWithLeader<SPOSet>& sp
{
ScopedTimer local(mo_timer_);
ValueMatrix C_partial_view(C->data(), requested_orb_size, BasisSetSize);
// TODO: make class for general blas interface in Platforms
// have instance of that class as member of LCAOrbitalSet, call gemm through that
BLAS::gemm('T', 'N',
requested_orb_size, // MOs
spo_list.size() * DIM_VGL, // walkers * DIM_VGL
Expand All @@ -463,6 +466,72 @@ void LCAOrbitalSet::mw_evaluateVGLImplGEMM(const RefVectorWithLeader<SPOSet>& sp
}
}

/** Evaluate orbital values for particle `iat` of every walker in one batch.
 *
 * Must be called on the leader of `spo_list`. The values are first computed
 * into the leader's multi-walker scratch array `phi_v` (resized here to
 * [spo_list.size()][OrbitalSetSize]) by mw_evaluateValueImplGEMM, and each
 * walker's row is then copied into the matching entry of `psi_v_list`.
 *
 * @param spo_list   list of SPO sets, one per walker; `this` must be its leader
 * @param P_list     list of particle sets, one per walker
 * @param iat        index of the active particle
 * @param psi_v_list per-walker output vectors; each receives OrbitalSetSize values
 */
void LCAOrbitalSet::mw_evaluateValue(const RefVectorWithLeader<SPOSet>& spo_list,
                                     const RefVectorWithLeader<ParticleSet>& P_list,
                                     int iat,
                                     const RefVector<ValueVector>& psi_v_list) const
{
  assert(this == &spo_list.getLeader());
  auto& spo_leader = spo_list.getCastedLeader<LCAOrbitalSet>();
  auto& phi_v      = spo_leader.mw_mem_handle_.getResource().phi_v;
  phi_v.resize(spo_list.size(), OrbitalSetSize);
  mw_evaluateValueImplGEMM(spo_list, P_list, iat, phi_v);

  const size_t output_size = phi_v.size(1);
  const size_t nw          = phi_v.size(0);

  // One output vector is written per walker row of phi_v.
  assert(psi_v_list.size() >= nw);

  for (size_t iw = 0; iw < nw; iw++)
  {
    // copy_n writes output_size elements unconditionally, so the destination must be large enough.
    assert(psi_v_list[iw].get().size() >= output_size);
    std::copy_n(phi_v.data_at(iw, 0), output_size, psi_v_list[iw].get().data());
  }
}

/** Packed multi-walker evaluation of orbital values.
 *
 * Must be called on the leader of `spo_list`. Basis-function values for
 * particle `iat` of every walker are evaluated into the leader's scratch array
 * `basis_v_mw` (resized here to [nw][BasisSetSize]). They are then either
 * copied row by row into `phi_v` (Identity case) or multiplied with the
 * coefficient matrix C through a single BLAS gemm call.
 *
 * @param spo_list list of SPO sets, one per walker; `this` must be its leader
 * @param P_list   list of particle sets, one per walker
 * @param iat      index of the active particle
 * @param phi_v    output array [walker][orbital]; its second extent sets how
 *                 many orbitals are produced per walker
 */
void LCAOrbitalSet::mw_evaluateValueImplGEMM(const RefVectorWithLeader<SPOSet>& spo_list,
                                             const RefVectorWithLeader<ParticleSet>& P_list,
                                             int iat,
                                             OffloadMWVArray& phi_v) const
{
  assert(this == &spo_list.getLeader());
  auto& spo_leader = spo_list.getCastedLeader<LCAOrbitalSet>();
  const size_t nw  = spo_list.size();
  auto& basis_v_mw = spo_leader.mw_mem_handle_.getResource().basis_v_mw;
  basis_v_mw.resize(nw, BasisSetSize);

  myBasisSet->mw_evaluateValue(P_list, iat, basis_v_mw);

  if (Identity)
  {
    // Rows of basis_v_mw have length BasisSetSize while rows of phi_v have length
    // phi_v.size(1). Copy walker by walker so the rows stay aligned even when the
    // two lengths differ; a single flat copy is only valid when they are equal.
    const size_t output_size = phi_v.size(1);
    assert(output_size <= BasisSetSize);
    for (size_t iw = 0; iw < nw; iw++)
      std::copy_n(basis_v_mw.data_at(iw, 0), output_size, phi_v.data_at(iw, 0));
  }
  else
  {
    const size_t requested_orb_size = phi_v.size(1);
    assert(requested_orb_size <= OrbitalSetSize);
    // View of the leading requested_orb_size rows of C.
    ValueMatrix C_partial_view(C->data(), requested_orb_size, BasisSetSize);
    BLAS::gemm('T', 'N',
               requested_orb_size, // MOs
               nw,                 // walkers
               BasisSetSize,       // AOs
               1, C_partial_view.data(), BasisSetSize, basis_v_mw.data(), BasisSetSize, 0, phi_v.data(),
               requested_orb_size);
  }
}

void LCAOrbitalSet::mw_evaluateDetRatios(const RefVectorWithLeader<SPOSet>& spo_list,
const RefVectorWithLeader<const VirtualParticleSet>& vp_list,
const RefVector<ValueVector>& psi_list,
const std::vector<const ValueType*>& invRow_ptr_list,
std::vector<std::vector<ValueType>>& ratios_list) const
{
const size_t nw = spo_list.size();
for (size_t iw = 0; iw < nw; iw++)
{
for (size_t iat = 0; iat < vp_list[iw].getTotalNum(); iat++)
{
spo_list[iw].evaluateValue(vp_list[iw], iat, psi_list[iw]);
ratios_list[iw][iat] = simd::dot(psi_list[iw].get().data(), invRow_ptr_list[iw], psi_list[iw].get().size());
}
}
}

void LCAOrbitalSet::evaluateDetRatios(const VirtualParticleSet& VP,
ValueVector& psi,
const ValueVector& psiinv,
Expand Down
16 changes: 16 additions & 0 deletions src/QMCWaveFunctions/LCAO/LCAOrbitalSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ struct LCAOrbitalSet : public SPOSet

void evaluateVGL(const ParticleSet& P, int iat, ValueVector& psi, GradVector& dpsi, ValueVector& d2psi) override;

void mw_evaluateValue(const RefVectorWithLeader<SPOSet>& spo_list,
const RefVectorWithLeader<ParticleSet>& P_list,
int iat,
const RefVector<ValueVector>& psi_v_list) const override;

void mw_evaluateVGL(const RefVectorWithLeader<SPOSet>& spo_list,
const RefVectorWithLeader<ParticleSet>& P_list,
Expand All @@ -86,6 +90,12 @@ struct LCAOrbitalSet : public SPOSet
const RefVector<GradVector>& dpsi_v_list,
const RefVector<ValueVector>& d2psi_v_list) const override;

void mw_evaluateDetRatios(const RefVectorWithLeader<SPOSet>& spo_list,
const RefVectorWithLeader<const VirtualParticleSet>& vp_list,
const RefVector<ValueVector>& psi_list,
const std::vector<const ValueType*>& invRow_ptr_list,
std::vector<std::vector<ValueType>>& ratios_list) const override;

void evaluateDetRatios(const VirtualParticleSet& VP,
ValueVector& psi,
const ValueVector& psiinv,
Expand Down Expand Up @@ -294,6 +304,12 @@ struct LCAOrbitalSet : public SPOSet
int iat,
OffloadMWVGLArray& phi_vgl_v) const;

/// packed walker GEMM implementation
void mw_evaluateValueImplGEMM(const RefVectorWithLeader<SPOSet>& spo_list,
const RefVectorWithLeader<ParticleSet>& P_list,
int iat,
OffloadMWVArray& phi_v) const;

struct LCAOMultiWalkerMem;
ResourceHandle<LCAOMultiWalkerMem> mw_mem_handle_;
/// timer for basis set
Expand Down
9 changes: 9 additions & 0 deletions src/QMCWaveFunctions/LCAO/SoaLocalizedBasisSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,15 @@ void SoaLocalizedBasisSet<COT, ORBT>::evaluateV(const ParticleSet& P, int iat, O
}
}

/** Evaluate basis-function values for particle `iat` of every walker.
 *
 * Calls the single-walker evaluateV once per walker, handing it a pointer to
 * the start of that walker's row in `v`. `v` is not resized here.
 *
 * @param P_list list of particle sets, one per walker
 * @param iat    index of the active particle
 * @param v      output array indexed as (walker, basis function)
 */
template<class COT, typename ORBT>
void SoaLocalizedBasisSet<COT, ORBT>::mw_evaluateValue(const RefVectorWithLeader<ParticleSet>& P_list,
                                                       int iat,
                                                       OffloadMWVArray& v)
{
  const size_t num_walkers = P_list.size();
  for (size_t walker = 0; walker < num_walkers; ++walker)
  {
    auto* row_begin = v.data_at(walker, 0);
    evaluateV(P_list[walker], iat, row_begin);
  }
}

template<class COT, typename ORBT>
void SoaLocalizedBasisSet<COT, ORBT>::evaluateGradSourceV(const ParticleSet& P,
int iat,
Expand Down
13 changes: 13 additions & 0 deletions src/QMCWaveFunctions/LCAO/SoaLocalizedBasisSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class SoaLocalizedBasisSet : public SoaBasisSetBase<ORBT>
using vghgh_type = typename BaseType::vghgh_type;
using PosType = typename ParticleSet::PosType;
using OffloadMWVGLArray = Array<ValueType, 3, OffloadPinnedAllocator<ValueType>>; // [VGL, walker, Orbs]
using OffloadMWVArray = Array<ValueType, 2, OffloadPinnedAllocator<ValueType>>; // [walker, Orbs]

using BaseType::BasisSetSize;

Expand Down Expand Up @@ -108,6 +109,18 @@ class SoaLocalizedBasisSet : public SoaBasisSetBase<ORBT>
*/
void evaluateVGL(const ParticleSet& P, int iat, vgl_type& vgl) override;

/** compute V using packed array with all walkers
* @param P_list list of quantum particleset (one for each walker)
* @param iat active particle
* @param v Array(n_walkers, BasisSetSize)
*/
void mw_evaluateValue(const RefVectorWithLeader<ParticleSet>& P_list, int iat, OffloadMWVArray& v) override;

/** compute VGL using packed array with all walkers
* @param P_list list of quantum particleset (one for each walker)
* @param iat active particle
* @param vgl Array(5, n_walkers, BasisSetSize)
*/
void mw_evaluateVGL(const RefVectorWithLeader<ParticleSet>& P_list, int iat, OffloadMWVGLArray& vgl) override;

/** compute VGH
Expand Down
1 change: 1 addition & 0 deletions src/QMCWaveFunctions/SPOSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class SPOSet : public QMCTraits
using GGGMatrix = OrbitalSetTraits<ValueType>::GradHessMatrix;
using SPOMap = std::map<std::string, const std::unique_ptr<const SPOSet>>;
using OffloadMWVGLArray = Array<ValueType, 3, OffloadPinnedAllocator<ValueType>>; // [VGL, walker, Orbs]
using OffloadMWVArray = Array<ValueType, 2, OffloadPinnedAllocator<ValueType>>; // [walker, Orbs]
template<typename DT>
using OffloadMatrix = Matrix<DT, OffloadPinnedAllocator<DT>>;

Expand Down
11 changes: 11 additions & 0 deletions src/QMCWaveFunctions/tests/test_MO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,11 @@ void test_EtOH_mw(bool transform)
RefVector<SPOSet::GradVector> dpsi_list = {dpsi_1, dpsi_2};
RefVector<SPOSet::ValueVector> d2psi_list = {d2psi_1, d2psi_2};

size_t nw = psi_list.size();
SPOSet::ValueVector psi_v_1(n_mo);
SPOSet::ValueVector psi_v_2(n_mo);
RefVector<SPOSet::ValueVector> psi_v_list{psi_v_1, psi_v_2};

ResourceCollection pset_res("test_pset_res");
ResourceCollection spo_res("test_spo_res");

Expand All @@ -433,9 +438,15 @@ void test_EtOH_mw(bool transform)
ResourceCollectionTeamLock<SPOSet> mw_sposet_lock(spo_res, spo_list);

sposet->mw_evaluateVGL(spo_list, P_list, 0, psi_list, dpsi_list, d2psi_list);
sposet->mw_evaluateValue(spo_list, P_list, 0, psi_v_list);

for (size_t iorb = 0; iorb < n_mo; iorb++)
{
for (size_t iw = 0; iw < nw; iw++)
{
// test values from OffloadMWVArray impl.
CHECK(std::real(psi_v_list[iw].get()[iorb]) == Approx(psi_list[iw].get()[iorb]));
}
CHECK(std::real(psi_list[0].get()[iorb]) == Approx(psiref_0[iorb]));
CHECK(std::real(psi_list[1].get()[iorb]) == Approx(psiref_1[iorb]));
CHECK(std::real(d2psi_list[0].get()[iorb]) == Approx(d2psiref_0[iorb]));
Expand Down

0 comments on commit a6a3275

Please sign in to comment.