From 8488e127420c6b348055f8ec622550be306a99d0 Mon Sep 17 00:00:00 2001 From: Kevin Gasperich Date: Thu, 20 Jun 2024 16:55:33 -0500 Subject: [PATCH 1/7] separate accept/reject walkers in MultiDiracDeterminant --- .../Fermion/MultiDiracDeterminant.cpp | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp b/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp index 30b5ab27d7..8efa0999aa 100644 --- a/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp +++ b/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp @@ -529,12 +529,27 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader& isAccepted) { - // TODO to be expanded to serve offload needs without relying on calling acceptMove and restore + // separate accepted/rejected walker indices + int n_accepted = std::count(isAccepted.begin(), isAccepted.begin() + wfc_list.size(), true); + int n_rejected = wfc_list.size() - n_accepted; + + //TODO: can put these in some preallocated work space (reserve up to n_walkers) + std::vector idx_accepted(n_accepted); + std::vector idx_rejected(n_rejected); + + int iacc = 0; + int irej = 0; for (int iw = 0; iw < wfc_list.size(); iw++) if (isAccepted[iw]) - wfc_list[iw].acceptMove(p_list[iw], iat, false); + idx_accepted[iacc++] = iw; else - wfc_list[iw].restore(iat); + idx_rejected[irej++] = iw; + + // TODO to be expanded to serve offload needs without relying on calling acceptMove and restore + for (auto& i : idx_accepted) + wfc_list[i].acceptMove(p_list[i], iat, false); + for (auto& i : idx_rejected) + wfc_list[i].restore(iat); } // this has been fixed From 018906f3d010b5347b875f51d112ee577e8d6c51 Mon Sep 17 00:00:00 2001 From: Kevin Gasperich Date: Wed, 26 Jun 2024 18:26:15 -0500 Subject: [PATCH 2/7] started mw_acceptreject in MultiDiracDet; needs better handling of grads --- .../Fermion/MultiDiracDeterminant.cpp | 286 +++++++++++++++++- 1 file changed, 272 insertions(+), 14 deletions(-) diff --git a/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp b/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp index 8efa0999aa..c903a7738c 100644 --- a/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp +++ b/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp @@ -22,6 +22,7 @@ #include "Message/Communicate.h" #include "Numerics/DeterminantOperators.h" #include "CPU/BLAS.hpp" +#include "OMPTarget/ompBLAS.hpp" #include "Numerics/MatrixOperators.h" #include #include @@ -529,27 +530,284 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader& isAccepted) { + const int nw = wfc_list.size(); // separate accepted/rejected walker indices - int n_accepted = std::count(isAccepted.begin(), isAccepted.begin() + wfc_list.size(), true); - int n_rejected = wfc_list.size() - n_accepted; + const int n_accepted = std::count(isAccepted.begin(), isAccepted.begin() + nw, true); + const int n_rejected = nw - n_accepted; //TODO: can put these in some preallocated work space (reserve up to n_walkers) - std::vector idx_accepted(n_accepted); - std::vector idx_rejected(n_rejected); + std::vector idx_Accepted(n_accepted); + std::vector idx_Rejected(n_rejected); - int iacc = 0; - int irej = 0; - for (int iw = 0; iw < wfc_list.size(); iw++) + for (int iw = 0, iacc = 0, irej = 0; iw < nw; iw++) if (isAccepted[iw]) - idx_accepted[iacc++] = iw; + idx_Accepted[iacc++] = iw; else - idx_rejected[irej++] = iw; + idx_Rejected[irej++] = iw; - // TODO to be expanded to serve offload needs without relying on calling acceptMove and restore - for (auto& i : idx_accepted) - wfc_list[i].acceptMove(p_list[i], iat, false); - for (auto& i : idx_rejected) - wfc_list[i].restore(iat); + + MultiDiracDeterminant& wfc_leader = wfc_list.getLeader(); + ParticleSet& p_leader = p_list.getLeader(); + const int ndet = wfc_leader.getNumDets(); + const int norb = wfc_leader.NumOrbitals; + const int nel = wfc_leader.NumPtcls; + auto& mw_res = wfc_leader.mw_res_handle_.getResource(); + + const int WorkingIndex = iat - wfc_leader.FirstIndex; + assert(WorkingIndex >= 0 && WorkingIndex < wfc_leader.LastIndex - wfc_leader.FirstIndex); + assert(p_leader.isSpinor() == wfc_leader.is_spinor_); + int handle = 0; + for (auto& iacc : idx_Accepted) + { + auto& wfc = wfc_list.getCastedElement(iacc); + if (wfc.curRatio == ValueType(0)) + { + std::ostringstream msg; + msg << "MultiDiracDeterminant::acceptMove curRatio is " << wfc.curRatio << " for walker " << iacc + << "! Report a bug." << std::endl; + throw std::runtime_error(msg.str()); + } + } + + for (auto& iacc : idx_Accepted) + { + auto& wfc = wfc_list.getCastedElement(iacc); + wfc.log_value_ref_det_ += convertValueToLog(wfc.curRatio); + wfc.curRatio = ValueType(1); + } + + // pointers to data for only accepted walkers + OffloadVector psiMinv_temp_acc_deviceptr_list; + OffloadVector psiMinv_acc_deviceptr_list; + OffloadVector psiV_acc_deviceptr_list; + OffloadVector TpsiM_col_acc_deviceptr_list; + OffloadVector psiM_row_acc_deviceptr_list; + OffloadVector new_ratios_to_ref_acc_deviceptr_list; + OffloadVector ratios_to_ref_acc_deviceptr_list; + + OffloadVector dpsiV_acc_deviceptr_list; + OffloadVector dpsiM_row_acc_deviceptr_list; + OffloadVector d2psiV_acc_deviceptr_list; + OffloadVector d2psiM_row_acc_deviceptr_list; + + Vector dspin_psiV_acc_ptr_list; + Vector dspin_psiM_row_acc_ptr_list; + + Vector new_grads_acc_ptr_list; + Vector grads_acc_ptr_list; + Vector new_lapls_acc_ptr_list; + Vector lapls_acc_ptr_list; + Vector new_spingrads_acc_ptr_list; + Vector spingrads_acc_ptr_list; + // cleaner to initialize with correct size, but + + /** + * some of these are in the mw_resource, and some are not + * for the ones that are, get device pointers from the resource collection + * for the ones that aren't, get pointers from MultiDiracDeterminant object + * TODO: I'm assuming here that all data is already up to date on the device before this function is called + */ + + // setup device pointer lists + switch (wfc_leader.UpdateMode) + { + default: + + + new_grads_acc_ptr_list.resize(n_accepted); + grads_acc_ptr_list.resize(n_accepted); + new_lapls_acc_ptr_list.resize(n_accepted); + lapls_acc_ptr_list.resize(n_accepted); + if (wfc_leader.is_spinor_) + { + new_spingrads_acc_ptr_list.resize(n_accepted); + spingrads_acc_ptr_list.resize(n_accepted); + } + for (int i = 0; i < n_accepted; i++) + { + auto iacc = idx_Accepted[i]; + auto& wfc = wfc_list.getCastedElement(iacc); + new_grads_acc_ptr_list[i] = wfc.new_grads.data(); + grads_acc_ptr_list[i] = wfc.grads.data(); + new_lapls_acc_ptr_list[i] = wfc.new_lapls.data(); + lapls_acc_ptr_list[i] = wfc.lapls.data(); + if (wfc_leader.is_spinor_) + { + new_spingrads_acc_ptr_list[i] = wfc.new_spingrads.data(); + spingrads_acc_ptr_list[i] = wfc.spingrads.data(); + } + } + case ORB_PBYP_PARTIAL: + dpsiV_acc_deviceptr_list.resize(n_accepted); + dpsiM_row_acc_deviceptr_list.resize(n_accepted); + d2psiV_acc_deviceptr_list.resize(n_accepted); + d2psiM_row_acc_deviceptr_list.resize(n_accepted); + dspin_psiV_acc_ptr_list.resize(n_accepted); + dspin_psiM_row_acc_ptr_list.resize(n_accepted); + for (int i = 0; i < n_accepted; i++) + { + auto iacc = idx_Accepted[i]; + auto& wfc = wfc_list.getCastedElement(iacc); + dpsiV_acc_deviceptr_list[i] = mw_res.dpsiV_deviceptr_list[iacc]; + dpsiM_row_acc_deviceptr_list[i] = mw_res.dpsiM_deviceptr_list[iacc] + WorkingIndex * norb; + d2psiV_acc_deviceptr_list[i] = wfc.d2psiV.device_data(); + d2psiM_row_acc_deviceptr_list[i] = wfc.d2psiM.device_data() + WorkingIndex * norb; + if (wfc_leader.is_spinor_) + { + dspin_psiV_acc_ptr_list[i] = wfc.dspin_psiV.data(); + dspin_psiM_row_acc_ptr_list[i] = wfc.dspin_psiM.data() + WorkingIndex * norb; + } + } + case ORB_PBYP_RATIO: + psiMinv_temp_acc_deviceptr_list.resize(n_accepted); + psiMinv_acc_deviceptr_list.resize(n_accepted); + psiV_acc_deviceptr_list.resize(n_accepted); + TpsiM_col_acc_deviceptr_list.resize(n_accepted); + psiM_row_acc_deviceptr_list.resize(n_accepted); + new_ratios_to_ref_acc_deviceptr_list.resize(n_accepted); + ratios_to_ref_acc_deviceptr_list.resize(n_accepted); + for (int i = 0; i < n_accepted; i++) + { + auto iacc = idx_Accepted[i]; + psiMinv_temp_acc_deviceptr_list[i] = mw_res.psiMinv_temp_deviceptr_list[iacc]; + psiMinv_acc_deviceptr_list[i] = mw_res.psiMinv_deviceptr_list[iacc]; + psiV_acc_deviceptr_list[i] = mw_res.psiV_deviceptr_list[iacc]; + TpsiM_col_acc_deviceptr_list[i] = mw_res.TpsiM_deviceptr_list[iacc] + WorkingIndex; + psiM_row_acc_deviceptr_list[i] = mw_res.psiM_deviceptr_list[iacc] + WorkingIndex * norb; + + auto& wfc = wfc_list.getCastedElement(iacc); + new_ratios_to_ref_acc_deviceptr_list[i] = wfc.new_ratios_to_ref_.device_data(); + ratios_to_ref_acc_deviceptr_list[i] = wfc.ratios_to_ref_.device_data(); + } + } + + // copy data for accepted walkers + switch (wfc_leader.UpdateMode) + { + case ORB_PBYP_RATIO: + /** + * psiMinv_temp[:,:] -> psiMinv[:,:]; [NumPtcls,NumPtcls] + * psiV[:] -> TpsiM[:,WorkingIndex]; [NumOrbitals] (NumPtcls in 2nd dim) + * psiV[:] -> psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) + * new_ratios_to_ref_[:] -> ratios_to_ref_[:]; [NumDets] + */ + + ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_deviceptr_list.data(), 1, + psiMinv_acc_deviceptr_list.data(), 1, n_accepted); + ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, TpsiM_col_acc_deviceptr_list.data(), norb, + n_accepted); + ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, psiM_row_acc_deviceptr_list.data(), 1, + n_accepted); + ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_deviceptr_list.data(), 1, + ratios_to_ref_acc_deviceptr_list.data(), 1, n_accepted); + break; + case ORB_PBYP_PARTIAL: + /** + * psiMinv_temp[:,:] -> psiMinv[:,:]; [NumPtcls,NumPtcls] + * psiV[:] -> TpsiM[:,WorkingIndex]; [NumOrbitals] (NumPtcls in 2nd dim) + * new_ratios_to_ref_[:] -> ratios_to_ref_[:]; [NumDets] + * psiV[:] -> psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) + * dpsiV[:] -> dpsiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) GradType + * d2psiV[:] -> d2psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) + * if (is_spinor_) + * dspin_psiV[:] -> dspin_psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) + */ + ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_deviceptr_list.data(), 1, + psiMinv_acc_deviceptr_list.data(), 1, n_accepted); + ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, TpsiM_col_acc_deviceptr_list.data(), norb, + n_accepted); + ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, psiM_row_acc_deviceptr_list.data(), 1, + n_accepted); + // ompBLAS::copy_batched(handle, norb * DIM, dpsiV_acc_deviceptr_list.data(), 1, dpsiM_row_acc_deviceptr_list.data(), 1, n_accepted); + ompBLAS::copy_batched(handle, norb, d2psiV_acc_deviceptr_list.data(), 1, d2psiM_row_acc_deviceptr_list.data(), 1, + n_accepted); + ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_deviceptr_list.data(), 1, + ratios_to_ref_acc_deviceptr_list.data(), 1, n_accepted); + + // dspin_psiM/V not on device + if (wfc_leader.is_spinor_) + for (int i = 0; i < n_accepted; i++) + BLAS::copy(norb, dspin_psiV_acc_ptr_list[i], 1, dspin_psiM_row_acc_ptr_list[i], 1); + // ompBLAS::copy_batched(handle, norb, dspin_psiV_acc_ptr_list.data(), 1, dspin_psiM_row_acc_ptr_list.data(), 1, n_accepted); + + break; + default: + /** + * psiMinv_temp[:,:] -> psiMinv[:,:]; [NumPtcls,NumPtcls] + * psiV[:] -> TpsiM[:,WorkingIndex]; [NumOrbitals] (NumPtcls in 2nd dim) + * new_ratios_to_ref_[:] -> ratios_to_ref_[:]; [NumDets] + * new_grads[:,:] -> grads[:,:]; [NumDets,NumPtcls] GradType + * new_lapls[:,:] -> lapls[:,:]; [NumDets,NumPtcls] + * psiV[:] -> psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) + * dpsiV[:] -> dpsiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) GradType + * d2psiV[:] -> d2psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) + * if (is_spinor_) + * dspin_psiV[:] -> dspin_psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) + * new_spingrads[:,:] -> spingrads[:,:]; [NumDets,NumPtcls] + */ + ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_deviceptr_list.data(), 1, + psiMinv_acc_deviceptr_list.data(), 1, n_accepted); + ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, TpsiM_col_acc_deviceptr_list.data(), norb, + n_accepted); + ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, psiM_row_acc_deviceptr_list.data(), 1, + n_accepted); + // ompBLAS::copy_batched(handle, norb * DIM, dpsiV_acc_deviceptr_list.data(), 1, dpsiM_row_acc_deviceptr_list.data(), 1, n_accepted); + ompBLAS::copy_batched(handle, norb, d2psiV_acc_deviceptr_list.data(), 1, d2psiM_row_acc_deviceptr_list.data(), 1, + n_accepted); + ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_deviceptr_list.data(), 1, + ratios_to_ref_acc_deviceptr_list.data(), 1, n_accepted); + + // grads,lapls not on device + // ompBLAS::copy_batched(handle, ndet * nel * DIM, new_grads_acc_ptr_list.data(), 1, grads_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, ndet * nel, new_lapls_acc_ptr_list.data(), 1, lapls_acc_ptr_list.data(), 1, n_accepted); + for (int i = 0; i < n_accepted; i++) + BLAS::copy(ndet * nel, new_lapls_acc_ptr_list[i], 1, lapls_acc_ptr_list[i], 1); + + if (wfc_leader.is_spinor_) + { + // ompBLAS::copy_batched(handle, norb, dspin_psiV_acc_ptr_list.data(), 1, dspin_psiM_row_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, ndet * nel, new_spingrads_acc_ptr_list.data(), 1, spingrads_acc_ptr_list.data(), 1, n_accepted); + for (int i = 0; i < n_accepted; i++) + { + BLAS::copy(norb, dspin_psiV_acc_ptr_list[i], 1, dspin_psiM_row_acc_ptr_list[i], 1); + BLAS::copy(ndet * nel, new_spingrads_acc_ptr_list[i], 1, spingrads_acc_ptr_list[i], 1); + } + } + + break; + } + + + // restore: + + // setup pointer lists + OffloadVector psiMinv_temp_rej_deviceptr_list(n_rejected); + OffloadVector psiMinv_rej_deviceptr_list(n_rejected); + OffloadVector TpsiM_col_rej_deviceptr_list(n_rejected); + OffloadVector psiM_row_rej_deviceptr_list(n_rejected); + for (int i = 0; i < n_rejected; i++) + { + auto irej = idx_Rejected[i]; + psiMinv_temp_rej_deviceptr_list[i] = mw_res.psiMinv_temp_deviceptr_list[irej]; + psiMinv_rej_deviceptr_list[i] = mw_res.psiMinv_deviceptr_list[irej]; + TpsiM_col_rej_deviceptr_list[i] = mw_res.TpsiM_deviceptr_list[irej] + WorkingIndex; + psiM_row_rej_deviceptr_list[i] = mw_res.psiM_deviceptr_list[irej] + WorkingIndex * norb; + } + + /** + * psiMinv[:,:] -> psiMinv_temp[:,:]; [NumPtcls,NumPtcls] + * psiM[WorkingIndex,:] -> TpsiM[:,WorkingIndex]; [NumOrbitals] (NumPtcls in other dim) + */ + ompBLAS::copy_batched(handle, nel * nel, psiMinv_rej_deviceptr_list.data(), 1, psiMinv_temp_rej_deviceptr_list.data(), + 1, n_rejected); + ompBLAS::copy_batched(handle, norb, psiM_row_rej_deviceptr_list.data(), 1, TpsiM_col_rej_deviceptr_list.data(), norb, + n_rejected); + + for (auto& irej : idx_Rejected) + { + auto& wfc = wfc_list.getCastedElement(irej); + wfc.curRatio = ValueType(1); + } } // this has been fixed From 49bd511df00e41c199d8d835ce2d1a2cd817fe87 Mon Sep 17 00:00:00 2001 From: Kevin Gasperich Date: Fri, 28 Jun 2024 15:00:30 -0500 Subject: [PATCH 3/7] fix ld/inc for TpsiM column --- src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp b/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp index c903a7738c..1934235999 100644 --- a/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp +++ b/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp @@ -694,7 +694,7 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader Date: Fri, 28 Jun 2024 17:00:11 -0500 Subject: [PATCH 4/7] fixed grad array pointers --- .../Fermion/MultiDiracDeterminant.cpp | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp b/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp index 1934235999..0628c19b0c 100644 --- a/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp +++ b/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp @@ -585,21 +585,20 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader new_ratios_to_ref_acc_deviceptr_list; OffloadVector ratios_to_ref_acc_deviceptr_list; - OffloadVector dpsiV_acc_deviceptr_list; - OffloadVector dpsiM_row_acc_deviceptr_list; + OffloadVector dpsiV_acc_deviceptr_list; + OffloadVector dpsiM_row_acc_deviceptr_list; OffloadVector d2psiV_acc_deviceptr_list; OffloadVector d2psiM_row_acc_deviceptr_list; Vector dspin_psiV_acc_ptr_list; Vector dspin_psiM_row_acc_ptr_list; - Vector new_grads_acc_ptr_list; - Vector grads_acc_ptr_list; + Vector new_grads_acc_ptr_list; + Vector grads_acc_ptr_list; Vector new_lapls_acc_ptr_list; Vector lapls_acc_ptr_list; Vector new_spingrads_acc_ptr_list; Vector spingrads_acc_ptr_list; - // cleaner to initialize with correct size, but /** * some of these are in the mw_resource, and some are not @@ -612,8 +611,6 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader(iacc); - new_grads_acc_ptr_list[i] = wfc.new_grads.data(); - grads_acc_ptr_list[i] = wfc.grads.data(); + new_grads_acc_ptr_list[i] = wfc.new_grads.data()->data(); + grads_acc_ptr_list[i] = wfc.grads.data()->data(); new_lapls_acc_ptr_list[i] = wfc.new_lapls.data(); lapls_acc_ptr_list[i] = wfc.lapls.data(); if (wfc_leader.is_spinor_) @@ -648,8 +645,8 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader(iacc); - dpsiV_acc_deviceptr_list[i] = mw_res.dpsiV_deviceptr_list[iacc]; - dpsiM_row_acc_deviceptr_list[i] = mw_res.dpsiM_deviceptr_list[iacc] + WorkingIndex * norb; + dpsiV_acc_deviceptr_list[i] = mw_res.dpsiV_deviceptr_list[iacc]->data(); + dpsiM_row_acc_deviceptr_list[i] = mw_res.dpsiM_deviceptr_list[iacc]->data() + WorkingIndex * norb * DIM; d2psiV_acc_deviceptr_list[i] = wfc.d2psiV.device_data(); d2psiM_row_acc_deviceptr_list[i] = wfc.d2psiM.device_data() + WorkingIndex * norb; if (wfc_leader.is_spinor_) @@ -718,7 +715,8 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader psiMinv_temp_rej_deviceptr_list(n_rejected); OffloadVector psiMinv_rej_deviceptr_list(n_rejected); From c216769b7fd8e7c1e46a6d23b1d8fb36a1f7202c Mon Sep 17 00:00:00 2001 From: Kevin Gasperich Date: Mon, 8 Jul 2024 16:29:56 -0500 Subject: [PATCH 5/7] cleanup some logic in mw_accept_rejectMove --- .../Fermion/MultiDiracDeterminant.cpp | 345 ++++++++++-------- 1 file changed, 189 insertions(+), 156 deletions(-) diff --git a/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp b/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp index 0628c19b0c..ce014eac0c 100644 --- a/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp +++ b/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp @@ -531,14 +531,16 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader& isAccepted) { const int nw = wfc_list.size(); + assert(isAccepted.size() == nw); // separate accepted/rejected walker indices - const int n_accepted = std::count(isAccepted.begin(), isAccepted.begin() + nw, true); + const int n_accepted = std::count(isAccepted.begin(), isAccepted.end(), true); const int n_rejected = nw - n_accepted; //TODO: can put these in some preallocated work space (reserve up to n_walkers) std::vector idx_Accepted(n_accepted); std::vector idx_Rejected(n_rejected); + // create lists of accepted/rejected walker indices for (int iw = 0, iacc = 0, irej = 0; iw < nw; iw++) if (isAccepted[iw]) idx_Accepted[iacc++] = iw; @@ -576,108 +578,6 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader psiMinv_temp_acc_deviceptr_list; - OffloadVector psiMinv_acc_deviceptr_list; - OffloadVector psiV_acc_deviceptr_list; - OffloadVector TpsiM_col_acc_deviceptr_list; - OffloadVector psiM_row_acc_deviceptr_list; - OffloadVector new_ratios_to_ref_acc_deviceptr_list; - OffloadVector ratios_to_ref_acc_deviceptr_list; - - OffloadVector dpsiV_acc_deviceptr_list; - OffloadVector dpsiM_row_acc_deviceptr_list; - OffloadVector d2psiV_acc_deviceptr_list; - OffloadVector d2psiM_row_acc_deviceptr_list; - - Vector dspin_psiV_acc_ptr_list; - Vector dspin_psiM_row_acc_ptr_list; - - Vector new_grads_acc_ptr_list; - Vector grads_acc_ptr_list; - Vector new_lapls_acc_ptr_list; - Vector lapls_acc_ptr_list; - Vector new_spingrads_acc_ptr_list; - Vector spingrads_acc_ptr_list; - - /** - * some of these are in the mw_resource, and some are not - * for the ones that are, get device pointers from the resource collection - * for the ones that aren't, get pointers from MultiDiracDeterminant object - * TODO: I'm assuming here that all data is already up to date on the device before this function is called - */ - - // setup device pointer lists - switch (wfc_leader.UpdateMode) - { - default: - new_grads_acc_ptr_list.resize(n_accepted); - grads_acc_ptr_list.resize(n_accepted); - new_lapls_acc_ptr_list.resize(n_accepted); - lapls_acc_ptr_list.resize(n_accepted); - if (wfc_leader.is_spinor_) - { - new_spingrads_acc_ptr_list.resize(n_accepted); - spingrads_acc_ptr_list.resize(n_accepted); - } - for (int i = 0; i < n_accepted; i++) - { - auto iacc = idx_Accepted[i]; - auto& wfc = wfc_list.getCastedElement(iacc); - new_grads_acc_ptr_list[i] = wfc.new_grads.data()->data(); - grads_acc_ptr_list[i] = wfc.grads.data()->data(); - new_lapls_acc_ptr_list[i] = wfc.new_lapls.data(); - lapls_acc_ptr_list[i] = wfc.lapls.data(); - if (wfc_leader.is_spinor_) - { - new_spingrads_acc_ptr_list[i] = wfc.new_spingrads.data(); - spingrads_acc_ptr_list[i] = wfc.spingrads.data(); - } - } - case ORB_PBYP_PARTIAL: - dpsiV_acc_deviceptr_list.resize(n_accepted); - dpsiM_row_acc_deviceptr_list.resize(n_accepted); - d2psiV_acc_deviceptr_list.resize(n_accepted); - d2psiM_row_acc_deviceptr_list.resize(n_accepted); - dspin_psiV_acc_ptr_list.resize(n_accepted); - dspin_psiM_row_acc_ptr_list.resize(n_accepted); - for (int i = 0; i < n_accepted; i++) - { - auto iacc = idx_Accepted[i]; - auto& wfc = wfc_list.getCastedElement(iacc); - dpsiV_acc_deviceptr_list[i] = mw_res.dpsiV_deviceptr_list[iacc]->data(); - dpsiM_row_acc_deviceptr_list[i] = mw_res.dpsiM_deviceptr_list[iacc]->data() + WorkingIndex * norb * DIM; - d2psiV_acc_deviceptr_list[i] = wfc.d2psiV.device_data(); - d2psiM_row_acc_deviceptr_list[i] = wfc.d2psiM.device_data() + WorkingIndex * norb; - if (wfc_leader.is_spinor_) - { - dspin_psiV_acc_ptr_list[i] = wfc.dspin_psiV.data(); - dspin_psiM_row_acc_ptr_list[i] = wfc.dspin_psiM.data() + WorkingIndex * norb; - } - } - case ORB_PBYP_RATIO: - psiMinv_temp_acc_deviceptr_list.resize(n_accepted); - psiMinv_acc_deviceptr_list.resize(n_accepted); - psiV_acc_deviceptr_list.resize(n_accepted); - TpsiM_col_acc_deviceptr_list.resize(n_accepted); - psiM_row_acc_deviceptr_list.resize(n_accepted); - new_ratios_to_ref_acc_deviceptr_list.resize(n_accepted); - ratios_to_ref_acc_deviceptr_list.resize(n_accepted); - for (int i = 0; i < n_accepted; i++) - { - auto iacc = idx_Accepted[i]; - psiMinv_temp_acc_deviceptr_list[i] = mw_res.psiMinv_temp_deviceptr_list[iacc]; - psiMinv_acc_deviceptr_list[i] = mw_res.psiMinv_deviceptr_list[iacc]; - psiV_acc_deviceptr_list[i] = mw_res.psiV_deviceptr_list[iacc]; - TpsiM_col_acc_deviceptr_list[i] = mw_res.TpsiM_deviceptr_list[iacc] + WorkingIndex; - psiM_row_acc_deviceptr_list[i] = mw_res.psiM_deviceptr_list[iacc] + WorkingIndex * norb; - - auto& wfc = wfc_list.getCastedElement(iacc); - new_ratios_to_ref_acc_deviceptr_list[i] = wfc.new_ratios_to_ref_.device_data(); - ratios_to_ref_acc_deviceptr_list[i] = wfc.ratios_to_ref_.device_data(); - } - } - // copy data for accepted walkers switch (wfc_leader.UpdateMode) { @@ -688,16 +588,43 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) * new_ratios_to_ref_[:] -> ratios_to_ref_[:]; [NumDets] */ + { + OffloadVector psiMinv_temp_acc_deviceptr_list(n_accepted); + OffloadVector psiMinv_acc_deviceptr_list(n_accepted); + + OffloadVector psiV_acc_deviceptr_list(n_accepted); + OffloadVector TpsiM_col_acc_deviceptr_list(n_accepted); + OffloadVector psiM_row_acc_deviceptr_list(n_accepted); + + OffloadVector new_ratios_to_ref_acc_deviceptr_list(n_accepted); + OffloadVector ratios_to_ref_acc_deviceptr_list(n_accepted); + + for (int i = 0; i < n_accepted; i++) + { + auto iacc = idx_Accepted[i]; + psiMinv_temp_acc_deviceptr_list[i] = mw_res.psiMinv_temp_deviceptr_list[iacc]; + psiMinv_acc_deviceptr_list[i] = mw_res.psiMinv_deviceptr_list[iacc]; + + psiV_acc_deviceptr_list[i] = mw_res.psiV_deviceptr_list[iacc]; + TpsiM_col_acc_deviceptr_list[i] = mw_res.TpsiM_deviceptr_list[iacc] + WorkingIndex; + psiM_row_acc_deviceptr_list[i] = mw_res.psiM_deviceptr_list[iacc] + WorkingIndex * norb; - ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_deviceptr_list.data(), 1, - psiMinv_acc_deviceptr_list.data(), 1, n_accepted); - ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, TpsiM_col_acc_deviceptr_list.data(), nel, - n_accepted); - ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, psiM_row_acc_deviceptr_list.data(), 1, - n_accepted); - ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_deviceptr_list.data(), 1, - ratios_to_ref_acc_deviceptr_list.data(), 1, n_accepted); + auto& wfc = wfc_list.getCastedElement(iacc); + new_ratios_to_ref_acc_deviceptr_list[i] = wfc.new_ratios_to_ref_.device_data(); + ratios_to_ref_acc_deviceptr_list[i] = wfc.ratios_to_ref_.device_data(); + } + + ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_deviceptr_list.data(), 1, + psiMinv_acc_deviceptr_list.data(), 1, n_accepted); + ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, TpsiM_col_acc_deviceptr_list.data(), nel, + n_accepted); + ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, psiM_row_acc_deviceptr_list.data(), 1, + n_accepted); + ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_deviceptr_list.data(), 1, + ratios_to_ref_acc_deviceptr_list.data(), 1, n_accepted); + } break; + case ORB_PBYP_PARTIAL: /** * psiMinv_temp[:,:] -> psiMinv[:,:]; [NumPtcls,NumPtcls] @@ -709,26 +636,73 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader dspin_psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) */ - ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_deviceptr_list.data(), 1, - psiMinv_acc_deviceptr_list.data(), 1, n_accepted); - ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, TpsiM_col_acc_deviceptr_list.data(), nel, - n_accepted); - ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, psiM_row_acc_deviceptr_list.data(), 1, - n_accepted); - ompBLAS::copy_batched(handle, norb * DIM, dpsiV_acc_deviceptr_list.data(), 1, dpsiM_row_acc_deviceptr_list.data(), - 1, n_accepted); - ompBLAS::copy_batched(handle, norb, d2psiV_acc_deviceptr_list.data(), 1, d2psiM_row_acc_deviceptr_list.data(), 1, - n_accepted); - ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_deviceptr_list.data(), 1, - ratios_to_ref_acc_deviceptr_list.data(), 1, n_accepted); - - // dspin_psiM/V not on device - if (wfc_leader.is_spinor_) - for (int i = 0; i < n_accepted; i++) - BLAS::copy(norb, dspin_psiV_acc_ptr_list[i], 1, dspin_psiM_row_acc_ptr_list[i], 1); - // ompBLAS::copy_batched(handle, norb, dspin_psiV_acc_ptr_list.data(), 1, dspin_psiM_row_acc_ptr_list.data(), 1, n_accepted); + { + OffloadVector psiMinv_temp_acc_deviceptr_list(n_accepted); + OffloadVector psiMinv_acc_deviceptr_list(n_accepted); + + OffloadVector psiV_acc_deviceptr_list(n_accepted); + OffloadVector TpsiM_col_acc_deviceptr_list(n_accepted); + OffloadVector psiM_row_acc_deviceptr_list(n_accepted); + + OffloadVector new_ratios_to_ref_acc_deviceptr_list(n_accepted); + OffloadVector ratios_to_ref_acc_deviceptr_list(n_accepted); + OffloadVector dpsiV_acc_deviceptr_list(n_accepted); + OffloadVector dpsiM_row_acc_deviceptr_list(n_accepted); + OffloadVector d2psiV_acc_deviceptr_list(n_accepted); + OffloadVector d2psiM_row_acc_deviceptr_list(n_accepted); + + for (int i = 0; i < n_accepted; i++) + { + auto iacc = idx_Accepted[i]; + psiMinv_temp_acc_deviceptr_list[i] = mw_res.psiMinv_temp_deviceptr_list[iacc]; + psiMinv_acc_deviceptr_list[i] = mw_res.psiMinv_deviceptr_list[iacc]; + + psiV_acc_deviceptr_list[i] = mw_res.psiV_deviceptr_list[iacc]; + TpsiM_col_acc_deviceptr_list[i] = mw_res.TpsiM_deviceptr_list[iacc] + WorkingIndex; + psiM_row_acc_deviceptr_list[i] = mw_res.psiM_deviceptr_list[iacc] + WorkingIndex * norb; + + auto& wfc = wfc_list.getCastedElement(iacc); + new_ratios_to_ref_acc_deviceptr_list[i] = wfc.new_ratios_to_ref_.device_data(); + ratios_to_ref_acc_deviceptr_list[i] = wfc.ratios_to_ref_.device_data(); + + dpsiV_acc_deviceptr_list[i] = mw_res.dpsiV_deviceptr_list[iacc]->data(); + dpsiM_row_acc_deviceptr_list[i] = mw_res.dpsiM_deviceptr_list[iacc]->data() + WorkingIndex * norb * DIM; + d2psiV_acc_deviceptr_list[i] = wfc.d2psiV.device_data(); + d2psiM_row_acc_deviceptr_list[i] = wfc.d2psiM.device_data() + WorkingIndex * norb; + } + ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_deviceptr_list.data(), 1, + psiMinv_acc_deviceptr_list.data(), 1, n_accepted); + ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, TpsiM_col_acc_deviceptr_list.data(), nel, + n_accepted); + ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, psiM_row_acc_deviceptr_list.data(), 1, + n_accepted); + ompBLAS::copy_batched(handle, norb * DIM, dpsiV_acc_deviceptr_list.data(), 1, dpsiM_row_acc_deviceptr_list.data(), + 1, n_accepted); + ompBLAS::copy_batched(handle, norb, d2psiV_acc_deviceptr_list.data(), 1, d2psiM_row_acc_deviceptr_list.data(), 1, + n_accepted); + ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_deviceptr_list.data(), 1, + ratios_to_ref_acc_deviceptr_list.data(), 1, n_accepted); + + // dspin_psiM/V not on device + if (wfc_leader.is_spinor_) + { + Vector dspin_psiV_acc_ptr_list(n_accepted); + Vector dspin_psiM_row_acc_ptr_list(n_accepted); + for (int i = 0; i < n_accepted; i++) + { + auto iacc = idx_Accepted[i]; + auto& wfc = wfc_list.getCastedElement(iacc); + dspin_psiV_acc_ptr_list[i] = wfc.dspin_psiV.data(); + dspin_psiM_row_acc_ptr_list[i] = wfc.dspin_psiM.data() + WorkingIndex * norb; + } + for (int i = 0; i < n_accepted; i++) + BLAS::copy(norb, dspin_psiV_acc_ptr_list[i], 1, dspin_psiM_row_acc_ptr_list[i], 1); + // ompBLAS::copy_batched(handle, norb, dspin_psiV_acc_ptr_list.data(), 1, dspin_psiM_row_acc_ptr_list.data(), 1, n_accepted); + } + } break; + default: /** * psiMinv_temp[:,:] -> psiMinv[:,:]; [NumPtcls,NumPtcls] @@ -743,39 +717,98 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader dspin_psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) * new_spingrads[:,:] -> spingrads[:,:]; [NumDets,NumPtcls] */ - ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_deviceptr_list.data(), 1, - psiMinv_acc_deviceptr_list.data(), 1, n_accepted); - ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, TpsiM_col_acc_deviceptr_list.data(), norb, - n_accepted); - ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, psiM_row_acc_deviceptr_list.data(), 1, - n_accepted); - ompBLAS::copy_batched(handle, norb * DIM, dpsiV_acc_deviceptr_list.data(), 1, dpsiM_row_acc_deviceptr_list.data(), - 1, n_accepted); - ompBLAS::copy_batched(handle, norb, d2psiV_acc_deviceptr_list.data(), 1, d2psiM_row_acc_deviceptr_list.data(), 1, - n_accepted); - ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_deviceptr_list.data(), 1, - ratios_to_ref_acc_deviceptr_list.data(), 1, n_accepted); - - // grads,lapls not on device - // ompBLAS::copy_batched(handle, ndet * nel * DIM, new_grads_acc_ptr_list.data(), 1, grads_acc_ptr_list.data(), 1, n_accepted); - // ompBLAS::copy_batched(handle, ndet * nel, new_lapls_acc_ptr_list.data(), 1, lapls_acc_ptr_list.data(), 1, n_accepted); - for (int i = 0; i < n_accepted; i++) { - BLAS::copy(ndet * nel * DIM, new_grads_acc_ptr_list[i], 1, grads_acc_ptr_list[i], 1); - BLAS::copy(ndet * nel, new_lapls_acc_ptr_list[i], 1, lapls_acc_ptr_list[i], 1); - } + OffloadVector psiMinv_temp_acc_deviceptr_list(n_accepted); + OffloadVector psiMinv_acc_deviceptr_list(n_accepted); + + OffloadVector psiV_acc_deviceptr_list(n_accepted); + OffloadVector TpsiM_col_acc_deviceptr_list(n_accepted); + OffloadVector psiM_row_acc_deviceptr_list(n_accepted); + + OffloadVector new_ratios_to_ref_acc_deviceptr_list(n_accepted); + OffloadVector ratios_to_ref_acc_deviceptr_list(n_accepted); + + OffloadVector dpsiV_acc_deviceptr_list(n_accepted); + OffloadVector dpsiM_row_acc_deviceptr_list(n_accepted); + OffloadVector d2psiV_acc_deviceptr_list(n_accepted); + OffloadVector d2psiM_row_acc_deviceptr_list(n_accepted); + + Vector new_grads_acc_ptr_list(n_accepted); + Vector grads_acc_ptr_list(n_accepted); + Vector new_lapls_acc_ptr_list(n_accepted); + Vector lapls_acc_ptr_list(n_accepted); - if (wfc_leader.is_spinor_) - { - // ompBLAS::copy_batched(handle, norb, dspin_psiV_acc_ptr_list.data(), 1, dspin_psiM_row_acc_ptr_list.data(), 1, n_accepted); - // ompBLAS::copy_batched(handle, ndet * nel, new_spingrads_acc_ptr_list.data(), 1, spingrads_acc_ptr_list.data(), 1, n_accepted); for (int i = 0; i < n_accepted; i++) { - BLAS::copy(norb, dspin_psiV_acc_ptr_list[i], 1, dspin_psiM_row_acc_ptr_list[i], 1); - BLAS::copy(ndet * nel, new_spingrads_acc_ptr_list[i], 1, spingrads_acc_ptr_list[i], 1); + auto iacc = idx_Accepted[i]; + psiMinv_temp_acc_deviceptr_list[i] = mw_res.psiMinv_temp_deviceptr_list[iacc]; + psiMinv_acc_deviceptr_list[i] = mw_res.psiMinv_deviceptr_list[iacc]; + + psiV_acc_deviceptr_list[i] = mw_res.psiV_deviceptr_list[iacc]; + TpsiM_col_acc_deviceptr_list[i] = mw_res.TpsiM_deviceptr_list[iacc] + WorkingIndex; + psiM_row_acc_deviceptr_list[i] = mw_res.psiM_deviceptr_list[iacc] + WorkingIndex * norb; + + auto& wfc = wfc_list.getCastedElement(iacc); + new_ratios_to_ref_acc_deviceptr_list[i] = wfc.new_ratios_to_ref_.device_data(); + ratios_to_ref_acc_deviceptr_list[i] = wfc.ratios_to_ref_.device_data(); + + dpsiV_acc_deviceptr_list[i] = mw_res.dpsiV_deviceptr_list[iacc]->data(); + dpsiM_row_acc_deviceptr_list[i] = mw_res.dpsiM_deviceptr_list[iacc]->data() + WorkingIndex * norb * DIM; + d2psiV_acc_deviceptr_list[i] = wfc.d2psiV.device_data(); + d2psiM_row_acc_deviceptr_list[i] = wfc.d2psiM.device_data() + WorkingIndex * norb; + + new_grads_acc_ptr_list[i] = wfc.new_grads.data()->data(); + grads_acc_ptr_list[i] = wfc.grads.data()->data(); + new_lapls_acc_ptr_list[i] = wfc.new_lapls.data(); + lapls_acc_ptr_list[i] = wfc.lapls.data(); + } + ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_deviceptr_list.data(), 1, + psiMinv_acc_deviceptr_list.data(), 1, n_accepted); + ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, TpsiM_col_acc_deviceptr_list.data(), norb, + n_accepted); + ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, psiM_row_acc_deviceptr_list.data(), 1, + n_accepted); + ompBLAS::copy_batched(handle, norb * DIM, dpsiV_acc_deviceptr_list.data(), 1, dpsiM_row_acc_deviceptr_list.data(), + 1, n_accepted); + ompBLAS::copy_batched(handle, norb, d2psiV_acc_deviceptr_list.data(), 1, d2psiM_row_acc_deviceptr_list.data(), 1, + n_accepted); + ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_deviceptr_list.data(), 1, + ratios_to_ref_acc_deviceptr_list.data(), 1, n_accepted); + + // grads,lapls not on device + // ompBLAS::copy_batched(handle, ndet * nel * DIM, new_grads_acc_ptr_list.data(), 1, grads_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, ndet * nel, new_lapls_acc_ptr_list.data(), 1, lapls_acc_ptr_list.data(), 1, n_accepted); + for (int i = 0; i < n_accepted; i++) + { + BLAS::copy(ndet * nel * DIM, new_grads_acc_ptr_list[i], 1, grads_acc_ptr_list[i], 1); + BLAS::copy(ndet * nel, new_lapls_acc_ptr_list[i], 1, lapls_acc_ptr_list[i], 1); } - } + if (wfc_leader.is_spinor_) + { + Vector dspin_psiV_acc_ptr_list(n_accepted); + Vector dspin_psiM_row_acc_ptr_list(n_accepted); + Vector new_spingrads_acc_ptr_list(n_accepted); + Vector spingrads_acc_ptr_list(n_accepted); + + for (int i = 0; i < n_accepted; i++) + { + auto iacc = idx_Accepted[i]; + auto& wfc = wfc_list.getCastedElement(iacc); + dspin_psiV_acc_ptr_list[i] = wfc.dspin_psiV.data(); + dspin_psiM_row_acc_ptr_list[i] = wfc.dspin_psiM.data() + WorkingIndex * norb; + new_spingrads_acc_ptr_list[i] = wfc.new_spingrads.data(); + spingrads_acc_ptr_list[i] = wfc.spingrads.data(); + } + for (int i = 0; i < n_accepted; i++) + { + BLAS::copy(norb, dspin_psiV_acc_ptr_list[i], 1, dspin_psiM_row_acc_ptr_list[i], 1); + BLAS::copy(ndet * nel, new_spingrads_acc_ptr_list[i], 1, spingrads_acc_ptr_list[i], 1); + // ompBLAS::copy_batched(handle, norb, dspin_psiV_acc_ptr_list.data(), 1, dspin_psiM_row_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, ndet * nel, new_spingrads_acc_ptr_list.data(), 1, spingrads_acc_ptr_list.data(), 1, n_accepted); + } + } + } break; } From 42986f42d5742cf527e1de4fa79b0f70fab291b4 Mon Sep 17 00:00:00 2001 From: Kevin Gasperich Date: Mon, 19 Aug 2024 14:53:02 -0500 Subject: [PATCH 6/7] keep old acc/rej data movement behavior --- .../Fermion/MultiDiracDeterminant.cpp | 252 ++++++++++-------- 1 file changed, 139 insertions(+), 113 deletions(-) diff --git a/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp b/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp index ce014eac0c..771af33ef9 100644 --- a/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp +++ b/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp @@ -589,39 +589,42 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader ratios_to_ref_[:]; [NumDets] */ { - OffloadVector psiMinv_temp_acc_deviceptr_list(n_accepted); - OffloadVector psiMinv_acc_deviceptr_list(n_accepted); + Vector psiMinv_temp_acc_ptr_list(n_accepted); + Vector psiMinv_acc_ptr_list(n_accepted); - OffloadVector psiV_acc_deviceptr_list(n_accepted); - OffloadVector TpsiM_col_acc_deviceptr_list(n_accepted); - OffloadVector psiM_row_acc_deviceptr_list(n_accepted); + Vector psiV_acc_ptr_list(n_accepted); + Vector TpsiM_col_acc_ptr_list(n_accepted); + Vector psiM_row_acc_ptr_list(n_accepted); - OffloadVector new_ratios_to_ref_acc_deviceptr_list(n_accepted); - OffloadVector ratios_to_ref_acc_deviceptr_list(n_accepted); + Vector new_ratios_to_ref_acc_ptr_list(n_accepted); + Vector ratios_to_ref_acc_ptr_list(n_accepted); for (int i = 0; i < n_accepted; i++) { - auto iacc = idx_Accepted[i]; - psiMinv_temp_acc_deviceptr_list[i] = mw_res.psiMinv_temp_deviceptr_list[iacc]; - psiMinv_acc_deviceptr_list[i] = mw_res.psiMinv_deviceptr_list[iacc]; + auto iacc = idx_Accepted[i]; + auto& wfc = wfc_list.getCastedElement(iacc); - psiV_acc_deviceptr_list[i] = mw_res.psiV_deviceptr_list[iacc]; - TpsiM_col_acc_deviceptr_list[i] = mw_res.TpsiM_deviceptr_list[iacc] + WorkingIndex; - psiM_row_acc_deviceptr_list[i] = mw_res.psiM_deviceptr_list[iacc] + WorkingIndex * norb; + psiMinv_temp_acc_ptr_list[i] = wfc.psiMinv_temp.data(); + psiMinv_acc_ptr_list[i] = wfc.psiMinv.data(); - auto& wfc = wfc_list.getCastedElement(iacc); - new_ratios_to_ref_acc_deviceptr_list[i] = wfc.new_ratios_to_ref_.device_data(); - ratios_to_ref_acc_deviceptr_list[i] = wfc.ratios_to_ref_.device_data(); - } + psiV_acc_ptr_list[i] = wfc.psiV.data(); + TpsiM_col_acc_ptr_list[i] = wfc.TpsiM.data() + WorkingIndex; + psiM_row_acc_ptr_list[i] = wfc.psiM.data() + WorkingIndex * norb; - ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_deviceptr_list.data(), 1, - psiMinv_acc_deviceptr_list.data(), 1, n_accepted); - ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, TpsiM_col_acc_deviceptr_list.data(), nel, - n_accepted); - ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, psiM_row_acc_deviceptr_list.data(), 1, - n_accepted); - ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_deviceptr_list.data(), 1, - ratios_to_ref_acc_deviceptr_list.data(), 1, n_accepted); + new_ratios_to_ref_acc_ptr_list[i] = wfc.new_ratios_to_ref_.data(); + ratios_to_ref_acc_ptr_list[i] = wfc.ratios_to_ref_.data(); + } + for (int i = 0; i < n_accepted; i++) + { + BLAS::copy(nel * nel, psiMinv_temp_acc_ptr_list[i], 1, psiMinv_acc_ptr_list[i], 1); + BLAS::copy(norb, psiV_acc_ptr_list[i], 1, TpsiM_col_acc_ptr_list[i], nel); + BLAS::copy(norb, psiV_acc_ptr_list[i], 1, psiM_row_acc_ptr_list[i], 1); + BLAS::copy(ndet, new_ratios_to_ref_acc_ptr_list[i], 1, ratios_to_ref_acc_ptr_list[i], 1); + } + // ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_ptr_list.data(), 1, psiMinv_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, norb, psiV_acc_ptr_list.data(), 1, TpsiM_col_acc_ptr_list.data(), nel, n_accepted); + // ompBLAS::copy_batched(handle, norb, psiV_acc_ptr_list.data(), 1, psiM_row_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_ptr_list.data(), 1, ratios_to_ref_acc_ptr_list.data(), 1, n_accepted); } break; @@ -637,52 +640,56 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader dspin_psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) */ { - OffloadVector psiMinv_temp_acc_deviceptr_list(n_accepted); - OffloadVector psiMinv_acc_deviceptr_list(n_accepted); + Vector psiMinv_temp_acc_ptr_list(n_accepted); + Vector psiMinv_acc_ptr_list(n_accepted); + + Vector psiV_acc_ptr_list(n_accepted); + Vector TpsiM_col_acc_ptr_list(n_accepted); + Vector psiM_row_acc_ptr_list(n_accepted); + + Vector new_ratios_to_ref_acc_ptr_list(n_accepted); + Vector ratios_to_ref_acc_ptr_list(n_accepted); + + Vector dpsiV_acc_ptr_list(n_accepted); + Vector dpsiM_row_acc_ptr_list(n_accepted); + Vector d2psiV_acc_ptr_list(n_accepted); + Vector d2psiM_row_acc_ptr_list(n_accepted); + + for (int i = 0; i < n_accepted; i++) + { + auto iacc = idx_Accepted[i]; + auto& wfc = wfc_list.getCastedElement(iacc); - OffloadVector psiV_acc_deviceptr_list(n_accepted); - OffloadVector TpsiM_col_acc_deviceptr_list(n_accepted); - OffloadVector psiM_row_acc_deviceptr_list(n_accepted); + psiMinv_temp_acc_ptr_list[i] = wfc.psiMinv_temp.data(); + psiMinv_acc_ptr_list[i] = wfc.psiMinv.data(); - OffloadVector new_ratios_to_ref_acc_deviceptr_list(n_accepted); - OffloadVector ratios_to_ref_acc_deviceptr_list(n_accepted); + psiV_acc_ptr_list[i] = wfc.psiV.data(); + TpsiM_col_acc_ptr_list[i] = wfc.TpsiM.data() + WorkingIndex; + psiM_row_acc_ptr_list[i] = wfc.psiM.data() + WorkingIndex * norb; - OffloadVector dpsiV_acc_deviceptr_list(n_accepted); - OffloadVector dpsiM_row_acc_deviceptr_list(n_accepted); - OffloadVector d2psiV_acc_deviceptr_list(n_accepted); - OffloadVector d2psiM_row_acc_deviceptr_list(n_accepted); + new_ratios_to_ref_acc_ptr_list[i] = wfc.new_ratios_to_ref_.data(); + ratios_to_ref_acc_ptr_list[i] = wfc.ratios_to_ref_.data(); + dpsiV_acc_ptr_list[i] = wfc.dpsiV.data()->data(); + dpsiM_row_acc_ptr_list[i] = wfc.dpsiM.data()->data() + WorkingIndex * norb * DIM; + d2psiV_acc_ptr_list[i] = wfc.d2psiV.data(); + d2psiM_row_acc_ptr_list[i] = wfc.d2psiM.data() + WorkingIndex * norb; + } for (int i = 0; i < n_accepted; i++) { - auto iacc = idx_Accepted[i]; - psiMinv_temp_acc_deviceptr_list[i] = mw_res.psiMinv_temp_deviceptr_list[iacc]; - psiMinv_acc_deviceptr_list[i] = mw_res.psiMinv_deviceptr_list[iacc]; - - psiV_acc_deviceptr_list[i] = mw_res.psiV_deviceptr_list[iacc]; - TpsiM_col_acc_deviceptr_list[i] = mw_res.TpsiM_deviceptr_list[iacc] + WorkingIndex; - psiM_row_acc_deviceptr_list[i] = mw_res.psiM_deviceptr_list[iacc] + WorkingIndex * norb; - - auto& wfc = wfc_list.getCastedElement(iacc); - new_ratios_to_ref_acc_deviceptr_list[i] = wfc.new_ratios_to_ref_.device_data(); - ratios_to_ref_acc_deviceptr_list[i] = wfc.ratios_to_ref_.device_data(); - - dpsiV_acc_deviceptr_list[i] = mw_res.dpsiV_deviceptr_list[iacc]->data(); - dpsiM_row_acc_deviceptr_list[i] = mw_res.dpsiM_deviceptr_list[iacc]->data() + WorkingIndex * norb * DIM; - d2psiV_acc_deviceptr_list[i] = wfc.d2psiV.device_data(); - d2psiM_row_acc_deviceptr_list[i] = wfc.d2psiM.device_data() + WorkingIndex * norb; + BLAS::copy(nel * nel, psiMinv_temp_acc_ptr_list[i], 1, psiMinv_acc_ptr_list[i], 1); + BLAS::copy(norb, psiV_acc_ptr_list[i], 1, TpsiM_col_acc_ptr_list[i], nel); + BLAS::copy(norb, psiV_acc_ptr_list[i], 1, psiM_row_acc_ptr_list[i], 1); + BLAS::copy(norb * DIM, dpsiV_acc_ptr_list[i], 1, dpsiM_row_acc_ptr_list[i], 1); + BLAS::copy(norb, d2psiV_acc_ptr_list[i], 1, d2psiM_row_acc_ptr_list[i], 1); + BLAS::copy(ndet, new_ratios_to_ref_acc_ptr_list[i], 1, ratios_to_ref_acc_ptr_list[i], 1); } - ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_deviceptr_list.data(), 1, - psiMinv_acc_deviceptr_list.data(), 1, n_accepted); - ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, TpsiM_col_acc_deviceptr_list.data(), nel, - n_accepted); - ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, psiM_row_acc_deviceptr_list.data(), 1, - n_accepted); - ompBLAS::copy_batched(handle, norb * DIM, dpsiV_acc_deviceptr_list.data(), 1, dpsiM_row_acc_deviceptr_list.data(), - 1, n_accepted); - ompBLAS::copy_batched(handle, norb, d2psiV_acc_deviceptr_list.data(), 1, d2psiM_row_acc_deviceptr_list.data(), 1, - n_accepted); - ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_deviceptr_list.data(), 1, - ratios_to_ref_acc_deviceptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_ptr_list.data(), 1, psiMinv_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, norb, psiV_acc_ptr_list.data(), 1, TpsiM_col_acc_ptr_list.data(), nel, n_accepted); + // ompBLAS::copy_batched(handle, norb, psiV_acc_ptr_list.data(), 1, psiM_row_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, norb * DIM, dpsiV_acc_ptr_list.data(), 1, dpsiM_row_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, norb, d2psiV_acc_ptr_list.data(), 1, d2psiM_row_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_ptr_list.data(), 1, ratios_to_ref_acc_ptr_list.data(), 1, n_accepted); // dspin_psiM/V not on device if (wfc_leader.is_spinor_) @@ -718,20 +725,20 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader spingrads[:,:]; [NumDets,NumPtcls] */ { - OffloadVector psiMinv_temp_acc_deviceptr_list(n_accepted); - OffloadVector psiMinv_acc_deviceptr_list(n_accepted); + Vector psiMinv_temp_acc_ptr_list(n_accepted); + Vector psiMinv_acc_ptr_list(n_accepted); - OffloadVector psiV_acc_deviceptr_list(n_accepted); - OffloadVector TpsiM_col_acc_deviceptr_list(n_accepted); - OffloadVector psiM_row_acc_deviceptr_list(n_accepted); + Vector psiV_acc_ptr_list(n_accepted); + Vector TpsiM_col_acc_ptr_list(n_accepted); + Vector psiM_row_acc_ptr_list(n_accepted); - OffloadVector new_ratios_to_ref_acc_deviceptr_list(n_accepted); - OffloadVector ratios_to_ref_acc_deviceptr_list(n_accepted); + Vector new_ratios_to_ref_acc_ptr_list(n_accepted); + Vector ratios_to_ref_acc_ptr_list(n_accepted); - OffloadVector dpsiV_acc_deviceptr_list(n_accepted); - OffloadVector dpsiM_row_acc_deviceptr_list(n_accepted); - OffloadVector d2psiV_acc_deviceptr_list(n_accepted); - OffloadVector d2psiM_row_acc_deviceptr_list(n_accepted); + Vector dpsiV_acc_ptr_list(n_accepted); + Vector dpsiM_row_acc_ptr_list(n_accepted); + Vector d2psiV_acc_ptr_list(n_accepted); + Vector d2psiM_row_acc_ptr_list(n_accepted); Vector new_grads_acc_ptr_list(n_accepted); Vector grads_acc_ptr_list(n_accepted); @@ -740,40 +747,44 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader(iacc); + + psiMinv_temp_acc_ptr_list[i] = wfc.psiMinv_temp.data(); + psiMinv_acc_ptr_list[i] = wfc.psiMinv.data(); - psiV_acc_deviceptr_list[i] = mw_res.psiV_deviceptr_list[iacc]; - TpsiM_col_acc_deviceptr_list[i] = mw_res.TpsiM_deviceptr_list[iacc] + WorkingIndex; - psiM_row_acc_deviceptr_list[i] = mw_res.psiM_deviceptr_list[iacc] + WorkingIndex * norb; + psiV_acc_ptr_list[i] = wfc.psiV.data(); + TpsiM_col_acc_ptr_list[i] = wfc.TpsiM.data() + WorkingIndex; + psiM_row_acc_ptr_list[i] = wfc.psiM.data() + WorkingIndex * norb; - auto& wfc = wfc_list.getCastedElement(iacc); - new_ratios_to_ref_acc_deviceptr_list[i] = wfc.new_ratios_to_ref_.device_data(); - ratios_to_ref_acc_deviceptr_list[i] = wfc.ratios_to_ref_.device_data(); + new_ratios_to_ref_acc_ptr_list[i] = wfc.new_ratios_to_ref_.data(); + ratios_to_ref_acc_ptr_list[i] = wfc.ratios_to_ref_.data(); - dpsiV_acc_deviceptr_list[i] = mw_res.dpsiV_deviceptr_list[iacc]->data(); - dpsiM_row_acc_deviceptr_list[i] = mw_res.dpsiM_deviceptr_list[iacc]->data() + WorkingIndex * norb * DIM; - d2psiV_acc_deviceptr_list[i] = wfc.d2psiV.device_data(); - d2psiM_row_acc_deviceptr_list[i] = wfc.d2psiM.device_data() + WorkingIndex * norb; + dpsiV_acc_ptr_list[i] = wfc.dpsiV.data()->data(); + dpsiM_row_acc_ptr_list[i] = wfc.dpsiM.data()->data() + WorkingIndex * norb * DIM; + d2psiV_acc_ptr_list[i] = wfc.d2psiV.data(); + d2psiM_row_acc_ptr_list[i] = wfc.d2psiM.data() + WorkingIndex * norb; new_grads_acc_ptr_list[i] = wfc.new_grads.data()->data(); grads_acc_ptr_list[i] = wfc.grads.data()->data(); new_lapls_acc_ptr_list[i] = wfc.new_lapls.data(); lapls_acc_ptr_list[i] = wfc.lapls.data(); } - ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_deviceptr_list.data(), 1, - psiMinv_acc_deviceptr_list.data(), 1, n_accepted); - ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, TpsiM_col_acc_deviceptr_list.data(), norb, - n_accepted); - ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, psiM_row_acc_deviceptr_list.data(), 1, - n_accepted); - ompBLAS::copy_batched(handle, norb * DIM, dpsiV_acc_deviceptr_list.data(), 1, dpsiM_row_acc_deviceptr_list.data(), - 1, n_accepted); - ompBLAS::copy_batched(handle, norb, d2psiV_acc_deviceptr_list.data(), 1, d2psiM_row_acc_deviceptr_list.data(), 1, - n_accepted); - ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_deviceptr_list.data(), 1, - ratios_to_ref_acc_deviceptr_list.data(), 1, n_accepted); + for (int i = 0; i < n_accepted; i++) + { + BLAS::copy(nel * nel, psiMinv_temp_acc_ptr_list[i], 1, psiMinv_acc_ptr_list[i], 1); + BLAS::copy(norb, psiV_acc_ptr_list[i], 1, TpsiM_col_acc_ptr_list[i], nel); + BLAS::copy(norb, psiV_acc_ptr_list[i], 1, psiM_row_acc_ptr_list[i], 1); + BLAS::copy(norb * DIM, dpsiV_acc_ptr_list[i], 1, dpsiM_row_acc_ptr_list[i], 1); + BLAS::copy(norb, d2psiV_acc_ptr_list[i], 1, d2psiM_row_acc_ptr_list[i], 1); + BLAS::copy(ndet, new_ratios_to_ref_acc_ptr_list[i], 1, ratios_to_ref_acc_ptr_list[i], 1); + } + // ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_ptr_list.data(), 1, psiMinv_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, norb, psiV_acc_ptr_list.data(), 1, TpsiM_col_acc_ptr_list.data(), nel, n_accepted); + // ompBLAS::copy_batched(handle, norb, psiV_acc_ptr_list.data(), 1, psiM_row_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, norb * DIM, dpsiV_acc_ptr_list.data(), 1, dpsiM_row_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, norb, d2psiV_acc_ptr_list.data(), 1, d2psiM_row_acc_ptr_list.data(), 1, n_accepted); + // ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_ptr_list.data(), 1, ratios_to_ref_acc_ptr_list.data(), 1, n_accepted); // grads,lapls not on device // ompBLAS::copy_batched(handle, ndet * nel * DIM, new_grads_acc_ptr_list.data(), 1, grads_acc_ptr_list.data(), 1, n_accepted); @@ -812,34 +823,49 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader(iacc); + wfc.ratios_to_ref_.updateTo(); + wfc.TpsiM.updateTo(); + wfc.psiMinv.updateTo(); + wfc.psiM.updateTo(); + wfc.dpsiM.updateTo(); + } // restore: // setup pointer lists - OffloadVector psiMinv_temp_rej_deviceptr_list(n_rejected); - OffloadVector psiMinv_rej_deviceptr_list(n_rejected); - OffloadVector TpsiM_col_rej_deviceptr_list(n_rejected); - OffloadVector psiM_row_rej_deviceptr_list(n_rejected); + Vector psiMinv_temp_rej_ptr_list(n_rejected); + Vector psiMinv_rej_ptr_list(n_rejected); + Vector TpsiM_col_rej_ptr_list(n_rejected); + Vector psiM_row_rej_ptr_list(n_rejected); for (int i = 0; i < n_rejected; i++) { - auto irej = idx_Rejected[i]; - psiMinv_temp_rej_deviceptr_list[i] = mw_res.psiMinv_temp_deviceptr_list[irej]; - psiMinv_rej_deviceptr_list[i] = mw_res.psiMinv_deviceptr_list[irej]; - TpsiM_col_rej_deviceptr_list[i] = mw_res.TpsiM_deviceptr_list[irej] + WorkingIndex; - psiM_row_rej_deviceptr_list[i] = mw_res.psiM_deviceptr_list[irej] + WorkingIndex * norb; + auto irej = idx_Rejected[i]; + auto& wfc = wfc_list.getCastedElement(irej); + psiMinv_temp_rej_ptr_list[i] = wfc.psiMinv_temp.data(); + psiMinv_rej_ptr_list[i] = wfc.psiMinv.data(); + TpsiM_col_rej_ptr_list[i] = wfc.TpsiM.data() + WorkingIndex; + psiM_row_rej_ptr_list[i] = wfc.psiM.data() + WorkingIndex * norb; } /** * psiMinv[:,:] -> psiMinv_temp[:,:]; [NumPtcls,NumPtcls] * psiM[WorkingIndex,:] -> TpsiM[:,WorkingIndex]; [NumOrbitals] (NumPtcls in other dim) */ - ompBLAS::copy_batched(handle, nel * nel, psiMinv_rej_deviceptr_list.data(), 1, psiMinv_temp_rej_deviceptr_list.data(), - 1, n_rejected); - ompBLAS::copy_batched(handle, norb, psiM_row_rej_deviceptr_list.data(), 1, TpsiM_col_rej_deviceptr_list.data(), nel, - n_rejected); + for (int i = 0; i < n_rejected; i++) + { + BLAS::copy(nel * nel, psiMinv_rej_ptr_list[i], 1, psiMinv_temp_rej_ptr_list[i], 1); + BLAS::copy(norb, psiM_row_rej_ptr_list[i], 1, TpsiM_col_rej_ptr_list[i], nel); + } + // ompBLAS::copy_batched(handle, nel * nel, psiMinv_rej_ptr_list.data(), 1, psiMinv_temp_rej_ptr_list.data(), 1, n_rejected); + // ompBLAS::copy_batched(handle, norb, psiM_row_rej_ptr_list.data(), 1, TpsiM_col_rej_ptr_list.data(), nel, n_rejected); for (auto& irej : idx_Rejected) { auto& wfc = wfc_list.getCastedElement(irej); wfc.curRatio = ValueType(1); + wfc.TpsiM.updateTo(); } } From 1d94b8ce51b7ee6a28ddc4bfed0289d1ff32e95f Mon Sep 17 00:00:00 2001 From: Kevin Gasperich Date: Mon, 19 Aug 2024 16:57:20 -0500 Subject: [PATCH 7/7] moved H2D into each case --- .../Fermion/MultiDiracDeterminant.cpp | 45 +++++++++++++------ 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp b/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp index 771af33ef9..5fd485541f 100644 --- a/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp +++ b/src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp @@ -625,6 +625,16 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader(iacc); + wfc.ratios_to_ref_.updateTo(); + wfc.TpsiM.updateTo(); + wfc.psiMinv.updateTo(); + wfc.psiM.updateTo(); + // dpsiM not updated on host in this case, but this H2D update is included for consistency with single-walker acceptMove + wfc.dpsiM.updateTo(); + } } break; @@ -665,7 +675,7 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader(iacc); + wfc.ratios_to_ref_.updateTo(); + wfc.TpsiM.updateTo(); + wfc.psiMinv.updateTo(); + wfc.psiM.updateTo(); + wfc.dpsiM.updateTo(); + } } break; @@ -819,20 +838,18 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader(iacc); + wfc.ratios_to_ref_.updateTo(); + wfc.TpsiM.updateTo(); + wfc.psiMinv.updateTo(); + wfc.psiM.updateTo(); + wfc.dpsiM.updateTo(); + } } break; } - - for (int i = 0; i < n_accepted; i++) - { - auto iacc = idx_Accepted[i]; - auto& wfc = wfc_list.getCastedElement(iacc); - wfc.ratios_to_ref_.updateTo(); - wfc.TpsiM.updateTo(); - wfc.psiMinv.updateTo(); - wfc.psiM.updateTo(); - wfc.dpsiM.updateTo(); - } // restore: // setup pointer lists Vector psiMinv_temp_rej_ptr_list(n_rejected); @@ -841,8 +858,8 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader psiM_row_rej_ptr_list(n_rejected); for (int i = 0; i < n_rejected; i++) { - auto irej = idx_Rejected[i]; - auto& wfc = wfc_list.getCastedElement(irej); + auto irej = idx_Rejected[i]; + auto& wfc = wfc_list.getCastedElement(irej); psiMinv_temp_rej_ptr_list[i] = wfc.psiMinv_temp.data(); psiMinv_rej_ptr_list[i] = wfc.psiMinv.data(); TpsiM_col_rej_ptr_list[i] = wfc.TpsiM.data() + WorkingIndex;