Skip to content

Commit

Permalink
Add new bits to RotatedSPOsT
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipFackler authored and williamfgc committed Sep 28, 2023
1 parent 81858bb commit 90010b4
Show file tree
Hide file tree
Showing 2 changed files with 216 additions and 0 deletions.
146 changes: 146 additions & 0 deletions src/QMCWaveFunctions/RotatedSPOsT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1688,6 +1688,152 @@ RotatedSPOsT<T>::makeClone() const
return myclone;
}

template <typename T>
void
RotatedSPOsT<T>::mw_evaluateDetRatios(
const RefVectorWithLeader<SPOSetT<T>>& spo_list,
const RefVectorWithLeader<const VirtualParticleSetT<T>>& vp_list,
const RefVector<ValueVector>& psi_list,
const std::vector<const ValueType*>& invRow_ptr_list,
std::vector<std::vector<ValueType>>& ratios_list) const
{
auto phi_list = extractPhiRefList(spo_list);
auto& leader = phi_list.getLeader();
leader.mw_evaluateDetRatios(
phi_list, vp_list, psi_list, invRow_ptr_list, ratios_list);
}

template <typename T>
void
RotatedSPOsT<T>::mw_evaluateValue(
const RefVectorWithLeader<SPOSetT<T>>& spo_list,
const RefVectorWithLeader<ParticleSetT<T>>& P_list, int iat,
const RefVector<ValueVector>& psi_v_list) const
{
auto phi_list = extractPhiRefList(spo_list);
auto& leader = phi_list.getLeader();
leader.mw_evaluateValue(phi_list, P_list, iat, psi_v_list);
}

template <typename T>
void
RotatedSPOsT<T>::mw_evaluateVGL(const RefVectorWithLeader<SPOSetT<T>>& spo_list,
const RefVectorWithLeader<ParticleSetT<T>>& P_list, int iat,
const RefVector<ValueVector>& psi_v_list,
const RefVector<GradVector>& dpsi_v_list,
const RefVector<ValueVector>& d2psi_v_list) const
{
auto phi_list = extractPhiRefList(spo_list);
auto& leader = phi_list.getLeader();
leader.mw_evaluateVGL(
phi_list, P_list, iat, psi_v_list, dpsi_v_list, d2psi_v_list);
}

template <typename T>
void
RotatedSPOsT<T>::mw_evaluateVGLWithSpin(
const RefVectorWithLeader<SPOSetT<T>>& spo_list,
const RefVectorWithLeader<ParticleSetT<T>>& P_list, int iat,
const RefVector<ValueVector>& psi_v_list,
const RefVector<GradVector>& dpsi_v_list,
const RefVector<ValueVector>& d2psi_v_list,
OffloadMatrix<ComplexType>& mw_dspin) const
{
auto phi_list = extractPhiRefList(spo_list);
auto& leader = phi_list.getLeader();
leader.mw_evaluateVGLWithSpin(
phi_list, P_list, iat, psi_v_list, dpsi_v_list, d2psi_v_list, mw_dspin);
}

template <typename T>
void
RotatedSPOsT<T>::mw_evaluateVGLandDetRatioGrads(
const RefVectorWithLeader<SPOSetT<T>>& spo_list,
const RefVectorWithLeader<ParticleSetT<T>>& P_list, int iat,
const std::vector<const ValueType*>& invRow_ptr_list,
OffloadMWVGLArray& phi_vgl_v, std::vector<ValueType>& ratios,
std::vector<GradType>& grads) const
{
auto phi_list = extractPhiRefList(spo_list);
auto& leader = phi_list.getLeader();
leader.mw_evaluateVGLandDetRatioGrads(
phi_list, P_list, iat, invRow_ptr_list, phi_vgl_v, ratios, grads);
}

template <typename T>
void
RotatedSPOsT<T>::mw_evaluateVGLandDetRatioGradsWithSpin(
const RefVectorWithLeader<SPOSetT<T>>& spo_list,
const RefVectorWithLeader<ParticleSetT<T>>& P_list, int iat,
const std::vector<const ValueType*>& invRow_ptr_list,
OffloadMWVGLArray& phi_vgl_v, std::vector<ValueType>& ratios,
std::vector<GradType>& grads, std::vector<ValueType>& spingrads) const
{
auto phi_list = extractPhiRefList(spo_list);
auto& leader = phi_list.getLeader();
leader.mw_evaluateVGLandDetRatioGradsWithSpin(phi_list, P_list, iat,
invRow_ptr_list, phi_vgl_v, ratios, grads, spingrads);
}

template <typename T>
void
RotatedSPOsT<T>::mw_evaluate_notranspose(
const RefVectorWithLeader<SPOSetT<T>>& spo_list,
const RefVectorWithLeader<ParticleSetT<T>>& P_list, int first, int last,
const RefVector<ValueMatrix>& logdet_list,
const RefVector<GradMatrix>& dlogdet_list,
const RefVector<ValueMatrix>& d2logdet_list) const
{
auto phi_list = extractPhiRefList(spo_list);
auto& leader = phi_list.getLeader();
leader.mw_evaluate_notranspose(phi_list, P_list, first, last, logdet_list,
dlogdet_list, d2logdet_list);
}

template <typename T>
void
RotatedSPOsT<T>::createResource(ResourceCollection& collection) const
{
Phi->createResource(collection);
}

template <typename T>
void
RotatedSPOsT<T>::acquireResource(ResourceCollection& collection,
const RefVectorWithLeader<SPOSetT<T>>& spo_list) const
{
auto phi_list = extractPhiRefList(spo_list);
auto& leader = phi_list.getLeader();
leader.acquireResource(collection, phi_list);
}

template <typename T>
void
RotatedSPOsT<T>::releaseResource(ResourceCollection& collection,
const RefVectorWithLeader<SPOSetT<T>>& spo_list) const
{
auto phi_list = extractPhiRefList(spo_list);
auto& leader = phi_list.getLeader();
leader.releaseResource(collection, phi_list);
}

template <typename T>
RefVectorWithLeader<SPOSetT<T>>
RotatedSPOsT<T>::extractPhiRefList(
const RefVectorWithLeader<SPOSetT<T>>& spo_list)
{
auto& spo_leader = spo_list.template getCastedLeader<RotatedSPOsT>();
const auto nw = spo_list.size();
RefVectorWithLeader<SPOSetT<T>> phi_list(*spo_leader.Phi);
phi_list.reserve(nw);
for (int iw = 0; iw < nw; iw++) {
RotatedSPOsT& rot =
spo_list.template getCastedElement<RotatedSPOsT>(iw);
phi_list.emplace_back(*rot.Phi);
}
return phi_list;
}

// Class concrete types from ValueType
template class RotatedSPOsT<double>;
template class RotatedSPOsT<float>;
Expand Down
70 changes: 70 additions & 0 deletions src/QMCWaveFunctions/RotatedSPOsT.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class RotatedSPOsT : public SPOSetT<T>, public OptimizableObjectT<T>
using IndexType = typename SPOSetT<T>::IndexType;
using RealType = typename SPOSetT<T>::RealType;
using ValueType = typename SPOSetT<T>::ValueType;
using GradType = typename SPOSetT<T>::GradType;
using ComplexType = typename SPOSetT<T>::ComplexType;
using FullRealType = typename SPOSetT<T>::FullRealType;
using ValueVector = typename SPOSetT<T>::ValueVector;
using ValueMatrix = typename SPOSetT<T>::ValueMatrix;
Expand All @@ -49,6 +51,9 @@ class RotatedSPOsT : public SPOSetT<T>, public OptimizableObjectT<T>
using HessMatrix = typename SPOSetT<T>::HessMatrix;
using GGGVector = typename SPOSetT<T>::GGGVector;
using GGGMatrix = typename SPOSetT<T>::GGGMatrix;
using OffloadMWVGLArray = typename SPOSetT<T>::OffloadMWVGLArray;
template <typename DT>
using OffloadMatrix = Matrix<DT, OffloadPinnedAllocator<DT>>;

// constructor
RotatedSPOsT(
Expand Down Expand Up @@ -399,6 +404,68 @@ class RotatedSPOsT : public SPOSetT<T>, public OptimizableObjectT<T>
use_global_rot_ = use_global_rotation;
}

void
mw_evaluateDetRatios(const RefVectorWithLeader<SPOSetT<T>>& spo_list,
const RefVectorWithLeader<const VirtualParticleSetT<T>>& vp_list,
const RefVector<ValueVector>& psi_list,
const std::vector<const ValueType*>& invRow_ptr_list,
std::vector<std::vector<ValueType>>& ratios_list) const override;

void
mw_evaluateValue(const RefVectorWithLeader<SPOSetT<T>>& spo_list,
const RefVectorWithLeader<ParticleSetT<T>>& P_list, int iat,
const RefVector<ValueVector>& psi_v_list) const override;

void
mw_evaluateVGL(const RefVectorWithLeader<SPOSetT<T>>& spo_list,
const RefVectorWithLeader<ParticleSetT<T>>& P_list, int iat,
const RefVector<ValueVector>& psi_v_list,
const RefVector<GradVector>& dpsi_v_list,
const RefVector<ValueVector>& d2psi_v_list) const override;

void
mw_evaluateVGLWithSpin(const RefVectorWithLeader<SPOSetT<T>>& spo_list,
const RefVectorWithLeader<ParticleSetT<T>>& P_list, int iat,
const RefVector<ValueVector>& psi_v_list,
const RefVector<GradVector>& dpsi_v_list,
const RefVector<ValueVector>& d2psi_v_list,
OffloadMatrix<ComplexType>& mw_dspin) const override;

void
mw_evaluateVGLandDetRatioGrads(
const RefVectorWithLeader<SPOSetT<T>>& spo_list,
const RefVectorWithLeader<ParticleSetT<T>>& P_list, int iat,
const std::vector<const ValueType*>& invRow_ptr_list,
OffloadMWVGLArray& phi_vgl_v, std::vector<ValueType>& ratios,
std::vector<GradType>& grads) const override;

void
mw_evaluateVGLandDetRatioGradsWithSpin(
const RefVectorWithLeader<SPOSetT<T>>& spo_list,
const RefVectorWithLeader<ParticleSetT<T>>& P_list, int iat,
const std::vector<const ValueType*>& invRow_ptr_list,
OffloadMWVGLArray& phi_vgl_v, std::vector<ValueType>& ratios,
std::vector<GradType>& grads,
std::vector<ValueType>& spingrads) const override;

void
mw_evaluate_notranspose(const RefVectorWithLeader<SPOSetT<T>>& spo_list,
const RefVectorWithLeader<ParticleSetT<T>>& P_list, int first, int last,
const RefVector<ValueMatrix>& logdet_list,
const RefVector<GradMatrix>& dlogdet_list,
const RefVector<ValueMatrix>& d2logdet_list) const override;

void
createResource(ResourceCollection& collection) const override;

void
acquireResource(ResourceCollection& collection,
const RefVectorWithLeader<SPOSetT<T>>& spo_list) const override;

void
releaseResource(ResourceCollection& collection,
const RefVectorWithLeader<SPOSetT<T>>& spo_list) const override;

private:
/// true if SPO parameters (orbital rotation parameters) have been supplied
/// by input
Expand All @@ -415,6 +482,9 @@ class RotatedSPOsT : public SPOSetT<T>, public OptimizableObjectT<T>
/// Use global rotation or history list
bool use_global_rot_ = true;

static RefVectorWithLeader<SPOSetT<T>>
extractPhiRefList(const RefVectorWithLeader<SPOSetT<T>>& spo_list);

friend OptVariablesType<double>&
testing::getMyVarsFull(RotatedSPOsT<double>& rot);
friend OptVariablesType<float>&
Expand Down

0 comments on commit 90010b4

Please sign in to comment.