-
Notifications
You must be signed in to change notification settings - Fork 139
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
batching in MultiDiracDeterminant::mw_accept_rejectMove #5071
Changes from 4 commits
8488e12
018906f
49bd511
2c9b71d
c216769
42986f4
1d94b8c
25ca1dd
8a2f6ae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ | |
#include "Message/Communicate.h" | ||
#include "Numerics/DeterminantOperators.h" | ||
#include "CPU/BLAS.hpp" | ||
#include "OMPTarget/ompBLAS.hpp" | ||
#include "Numerics/MatrixOperators.h" | ||
#include <algorithm> | ||
#include <vector> | ||
|
@@ -529,12 +530,284 @@ | |
int iat, | ||
const std::vector<bool>& isAccepted) | ||
{ | ||
// TODO to be expanded to serve offload needs without relying on calling acceptMove and restore | ||
for (int iw = 0; iw < wfc_list.size(); iw++) | ||
const int nw = wfc_list.size(); | ||
// separate accepted/rejected walker indices | ||
const int n_accepted = std::count(isAccepted.begin(), isAccepted.begin() + nw, true); | ||
const int n_rejected = nw - n_accepted; | ||
|
||
//TODO: can put these in some preallocated work space (reserve up to n_walkers) | ||
std::vector<int> idx_Accepted(n_accepted); | ||
std::vector<int> idx_Rejected(n_rejected); | ||
|
||
for (int iw = 0, iacc = 0, irej = 0; iw < nw; iw++) | ||
if (isAccepted[iw]) | ||
wfc_list[iw].acceptMove(p_list[iw], iat, false); | ||
idx_Accepted[iacc++] = iw; | ||
else | ||
wfc_list[iw].restore(iat); | ||
idx_Rejected[irej++] = iw; | ||
|
||
|
||
MultiDiracDeterminant& wfc_leader = wfc_list.getLeader(); | ||
ParticleSet& p_leader = p_list.getLeader(); | ||
const int ndet = wfc_leader.getNumDets(); | ||
const int norb = wfc_leader.NumOrbitals; | ||
const int nel = wfc_leader.NumPtcls; | ||
auto& mw_res = wfc_leader.mw_res_handle_.getResource(); | ||
|
||
const int WorkingIndex = iat - wfc_leader.FirstIndex; | ||
assert(WorkingIndex >= 0 && WorkingIndex < wfc_leader.LastIndex - wfc_leader.FirstIndex); | ||
assert(p_leader.isSpinor() == wfc_leader.is_spinor_); | ||
int handle = 0; | ||
for (auto& iacc : idx_Accepted) | ||
{ | ||
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(iacc); | ||
if (wfc.curRatio == ValueType(0)) | ||
{ | ||
std::ostringstream msg; | ||
msg << "MultiDiracDeterminant::acceptMove curRatio is " << wfc.curRatio << " for walker " << iacc | ||
<< "! Report a bug." << std::endl; | ||
throw std::runtime_error(msg.str()); | ||
} | ||
} | ||
|
||
for (auto& iacc : idx_Accepted) | ||
{ | ||
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(iacc); | ||
wfc.log_value_ref_det_ += convertValueToLog(wfc.curRatio); | ||
wfc.curRatio = ValueType(1); | ||
} | ||
|
||
// pointers to data for only accepted walkers | ||
OffloadVector<ValueType*> psiMinv_temp_acc_deviceptr_list; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since these aren't retained, I'm not sure why the copy and pointer-list-building cases aren't fused and these are just in those cases. Put blocks in the cases and these lists will be constructed only if you are actually using them. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, this got a bit messy; I got stuck between not wanting to duplicate code but also not wanting to make the fallthrough logic of the switches hard to follow, so I retained the original logic for the actual copying and used different logic for the pointer lists to avoid duplication (and to only resize the ones I'm actually using), and I ended up with something worse than if I'd just committed to one or the other |
||
OffloadVector<ValueType*> psiMinv_acc_deviceptr_list; | ||
OffloadVector<ValueType*> psiV_acc_deviceptr_list; | ||
OffloadVector<ValueType*> TpsiM_col_acc_deviceptr_list; | ||
OffloadVector<ValueType*> psiM_row_acc_deviceptr_list; | ||
OffloadVector<ValueType*> new_ratios_to_ref_acc_deviceptr_list; | ||
OffloadVector<ValueType*> ratios_to_ref_acc_deviceptr_list; | ||
|
||
OffloadVector<ValueType*> dpsiV_acc_deviceptr_list; | ||
OffloadVector<ValueType*> dpsiM_row_acc_deviceptr_list; | ||
OffloadVector<ValueType*> d2psiV_acc_deviceptr_list; | ||
OffloadVector<ValueType*> d2psiM_row_acc_deviceptr_list; | ||
|
||
Vector<ValueType*> dspin_psiV_acc_ptr_list; | ||
Vector<ValueType*> dspin_psiM_row_acc_ptr_list; | ||
|
||
Vector<ValueType*> new_grads_acc_ptr_list; | ||
Vector<ValueType*> grads_acc_ptr_list; | ||
Vector<ValueType*> new_lapls_acc_ptr_list; | ||
Vector<ValueType*> lapls_acc_ptr_list; | ||
Vector<ValueType*> new_spingrads_acc_ptr_list; | ||
Vector<ValueType*> spingrads_acc_ptr_list; | ||
|
||
/** | ||
* some of these are in the mw_resource, and some are not | ||
* for the ones that are, get device pointers from the resource collection | ||
* for the ones that aren't, get pointers from MultiDiracDeterminant object | ||
* TODO: I'm assuming here that all data is already up to date on the device before this function is called | ||
*/ | ||
|
||
// setup device pointer lists | ||
switch (wfc_leader.UpdateMode) | ||
{ | ||
default: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because there are no real unit tests for this class and just a massive procedure test for multislater determinant, I'm not sure whether this coverage warning means none of this default case is covered. My reading is the default case is not covered. I think the ORB_PBYP_PARTIAL/RATIOS cases are covered. Seems like a bunch of important code with no coverage. Before doing a big refactor, why not write the missing unit tests? What's the expected outcome of calling this function? It seems like this function basically just copies a bunch of state data in response to the isAccepted vector and the internal WorkingIndex state. Significantly, there are three modes of this. Fill members with test data, initialize WorkingIndex, call, check side effects are correct. You will probably need a testing friend class to break the encapsulation. See src/Estimators/OneBodyDensityMatricesTests.h There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
My plan going into this refactor was that it wasn't going to be a huge change, and that I was just planning to group some similar work together (the rejects and accepts) without any other change in behavior. I don't yet have a complete picture of everything else upstream and downstream from this function and what different cases might need to be handled.
This is also my take on what's happening, but the part I'm not sure about (and what I asked about in my comment above) is which state data I can assume to be up-to-date on the host/device before entering this function, and what needs to be up-to-date (and where) after exiting. |
||
new_grads_acc_ptr_list.resize(n_accepted); | ||
grads_acc_ptr_list.resize(n_accepted); | ||
new_lapls_acc_ptr_list.resize(n_accepted); | ||
lapls_acc_ptr_list.resize(n_accepted); | ||
if (wfc_leader.is_spinor_) | ||
{ | ||
new_spingrads_acc_ptr_list.resize(n_accepted); | ||
spingrads_acc_ptr_list.resize(n_accepted); | ||
} | ||
for (int i = 0; i < n_accepted; i++) | ||
{ | ||
auto iacc = idx_Accepted[i]; | ||
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(iacc); | ||
new_grads_acc_ptr_list[i] = wfc.new_grads.data()->data(); | ||
grads_acc_ptr_list[i] = wfc.grads.data()->data(); | ||
new_lapls_acc_ptr_list[i] = wfc.new_lapls.data(); | ||
lapls_acc_ptr_list[i] = wfc.lapls.data(); | ||
if (wfc_leader.is_spinor_) | ||
{ | ||
new_spingrads_acc_ptr_list[i] = wfc.new_spingrads.data(); | ||
spingrads_acc_ptr_list[i] = wfc.spingrads.data(); | ||
} | ||
} | ||
case ORB_PBYP_PARTIAL: | ||
dpsiV_acc_deviceptr_list.resize(n_accepted); | ||
dpsiM_row_acc_deviceptr_list.resize(n_accepted); | ||
d2psiV_acc_deviceptr_list.resize(n_accepted); | ||
d2psiM_row_acc_deviceptr_list.resize(n_accepted); | ||
dspin_psiV_acc_ptr_list.resize(n_accepted); | ||
dspin_psiM_row_acc_ptr_list.resize(n_accepted); | ||
for (int i = 0; i < n_accepted; i++) | ||
{ | ||
auto iacc = idx_Accepted[i]; | ||
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(iacc); | ||
dpsiV_acc_deviceptr_list[i] = mw_res.dpsiV_deviceptr_list[iacc]->data(); | ||
dpsiM_row_acc_deviceptr_list[i] = mw_res.dpsiM_deviceptr_list[iacc]->data() + WorkingIndex * norb * DIM; | ||
d2psiV_acc_deviceptr_list[i] = wfc.d2psiV.device_data(); | ||
d2psiM_row_acc_deviceptr_list[i] = wfc.d2psiM.device_data() + WorkingIndex * norb; | ||
if (wfc_leader.is_spinor_) | ||
{ | ||
dspin_psiV_acc_ptr_list[i] = wfc.dspin_psiV.data(); | ||
dspin_psiM_row_acc_ptr_list[i] = wfc.dspin_psiM.data() + WorkingIndex * norb; | ||
} | ||
} | ||
case ORB_PBYP_RATIO: | ||
psiMinv_temp_acc_deviceptr_list.resize(n_accepted); | ||
psiMinv_acc_deviceptr_list.resize(n_accepted); | ||
psiV_acc_deviceptr_list.resize(n_accepted); | ||
TpsiM_col_acc_deviceptr_list.resize(n_accepted); | ||
psiM_row_acc_deviceptr_list.resize(n_accepted); | ||
new_ratios_to_ref_acc_deviceptr_list.resize(n_accepted); | ||
ratios_to_ref_acc_deviceptr_list.resize(n_accepted); | ||
for (int i = 0; i < n_accepted; i++) | ||
{ | ||
auto iacc = idx_Accepted[i]; | ||
psiMinv_temp_acc_deviceptr_list[i] = mw_res.psiMinv_temp_deviceptr_list[iacc]; | ||
psiMinv_acc_deviceptr_list[i] = mw_res.psiMinv_deviceptr_list[iacc]; | ||
psiV_acc_deviceptr_list[i] = mw_res.psiV_deviceptr_list[iacc]; | ||
TpsiM_col_acc_deviceptr_list[i] = mw_res.TpsiM_deviceptr_list[iacc] + WorkingIndex; | ||
psiM_row_acc_deviceptr_list[i] = mw_res.psiM_deviceptr_list[iacc] + WorkingIndex * norb; | ||
|
||
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(iacc); | ||
new_ratios_to_ref_acc_deviceptr_list[i] = wfc.new_ratios_to_ref_.device_data(); | ||
ratios_to_ref_acc_deviceptr_list[i] = wfc.ratios_to_ref_.device_data(); | ||
} | ||
} | ||
|
||
// copy data for accepted walkers | ||
switch (wfc_leader.UpdateMode) | ||
{ | ||
case ORB_PBYP_RATIO: | ||
/** | ||
* psiMinv_temp[:,:] -> psiMinv[:,:]; [NumPtcls,NumPtcls] | ||
* psiV[:] -> TpsiM[:,WorkingIndex]; [NumOrbitals] (NumPtcls in 2nd dim) | ||
* psiV[:] -> psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) | ||
* new_ratios_to_ref_[:] -> ratios_to_ref_[:]; [NumDets] | ||
*/ | ||
|
||
ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_deviceptr_list.data(), 1, | ||
psiMinv_acc_deviceptr_list.data(), 1, n_accepted); | ||
ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, TpsiM_col_acc_deviceptr_list.data(), nel, | ||
n_accepted); | ||
ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, psiM_row_acc_deviceptr_list.data(), 1, | ||
n_accepted); | ||
ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_deviceptr_list.data(), 1, | ||
ratios_to_ref_acc_deviceptr_list.data(), 1, n_accepted); | ||
break; | ||
case ORB_PBYP_PARTIAL: | ||
/** | ||
* psiMinv_temp[:,:] -> psiMinv[:,:]; [NumPtcls,NumPtcls] | ||
* psiV[:] -> TpsiM[:,WorkingIndex]; [NumOrbitals] (NumPtcls in 2nd dim) | ||
* new_ratios_to_ref_[:] -> ratios_to_ref_[:]; [NumDets] | ||
* psiV[:] -> psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) | ||
* dpsiV[:] -> dpsiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) GradType | ||
* d2psiV[:] -> d2psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) | ||
* if (is_spinor_) | ||
* dspin_psiV[:] -> dspin_psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) | ||
*/ | ||
ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_deviceptr_list.data(), 1, | ||
psiMinv_acc_deviceptr_list.data(), 1, n_accepted); | ||
ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, TpsiM_col_acc_deviceptr_list.data(), nel, | ||
n_accepted); | ||
ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, psiM_row_acc_deviceptr_list.data(), 1, | ||
n_accepted); | ||
ompBLAS::copy_batched(handle, norb * DIM, dpsiV_acc_deviceptr_list.data(), 1, dpsiM_row_acc_deviceptr_list.data(), | ||
1, n_accepted); | ||
ompBLAS::copy_batched(handle, norb, d2psiV_acc_deviceptr_list.data(), 1, d2psiM_row_acc_deviceptr_list.data(), 1, | ||
n_accepted); | ||
ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_deviceptr_list.data(), 1, | ||
ratios_to_ref_acc_deviceptr_list.data(), 1, n_accepted); | ||
|
||
// dspin_psiM/V not on device | ||
if (wfc_leader.is_spinor_) | ||
for (int i = 0; i < n_accepted; i++) | ||
BLAS::copy(norb, dspin_psiV_acc_ptr_list[i], 1, dspin_psiM_row_acc_ptr_list[i], 1); | ||
// ompBLAS::copy_batched(handle, norb, dspin_psiV_acc_ptr_list.data(), 1, dspin_psiM_row_acc_ptr_list.data(), 1, n_accepted); | ||
|
||
break; | ||
default: | ||
/** | ||
* psiMinv_temp[:,:] -> psiMinv[:,:]; [NumPtcls,NumPtcls] | ||
* psiV[:] -> TpsiM[:,WorkingIndex]; [NumOrbitals] (NumPtcls in 2nd dim) | ||
* new_ratios_to_ref_[:] -> ratios_to_ref_[:]; [NumDets] | ||
* new_grads[:,:] -> grads[:,:]; [NumDets,NumPtcls] GradType | ||
* new_lapls[:,:] -> lapls[:,:]; [NumDets,NumPtcls] | ||
* psiV[:] -> psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) | ||
* dpsiV[:] -> dpsiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) GradType | ||
* d2psiV[:] -> d2psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) | ||
* if (is_spinor_) | ||
* dspin_psiV[:] -> dspin_psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) | ||
* new_spingrads[:,:] -> spingrads[:,:]; [NumDets,NumPtcls] | ||
*/ | ||
ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_deviceptr_list.data(), 1, | ||
psiMinv_acc_deviceptr_list.data(), 1, n_accepted); | ||
ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, TpsiM_col_acc_deviceptr_list.data(), norb, | ||
n_accepted); | ||
ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, psiM_row_acc_deviceptr_list.data(), 1, | ||
n_accepted); | ||
ompBLAS::copy_batched(handle, norb * DIM, dpsiV_acc_deviceptr_list.data(), 1, dpsiM_row_acc_deviceptr_list.data(), | ||
1, n_accepted); | ||
ompBLAS::copy_batched(handle, norb, d2psiV_acc_deviceptr_list.data(), 1, d2psiM_row_acc_deviceptr_list.data(), 1, | ||
n_accepted); | ||
ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_deviceptr_list.data(), 1, | ||
ratios_to_ref_acc_deviceptr_list.data(), 1, n_accepted); | ||
|
||
// grads,lapls not on device | ||
// ompBLAS::copy_batched(handle, ndet * nel * DIM, new_grads_acc_ptr_list.data(), 1, grads_acc_ptr_list.data(), 1, n_accepted); | ||
// ompBLAS::copy_batched(handle, ndet * nel, new_lapls_acc_ptr_list.data(), 1, lapls_acc_ptr_list.data(), 1, n_accepted); | ||
for (int i = 0; i < n_accepted; i++) | ||
{ | ||
BLAS::copy(ndet * nel * DIM, new_grads_acc_ptr_list[i], 1, grads_acc_ptr_list[i], 1); | ||
BLAS::copy(ndet * nel, new_lapls_acc_ptr_list[i], 1, lapls_acc_ptr_list[i], 1); | ||
} | ||
|
||
if (wfc_leader.is_spinor_) | ||
{ | ||
// ompBLAS::copy_batched(handle, norb, dspin_psiV_acc_ptr_list.data(), 1, dspin_psiM_row_acc_ptr_list.data(), 1, n_accepted); | ||
// ompBLAS::copy_batched(handle, ndet * nel, new_spingrads_acc_ptr_list.data(), 1, spingrads_acc_ptr_list.data(), 1, n_accepted); | ||
for (int i = 0; i < n_accepted; i++) | ||
{ | ||
BLAS::copy(norb, dspin_psiV_acc_ptr_list[i], 1, dspin_psiM_row_acc_ptr_list[i], 1); | ||
BLAS::copy(ndet * nel, new_spingrads_acc_ptr_list[i], 1, spingrads_acc_ptr_list[i], 1); | ||
} | ||
} | ||
|
||
break; | ||
} | ||
|
||
// restore: | ||
// setup pointer lists | ||
OffloadVector<ValueType*> psiMinv_temp_rej_deviceptr_list(n_rejected); | ||
OffloadVector<ValueType*> psiMinv_rej_deviceptr_list(n_rejected); | ||
OffloadVector<ValueType*> TpsiM_col_rej_deviceptr_list(n_rejected); | ||
OffloadVector<ValueType*> psiM_row_rej_deviceptr_list(n_rejected); | ||
for (int i = 0; i < n_rejected; i++) | ||
{ | ||
auto irej = idx_Rejected[i]; | ||
psiMinv_temp_rej_deviceptr_list[i] = mw_res.psiMinv_temp_deviceptr_list[irej]; | ||
psiMinv_rej_deviceptr_list[i] = mw_res.psiMinv_deviceptr_list[irej]; | ||
TpsiM_col_rej_deviceptr_list[i] = mw_res.TpsiM_deviceptr_list[irej] + WorkingIndex; | ||
psiM_row_rej_deviceptr_list[i] = mw_res.psiM_deviceptr_list[irej] + WorkingIndex * norb; | ||
} | ||
|
||
/** | ||
* psiMinv[:,:] -> psiMinv_temp[:,:]; [NumPtcls,NumPtcls] | ||
* psiM[WorkingIndex,:] -> TpsiM[:,WorkingIndex]; [NumOrbitals] (NumPtcls in other dim) | ||
*/ | ||
ompBLAS::copy_batched(handle, nel * nel, psiMinv_rej_deviceptr_list.data(), 1, psiMinv_temp_rej_deviceptr_list.data(), | ||
1, n_rejected); | ||
ompBLAS::copy_batched(handle, norb, psiM_row_rej_deviceptr_list.data(), 1, TpsiM_col_rej_deviceptr_list.data(), nel, | ||
n_rejected); | ||
|
||
for (auto& irej : idx_Rejected) | ||
{ | ||
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(irej); | ||
wfc.curRatio = ValueType(1); | ||
} | ||
} | ||
|
||
// this has been fixed | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please have this follow the common idiom of
This exceptional-looking code led me, and likely other careful readers, to wonder just why
isAccepted
could be longer than wfc_list. Also, you are tossing the implicit bounds checking that using the begin()/end() iterator pair provides. The "multi walker" arguments of mw functions must always be the same size; you can and should write the code as if this were true. If you have doubts during development, use an
assert
definitely don't write code that "defends" against this defect. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, this makes sense; sometimes I try too hard to replicate exactly what the prior logic was even when (like here) that results in some ugly, non-standard code