diff --git a/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp b/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp index 30b5ab27d7..5fd485541f 100644 --- a/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp +++ b/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp @@ -22,6 +22,7 @@ #include "Message/Communicate.h" #include "Numerics/DeterminantOperators.h" #include "CPU/BLAS.hpp" +#include "OMPTarget/ompBLAS.hpp" #include "Numerics/MatrixOperators.h" #include #include @@ -529,12 +530,360 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader& isAccepted) { - // TODO to be expanded to serve offload needs without relying on calling acceptMove and restore - for (int iw = 0; iw < wfc_list.size(); iw++) + const int nw = wfc_list.size(); + assert(isAccepted.size() == nw); + // separate accepted/rejected walker indices + const int n_accepted = std::count(isAccepted.begin(), isAccepted.end(), true); + const int n_rejected = nw - n_accepted; + + //TODO: can put these in some preallocated work space (reserve up to n_walkers) + std::vector idx_Accepted(n_accepted); + std::vector idx_Rejected(n_rejected); + + // create lists of accepted/rejected walker indices + for (int iw = 0, iacc = 0, irej = 0; iw < nw; iw++) if (isAccepted[iw]) - wfc_list[iw].acceptMove(p_list[iw], iat, false); + idx_Accepted[iacc++] = iw; else - wfc_list[iw].restore(iat); + idx_Rejected[irej++] = iw; + + + MultiDiracDeterminant& wfc_leader = wfc_list.getLeader(); + ParticleSet& p_leader = p_list.getLeader(); + const int ndet = wfc_leader.getNumDets(); + const int norb = wfc_leader.NumOrbitals; + const int nel = wfc_leader.NumPtcls; + auto& mw_res = wfc_leader.mw_res_handle_.getResource(); + + const int WorkingIndex = iat - wfc_leader.FirstIndex; + assert(WorkingIndex >= 0 && WorkingIndex < wfc_leader.LastIndex - wfc_leader.FirstIndex); + assert(p_leader.isSpinor() == wfc_leader.is_spinor_); + int handle = 0; + for (auto& iacc : idx_Accepted) + { + auto& wfc = wfc_list.getCastedElement(iacc); + if (wfc.curRatio == ValueType(0)) + { + std::ostringstream msg; + msg << "MultiDiracDeterminant::acceptMove curRatio is " << wfc.curRatio << " for walker " << iacc + << "! Report a bug." << std::endl; + throw std::runtime_error(msg.str()); + } + } + + for (auto& iacc : idx_Accepted) + { + auto& wfc = wfc_list.getCastedElement(iacc); + wfc.log_value_ref_det_ += convertValueToLog(wfc.curRatio); + wfc.curRatio = ValueType(1); + } + + // copy data for accepted walkers + switch (wfc_leader.UpdateMode) + { + case ORB_PBYP_RATIO: + /** + * psiMinv_temp[:,:] -> psiMinv[:,:]; [NumPtcls,NumPtcls] + * psiV[:] -> TpsiM[:,WorkingIndex]; [NumOrbitals] (NumPtcls in 2nd dim) + * psiV[:] -> psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) + * new_ratios_to_ref_[:] -> ratios_to_ref_[:]; [NumDets] + */ + { + Vector psiMinv_temp_acc_ptr_list(n_accepted); + Vector psiMinv_acc_ptr_list(n_accepted); + + Vector psiV_acc_ptr_list(n_accepted); + Vector TpsiM_col_acc_ptr_list(n_accepted); + Vector psiM_row_acc_ptr_list(n_accepted); + + Vector new_ratios_to_ref_acc_ptr_list(n_accepted); + Vector ratios_to_ref_acc_ptr_list(n_accepted); + + for (int i = 0; i < n_accepted; i++) + { + auto iacc = idx_Accepted[i]; + auto& wfc = wfc_list.getCastedElement(iacc); + + psiMinv_temp_acc_ptr_list[i] = wfc.psiMinv_temp.data(); + psiMinv_acc_ptr_list[i] = wfc.psiMinv.data(); + + psiV_acc_ptr_list[i] = wfc.psiV.data(); + TpsiM_col_acc_ptr_list[i] = wfc.TpsiM.data() + WorkingIndex; + psiM_row_acc_ptr_list[i] = wfc.psiM.data() + WorkingIndex * norb; + + new_ratios_to_ref_acc_ptr_list[i] = wfc.new_ratios_to_ref_.data(); + ratios_to_ref_acc_ptr_list[i] = wfc.ratios_to_ref_.data(); + } + for (int i = 0; i < n_accepted; i++) + { + BLAS::copy(nel * nel, psiMinv_temp_acc_ptr_list[i], 1, psiMinv_acc_ptr_list[i], 1); + BLAS::copy(norb, psiV_acc_ptr_list[i], 1, TpsiM_col_acc_ptr_list[i], nel); + BLAS::copy(norb, psiV_acc_ptr_list[i], 1, psiM_row_acc_ptr_list[i], 1); + BLAS::copy(ndet, new_ratios_to_ref_acc_ptr_list[i], 1, ratios_to_ref_acc_ptr_list[i], 1); + } + // ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_ptr_list.data(), 1, psiMinv_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, norb, psiV_acc_ptr_list.data(), 1, TpsiM_col_acc_ptr_list.data(), nel, n_accepted); + // ompBLAS::copy_batched(handle, norb, psiV_acc_ptr_list.data(), 1, psiM_row_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_ptr_list.data(), 1, ratios_to_ref_acc_ptr_list.data(), 1, n_accepted); + for (auto& iacc : idx_Accepted) + { + auto& wfc = wfc_list.getCastedElement(iacc); + wfc.ratios_to_ref_.updateTo(); + wfc.TpsiM.updateTo(); + wfc.psiMinv.updateTo(); + wfc.psiM.updateTo(); + // dpsiM not updated on host in this case, but this H2D update is included for consistency with single-walker acceptMove + wfc.dpsiM.updateTo(); + } + } + break; + + case ORB_PBYP_PARTIAL: + /** + * psiMinv_temp[:,:] -> psiMinv[:,:]; [NumPtcls,NumPtcls] + * psiV[:] -> TpsiM[:,WorkingIndex]; [NumOrbitals] (NumPtcls in 2nd dim) + * new_ratios_to_ref_[:] -> ratios_to_ref_[:]; [NumDets] + * psiV[:] -> psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) + * dpsiV[:] -> dpsiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) GradType + * d2psiV[:] -> d2psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) + * if (is_spinor_) + * dspin_psiV[:] -> dspin_psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) + */ + { + Vector psiMinv_temp_acc_ptr_list(n_accepted); + Vector psiMinv_acc_ptr_list(n_accepted); + + Vector psiV_acc_ptr_list(n_accepted); + Vector TpsiM_col_acc_ptr_list(n_accepted); + Vector psiM_row_acc_ptr_list(n_accepted); + + Vector new_ratios_to_ref_acc_ptr_list(n_accepted); + Vector ratios_to_ref_acc_ptr_list(n_accepted); + + Vector dpsiV_acc_ptr_list(n_accepted); + Vector dpsiM_row_acc_ptr_list(n_accepted); + Vector d2psiV_acc_ptr_list(n_accepted); + Vector d2psiM_row_acc_ptr_list(n_accepted); + + for (int i = 0; i < n_accepted; i++) + { + auto iacc = idx_Accepted[i]; + auto& wfc = wfc_list.getCastedElement(iacc); + + psiMinv_temp_acc_ptr_list[i] = wfc.psiMinv_temp.data(); + psiMinv_acc_ptr_list[i] = wfc.psiMinv.data(); + + psiV_acc_ptr_list[i] = wfc.psiV.data(); + TpsiM_col_acc_ptr_list[i] = wfc.TpsiM.data() + WorkingIndex; + psiM_row_acc_ptr_list[i] = wfc.psiM.data() + WorkingIndex * norb; + + new_ratios_to_ref_acc_ptr_list[i] = wfc.new_ratios_to_ref_.data(); + ratios_to_ref_acc_ptr_list[i] = wfc.ratios_to_ref_.data(); + + dpsiV_acc_ptr_list[i] = wfc.dpsiV.data()->data(); + dpsiM_row_acc_ptr_list[i] = wfc.dpsiM.data()->data() + WorkingIndex * norb * DIM; + d2psiV_acc_ptr_list[i] = wfc.d2psiV.data(); + d2psiM_row_acc_ptr_list[i] = wfc.d2psiM.data() + WorkingIndex * norb; + } + for (int i = 0; i < n_accepted; i++) + { + BLAS::copy(nel * nel, psiMinv_temp_acc_ptr_list[i], 1, psiMinv_acc_ptr_list[i], 1); + BLAS::copy(norb, psiV_acc_ptr_list[i], 1, TpsiM_col_acc_ptr_list[i], nel); + BLAS::copy(norb, psiV_acc_ptr_list[i], 1, psiM_row_acc_ptr_list[i], 1); + BLAS::copy(norb * DIM, dpsiV_acc_ptr_list[i], 1, dpsiM_row_acc_ptr_list[i], 1); + BLAS::copy(norb, d2psiV_acc_ptr_list[i], 1, d2psiM_row_acc_ptr_list[i], 1); + BLAS::copy(ndet, new_ratios_to_ref_acc_ptr_list[i], 1, ratios_to_ref_acc_ptr_list[i], 1); + } + // ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_ptr_list.data(), 1, psiMinv_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, norb, psiV_acc_ptr_list.data(), 1, TpsiM_col_acc_ptr_list.data(), nel, n_accepted); + // ompBLAS::copy_batched(handle, norb, psiV_acc_ptr_list.data(), 1, psiM_row_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, norb * DIM, dpsiV_acc_ptr_list.data(), 1, dpsiM_row_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, norb, d2psiV_acc_ptr_list.data(), 1, d2psiM_row_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_ptr_list.data(), 1, ratios_to_ref_acc_ptr_list.data(), 1, n_accepted); + + // dspin_psiM/V not on device + if (wfc_leader.is_spinor_) + { + Vector dspin_psiV_acc_ptr_list(n_accepted); + Vector dspin_psiM_row_acc_ptr_list(n_accepted); + for (int i = 0; i < n_accepted; i++) + { + auto iacc = idx_Accepted[i]; + auto& wfc = wfc_list.getCastedElement(iacc); + dspin_psiV_acc_ptr_list[i] = wfc.dspin_psiV.data(); + dspin_psiM_row_acc_ptr_list[i] = wfc.dspin_psiM.data() + WorkingIndex * norb; + } + for (int i = 0; i < n_accepted; i++) + BLAS::copy(norb, dspin_psiV_acc_ptr_list[i], 1, dspin_psiM_row_acc_ptr_list[i], 1); + // ompBLAS::copy_batched(handle, norb, dspin_psiV_acc_ptr_list.data(), 1, dspin_psiM_row_acc_ptr_list.data(), 1, n_accepted); + } + for (auto& iacc : idx_Accepted) + { + auto& wfc = wfc_list.getCastedElement(iacc); + wfc.ratios_to_ref_.updateTo(); + wfc.TpsiM.updateTo(); + wfc.psiMinv.updateTo(); + wfc.psiM.updateTo(); + wfc.dpsiM.updateTo(); + } + } + break; + + default: + /** + * psiMinv_temp[:,:] -> psiMinv[:,:]; [NumPtcls,NumPtcls] + * psiV[:] -> TpsiM[:,WorkingIndex]; [NumOrbitals] (NumPtcls in 2nd dim) + * new_ratios_to_ref_[:] -> ratios_to_ref_[:]; [NumDets] + * new_grads[:,:] -> grads[:,:]; [NumDets,NumPtcls] GradType + * new_lapls[:,:] -> lapls[:,:]; [NumDets,NumPtcls] + * psiV[:] -> psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) + * dpsiV[:] -> dpsiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) GradType + * d2psiV[:] -> d2psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) + * if (is_spinor_) + * dspin_psiV[:] -> dspin_psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) + * new_spingrads[:,:] -> spingrads[:,:]; [NumDets,NumPtcls] + */ + { + Vector psiMinv_temp_acc_ptr_list(n_accepted); + Vector psiMinv_acc_ptr_list(n_accepted); + + Vector psiV_acc_ptr_list(n_accepted); + Vector TpsiM_col_acc_ptr_list(n_accepted); + Vector psiM_row_acc_ptr_list(n_accepted); + + Vector new_ratios_to_ref_acc_ptr_list(n_accepted); + Vector ratios_to_ref_acc_ptr_list(n_accepted); + + Vector dpsiV_acc_ptr_list(n_accepted); + Vector dpsiM_row_acc_ptr_list(n_accepted); + Vector d2psiV_acc_ptr_list(n_accepted); + Vector d2psiM_row_acc_ptr_list(n_accepted); + + Vector new_grads_acc_ptr_list(n_accepted); + Vector grads_acc_ptr_list(n_accepted); + Vector new_lapls_acc_ptr_list(n_accepted); + Vector lapls_acc_ptr_list(n_accepted); + + for (int i = 0; i < n_accepted; i++) + { + auto iacc = idx_Accepted[i]; + auto& wfc = wfc_list.getCastedElement(iacc); + + psiMinv_temp_acc_ptr_list[i] = wfc.psiMinv_temp.data(); + psiMinv_acc_ptr_list[i] = wfc.psiMinv.data(); + + psiV_acc_ptr_list[i] = wfc.psiV.data(); + TpsiM_col_acc_ptr_list[i] = wfc.TpsiM.data() + WorkingIndex; + psiM_row_acc_ptr_list[i] = wfc.psiM.data() + WorkingIndex * norb; + + new_ratios_to_ref_acc_ptr_list[i] = wfc.new_ratios_to_ref_.data(); + ratios_to_ref_acc_ptr_list[i] = wfc.ratios_to_ref_.data(); + + dpsiV_acc_ptr_list[i] = wfc.dpsiV.data()->data(); + dpsiM_row_acc_ptr_list[i] = wfc.dpsiM.data()->data() + WorkingIndex * norb * DIM; + d2psiV_acc_ptr_list[i] = wfc.d2psiV.data(); + d2psiM_row_acc_ptr_list[i] = wfc.d2psiM.data() + WorkingIndex * norb; + + new_grads_acc_ptr_list[i] = wfc.new_grads.data()->data(); + grads_acc_ptr_list[i] = wfc.grads.data()->data(); + new_lapls_acc_ptr_list[i] = wfc.new_lapls.data(); + lapls_acc_ptr_list[i] = wfc.lapls.data(); + } + for (int i = 0; i < n_accepted; i++) + { + BLAS::copy(nel * nel, psiMinv_temp_acc_ptr_list[i], 1, psiMinv_acc_ptr_list[i], 1); + BLAS::copy(norb, psiV_acc_ptr_list[i], 1, TpsiM_col_acc_ptr_list[i], nel); + BLAS::copy(norb, psiV_acc_ptr_list[i], 1, psiM_row_acc_ptr_list[i], 1); + BLAS::copy(norb * DIM, dpsiV_acc_ptr_list[i], 1, dpsiM_row_acc_ptr_list[i], 1); + BLAS::copy(norb, d2psiV_acc_ptr_list[i], 1, d2psiM_row_acc_ptr_list[i], 1); + BLAS::copy(ndet, new_ratios_to_ref_acc_ptr_list[i], 1, ratios_to_ref_acc_ptr_list[i], 1); + } + // ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_ptr_list.data(), 1, psiMinv_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, norb, psiV_acc_ptr_list.data(), 1, TpsiM_col_acc_ptr_list.data(), nel, n_accepted); + // ompBLAS::copy_batched(handle, norb, psiV_acc_ptr_list.data(), 1, psiM_row_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, norb * DIM, dpsiV_acc_ptr_list.data(), 1, dpsiM_row_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, norb, d2psiV_acc_ptr_list.data(), 1, d2psiM_row_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_ptr_list.data(), 1, ratios_to_ref_acc_ptr_list.data(), 1, n_accepted); + + // grads,lapls not on device + // ompBLAS::copy_batched(handle, ndet * nel * DIM, new_grads_acc_ptr_list.data(), 1, grads_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, ndet * nel, new_lapls_acc_ptr_list.data(), 1, lapls_acc_ptr_list.data(), 1, n_accepted); + for (int i = 0; i < n_accepted; i++) + { + BLAS::copy(ndet * nel * DIM, new_grads_acc_ptr_list[i], 1, grads_acc_ptr_list[i], 1); + BLAS::copy(ndet * nel, new_lapls_acc_ptr_list[i], 1, lapls_acc_ptr_list[i], 1); + } + + if (wfc_leader.is_spinor_) + { + Vector dspin_psiV_acc_ptr_list(n_accepted); + Vector dspin_psiM_row_acc_ptr_list(n_accepted); + Vector new_spingrads_acc_ptr_list(n_accepted); + Vector spingrads_acc_ptr_list(n_accepted); + + for (int i = 0; i < n_accepted; i++) + { + auto iacc = idx_Accepted[i]; + auto& wfc = wfc_list.getCastedElement(iacc); + dspin_psiV_acc_ptr_list[i] = wfc.dspin_psiV.data(); + dspin_psiM_row_acc_ptr_list[i] = wfc.dspin_psiM.data() + WorkingIndex * norb; + new_spingrads_acc_ptr_list[i] = wfc.new_spingrads.data(); + spingrads_acc_ptr_list[i] = wfc.spingrads.data(); + } + for (int i = 0; i < n_accepted; i++) + { + BLAS::copy(norb, dspin_psiV_acc_ptr_list[i], 1, dspin_psiM_row_acc_ptr_list[i], 1); + BLAS::copy(ndet * nel, new_spingrads_acc_ptr_list[i], 1, spingrads_acc_ptr_list[i], 1); + // ompBLAS::copy_batched(handle, norb, dspin_psiV_acc_ptr_list.data(), 1, dspin_psiM_row_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, ndet * nel, new_spingrads_acc_ptr_list.data(), 1, spingrads_acc_ptr_list.data(), 1, n_accepted); + } + } + for (auto& iacc : idx_Accepted) + { + auto& wfc = wfc_list.getCastedElement(iacc); + wfc.ratios_to_ref_.updateTo(); + wfc.TpsiM.updateTo(); + wfc.psiMinv.updateTo(); + wfc.psiM.updateTo(); + wfc.dpsiM.updateTo(); + } + } + break; + } + // restore: + // setup pointer lists + Vector psiMinv_temp_rej_ptr_list(n_rejected); + Vector psiMinv_rej_ptr_list(n_rejected); + Vector TpsiM_col_rej_ptr_list(n_rejected); + Vector psiM_row_rej_ptr_list(n_rejected); + for (int i = 0; i < n_rejected; i++) + { + auto irej = idx_Rejected[i]; + auto& wfc = wfc_list.getCastedElement(irej); + psiMinv_temp_rej_ptr_list[i] = wfc.psiMinv_temp.data(); + psiMinv_rej_ptr_list[i] = wfc.psiMinv.data(); + TpsiM_col_rej_ptr_list[i] = wfc.TpsiM.data() + WorkingIndex; + psiM_row_rej_ptr_list[i] = wfc.psiM.data() + WorkingIndex * norb; + } + + /** + * psiMinv[:,:] -> psiMinv_temp[:,:]; [NumPtcls,NumPtcls] + * psiM[WorkingIndex,:] -> TpsiM[:,WorkingIndex]; [NumOrbitals] (NumPtcls in other dim) + */ + for (int i = 0; i < n_rejected; i++) + { + BLAS::copy(nel * nel, psiMinv_rej_ptr_list[i], 1, psiMinv_temp_rej_ptr_list[i], 1); + BLAS::copy(norb, psiM_row_rej_ptr_list[i], 1, TpsiM_col_rej_ptr_list[i], nel); + } + // ompBLAS::copy_batched(handle, nel * nel, psiMinv_rej_ptr_list.data(), 1, psiMinv_temp_rej_ptr_list.data(), 1, n_rejected); + // ompBLAS::copy_batched(handle, norb, psiM_row_rej_ptr_list.data(), 1, TpsiM_col_rej_ptr_list.data(), nel, n_rejected); + + for (auto& irej : idx_Rejected) + { + auto& wfc = wfc_list.getCastedElement(irej); + wfc.curRatio = ValueType(1); + wfc.TpsiM.updateTo(); + } } // this has been fixed