diff --git a/src/QMCWaveFunctions/RotatedSPOsT.cpp b/src/QMCWaveFunctions/RotatedSPOsT.cpp index 128bca9798..dabdc282a9 100644 --- a/src/QMCWaveFunctions/RotatedSPOsT.cpp +++ b/src/QMCWaveFunctions/RotatedSPOsT.cpp @@ -1688,6 +1688,152 @@ RotatedSPOsT::makeClone() const return myclone; } +template +void +RotatedSPOsT::mw_evaluateDetRatios( + const RefVectorWithLeader>& spo_list, + const RefVectorWithLeader>& vp_list, + const RefVector& psi_list, + const std::vector& invRow_ptr_list, + std::vector>& ratios_list) const +{ + auto phi_list = extractPhiRefList(spo_list); + auto& leader = phi_list.getLeader(); + leader.mw_evaluateDetRatios( + phi_list, vp_list, psi_list, invRow_ptr_list, ratios_list); +} + +template +void +RotatedSPOsT::mw_evaluateValue( + const RefVectorWithLeader>& spo_list, + const RefVectorWithLeader>& P_list, int iat, + const RefVector& psi_v_list) const +{ + auto phi_list = extractPhiRefList(spo_list); + auto& leader = phi_list.getLeader(); + leader.mw_evaluateValue(phi_list, P_list, iat, psi_v_list); +} + +template +void +RotatedSPOsT::mw_evaluateVGL(const RefVectorWithLeader>& spo_list, + const RefVectorWithLeader>& P_list, int iat, + const RefVector& psi_v_list, + const RefVector& dpsi_v_list, + const RefVector& d2psi_v_list) const +{ + auto phi_list = extractPhiRefList(spo_list); + auto& leader = phi_list.getLeader(); + leader.mw_evaluateVGL( + phi_list, P_list, iat, psi_v_list, dpsi_v_list, d2psi_v_list); +} + +template +void +RotatedSPOsT::mw_evaluateVGLWithSpin( + const RefVectorWithLeader>& spo_list, + const RefVectorWithLeader>& P_list, int iat, + const RefVector& psi_v_list, + const RefVector& dpsi_v_list, + const RefVector& d2psi_v_list, + OffloadMatrix& mw_dspin) const +{ + auto phi_list = extractPhiRefList(spo_list); + auto& leader = phi_list.getLeader(); + leader.mw_evaluateVGLWithSpin( + phi_list, P_list, iat, psi_v_list, dpsi_v_list, d2psi_v_list, mw_dspin); +} + +template +void +RotatedSPOsT::mw_evaluateVGLandDetRatioGrads( + const RefVectorWithLeader>& spo_list, + const RefVectorWithLeader>& P_list, int iat, + const std::vector& invRow_ptr_list, + OffloadMWVGLArray& phi_vgl_v, std::vector& ratios, + std::vector& grads) const +{ + auto phi_list = extractPhiRefList(spo_list); + auto& leader = phi_list.getLeader(); + leader.mw_evaluateVGLandDetRatioGrads( + phi_list, P_list, iat, invRow_ptr_list, phi_vgl_v, ratios, grads); +} + +template +void +RotatedSPOsT::mw_evaluateVGLandDetRatioGradsWithSpin( + const RefVectorWithLeader>& spo_list, + const RefVectorWithLeader>& P_list, int iat, + const std::vector& invRow_ptr_list, + OffloadMWVGLArray& phi_vgl_v, std::vector& ratios, + std::vector& grads, std::vector& spingrads) const +{ + auto phi_list = extractPhiRefList(spo_list); + auto& leader = phi_list.getLeader(); + leader.mw_evaluateVGLandDetRatioGradsWithSpin(phi_list, P_list, iat, + invRow_ptr_list, phi_vgl_v, ratios, grads, spingrads); +} + +template +void +RotatedSPOsT::mw_evaluate_notranspose( + const RefVectorWithLeader>& spo_list, + const RefVectorWithLeader>& P_list, int first, int last, + const RefVector& logdet_list, + const RefVector& dlogdet_list, + const RefVector& d2logdet_list) const +{ + auto phi_list = extractPhiRefList(spo_list); + auto& leader = phi_list.getLeader(); + leader.mw_evaluate_notranspose(phi_list, P_list, first, last, logdet_list, + dlogdet_list, d2logdet_list); +} + +template +void +RotatedSPOsT::createResource(ResourceCollection& collection) const +{ + Phi->createResource(collection); +} + +template +void +RotatedSPOsT::acquireResource(ResourceCollection& collection, + const RefVectorWithLeader>& spo_list) const +{ + auto phi_list = extractPhiRefList(spo_list); + auto& leader = phi_list.getLeader(); + leader.acquireResource(collection, phi_list); +} + +template +void +RotatedSPOsT::releaseResource(ResourceCollection& collection, + const RefVectorWithLeader>& spo_list) const +{ + auto phi_list = extractPhiRefList(spo_list); + auto& leader = phi_list.getLeader(); + leader.releaseResource(collection, phi_list); +} + +template +RefVectorWithLeader> +RotatedSPOsT::extractPhiRefList( + const RefVectorWithLeader>& spo_list) +{ + auto& spo_leader = spo_list.template getCastedLeader(); + const auto nw = spo_list.size(); + RefVectorWithLeader> phi_list(*spo_leader.Phi); + phi_list.reserve(nw); + for (int iw = 0; iw < nw; iw++) { + RotatedSPOsT& rot = + spo_list.template getCastedElement(iw); + phi_list.emplace_back(*rot.Phi); + } + return phi_list; +} + // Class concrete types from ValueType template class RotatedSPOsT; template class RotatedSPOsT; diff --git a/src/QMCWaveFunctions/RotatedSPOsT.h b/src/QMCWaveFunctions/RotatedSPOsT.h index 971d2528b3..fa4778a6f4 100644 --- a/src/QMCWaveFunctions/RotatedSPOsT.h +++ b/src/QMCWaveFunctions/RotatedSPOsT.h @@ -40,6 +40,8 @@ class RotatedSPOsT : public SPOSetT, public OptimizableObjectT using IndexType = typename SPOSetT::IndexType; using RealType = typename SPOSetT::RealType; using ValueType = typename SPOSetT::ValueType; + using GradType = typename SPOSetT::GradType; + using ComplexType = typename SPOSetT::ComplexType; using FullRealType = typename SPOSetT::FullRealType; using ValueVector = typename SPOSetT::ValueVector; using ValueMatrix = typename SPOSetT::ValueMatrix; @@ -49,6 +51,9 @@ class RotatedSPOsT : public SPOSetT, public OptimizableObjectT using HessMatrix = typename SPOSetT::HessMatrix; using GGGVector = typename SPOSetT::GGGVector; using GGGMatrix = typename SPOSetT::GGGMatrix; + using OffloadMWVGLArray = typename SPOSetT::OffloadMWVGLArray; + template + using OffloadMatrix = Matrix>; // constructor RotatedSPOsT( @@ -399,6 +404,68 @@ class RotatedSPOsT : public SPOSetT, public OptimizableObjectT use_global_rot_ = use_global_rotation; } + void + mw_evaluateDetRatios(const RefVectorWithLeader>& spo_list, + const RefVectorWithLeader>& vp_list, + const RefVector& psi_list, + const std::vector& invRow_ptr_list, + std::vector>& ratios_list) const override; + + void + mw_evaluateValue(const RefVectorWithLeader>& spo_list, + const RefVectorWithLeader>& P_list, int iat, + const RefVector& psi_v_list) const override; + + void + mw_evaluateVGL(const RefVectorWithLeader>& spo_list, + const RefVectorWithLeader>& P_list, int iat, + const RefVector& psi_v_list, + const RefVector& dpsi_v_list, + const RefVector& d2psi_v_list) const override; + + void + mw_evaluateVGLWithSpin(const RefVectorWithLeader>& spo_list, + const RefVectorWithLeader>& P_list, int iat, + const RefVector& psi_v_list, + const RefVector& dpsi_v_list, + const RefVector& d2psi_v_list, + OffloadMatrix& mw_dspin) const override; + + void + mw_evaluateVGLandDetRatioGrads( + const RefVectorWithLeader>& spo_list, + const RefVectorWithLeader>& P_list, int iat, + const std::vector& invRow_ptr_list, + OffloadMWVGLArray& phi_vgl_v, std::vector& ratios, + std::vector& grads) const override; + + void + mw_evaluateVGLandDetRatioGradsWithSpin( + const RefVectorWithLeader>& spo_list, + const RefVectorWithLeader>& P_list, int iat, + const std::vector& invRow_ptr_list, + OffloadMWVGLArray& phi_vgl_v, std::vector& ratios, + std::vector& grads, + std::vector& spingrads) const override; + + void + mw_evaluate_notranspose(const RefVectorWithLeader>& spo_list, + const RefVectorWithLeader>& P_list, int first, int last, + const RefVector& logdet_list, + const RefVector& dlogdet_list, + const RefVector& d2logdet_list) const override; + + void + createResource(ResourceCollection& collection) const override; + + void + acquireResource(ResourceCollection& collection, + const RefVectorWithLeader>& spo_list) const override; + + void + releaseResource(ResourceCollection& collection, + const RefVectorWithLeader>& spo_list) const override; + private: /// true if SPO parameters (orbital rotation parameters) have been supplied /// by input @@ -415,6 +482,9 @@ class RotatedSPOsT : public SPOSetT, public OptimizableObjectT /// Use global rotation or history list bool use_global_rot_ = true; + static RefVectorWithLeader> + extractPhiRefList(const RefVectorWithLeader>& spo_list); + friend OptVariablesType& testing::getMyVarsFull(RotatedSPOsT& rot); friend OptVariablesType&