Skip to content

Commit

Permalink
Merge pull request #5071 from kgasperich/md-mw-accept-reject
Browse files: browse the repository at this point in the history
batching in MultiDiracDeterminant::mw_accept_rejectMove
  • Loading branch information
prckent authored Aug 23, 2024
2 parents bca6887 + 8a2f6ae commit 9a4dd3b
Showing 1 changed file with 353 additions and 4 deletions.
357 changes: 353 additions & 4 deletions src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "Message/Communicate.h"
#include "Numerics/DeterminantOperators.h"
#include "CPU/BLAS.hpp"
#include "OMPTarget/ompBLAS.hpp"
#include "Numerics/MatrixOperators.h"
#include <algorithm>
#include <vector>
Expand Down Expand Up @@ -529,12 +530,360 @@ void MultiDiracDeterminant::mw_accept_rejectMove(const RefVectorWithLeader<Multi
int iat,
const std::vector<bool>& isAccepted)
{
// TODO to be expanded to serve offload needs without relying on calling acceptMove and restore
for (int iw = 0; iw < wfc_list.size(); iw++)
const int nw = wfc_list.size();
assert(isAccepted.size() == nw);
// separate accepted/rejected walker indices
const int n_accepted = std::count(isAccepted.begin(), isAccepted.end(), true);
const int n_rejected = nw - n_accepted;

//TODO: can put these in some preallocated work space (reserve up to n_walkers)
std::vector<int> idx_Accepted(n_accepted);
std::vector<int> idx_Rejected(n_rejected);

// create lists of accepted/rejected walker indices
for (int iw = 0, iacc = 0, irej = 0; iw < nw; iw++)
if (isAccepted[iw])
wfc_list[iw].acceptMove(p_list[iw], iat, false);
idx_Accepted[iacc++] = iw;
else
wfc_list[iw].restore(iat);
idx_Rejected[irej++] = iw;


MultiDiracDeterminant& wfc_leader = wfc_list.getLeader();
ParticleSet& p_leader = p_list.getLeader();
const int ndet = wfc_leader.getNumDets();
const int norb = wfc_leader.NumOrbitals;
const int nel = wfc_leader.NumPtcls;
auto& mw_res = wfc_leader.mw_res_handle_.getResource();

const int WorkingIndex = iat - wfc_leader.FirstIndex;
assert(WorkingIndex >= 0 && WorkingIndex < wfc_leader.LastIndex - wfc_leader.FirstIndex);
assert(p_leader.isSpinor() == wfc_leader.is_spinor_);
int handle = 0;
for (auto& iacc : idx_Accepted)
{
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(iacc);
if (wfc.curRatio == ValueType(0))
{
std::ostringstream msg;
msg << "MultiDiracDeterminant::acceptMove curRatio is " << wfc.curRatio << " for walker " << iacc
<< "! Report a bug." << std::endl;
throw std::runtime_error(msg.str());
}
}

for (auto& iacc : idx_Accepted)
{
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(iacc);
wfc.log_value_ref_det_ += convertValueToLog(wfc.curRatio);
wfc.curRatio = ValueType(1);
}

// copy data for accepted walkers
switch (wfc_leader.UpdateMode)
{
case ORB_PBYP_RATIO:
/**
* psiMinv_temp[:,:] -> psiMinv[:,:]; [NumPtcls,NumPtcls]
* psiV[:] -> TpsiM[:,WorkingIndex]; [NumOrbitals] (NumPtcls in 2nd dim)
* psiV[:] -> psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim)
* new_ratios_to_ref_[:] -> ratios_to_ref_[:]; [NumDets]
*/
{
Vector<ValueType*> psiMinv_temp_acc_ptr_list(n_accepted);
Vector<ValueType*> psiMinv_acc_ptr_list(n_accepted);

Vector<ValueType*> psiV_acc_ptr_list(n_accepted);
Vector<ValueType*> TpsiM_col_acc_ptr_list(n_accepted);
Vector<ValueType*> psiM_row_acc_ptr_list(n_accepted);

Vector<ValueType*> new_ratios_to_ref_acc_ptr_list(n_accepted);
Vector<ValueType*> ratios_to_ref_acc_ptr_list(n_accepted);

for (int i = 0; i < n_accepted; i++)
{
auto iacc = idx_Accepted[i];
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(iacc);

psiMinv_temp_acc_ptr_list[i] = wfc.psiMinv_temp.data();
psiMinv_acc_ptr_list[i] = wfc.psiMinv.data();

psiV_acc_ptr_list[i] = wfc.psiV.data();
TpsiM_col_acc_ptr_list[i] = wfc.TpsiM.data() + WorkingIndex;
psiM_row_acc_ptr_list[i] = wfc.psiM.data() + WorkingIndex * norb;

new_ratios_to_ref_acc_ptr_list[i] = wfc.new_ratios_to_ref_.data();
ratios_to_ref_acc_ptr_list[i] = wfc.ratios_to_ref_.data();
}
for (int i = 0; i < n_accepted; i++)
{
BLAS::copy(nel * nel, psiMinv_temp_acc_ptr_list[i], 1, psiMinv_acc_ptr_list[i], 1);
BLAS::copy(norb, psiV_acc_ptr_list[i], 1, TpsiM_col_acc_ptr_list[i], nel);
BLAS::copy(norb, psiV_acc_ptr_list[i], 1, psiM_row_acc_ptr_list[i], 1);
BLAS::copy(ndet, new_ratios_to_ref_acc_ptr_list[i], 1, ratios_to_ref_acc_ptr_list[i], 1);
}
// ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_ptr_list.data(), 1, psiMinv_acc_ptr_list.data(), 1, n_accepted);
// ompBLAS::copy_batched(handle, norb, psiV_acc_ptr_list.data(), 1, TpsiM_col_acc_ptr_list.data(), nel, n_accepted);
// ompBLAS::copy_batched(handle, norb, psiV_acc_ptr_list.data(), 1, psiM_row_acc_ptr_list.data(), 1, n_accepted);
// ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_ptr_list.data(), 1, ratios_to_ref_acc_ptr_list.data(), 1, n_accepted);
for (auto& iacc : idx_Accepted)
{
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(iacc);
wfc.ratios_to_ref_.updateTo();
wfc.TpsiM.updateTo();
wfc.psiMinv.updateTo();
wfc.psiM.updateTo();
// dpsiM not updated on host in this case, but this H2D update is included for consistency with single-walker acceptMove
wfc.dpsiM.updateTo();
}
}
break;

case ORB_PBYP_PARTIAL:
/**
* psiMinv_temp[:,:] -> psiMinv[:,:]; [NumPtcls,NumPtcls]
* psiV[:] -> TpsiM[:,WorkingIndex]; [NumOrbitals] (NumPtcls in 2nd dim)
* new_ratios_to_ref_[:] -> ratios_to_ref_[:]; [NumDets]
* psiV[:] -> psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim)
* dpsiV[:] -> dpsiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) GradType
* d2psiV[:] -> d2psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim)
* if (is_spinor_)
* dspin_psiV[:] -> dspin_psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim)
*/
{
Vector<ValueType*> psiMinv_temp_acc_ptr_list(n_accepted);
Vector<ValueType*> psiMinv_acc_ptr_list(n_accepted);

Vector<ValueType*> psiV_acc_ptr_list(n_accepted);
Vector<ValueType*> TpsiM_col_acc_ptr_list(n_accepted);
Vector<ValueType*> psiM_row_acc_ptr_list(n_accepted);

Vector<ValueType*> new_ratios_to_ref_acc_ptr_list(n_accepted);
Vector<ValueType*> ratios_to_ref_acc_ptr_list(n_accepted);

Vector<ValueType*> dpsiV_acc_ptr_list(n_accepted);
Vector<ValueType*> dpsiM_row_acc_ptr_list(n_accepted);
Vector<ValueType*> d2psiV_acc_ptr_list(n_accepted);
Vector<ValueType*> d2psiM_row_acc_ptr_list(n_accepted);

for (int i = 0; i < n_accepted; i++)
{
auto iacc = idx_Accepted[i];
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(iacc);

psiMinv_temp_acc_ptr_list[i] = wfc.psiMinv_temp.data();
psiMinv_acc_ptr_list[i] = wfc.psiMinv.data();

psiV_acc_ptr_list[i] = wfc.psiV.data();
TpsiM_col_acc_ptr_list[i] = wfc.TpsiM.data() + WorkingIndex;
psiM_row_acc_ptr_list[i] = wfc.psiM.data() + WorkingIndex * norb;

new_ratios_to_ref_acc_ptr_list[i] = wfc.new_ratios_to_ref_.data();
ratios_to_ref_acc_ptr_list[i] = wfc.ratios_to_ref_.data();

dpsiV_acc_ptr_list[i] = wfc.dpsiV.data()->data();
dpsiM_row_acc_ptr_list[i] = wfc.dpsiM.data()->data() + WorkingIndex * norb * DIM;
d2psiV_acc_ptr_list[i] = wfc.d2psiV.data();
d2psiM_row_acc_ptr_list[i] = wfc.d2psiM.data() + WorkingIndex * norb;
}
for (int i = 0; i < n_accepted; i++)
{
BLAS::copy(nel * nel, psiMinv_temp_acc_ptr_list[i], 1, psiMinv_acc_ptr_list[i], 1);
BLAS::copy(norb, psiV_acc_ptr_list[i], 1, TpsiM_col_acc_ptr_list[i], nel);
BLAS::copy(norb, psiV_acc_ptr_list[i], 1, psiM_row_acc_ptr_list[i], 1);
BLAS::copy(norb * DIM, dpsiV_acc_ptr_list[i], 1, dpsiM_row_acc_ptr_list[i], 1);
BLAS::copy(norb, d2psiV_acc_ptr_list[i], 1, d2psiM_row_acc_ptr_list[i], 1);
BLAS::copy(ndet, new_ratios_to_ref_acc_ptr_list[i], 1, ratios_to_ref_acc_ptr_list[i], 1);
}
// ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_ptr_list.data(), 1, psiMinv_acc_ptr_list.data(), 1, n_accepted);
// ompBLAS::copy_batched(handle, norb, psiV_acc_ptr_list.data(), 1, TpsiM_col_acc_ptr_list.data(), nel, n_accepted);
// ompBLAS::copy_batched(handle, norb, psiV_acc_ptr_list.data(), 1, psiM_row_acc_ptr_list.data(), 1, n_accepted);
// ompBLAS::copy_batched(handle, norb * DIM, dpsiV_acc_ptr_list.data(), 1, dpsiM_row_acc_ptr_list.data(), 1, n_accepted);
// ompBLAS::copy_batched(handle, norb, d2psiV_acc_ptr_list.data(), 1, d2psiM_row_acc_ptr_list.data(), 1, n_accepted);
// ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_ptr_list.data(), 1, ratios_to_ref_acc_ptr_list.data(), 1, n_accepted);

// dspin_psiM/V not on device
if (wfc_leader.is_spinor_)
{
Vector<ValueType*> dspin_psiV_acc_ptr_list(n_accepted);
Vector<ValueType*> dspin_psiM_row_acc_ptr_list(n_accepted);
for (int i = 0; i < n_accepted; i++)
{
auto iacc = idx_Accepted[i];
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(iacc);
dspin_psiV_acc_ptr_list[i] = wfc.dspin_psiV.data();
dspin_psiM_row_acc_ptr_list[i] = wfc.dspin_psiM.data() + WorkingIndex * norb;
}
for (int i = 0; i < n_accepted; i++)
BLAS::copy(norb, dspin_psiV_acc_ptr_list[i], 1, dspin_psiM_row_acc_ptr_list[i], 1);
// ompBLAS::copy_batched(handle, norb, dspin_psiV_acc_ptr_list.data(), 1, dspin_psiM_row_acc_ptr_list.data(), 1, n_accepted);
}
for (auto& iacc : idx_Accepted)
{
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(iacc);
wfc.ratios_to_ref_.updateTo();
wfc.TpsiM.updateTo();
wfc.psiMinv.updateTo();
wfc.psiM.updateTo();
wfc.dpsiM.updateTo();
}
}
break;

default:
/**
* psiMinv_temp[:,:] -> psiMinv[:,:]; [NumPtcls,NumPtcls]
* psiV[:] -> TpsiM[:,WorkingIndex]; [NumOrbitals] (NumPtcls in 2nd dim)
* new_ratios_to_ref_[:] -> ratios_to_ref_[:]; [NumDets]
* new_grads[:,:] -> grads[:,:]; [NumDets,NumPtcls] GradType
* new_lapls[:,:] -> lapls[:,:]; [NumDets,NumPtcls]
* psiV[:] -> psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim)
* dpsiV[:] -> dpsiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim) GradType
* d2psiV[:] -> d2psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim)
* if (is_spinor_)
* dspin_psiV[:] -> dspin_psiM[WorkingIndex,:]; [NumOrbitals] (NumPtcls in 1st dim)
* new_spingrads[:,:] -> spingrads[:,:]; [NumDets,NumPtcls]
*/
{
Vector<ValueType*> psiMinv_temp_acc_ptr_list(n_accepted);
Vector<ValueType*> psiMinv_acc_ptr_list(n_accepted);

Vector<ValueType*> psiV_acc_ptr_list(n_accepted);
Vector<ValueType*> TpsiM_col_acc_ptr_list(n_accepted);
Vector<ValueType*> psiM_row_acc_ptr_list(n_accepted);

Vector<ValueType*> new_ratios_to_ref_acc_ptr_list(n_accepted);
Vector<ValueType*> ratios_to_ref_acc_ptr_list(n_accepted);

Vector<ValueType*> dpsiV_acc_ptr_list(n_accepted);
Vector<ValueType*> dpsiM_row_acc_ptr_list(n_accepted);
Vector<ValueType*> d2psiV_acc_ptr_list(n_accepted);
Vector<ValueType*> d2psiM_row_acc_ptr_list(n_accepted);

Vector<ValueType*> new_grads_acc_ptr_list(n_accepted);
Vector<ValueType*> grads_acc_ptr_list(n_accepted);
Vector<ValueType*> new_lapls_acc_ptr_list(n_accepted);
Vector<ValueType*> lapls_acc_ptr_list(n_accepted);

for (int i = 0; i < n_accepted; i++)
{
auto iacc = idx_Accepted[i];
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(iacc);

psiMinv_temp_acc_ptr_list[i] = wfc.psiMinv_temp.data();
psiMinv_acc_ptr_list[i] = wfc.psiMinv.data();

psiV_acc_ptr_list[i] = wfc.psiV.data();
TpsiM_col_acc_ptr_list[i] = wfc.TpsiM.data() + WorkingIndex;
psiM_row_acc_ptr_list[i] = wfc.psiM.data() + WorkingIndex * norb;

new_ratios_to_ref_acc_ptr_list[i] = wfc.new_ratios_to_ref_.data();
ratios_to_ref_acc_ptr_list[i] = wfc.ratios_to_ref_.data();

dpsiV_acc_ptr_list[i] = wfc.dpsiV.data()->data();
dpsiM_row_acc_ptr_list[i] = wfc.dpsiM.data()->data() + WorkingIndex * norb * DIM;
d2psiV_acc_ptr_list[i] = wfc.d2psiV.data();
d2psiM_row_acc_ptr_list[i] = wfc.d2psiM.data() + WorkingIndex * norb;

new_grads_acc_ptr_list[i] = wfc.new_grads.data()->data();
grads_acc_ptr_list[i] = wfc.grads.data()->data();
new_lapls_acc_ptr_list[i] = wfc.new_lapls.data();
lapls_acc_ptr_list[i] = wfc.lapls.data();
}
for (int i = 0; i < n_accepted; i++)
{
BLAS::copy(nel * nel, psiMinv_temp_acc_ptr_list[i], 1, psiMinv_acc_ptr_list[i], 1);
BLAS::copy(norb, psiV_acc_ptr_list[i], 1, TpsiM_col_acc_ptr_list[i], nel);
BLAS::copy(norb, psiV_acc_ptr_list[i], 1, psiM_row_acc_ptr_list[i], 1);
BLAS::copy(norb * DIM, dpsiV_acc_ptr_list[i], 1, dpsiM_row_acc_ptr_list[i], 1);
BLAS::copy(norb, d2psiV_acc_ptr_list[i], 1, d2psiM_row_acc_ptr_list[i], 1);
BLAS::copy(ndet, new_ratios_to_ref_acc_ptr_list[i], 1, ratios_to_ref_acc_ptr_list[i], 1);
}
// ompBLAS::copy_batched(handle, nel * nel, psiMinv_temp_acc_ptr_list.data(), 1, psiMinv_acc_ptr_list.data(), 1, n_accepted);
// ompBLAS::copy_batched(handle, norb, psiV_acc_ptr_list.data(), 1, TpsiM_col_acc_ptr_list.data(), nel, n_accepted);
// ompBLAS::copy_batched(handle, norb, psiV_acc_ptr_list.data(), 1, psiM_row_acc_ptr_list.data(), 1, n_accepted);
// ompBLAS::copy_batched(handle, norb * DIM, dpsiV_acc_ptr_list.data(), 1, dpsiM_row_acc_ptr_list.data(), 1, n_accepted);
// ompBLAS::copy_batched(handle, norb, d2psiV_acc_ptr_list.data(), 1, d2psiM_row_acc_ptr_list.data(), 1, n_accepted);
// ompBLAS::copy_batched(handle, ndet, new_ratios_to_ref_acc_ptr_list.data(), 1, ratios_to_ref_acc_ptr_list.data(), 1, n_accepted);

// grads,lapls not on device
// ompBLAS::copy_batched(handle, ndet * nel * DIM, new_grads_acc_ptr_list.data(), 1, grads_acc_ptr_list.data(), 1, n_accepted);
// ompBLAS::copy_batched(handle, ndet * nel, new_lapls_acc_ptr_list.data(), 1, lapls_acc_ptr_list.data(), 1, n_accepted);
for (int i = 0; i < n_accepted; i++)
{
BLAS::copy(ndet * nel * DIM, new_grads_acc_ptr_list[i], 1, grads_acc_ptr_list[i], 1);
BLAS::copy(ndet * nel, new_lapls_acc_ptr_list[i], 1, lapls_acc_ptr_list[i], 1);
}

if (wfc_leader.is_spinor_)
{
Vector<ValueType*> dspin_psiV_acc_ptr_list(n_accepted);
Vector<ValueType*> dspin_psiM_row_acc_ptr_list(n_accepted);
Vector<ValueType*> new_spingrads_acc_ptr_list(n_accepted);
Vector<ValueType*> spingrads_acc_ptr_list(n_accepted);

for (int i = 0; i < n_accepted; i++)
{
auto iacc = idx_Accepted[i];
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(iacc);
dspin_psiV_acc_ptr_list[i] = wfc.dspin_psiV.data();
dspin_psiM_row_acc_ptr_list[i] = wfc.dspin_psiM.data() + WorkingIndex * norb;
new_spingrads_acc_ptr_list[i] = wfc.new_spingrads.data();
spingrads_acc_ptr_list[i] = wfc.spingrads.data();
}
for (int i = 0; i < n_accepted; i++)
{
BLAS::copy(norb, dspin_psiV_acc_ptr_list[i], 1, dspin_psiM_row_acc_ptr_list[i], 1);
BLAS::copy(ndet * nel, new_spingrads_acc_ptr_list[i], 1, spingrads_acc_ptr_list[i], 1);
// ompBLAS::copy_batched(handle, norb, dspin_psiV_acc_ptr_list.data(), 1, dspin_psiM_row_acc_ptr_list.data(), 1, n_accepted);
// ompBLAS::copy_batched(handle, ndet * nel, new_spingrads_acc_ptr_list.data(), 1, spingrads_acc_ptr_list.data(), 1, n_accepted);
}
}
for (auto& iacc : idx_Accepted)
{
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(iacc);
wfc.ratios_to_ref_.updateTo();
wfc.TpsiM.updateTo();
wfc.psiMinv.updateTo();
wfc.psiM.updateTo();
wfc.dpsiM.updateTo();
}
}
break;
}
// restore:
// setup pointer lists
Vector<ValueType*> psiMinv_temp_rej_ptr_list(n_rejected);
Vector<ValueType*> psiMinv_rej_ptr_list(n_rejected);
Vector<ValueType*> TpsiM_col_rej_ptr_list(n_rejected);
Vector<ValueType*> psiM_row_rej_ptr_list(n_rejected);
for (int i = 0; i < n_rejected; i++)
{
auto irej = idx_Rejected[i];
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(irej);
psiMinv_temp_rej_ptr_list[i] = wfc.psiMinv_temp.data();
psiMinv_rej_ptr_list[i] = wfc.psiMinv.data();
TpsiM_col_rej_ptr_list[i] = wfc.TpsiM.data() + WorkingIndex;
psiM_row_rej_ptr_list[i] = wfc.psiM.data() + WorkingIndex * norb;
}

/**
* psiMinv[:,:] -> psiMinv_temp[:,:]; [NumPtcls,NumPtcls]
* psiM[WorkingIndex,:] -> TpsiM[:,WorkingIndex]; [NumOrbitals] (NumPtcls in other dim)
*/
for (int i = 0; i < n_rejected; i++)
{
BLAS::copy(nel * nel, psiMinv_rej_ptr_list[i], 1, psiMinv_temp_rej_ptr_list[i], 1);
BLAS::copy(norb, psiM_row_rej_ptr_list[i], 1, TpsiM_col_rej_ptr_list[i], nel);
}
// ompBLAS::copy_batched(handle, nel * nel, psiMinv_rej_ptr_list.data(), 1, psiMinv_temp_rej_ptr_list.data(), 1, n_rejected);
// ompBLAS::copy_batched(handle, norb, psiM_row_rej_ptr_list.data(), 1, TpsiM_col_rej_ptr_list.data(), nel, n_rejected);

for (auto& irej : idx_Rejected)
{
auto& wfc = wfc_list.getCastedElement<MultiDiracDeterminant>(irej);
wfc.curRatio = ValueType(1);
wfc.TpsiM.updateTo();
}
}

// this has been fixed
Expand Down

0 comments on commit 9a4dd3b

Please sign in to comment.