Skip to content

Commit

Permalink
Updates
Browse files Browse the repository at this point in the history
  • Loading branch information
siuwuncheung committed Dec 13, 2024
1 parent b09a20f commit cac12ac
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 19 deletions.
5 changes: 3 additions & 2 deletions src/MGmol.h
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,9 @@ class MGmol : public MGmolInterface
}

#ifdef MGMOL_HAS_LIBROM
int save_orbital_snapshot(std::string snapshot_dir, OrbitalsType& orbitals);
void project_orbital(std::string snapshot_dir, int rdim, OrbitalsType& orbitals);
int save_orbital_snapshot(std::string file_path, OrbitalsType& orbitals);
void project_orbital(std::string file_path, int rdim, OrbitalsType& orbitals);
void set_orbital(std::string file_path, int rdim, OrbitalsType& orbitals);
#endif
};
// Instantiate static variables here to avoid clang warnings
Expand Down
1 change: 0 additions & 1 deletion src/md.cc
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,6 @@ void MGmol<OrbitalsType>::md(OrbitalsType** orbitals, Ions& ions)
force(**orbitals, ions);

#ifdef MGMOL_HAS_LIBROM
// TODO: cleanup
if (ct.getROMOptions().num_orbbasis > 0)
{
if (onpe0)
Expand Down
21 changes: 21 additions & 0 deletions src/rom.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,27 @@ void MGmol<OrbitalsType>::project_orbital(std::string file_path, int rdim, Orbit
}
}

template <class OrbitalsType>
void MGmol<OrbitalsType>::set_orbital(std::string file_path, int rdim, OrbitalsType& orbitals)
{
const int dim = orbitals.getLocNumpt();
const int totalSamples = orbitals.chromatic_number();

CAROM::BasisReader reader(file_path);
CAROM::Matrix* orbital_basis = reader.getSpatialBasis(rdim);

Control& ct = *(Control::instance());
Mesh* mymesh = Mesh::instance();
pb::GridFunc<ORBDTYPE> gf_psi(mymesh->grid(), ct.bcWF[0], ct.bcWF[1], ct.bcWF[2]);
CAROM::Vector psi;
for (int i = 0; i < rdim; ++i)
{
orbital_basis->getColumn(i, psi);
gf_psi.assign(psi.getData());
orbitals.setPsi(gf_psi, i);
}
}

template class MGmol<LocGridOrbitals>;
template class MGmol<ExtendedGridOrbitals>;

Expand Down
25 changes: 9 additions & 16 deletions tests/PinnedH2O_3DOF/testPinnedH2O_3DOF.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

#ifdef MGMOL_HAS_LIBROM
#include "librom.h"
#endif // MGMOL_HAS_LIBROM

#include <cassert>
#include <iostream>
Expand Down Expand Up @@ -154,8 +153,12 @@ int main(int argc, char** argv)
}
}

// compute energy and forces again using wavefunctions
// from previous call
// compute energy and forces again with projected problem onto ROM subspace
if (MPIdata::onpe0)
{
std::cout << "Loading ROM basis " << ct.getROMOptions().basis_file << std::endl;
std::cout << "ROM basis dimension = " << ct.getROMOptions().num_orbbasis << std::endl;
}
Mesh* mymesh = Mesh::instance();
const pb::Grid& mygrid = mymesh->grid();

Expand All @@ -166,19 +169,8 @@ int main(int argc, char** argv)
ct.numst, ct.bcWF, projmatrices.get(), nullptr, nullptr, nullptr,
nullptr);

#ifdef MGMOL_HAS_LIBROM
CAROM::BasisReader reader(ct.getROMOptions().basis_file);
CAROM::Matrix* Psi = reader.getSpatialBasis(ct.getROMOptions().num_orbbasis);
//mgmol->carom_matrix_to_orbitals(Psi, orbitals);
pb::GridFunc<ORBDTYPE> gf_psi(mymesh->grid(), ct.bcWF[0], ct.bcWF[1], ct.bcWF[2]);
CAROM::Vector psi;
for (int i = 0; i < Psi->numColumns(); ++i)
{
Psi->getColumn(i, psi);
gf_psi.assign(psi.getData());
orbitals.setPsi(gf_psi, i);
}
#endif // MGMOL_HAS_LIBROM
MGmol<ExtendedGridOrbitals>* mgmol_ = dynamic_cast<MGmol<ExtendedGridOrbitals>*>(mgmol);
mgmol_->set_orbital(ct.getROMOptions().basis_file, ct.getROMOptions().num_orbbasis, orbitals);

//
// evaluate energy and forces again
Expand Down Expand Up @@ -219,3 +211,4 @@ int main(int argc, char** argv)

return 0;
}
#endif // MGMOL_HAS_LIBROM

0 comments on commit cac12ac

Please sign in to comment.