Skip to content

Commit

Permalink
exchange operator apply working in low rank form and is faster
Browse files Browse the repository at this point in the history
  • Loading branch information
fbischoff committed Oct 27, 2023
1 parent becd4a7 commit 9ff2820
Show file tree
Hide file tree
Showing 9 changed files with 300 additions and 287 deletions.
12 changes: 12 additions & 0 deletions src/madness/chem/CCStructures.cc
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,12 @@ CCConvolutionOperator::init_op(const OpType& type, const Parameters& parameters)
<< " and lo=" << parameters.lo << " and Gamma=" << parameters.gamma << std::endl;
return SlaterF12OperatorPtr(world, parameters.gamma, parameters.lo, parameters.thresh_op);
}
case OpType::OT_F212 : {
if (printme)
std::cout << "Creating " << assign_name(type) << " Operator with thresh=" << parameters.thresh_op
<< " and lo=" << parameters.lo << " and Gamma=" << parameters.gamma << std::endl;
return SlaterF12sqOperatorPtr(world, parameters.gamma, parameters.lo, parameters.thresh_op);
}
case OpType::OT_SLATER : {
if (printme)
std::cout << "Creating " << assign_name(type) << " Operator with thresh=" << parameters.thresh_op
Expand Down Expand Up @@ -491,6 +497,12 @@ assign_name(const OpType& input) {
return "bsh";
case OpType::OT_ONE:
return "identity";
case OpType::OT_GAUSS: /// exp(-r2)
return "gauss";
case OpType::OT_F212: /// (1-exp(-r))^2
return "f12^2";
case OpType::OT_F2G12: /// (1-exp(-r))^2/r = 1/r + exp(-2r)/r - 2 exp(-r)/r
return "f12^2g";
default: {
MADNESS_EXCEPTION("Unvalid enum assignement!", 1);
return "undefined";
Expand Down
4 changes: 3 additions & 1 deletion src/madness/chem/CCStructures.h
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,9 @@ struct CCConvolutionOperator {

friend std::shared_ptr<CCConvolutionOperator> combine(const std::shared_ptr<CCConvolutionOperator>& a,
const std::shared_ptr<CCConvolutionOperator>& b) {
auto bla=combine(*a,*b);
if (a and (not b)) return a;
if ((not a) and b) return b;
if ((not a) and (not b)) return nullptr;
return std::shared_ptr<CCConvolutionOperator>(new CCConvolutionOperator(combine(*a,*b)));
}

Expand Down
70 changes: 25 additions & 45 deletions src/madness/chem/ccpairfunction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -432,30 +432,20 @@ double CCPairFunction::inner_internal(const CCPairFunction& other, const real_fu
const vector_real_function_3d b = R2.is_initialized() ? R2 * f2.get_b() : copy(world(), f2.get_b());
const pureT& bra=f1.get_function();

auto ops=combine(f1.get_operator_ptr(),f2.get_operator_ptr());
MADNESS_EXCEPTION("still to debug",1);
// if (ops.size()>0) {
// for (const auto& single_op : ops) {
// auto fac = single_op.first;
// auto op = single_op.second;
// double bla=0.0;
// for (int i=0; i<a.size(); ++i) {
// if (op.get_op()) {
// real_function_6d tmp = CompositeFactory<double, 6, 3>(world()).g12(op.get_kernel()).particle1(a[i]).particle2(b[i]);
// bla += fac * inner(bra, tmp);
// } else {
// real_function_6d tmp = CompositeFactory<double, 6, 3>(world()).particle1(a[i]).particle2(b[i]);
// bla += fac * inner(bra,tmp);
// }
// }
// result+=bla;
// }
// } else { // no operators
auto op=combine(f1.get_operator_ptr(),f2.get_operator_ptr());
if (op) {
double bla=0.0;
for (int i=0; i<a.size(); ++i) {
real_function_6d tmp = CompositeFactory<double, 6, 3>(world()).g12(op->get_kernel()).particle1(a[i]).particle2(b[i]);
bla += inner(bra, tmp);
}
result+=bla;
} else { // no operators
for (int i=0; i<a.size(); ++i) {
real_function_6d tmp = CompositeFactory<double, 6, 3>(world()).particle1(a[i]).particle2(b[i]);
result+=inner(bra,tmp);
}
// }
}
} else if (f1.is_decomposed() and f2.is_pure()) { // with or without op
result= f2.inner_internal(f1,R2);

Expand All @@ -469,31 +459,21 @@ double CCPairFunction::inner_internal(const CCPairFunction& other, const real_fu
const vector_real_function_3d b2 = R2.is_initialized() ? R2* f2.get_b() : f2.get_b();


MADNESS_EXCEPTION("still to debug",1);
auto ops=combine(f1.get_operator_ptr(),f2.get_operator_ptr());
// if (ops.size()==0) {
// // <p1 | p2> = \sum_ij <a_i b_i | a_j b_j> = \sum_ij <a_i|a_j> <b_i|b_j>
// result = (matrix_inner(world(), a1, a2)).trace(matrix_inner(world(),b1,b2));
// } else {
// // <a_i b_i | op | a_j b_j> = <a_i * a_j | op(b_i*b_j) >
// for (const auto& single_op : ops) {
// auto fac = single_op.first;
// auto op = single_op.second;
//
// double bla=0.0;
// if (op.get_op()) {
// for (size_t i = 0; i < a1.size(); i++) {
// vector_real_function_3d aa = truncate(a1[i] * a2);
// vector_real_function_3d bb = truncate(b1[i] * b2);
// vector_real_function_3d aopx = op(aa);
// bla += fac * inner(bb, aopx);
// }
// } else {
// bla += fac*(matrix_inner(world(), a1, a2)).trace(matrix_inner(world(),b1,b2));
// }
// result+=bla;
// }
// }
// MADNESS_EXCEPTION("still to debug",1);
auto op=combine(f1.get_operator_ptr(),f2.get_operator_ptr());
if (not op) {
// <p1 | p2> = \sum_ij <a_i b_i | a_j b_j> = \sum_ij <a_i|a_j> <b_i|b_j>
result = (matrix_inner(world(), a1, a2)).trace(matrix_inner(world(),b1,b2));
} else {
// <a_i b_i | op | a_j b_j> = <a_i * a_j | op(b_i*b_j) >
result=0.0;
for (size_t i = 0; i < a1.size(); i++) {
vector_real_function_3d aa = truncate(a1[i] * a2);
vector_real_function_3d bb = truncate(b1[i] * b2);
vector_real_function_3d aopx = (*op)(aa);
result += inner(bb, aopx);
}
}
} else MADNESS_EXCEPTION(
("CCPairFunction Overlap not supported for combination " + f1.name() + " and " + f2.name()).c_str(), 1) ;
return result;
Expand Down
162 changes: 109 additions & 53 deletions src/madness/chem/test_low_rank_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ int test_lowrank_function(World& world, LowRankFunctionParameters parameters) {
LRFunctorF12<double,6> lrfunctor(f12,phi1,phi1);
double cpu0=cpu_time();
auto lrf=LowRankFunctionFactory<double,6>(parameters).project(lrfunctor);
lrf.do_print=true;
// plot_plane<6>(world,lrfunctor,"plot_original."+id,PlotParameters(world).set_plane({"x1","x4"}));
double cpu1=cpu_time();
double error1=lrf.l2error(lrfunctor);
Expand Down Expand Up @@ -172,6 +173,7 @@ int test_Kcommutator(World& world, LowRankFunctionParameters& parameters) {
parameters.print("grid");

real_convolution_3d g12=(CoulombOperator(world,1.e-6,FunctionDefaults<LDIM>::get_thresh()));
g12.particle()=1;
std::shared_ptr<real_convolution_3d> f12ptr;
std::string f12type=parameters.f12type();
if (f12type=="slaterf12") f12ptr.reset(SlaterF12OperatorPtr(world,parameters.gamma(),1.e-6,FunctionDefaults<LDIM>::get_thresh()));
Expand All @@ -184,7 +186,7 @@ int test_Kcommutator(World& world, LowRankFunctionParameters& parameters) {
real_function_3d phi=real_factory_3d(world).f([](const coord_3d& r){return exp(-r.normf());});
double n=phi.norm2();
phi.scale(1/n);
real_function_3d phi_k=phi; // lookys silly, helps reading.
real_function_3d phi_k=phi; // looks silly, helps reading.


// reference term ( < ij | K(1) ) = <Ki(1) j(2) |
Expand Down Expand Up @@ -222,15 +224,17 @@ int test_Kcommutator(World& world, LowRankFunctionParameters& parameters) {
Tensor<double> j_hj = inner(world, phi, hj);
Tensor<double> i_gk = inner(world, phi, gk);
double result_right = j_hj.trace(i_gk);
print(msg, result_right);
print(msg,"norm ", result_right);
print(msg,"error", result_right-reference);
print(msg,"rank ", lrf.rank());
j[msg]=result_right-reference;
j[msg+"_rank"]=lrf.rank();
j[msg+"_compute_time"]=t.tag(msg+"_compute_time");
json2file(j,jsonfilename);
};


{
if (1) {
// lowrankfunction left phi: lrf(1',2) = f12(1',2) i(1')
// K f12 ij = \sum_k k(1) \int g(1,1') f12(1'2) i(1') j(2) k(1') d1'
// = \sum_kr k(1) j(2) \int g(1,1') g_r(1') h_r(2) k(1') d1'
Expand All @@ -254,54 +258,101 @@ int test_Kcommutator(World& world, LowRankFunctionParameters& parameters) {
json2file(j,jsonfilename);
compute_error("left_optimize",fi_one);

fi_one.reorthonormalize();
j["left_reorthonormalize"]=t.tag("left_reorthonormalize");
json2file(j,jsonfilename);
compute_error("left_reorthonormalize",fi_one);
// fi_one.reorthonormalize();
// j["left_reorthonormalize"]=t.tag("left_reorthonormalize");
// json2file(j,jsonfilename);
// compute_error("left_reorthonormalize",fi_one);

LowRankFunction<double,6> phi0(world);
phi0.g={phi};
phi0.h={phi};

timer t2(world);
// this is f12|ij>
auto f12ij=copy(fi_one);
f12ij.h=f12ij.h*phi;
double result1=inner(phi0,f12ij);
print("<ij | f12 | ij>",result1);
t2.tag("multiply 1(1)* phi(1))");

// this is f12|(ki) j>
auto f12kij=copy(f12ij);
f12kij.g=f12kij.g*phi;
double result3=inner(phi0,f12kij);
print("<ij | f12 | (ki) j>",result3);
t2.tag("multiply 1");

// this is g(f12|(ki) j>);
auto gf12kij=copy(f12kij);
gf12kij.g=g12(f12kij.g);
double result2=inner(phi0,gf12kij);
print("<ij | g(f12 | (ki) j>)",result2);
t2.tag("apply g ");

// this is kg(f12|(ki) j>);
auto kgf12kij=copy(gf12kij);
kgf12kij.g=gf12kij.g * phi;
double result4=inner(phi0,kgf12kij);
print("<ij | k g(f12 | (ki) j>)",result4);
t2.tag("apply g ");
}

// // lowrankfunction right phi: lrf(1',2) = f12(1',2) i(1')
// {
// real_function_3d one = real_factory_3d(world).f([](const coord_3d &r) { return 1.0; });
// LowRankFunction<double, 6> fi_one(f12ptr, copy(one), copy(phi));
// fi_one.project(parameters);
// std::swap(fi_one.g,fi_one.h);
// j["right_project_time"]=t.tag("right_project_time");
// json2file(j,jsonfilename);
//
//
// {
// auto gk = mul(world, phi_k, g12(fi_one.g * phi_k)); // function of 1
// auto hj = fi_one.h * phi; // function of 2
// Tensor<double> j_hj = inner(world, phi, hj);
// Tensor<double> i_gk = inner(world, phi, gk);
// double result_right = j_hj.trace(i_gk);
// print("result_right, project only", result_right);
// j["right_project"]=result_right-reference;
// j["right_project_rank"]=fi_one.rank();
// j["left_optimize_compute_time"]=t.tag("left_optimize_compute_time");
// j["right_project_compute_time"]=t.tag("right_project_compute_time");
// }
// json2file(j,jsonfilename);
// std::swap(fi_one.g,fi_one.h);
// fi_one.optimize();
// std::swap(fi_one.g,fi_one.h);
// {
// auto gk = mul(world, phi_k, g12(fi_one.g * phi_k)); // function of 1
// auto hj = fi_one.h * phi; // function of 2
// Tensor<double> j_hj = inner(world, phi, hj);
// Tensor<double> i_gk = inner(world, phi, gk);
// double result_right = j_hj.trace(i_gk);
// print("result_right, optimize", result_right);
// j["right_optimize"]=result_right-reference;
// j["right_optimize_rank"]=fi_one.rank();
// j["right_optimize_compute_time"]=t.tag("right_optimize_compute_time");
// }
// json2file(j,jsonfilename);
//
// }

return 0;
// apply exchange operator in 6d
// if (f12type=="slaterf12") {
if (1) {
// FunctionDefaults<3>::print();
// FunctionDefaults<6>::print();
real_function_6d phi0=CompositeFactory<double,6,3>(world).particle1(phi).particle2(phi);

double thresh=FunctionDefaults<3>::get_thresh();
double dcut=1.e-6;
real_function_6d tmp=TwoElectronFactory(world).dcut(dcut).gamma(parameters.gamma()).f12().thresh(thresh);
real_function_6d f12ij=CompositeFactory<double,6,3>(world).g12(tmp).particle1(copy(phi)).particle2(copy(phi));

timer t(world);
f12ij.fill_tree();
t.tag("exchange: fill_tree");
f12ij.print_size("f12ij");

auto result1=madness::inner(phi0,f12ij);
print("<ij | f12 | ij>", result1);
double reference1=inner(phi*phi,f12(phi*phi));
print("reference <ij |f12 | ij>",reference1);


real_function_6d kf12ij=multiply(f12ij,copy(phi_k),1);
kf12ij.print_size("kf12ij");
t.tag("exchange: multiply 1");

auto result2=madness::inner(phi0,kf12ij);
print("<ij | f12 | (ki) j>", result2);
double reference2=inner(phi*phi,f12(phi*phi*phi));
print("reference <ij |f12 | (ki) j>",reference2);

kf12ij.change_tree_state(reconstructed);
real_function_6d gkf12ij=g12(kf12ij).truncate();
gkf12ij.print_size("gkf12ij");
t.tag("exchange: apply g");

auto result3=madness::inner(phi0,gkf12ij);
print("<ij | g(1'1) f12 | (ki) j>", result3);
double reference3=inner(phi*phi,f12(phi*phi*g12(copy(phi))));
print("reference <ij | g(1'1) f12 | (ki) j>",reference3);


auto exf12ij=multiply(gkf12ij,copy(phi_k),1).truncate();
exf12ij.print_size("exf12ij");
t.tag("exchange: multiply 2");

auto result=madness::inner(phi0,exf12ij);
print("<ij | K1 f12 | ij>", result);
print("error",result-reference);


}

return t1.end();

}

Expand Down Expand Up @@ -648,12 +699,17 @@ int main(int argc, char **argv) {
double thresh = parser.key_exists("thresh") ? std::stod(parser.value("thresh")) : 1.e-5;
FunctionDefaults<6>::set_tensor_type(TT_2D);


FunctionDefaults<3>::set_truncate_mode(1);
FunctionDefaults<6>::set_truncate_mode(1);

FunctionDefaults<1>::set_thresh(thresh);
FunctionDefaults<2>::set_thresh(thresh);
FunctionDefaults<3>::set_thresh(thresh);
FunctionDefaults<4>::set_thresh(thresh);
FunctionDefaults<5>::set_thresh(thresh);
FunctionDefaults<6>::set_thresh(thresh);
FunctionDefaults<6>::set_thresh(1.e-3);

FunctionDefaults<1>::set_k(k);
FunctionDefaults<2>::set_k(k);
Expand All @@ -674,9 +730,9 @@ int main(int argc, char **argv) {
FunctionDefaults<6>::get_thresh());
LowRankFunctionParameters parameters;
parameters.read_and_set_derived_values(world,parser,"grid");
parameters.set_user_defined_value("radius",3.0);
parameters.set_user_defined_value("volume_element",2.e-2);
parameters.set_user_defined_value("tol",1.0e-10);
// parameters.set_user_defined_value("radius",3.0);
// parameters.set_user_defined_value("volume_element",2.e-2);
// parameters.set_user_defined_value("tol",1.0e-10);
parameters.print("grid");
int isuccess=0;
#ifdef USE_GENTENSOR
Expand All @@ -696,8 +752,8 @@ int main(int argc, char **argv) {
// isuccess+=test_inner<1>(world,parameters);
// isuccess+=test_inner<2>(world,parameters);

parameters.set_user_defined_value("volume_element",1.e-1);
isuccess+=test_lowrank_function(world,parameters);
// parameters.set_user_defined_value("volume_element",1.e-1);
// isuccess+=test_lowrank_function(world,parameters);
isuccess+=test_Kcommutator(world,parameters);
} catch (std::exception& e) {
madness::print("an error occured");
Expand Down
2 changes: 2 additions & 0 deletions src/madness/mra/funcimpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -4444,6 +4444,8 @@ namespace madness {
/// reconstruct this tree -- respects fence
void reconstruct(bool fence);

void change_tree_state(const TreeState finalstate, bool fence=true);

// Invoked on node where key is local
// void reconstruct_op(const keyT& key, const tensorT& s);
void reconstruct_op(const keyT& key, const coeffT& s, const bool accumulate_NS=true);
Expand Down
Loading

0 comments on commit 9ff2820

Please sign in to comment.