Skip to content

Commit

Permalink
added switch for tiling of Vpsi and BSH apply in
Browse files Browse the repository at this point in the history
SCF.cc and changed mul_tol in exchangeoperator.h
(mul_tol was 0.0 previously)
  • Loading branch information
“hborchert” committed Aug 15, 2024
1 parent bc836f0 commit 215adbd
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 50 deletions.
120 changes: 71 additions & 49 deletions src/madness/chem/SCF.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1375,28 +1375,38 @@ vecfuncT SCF::apply_potential(World& world, const tensorT& occ,

// compute Vpsi and truncation
START_TIMER(world);
const bool tile_Vpsi = true;
size_t min_tile = 10;
size_t ntile = std::min(amo.size(), min_tile);
if (!molecule.parameters.pure_ae()) {
gaxpy(world, 1.0, Vpsi, 1.0, gthpseudopotential->apply_potential(world, vloc, amo, occ, enl));
} else {
for (size_t ilo=0; ilo<amo.size(); ilo+=ntile) {
size_t iend = std::min(ilo+ntile,amo.size());
vecfuncT tmpamo(amo.begin()+ilo,amo.begin()+iend);
auto tmpVpsi = mul_sparse(world, vloc, tmpamo, vtol);
print_size(world, tmpVpsi, "tmpVpsi before truncation");

//truncate tmpVpsi
truncate(world, tmpVpsi);
print_size(world, tmpVpsi, "tmpVpsi after truncation");

//put the results into their final home
for (size_t i = ilo; i<iend; ++i){
Vpsi[i] += tmpVpsi[i-ilo];
if (tile_Vpsi){
for (size_t ilo=0; ilo<amo.size(); ilo+=ntile) {
size_t iend = std::min(ilo+ntile,amo.size());
vecfuncT tmpamo(amo.begin()+ilo,amo.begin()+iend);
auto tmpVpsi = mul_sparse(world, vloc, tmpamo, vtol);
print_size(world, tmpVpsi, "tmpVpsi before truncation");

//truncate tmpVpsi
truncate(world, tmpVpsi);
print_size(world, tmpVpsi, "tmpVpsi after truncation");

//put the results into their final home
for (size_t i = ilo; i<iend; ++i){
Vpsi[i] += tmpVpsi[i-ilo];
}
}
END_TIMER(world, "V*psi");
} else {
gaxpy(world, 1.0, Vpsi, 1.0, mul_sparse(world, vloc, amo, vtol));
END_TIMER(world, "V*psi");
START_TIMER(world);
truncate(world, Vpsi);
END_TIMER(world, "Truncate Vpsi");
print_meminfo(world.rank(), "Truncate Vpsi");
}
}
END_TIMER(world, "V*psi");

//START_TIMER(world);
//if (!molecule.parameters.pure_ae()) {
Expand Down Expand Up @@ -1549,57 +1559,69 @@ vecfuncT SCF::compute_residual(World& world, tensorT& occ, tensorT& fock,
fpsi.clear();
std::vector<double> fac(nmo, -2.0);
scale(world, Vpsi, fac);
//std::vector<poperatorT> ops = make_bsh_operators(world, eps);
//set_thresh(world, Vpsi, FunctionDefaults<3>::get_thresh());
END_TIMER(world, "Compute residual stuff");

START_TIMER(world);
const bool tile_applyBSH = true;
vecfuncT new_psi;

//vecfuncT new_psi = apply(world, ops, Vpsi);
if (tile_applyBSH) {
START_TIMER(world);
size_t min_tile = 10;
size_t ntile = std::min(amo.size(), min_tile);
new_psi = zero_functions<double,3>(world, Vpsi.size());

for (size_t ilo=0; ilo<Vpsi.size(); ilo+=ntile) {
size_t iend = std::min(ilo+ntile,Vpsi.size());
vecfuncT tmp_Vpsi(Vpsi.begin()+ilo,Vpsi.begin()+iend);

int tmp_nmo = tmp_Vpsi.size();
tensorT tmp_eps(tmp_nmo);
for (int i = 0; i < tmp_nmo; ++i) {
tmp_eps(i) = std::min(-0.05, fock(i+ilo, i+ilo));
}

// TODO: tile apply
size_t min_tile = 10;
size_t ntile = std::min(amo.size(), min_tile);
vecfuncT new_psi = zero_functions<double,3>(world, Vpsi.size());
std::vector<poperatorT> ops = make_bsh_operators(world, tmp_eps);
set_thresh(world, tmp_Vpsi, FunctionDefaults<3>::get_thresh());

for (size_t ilo=0; ilo<Vpsi.size(); ilo+=ntile) {
size_t iend = std::min(ilo+ntile,Vpsi.size());
vecfuncT tmp_Vpsi(Vpsi.begin()+ilo,Vpsi.begin()+iend);
vecfuncT tmp_new_psi = apply(world, ops, tmp_Vpsi);
print_size(world, tmp_new_psi, "tmp_new_psi before truncation");

int tmp_nmo = tmp_Vpsi.size();
tensorT tmp_eps(tmp_nmo);
for (int i = 0; i < tmp_nmo; ++i) {
tmp_eps(i) = std::min(-0.05, fock(i+ilo, i+ilo));
}
//truncate tmp_new_psi
truncate(world, tmp_new_psi);
print_size(world, tmp_new_psi, "tmp_new_psi after truncation");

std::vector<poperatorT> ops = make_bsh_operators(world, tmp_eps);
set_thresh(world, tmp_Vpsi, FunctionDefaults<3>::get_thresh());
//put the results into their final home
for (size_t i = ilo; i<iend; ++i){
new_psi[i] += tmp_new_psi[i-ilo];
}
ops.clear();
}

vecfuncT tmp_new_psi = apply(world, ops, tmp_Vpsi);
print_size(world, tmp_new_psi, "tmp_new_psi before truncation");
Vpsi.clear();
world.gop.fence();
END_TIMER(world, "Apply BSH");
} else {
START_TIMER(world);

//truncate tmp_new_psi
truncate(world, tmp_new_psi);
print_size(world, tmp_new_psi, "tmp_new_psi after truncation");
std::vector<poperatorT> ops = make_bsh_operators(world, eps);
set_thresh(world, Vpsi, FunctionDefaults<3>::get_thresh());

//put the results into their final home
for (size_t i = ilo; i<iend; ++i){
new_psi[i] += tmp_new_psi[i-ilo];
}
new_psi = apply(world, ops, Vpsi);

ops.clear();
}
Vpsi.clear();
world.gop.fence();

END_TIMER(world, "Apply BSH");
//ops.clear();
Vpsi.clear();
world.gop.fence();
END_TIMER(world, "Apply BSH");

START_TIMER(world);
truncate(world, new_psi);
END_TIMER(world, "Truncate new psi");
}

// It seemed like a bad idea to truncate *before* computing the residual,
// but simple tests suggest otherwise: truncating does not increase the
// number of iterations, and it reduces the time per iteration.
//START_TIMER(world);
//truncate(world, new_psi);
//END_TIMER(world, "Truncate new psi");

START_TIMER(world);
vecfuncT r = sub(world, psi, new_psi);
Expand Down
4 changes: 3 additions & 1 deletion src/madness/chem/exchangeoperator.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class Exchange<T,NDIM>::ExchangeImpl {
double lo = 1.e-4;
double thresh = FunctionDefaults<NDIM>::get_thresh();
long printlevel = 0;
double mul_tol = 0.0;
double mul_tol = FunctionDefaults<NDIM>::get_thresh()*0.1;

class MacroTaskExchangeSimple : public MacroTaskOperationBase {

Expand Down Expand Up @@ -345,7 +345,9 @@ class Exchange<T,NDIM>::ExchangeImpl {
const std::vector<Function<T, NDIM>>& mo_ket) {

World& world = vket.front().world();
mul_tol = 0.0;
print("mul_tol ", mul_tol);


resultT Kf = zero_functions_compressed<T, NDIM>(world, 1);
auto poisson = Exchange<double, 3>::ExchangeImpl::set_poisson(world, lo);
Expand Down

0 comments on commit 215adbd

Please sign in to comment.