batching in MultiDiracDeterminant::mw_accept_rejectMove #5071

Merged: 9 commits, Aug 23, 2024
281 changes: 277 additions & 4 deletions src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp
@@ -22,6 +22,7 @@
#include "Message/Communicate.h"
#include "Numerics/DeterminantOperators.h"
#include "CPU/BLAS.hpp"
#include "OMPTarget/ompBLAS.hpp"
#include "Numerics/MatrixOperators.h"
#include <algorithm>
#include <vector>
@@ -529,12 +530,284 @@
int iat,
const std::vector<bool>& isAccepted)
{
  // removed: TODO to be expanded to serve offload needs without relying on calling acceptMove and restore
  const int nw = wfc_list.size();
// separate accepted/rejected walker indices
const int n_accepted = std::count(isAccepted.begin(), isAccepted.begin() + nw, true);
Contributor:

Please have this follow the common idiom of

std::count(isAccepted.begin(), isAccepted.end(), true)

This exceptional-looking code led me, and likely other careful readers, to wonder just why isAccepted could be longer than wfc_list. Also, you are tossing the implicit bounds checking that using the begin()/end() iterator pair provides.

mw ("multi walker") function arguments must always be the same size, so you can and should write the code as if that were true. If you have doubts during development, use an assert; definitely don't write code that "defends" against this defect.

Contributor Author:

yeah, this makes sense; sometimes I try too hard to replicate exactly what the prior logic was even when (like here) that results in some ugly, non-standard code

const int n_rejected = nw - n_accepted;

//TODO: can put these in some preallocated work space (reserve up to n_walkers)
std::vector<int> idx_Accepted(n_accepted);
std::vector<int> idx_Rejected(n_rejected);

  for (int iw = 0, iacc = 0, irej = 0; iw < nw; iw++)
    if (isAccepted[iw])
      idx_Accepted[iacc++] = iw; // removed: wfc_list[iw].acceptMove(p_list[iw], iat, false);
    else
      idx_Rejected[irej++] = iw; // removed: wfc_list[iw].restore(iat);


MultiDiracDeterminant& wfc_leader = wfc_list.getLeader();
ParticleSet& p_leader = p_list.getLeader();
const int ndet = wfc_leader.getNumDets();
const int norb = wfc_leader.NumOrbitals;
const int nel = wfc_leader.NumPtcls;
auto& mw_res = wfc_leader.mw_res_handle_.getResource();

const int WorkingIndex = iat - wfc_leader.FirstIndex;
assert(WorkingIndex >= 0 && WorkingIndex < wfc_leader.LastIndex - wfc_leader.FirstIndex);
assert(p_leader.isSpinor() == wfc_leader.is_spinor_);
int handle = 0;
for (auto& iacc : idx_Accepted)
{
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(iacc);
if (wfc.curRatio == ValueType(0))
{
std::ostringstream msg;
msg << "MultiDiracDeterminant::acceptMove curRatio is " << wfc.curRatio << " for walker " << iacc
<< "! Report a bug." << std::endl;
throw std::runtime_error(msg.str());
}
}

for (auto& iacc : idx_Accepted)
{
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(iacc);
wfc.log_value_ref_det_ += convertValueToLog(wfc.curRatio);
wfc.curRatio = ValueType(1);
}

// pointers to data for only accepted walkers
OffloadVector<ValueType*> psiMinv_temp_acc_deviceptr_list;
Contributor:

Since these aren't retained, I'm not sure why the copy and the pointer-list building aren't fused, with the lists declared just in those cases. Put blocks in the cases and these lists will be constructed only if you are actually using them.

Contributor Author:

yeah, this got a bit messy; I got stuck between not wanting to duplicate code but also not wanting to make the fallthrough logic of the switches hard to follow. I retained the original logic for the actual copying and used different logic for the pointer lists to avoid duplication (and to only resize the ones I'm actually using), and I ended up with something worse than if I'd just committed to one or the other.

OffloadVector<ValueType*> psiMinv_acc_deviceptr_list;
OffloadVector<ValueType*> psiV_acc_deviceptr_list;
OffloadVector<ValueType*> TpsiM_col_acc_deviceptr_list;
OffloadVector<ValueType*> psiM_row_acc_deviceptr_list;
OffloadVector<ValueType*> new_ratios_to_ref_acc_deviceptr_list;
OffloadVector<ValueType*> ratios_to_ref_acc_deviceptr_list;

OffloadVector<ValueType*> dpsiV_acc_deviceptr_list;
OffloadVector<ValueType*> dpsiM_row_acc_deviceptr_list;
OffloadVector<ValueType*> d2psiV_acc_deviceptr_list;
OffloadVector<ValueType*> d2psiM_row_acc_deviceptr_list;

Vector<ValueType*> dspin_psiV_acc_ptr_list;
Vector<ValueType*> dspin_psiM_row_acc_ptr_list;

Vector<ValueType*> new_grads_acc_ptr_list;
Vector<ValueType*> grads_acc_ptr_list;
Vector<ValueType*> new_lapls_acc_ptr_list;
Vector<ValueType*> lapls_acc_ptr_list;
Vector<ValueType*> new_spingrads_acc_ptr_list;
Vector<ValueType*> spingrads_acc_ptr_list;

/**
* some of these are in the mw_resource, and some are not
* for the ones that are, get device pointers from the resource collection
* for the ones that aren't, get pointers from MultiDiracDeterminant object
* TODO: I'm assuming here that all data is already up to date on the device before this function is called
*/

// setup device pointer lists
switch (wfc_leader.UpdateMode)
{
default:

Codecov warning: added line #L613 in src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp was not covered by tests.
Contributor:

Because there are no real unit tests for this class, just a massive procedure test for the multislater determinant, I'm not sure whether this coverage warning means none of this default case is covered. My reading is the default case is not covered. I think the ORB_PBYP_PARTIAL/RATIO cases are covered. Seems like a bunch of important code with no coverage.

Before doing a big refactor, why not write the missing unit tests? What's the expected outcome of calling this function?

It seems like this function basically just copies a bunch of state data in response to the isAccepted vector and the internal WorkingIndex state. Significantly, there are three modes of this.

Fill members with test data, initialize WorkingIndex, call, check side effects are correct. You will probably need a testing friend class to break the encapsulation. See src/Estimators/OneBodyDensityMatricesTests.h

Contributor Author:

> Before doing a big refactor, why not write the missing unit tests? What's the expected outcome of calling this function?

My plan going into this refactor was that it wasn't going to be a huge change, and that I was just planning to group some similar work together (the rejects and accepts) without any other change in behavior. I don't yet have a complete picture of everything else upstream and downstream from this function and what different cases might need to be handled.

> It seems like this function basically just copies a bunch of state data in response to the isAccepted vector and the internal WorkingIndex state.

This is also my take on what's happening, but the part I'm not sure about (and what I asked about in my comment above) is which state data I can assume to be up to date on the host/device before entering this function, and what needs to be up to date (and where) after exiting.
This is the only point where I'd planned to possibly deviate from the current functionality. E.g., in the ORB_PBYP_RATIO case for accepted moves, we do the relevant copying (host to host) to psiMinv, psiM, TpsiM, and ratios_to_ref_, and then we do an H2D update of all of those in addition to dpsiM. If there are cases where all of these (and psiV, psiMinv_temp, and new_ratios_to_ref_) are already up to date on the device, and if anything downstream from here will only use them on the device, then it would make sense to just do the copy on the device.

new_grads_acc_ptr_list.resize(n_accepted);
grads_acc_ptr_list.resize(n_accepted);
new_lapls_acc_ptr_list.resize(n_accepted);
lapls_acc_ptr_list.resize(n_accepted);
if (wfc_leader.is_spinor_)
{
new_spingrads_acc_ptr_list.resize(n_accepted);
spingrads_acc_ptr_list.resize(n_accepted);
}
for (int i = 0; i < n_accepted; i++)
{
auto iacc = idx_Accepted[i];
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(iacc);
new_grads_acc_ptr_list[i] = wfc.new_grads.data()->data();
grads_acc_ptr_list[i] = wfc.grads.data()->data();
new_lapls_acc_ptr_list[i] = wfc.new_lapls.data();
lapls_acc_ptr_list[i] = wfc.lapls.data();

Codecov warning: added lines #L628 - #L630 were not covered by tests.
if (wfc_leader.is_spinor_)
{
new_spingrads_acc_ptr_list[i] = wfc.new_spingrads.data();
spingrads_acc_ptr_list[i] = wfc.spingrads.data();

Codecov warning: added lines #L633 - #L634 were not covered by tests.
}
}
case ORB_PBYP_PARTIAL:
dpsiV_acc_deviceptr_list.resize(n_accepted);
dpsiM_row_acc_deviceptr_list.resize(n_accepted);
d2psiV_acc_deviceptr_list.resize(n_accepted);
d2psiM_row_acc_deviceptr_list.resize(n_accepted);
dspin_psiV_acc_ptr_list.resize(n_accepted);
dspin_psiM_row_acc_ptr_list.resize(n_accepted);
for (int i = 0; i < n_accepted; i++)
{
auto iacc = idx_Accepted[i];
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(iacc);
dpsiV_acc_deviceptr_list[i] = mw_res.dpsiV_deviceptr_list[iacc]->data();
dpsiM_row_acc_deviceptr_list[i] = mw_res.dpsiM_deviceptr_list[iacc]->data() + WorkingIndex * norb * DIM;
d2psiV_acc_deviceptr_list[i] = wfc.d2psiV.device_data();
d2psiM_row_acc_deviceptr_list[i] = wfc.d2psiM.device_data() + WorkingIndex * norb;
if (wfc_leader.is_spinor_)
{
dspin_psiV_acc_ptr_list[i] = wfc.dspin_psiV.data();
dspin_psiM_row_acc_ptr_list[i] = wfc.dspin_psiM.data() + WorkingIndex * norb;

Codecov warning: added lines #L654 - #L655 were not covered by tests.
}
}
case ORB_PBYP_RATIO:
psiMinv_temp_acc_deviceptr_list.resize(n_accepted);
psiMinv_acc_deviceptr_list.resize(n_accepted);
psiV_acc_deviceptr_list.resize(n_accepted);
TpsiM_col_acc_deviceptr_list.resize(n_accepted);
psiM_row_acc_deviceptr_list.resize(n_accepted);
new_ratios_to_ref_acc_deviceptr_list.resize(n_accepted);
ratios_to_ref_acc_deviceptr_list.resize(n_accepted);
for (int i = 0; i < n_accepted; i++)
{
auto iacc = idx_Accepted[i];
psiMinv_temp_acc_deviceptr_list[i] = mw_res.psiMinv_temp_deviceptr_list[iacc];
psiMinv_acc_deviceptr_list[i] = mw_res.psiMinv_deviceptr_list[iacc];
psiV_acc_deviceptr_list[i] = mw_res.psiV_deviceptr_list[iacc];
TpsiM_col_acc_deviceptr_list[i] = mw_res.TpsiM_deviceptr_list[iacc] + WorkingIndex;
psiM_row_acc_deviceptr_list[i] = mw_res.psiM_deviceptr_list[iacc] + WorkingIndex * norb;

auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(iacc);
new_ratios_to_ref_acc_deviceptr_list[i] = wfc.new_ratios_to_ref_.device_data();
ratios_to_ref_acc_deviceptr_list[i] = wfc.ratios_to_ref_.device_data();
}
}

// copy data for accepted walkers
switch (wfc_leader.UpdateMode)
{
case ORB_PBYP_RATIO:
/**
* psiMinv_temp[:,:] -> psiMinv[:,:]; [NumPtcls,NumPtcls]
* psiV[:] -> TpsiM[:,WorkingIndex]; [NumOrbitals] (NumPtcls in 2nd dim)
* psiV[:] -> psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim)
* new_ratios_to_ref_[:] -> ratios_to_ref_[:]; [NumDets]
*/

ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_deviceptr_list.data(), 1,
psiMinv_acc_deviceptr_list.data(), 1, n_accepted);
ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, TpsiM_col_acc_deviceptr_list.data(), nel,
n_accepted);
ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, psiM_row_acc_deviceptr_list.data(), 1,
n_accepted);
ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_deviceptr_list.data(), 1,
ratios_to_ref_acc_deviceptr_list.data(), 1, n_accepted);
break;
case ORB_PBYP_PARTIAL:
/**
* psiMinv_temp[:,:] -> psiMinv[:,:]; [NumPtcls,NumPtcls]
* psiV[:] -> TpsiM[:,WorkingIndex]; [NumOrbitals] (NumPtcls in 2nd dim)
* new_ratios_to_ref_[:] -> ratios_to_ref_[:]; [NumDets]
* psiV[:] -> psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim)
* dpsiV[:] -> dpsiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) GradType
* d2psiV[:] -> d2psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim)
* if (is_spinor_)
* dspin_psiV[:] -> dspin_psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim)
*/
ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_deviceptr_list.data(), 1,
psiMinv_acc_deviceptr_list.data(), 1, n_accepted);
ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, TpsiM_col_acc_deviceptr_list.data(), nel,
n_accepted);
ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, psiM_row_acc_deviceptr_list.data(), 1,
n_accepted);
ompBLAS::copy_batched(handle, norb * DIM, dpsiV_acc_deviceptr_list.data(), 1, dpsiM_row_acc_deviceptr_list.data(),
1, n_accepted);
ompBLAS::copy_batched(handle, norb, d2psiV_acc_deviceptr_list.data(), 1, d2psiM_row_acc_deviceptr_list.data(), 1,
n_accepted);
ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_deviceptr_list.data(), 1,
ratios_to_ref_acc_deviceptr_list.data(), 1, n_accepted);

// dspin_psiM/V not on device
if (wfc_leader.is_spinor_)
for (int i = 0; i < n_accepted; i++)
BLAS::copy(norb, dspin_psiV_acc_ptr_list[i], 1, dspin_psiM_row_acc_ptr_list[i], 1);

Codecov warning: added line #L728 was not covered by tests.
// ompBLAS::copy_batched(handle, norb, dspin_psiV_acc_ptr_list.data(), 1, dspin_psiM_row_acc_ptr_list.data(), 1, n_accepted);

break;
default:

Codecov warning: added line #L732 was not covered by tests.
/**
* psiMinv_temp[:,:] -> psiMinv[:,:]; [NumPtcls,NumPtcls]
* psiV[:] -> TpsiM[:,WorkingIndex]; [NumOrbitals] (NumPtcls in 2nd dim)
* new_ratios_to_ref_[:] -> ratios_to_ref_[:]; [NumDets]
* new_grads[:,:] -> grads[:,:]; [NumDets,NumPtcls] GradType
* new_lapls[:,:] -> lapls[:,:]; [NumDets,NumPtcls]
* psiV[:] -> psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim)
* dpsiV[:] -> dpsiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) GradType
* d2psiV[:] -> d2psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim)
* if (is_spinor_)
* dspin_psiV[:] -> dspin_psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim)
* new_spingrads[:,:] -> spingrads[:,:]; [NumDets,NumPtcls]
*/
ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_deviceptr_list.data(), 1,

Codecov warning: added line #L746 was not covered by tests.
psiMinv_acc_deviceptr_list.data(), 1, n_accepted);
ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, TpsiM_col_acc_deviceptr_list.data(), norb,
n_accepted);
ompBLAS::copy_batched(handle, norb, psiV_acc_deviceptr_list.data(), 1, psiM_row_acc_deviceptr_list.data(), 1,
n_accepted);
ompBLAS::copy_batched(handle, norb * DIM, dpsiV_acc_deviceptr_list.data(), 1, dpsiM_row_acc_deviceptr_list.data(),
1, n_accepted);
ompBLAS::copy_batched(handle, norb, d2psiV_acc_deviceptr_list.data(), 1, d2psiM_row_acc_deviceptr_list.data(), 1,
n_accepted);
ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_deviceptr_list.data(), 1,

Codecov warning: added line #L756 was not covered by tests.
ratios_to_ref_acc_deviceptr_list.data(), 1, n_accepted);

// grads,lapls not on device
// ompBLAS::copy_batched(handle, ndet * nel * DIM, new_grads_acc_ptr_list.data(), 1, grads_acc_ptr_list.data(), 1, n_accepted);
// ompBLAS::copy_batched(handle, ndet * nel, new_lapls_acc_ptr_list.data(), 1, lapls_acc_ptr_list.data(), 1, n_accepted);
for (int i = 0; i < n_accepted; i++)
{
BLAS::copy(ndet * nel * DIM, new_grads_acc_ptr_list[i], 1, grads_acc_ptr_list[i], 1);
BLAS::copy(ndet * nel, new_lapls_acc_ptr_list[i], 1, lapls_acc_ptr_list[i], 1);

Codecov warning: added lines #L764 - #L765 were not covered by tests.
}

if (wfc_leader.is_spinor_)
{
// ompBLAS::copy_batched(handle, norb, dspin_psiV_acc_ptr_list.data(), 1, dspin_psiM_row_acc_ptr_list.data(), 1, n_accepted);
// ompBLAS::copy_batched(handle, ndet * nel, new_spingrads_acc_ptr_list.data(), 1, spingrads_acc_ptr_list.data(), 1, n_accepted);
for (int i = 0; i < n_accepted; i++)
{
BLAS::copy(norb, dspin_psiV_acc_ptr_list[i], 1, dspin_psiM_row_acc_ptr_list[i], 1);
BLAS::copy(ndet * nel, new_spingrads_acc_ptr_list[i], 1, spingrads_acc_ptr_list[i], 1);

Codecov warning: added lines #L774 - #L775 were not covered by tests.
}
}

break;
}

// restore:
// setup pointer lists
OffloadVector<ValueType*> psiMinv_temp_rej_deviceptr_list(n_rejected);
OffloadVector<ValueType*> psiMinv_rej_deviceptr_list(n_rejected);
OffloadVector<ValueType*> TpsiM_col_rej_deviceptr_list(n_rejected);
OffloadVector<ValueType*> psiM_row_rej_deviceptr_list(n_rejected);
for (int i = 0; i < n_rejected; i++)
{
auto irej = idx_Rejected[i];
psiMinv_temp_rej_deviceptr_list[i] = mw_res.psiMinv_temp_deviceptr_list[irej];
psiMinv_rej_deviceptr_list[i] = mw_res.psiMinv_deviceptr_list[irej];
TpsiM_col_rej_deviceptr_list[i] = mw_res.TpsiM_deviceptr_list[irej] + WorkingIndex;
psiM_row_rej_deviceptr_list[i] = mw_res.psiM_deviceptr_list[irej] + WorkingIndex * norb;
}

/**
* psiMinv[:,:] -> psiMinv_temp[:,:]; [NumPtcls,NumPtcls]
* psiM[WorkingIndex,:] -> TpsiM[:,WorkingIndex]; [NumOrbitals] (NumPtcls in other dim)
*/
ompBLAS::copy_batched(handle, nel * nel, psiMinv_rej_deviceptr_list.data(), 1, psiMinv_temp_rej_deviceptr_list.data(),
1, n_rejected);
ompBLAS::copy_batched(handle, norb, psiM_row_rej_deviceptr_list.data(), 1, TpsiM_col_rej_deviceptr_list.data(), nel,
n_rejected);

for (auto& irej : idx_Rejected)
{
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(irej);
wfc.curRatio = ValueType(1);
}
}

// this has been fixed