diff --git a/src/MGmol.h b/src/MGmol.h index f7525600..5a0ffce3 100644 --- a/src/MGmol.h +++ b/src/MGmol.h @@ -359,8 +359,9 @@ class MGmol : public MGmolInterface } #ifdef MGMOL_HAS_LIBROM - int save_orbital_snapshot(std::string snapshot_dir, OrbitalsType& orbitals); - void project_orbital(std::string snapshot_dir, int rdim, OrbitalsType& orbitals); + int save_orbital_snapshot(std::string file_path, OrbitalsType& orbitals); + void project_orbital(std::string file_path, int rdim, OrbitalsType& orbitals); + void set_orbital(std::string file_path, int rdim, OrbitalsType& orbitals); #endif }; // Instantiate static variables here to avoid clang warnings diff --git a/src/md.cc b/src/md.cc index 314fd0dc..0f79f5dd 100644 --- a/src/md.cc +++ b/src/md.cc @@ -493,7 +493,6 @@ void MGmol::md(OrbitalsType** orbitals, Ions& ions) force(**orbitals, ions); #ifdef MGMOL_HAS_LIBROM - // TODO: cleanup if (ct.getROMOptions().num_orbbasis > 0) { if (onpe0) diff --git a/src/rom.cc b/src/rom.cc index 19ff66a3..7992e0ff 100644 --- a/src/rom.cc +++ b/src/rom.cc @@ -91,6 +91,27 @@ void MGmol::project_orbital(std::string file_path, int rdim, Orbit } } +template +void MGmol::set_orbital(std::string file_path, int rdim, OrbitalsType& orbitals) +{ + const int dim = orbitals.getLocNumpt(); + const int totalSamples = orbitals.chromatic_number(); + + CAROM::BasisReader reader(file_path); + CAROM::Matrix* orbital_basis = reader.getSpatialBasis(rdim); + + Control& ct = *(Control::instance()); + Mesh* mymesh = Mesh::instance(); + pb::GridFunc gf_psi(mymesh->grid(), ct.bcWF[0], ct.bcWF[1], ct.bcWF[2]); + CAROM::Vector psi; + for (int i = 0; i < rdim; ++i) + { + orbital_basis->getColumn(i, psi); + gf_psi.assign(psi.getData()); + orbitals.setPsi(gf_psi, i); + } +} + template class MGmol; template class MGmol; diff --git a/tests/PinnedH2O_3DOF/testPinnedH2O_3DOF.cc b/tests/PinnedH2O_3DOF/testPinnedH2O_3DOF.cc index 1cf93fdf..9e6582cc 100644 --- a/tests/PinnedH2O_3DOF/testPinnedH2O_3DOF.cc +++ b/tests/PinnedH2O_3DOF/testPinnedH2O_3DOF.cc @@ -17,7 +17,6 @@ #ifdef MGMOL_HAS_LIBROM #include "librom.h" -#endif // MGMOL_HAS_LIBROM #include #include @@ -154,8 +153,12 @@ int main(int argc, char** argv) } } - // compute energy and forces again using wavefunctions - // from previous call + // compute energy and forces again with projected problem onto ROM subspace + if (MPIdata::onpe0) + { + std::cout << "Loading ROM basis " << ct.getROMOptions().basis_file << std::endl; + std::cout << "ROM basis dimension = " << ct.getROMOptions().num_orbbasis << std::endl; + } Mesh* mymesh = Mesh::instance(); const pb::Grid& mygrid = mymesh->grid(); @@ -166,19 +169,8 @@ int main(int argc, char** argv) ct.numst, ct.bcWF, projmatrices.get(), nullptr, nullptr, nullptr, nullptr); -#ifdef MGMOL_HAS_LIBROM - CAROM::BasisReader reader(ct.getROMOptions().basis_file); - CAROM::Matrix* Psi = reader.getSpatialBasis(ct.getROMOptions().num_orbbasis); - //mgmol->carom_matrix_to_orbitals(Psi, orbitals); - pb::GridFunc gf_psi(mymesh->grid(), ct.bcWF[0], ct.bcWF[1], ct.bcWF[2]); - CAROM::Vector psi; - for (int i = 0; i < Psi->numColumns(); ++i) - { - Psi->getColumn(i, psi); - gf_psi.assign(psi.getData()); - orbitals.setPsi(gf_psi, i); - } -#endif // MGMOL_HAS_LIBROM + MGmol* mgmol_ = dynamic_cast*>(mgmol); + mgmol_->set_orbital(ct.getROMOptions().basis_file, ct.getROMOptions().num_orbbasis, orbitals); // // evaluate energy and forces again @@ -219,3 +211,4 @@ int main(int argc, char** argv) return 0; } +#endif // MGMOL_HAS_LIBROM