Skip to content

Commit

Permalink
Merge pull request #5168 from prckent/bettergradcheck
Browse files Browse the repository at this point in the history
Remove erroneous gradient check
  • Loading branch information
ye-luo authored Sep 17, 2024
2 parents 98ae91d + c0fab4a commit 34fb7dd
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 29 deletions.
1 change: 0 additions & 1 deletion src/QMCWaveFunctions/Fermion/DiracDeterminant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ typename DiracDeterminant<DU_TYPE>::GradType DiracDeterminant<DU_TYPE>::evalGrad
invRow_id = WorkingIndex;
updateEng.getInvRow(psiM, WorkingIndex, invRow);
GradType g = simd::dot(invRow.data(), dpsiM[WorkingIndex], invRow.size());
assert(checkG(g));
return g;
}

Expand Down
16 changes: 0 additions & 16 deletions src/QMCWaveFunctions/Fermion/DiracDeterminantBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,22 +189,6 @@ class DiracDeterminantBase : public WaveFunctionComponent
ValueMatrix dummy_vmt;
#endif

static bool checkG(const GradType& g)
{
#if !defined(NDEBUG)
auto g_mag = std::abs(dot(g, g));
if (qmcplusplus::isnan(g_mag))
throw std::runtime_error("gradient of NaN");
if (qmcplusplus::isinf(g_mag))
throw std::runtime_error("gradient of inf");
if (g_mag < std::abs(std::numeric_limits<RealType>::epsilon()))
{
std::cerr << "evalGrad gradient is " << g[0] << ' ' << g[1] << ' ' << g[2] << '\n';
throw std::runtime_error("gradient of zero");
}
#endif
return true;
}
};

} // namespace qmcplusplus
Expand Down
12 changes: 0 additions & 12 deletions src/QMCWaveFunctions/Fermion/DiracDeterminantBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ typename DiracDeterminantBatched<PL, VT, FPVT>::Grad DiracDeterminantBatched<PL,
ScopedTimer local_timer(RatioTimer);
const int WorkingIndex = iat - FirstIndex;
Grad g = simd::dot(psiMinv_[WorkingIndex], dpsiM[WorkingIndex], NumOrbitals);
assert(checkG(g));
return g;
}

Expand Down Expand Up @@ -209,11 +208,6 @@ void DiracDeterminantBatched<PL, VT, FPVT>::mw_evalGrad(const RefVectorWithLeade

UpdateEngine::mw_evalGrad(engine_list, wfc_leader.mw_res_handle_.getResource().engine_rsc, mw_res.psiMinv_refs,
dpsiM_row_list, WorkingIndex, grad_now);

#ifndef NDEBUG
for (int iw = 0; iw < nw; iw++)
checkG(grad_now[iw]);
#endif
}

template<PlatformKind PL, typename VT, typename FPVT>
Expand All @@ -227,7 +221,6 @@ typename DiracDeterminantBatched<PL, VT, FPVT>::Grad DiracDeterminantBatched<PL,
const int WorkingIndex = iat - FirstIndex;
Grad g = simd::dot(psiMinv_[WorkingIndex], dpsiM[WorkingIndex], NumOrbitals);
ComplexType spin_g = simd::dot(psiMinv_[WorkingIndex], dspin_psiV.data(), NumOrbitals);
assert(checkG(g));
spingrad += spin_g;
return g;
}
Expand Down Expand Up @@ -286,11 +279,6 @@ void DiracDeterminantBatched<PL, VT, FPVT>::mw_evalGradWithSpin(
UpdateEngine::mw_evalGradWithSpin(engine_list, wfc_leader.mw_res_handle_.getResource().engine_rsc,
mw_res.psiMinv_refs, dpsiM_row_list, mw_dspin, WorkingIndex, grad_now,
spingrad_now);

#ifndef NDEBUG
for (int iw = 0; iw < nw; iw++)
checkG(grad_now[iw]);
#endif
}

template<PlatformKind PL, typename VT, typename FPVT>
Expand Down

0 comments on commit 34fb7dd

Please sign in to comment.