Skip to content

Commit

Permalink
add phi_v to shared resource
Browse files Browse the repository at this point in the history
  • Loading branch information
kgasperich committed Jun 22, 2023
1 parent b61f017 commit 95cc63c
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 8 deletions.
17 changes: 11 additions & 6 deletions src/QMCWaveFunctions/LCAO/LCAOrbitalSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ struct LCAOrbitalSet::LCAOMultiWalkerMem : public Resource

OffloadMWVGLArray phi_vgl_v;
// [5][NW][NumAO]
OffloadMWVArray phi_v;
OffloadMWVGLArray basis_mw;
};

Expand Down Expand Up @@ -453,8 +454,8 @@ void LCAOrbitalSet::mw_evaluateVGLImplGEMM(const RefVectorWithLeader<SPOSet>& sp
{
ScopedTimer local(mo_timer_);
ValueMatrix C_partial_view(C->data(), requested_orb_size, BasisSetSize);
// TODO: make class for general blas interface in Platforms
// have instance of that class as member of LCAOrbitalSet, call gemm through that
// TODO: make class for general blas interface in Platforms
// have instance of that class as member of LCAOrbitalSet, call gemm through that
BLAS::gemm('T', 'N',
requested_orb_size, // MOs
spo_list.size() * DIM_VGL, // walkers * DIM_VGL
Expand All @@ -470,7 +471,9 @@ void LCAOrbitalSet::mw_evaluateValue(const RefVectorWithLeader<SPOSet>& spo_list
int iat,
const RefVector<ValueVector>& psi_v_list) const
{
OffloadMWVArray phi_v;
assert(this == &spo_list.getLeader());
auto& spo_leader = spo_list.getCastedLeader<LCAOrbitalSet>();
auto& phi_v = spo_leader.mw_mem_handle_.getResource().phi_v;
phi_v.resize(spo_list.size(), OrbitalSetSize);
mw_evaluateValueImplGEMM(spo_list, P_list, iat, phi_v);

Expand All @@ -486,8 +489,10 @@ void LCAOrbitalSet::mw_evaluateValueImplGEMM(const RefVectorWithLeader<SPOSet>&
int iat,
OffloadMWVArray& psi_v) const
{
const size_t nw = spo_list.size();
OffloadMWVArray phi_v;
assert(this == &spo_list.getLeader());
auto& spo_leader = spo_list.getCastedLeader<LCAOrbitalSet>();
const size_t nw = spo_list.size();
auto& phi_v = spo_leader.mw_mem_handle_.getResource().phi_v;
phi_v.resize(nw, BasisSetSize);

myBasisSet->mw_evaluateValue(P_list, iat, phi_v);
Expand Down Expand Up @@ -515,7 +520,7 @@ void LCAOrbitalSet::mw_evaluateDetRatios(const RefVectorWithLeader<SPOSet>& spo_
const std::vector<const ValueType*>& invRow_ptr_list,
std::vector<std::vector<ValueType>>& ratios_list) const
{
size_t nw = spo_list.size();
const size_t nw = spo_list.size();
for (size_t iw = 0; iw < nw; iw++)
{
for (size_t iat = 0; iat < vp_list[iw].getTotalNum(); iat++)
Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/LCAO/SoaLocalizedBasisSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ void SoaLocalizedBasisSet<COT, ORBT>::mw_evaluateValue(const RefVectorWithLeader
OffloadMWVArray& v)
{
for (size_t iw = 0; iw < P_list.size(); iw++)
evaluateV(P_list[iw], iat, v.data_at(iw,0));
evaluateV(P_list[iw], iat, v.data_at(iw, 0));
}

template<class COT, typename ORBT>
Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/tests/test_MO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ void test_EtOH_mw(bool transform)
size_t nw = psi_list.size();
SPOSet::ValueVector psi_v_1(n_mo);
SPOSet::ValueVector psi_v_2(n_mo);
RefVector<SPOSet::ValueVector> psi_v_list = {psi_v_1, psi_v_2};
RefVector<SPOSet::ValueVector> psi_v_list{psi_v_1, psi_v_2};

ResourceCollection pset_res("test_pset_res");
ResourceCollection spo_res("test_spo_res");
Expand Down

0 comments on commit 95cc63c

Please sign in to comment.