Skip to content

Commit

Permalink
collect multiple restart files in a format.
Browse files Browse the repository at this point in the history
  • Loading branch information
dreamer2368 committed Apr 20, 2024
1 parent 1b40a1d commit cf08928
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 12 deletions.
11 changes: 7 additions & 4 deletions examples/Carbyne/carbyne.rom.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,17 @@ pseudopotential=pseudo.C_ONCV_PBE_SG15
[Run]
type=QUENCH
[Quench]
max_steps=10
max_steps=5
atol=1.e-8
[Orbitals]
initial_type=Fourier
[Restart]
output_level=4
# input_level=4
# input_filename=snapshot0_24_109_22_36
input_level=4
input_filename=snapshot0_000

[ROM.offline]
restartFilename=snapshot0_24_109_22_39
restart_filefmt=snapshot0_%03d
restart_min_idx=0
restart_max_idx=1
basis_file=carom
15 changes: 13 additions & 2 deletions src/Control.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2063,7 +2063,11 @@ void Control::setROMOptions(const boost::program_options::variables_map& vm)

if (onpe0)
{
rom_pri_option.restart_filename = vm["ROM.offline.restartFilename"].as<std::string>();
rom_pri_option.restart_file_fmt = vm["ROM.offline.restart_filefmt"].as<std::string>();
rom_pri_option.restart_file_minidx = vm["ROM.offline.restart_min_idx"].as<int>();
rom_pri_option.restart_file_maxidx = vm["ROM.offline.restart_max_idx"].as<int>();

rom_pri_option.basis_file = vm["ROM.offline.basis_file"].as<std::string>();
} // onpe0

// synchronize all processors
Expand All @@ -2077,7 +2081,8 @@ void Control::syncROMOptions()

MGmol_MPI& mmpi = *(MGmol_MPI::instance());

mmpi.bcast(rom_pri_option.restart_filename, comm_global_);
mmpi.bcast(rom_pri_option.restart_file_fmt, comm_global_);
mmpi.bcast(rom_pri_option.basis_file, comm_global_);

auto bcast_check = [](int mpirc) {
if (mpirc != MPI_SUCCESS)
Expand All @@ -2092,5 +2097,11 @@ void Control::syncROMOptions()
mpirc = MPI_Bcast(&rom_stage, 1, MPI_SHORT, 0, comm_global_);
bcast_check(mpirc);

mpirc = MPI_Bcast(&rom_pri_option.restart_file_minidx, 1, MPI_INT, 0, comm_global_);
bcast_check(mpirc);

mpirc = MPI_Bcast(&rom_pri_option.restart_file_maxidx, 1, MPI_INT, 0, comm_global_);
bcast_check(mpirc);

rom_pri_option.rom_stage = static_cast<ROMStage>(rom_stage);
}
5 changes: 4 additions & 1 deletion src/rom_Control.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ enum class ROMStage
/* Stored as a private member variable of Control class */
struct ROMPrivateOptions
{
std::string restart_filename = "";
std::string restart_file_fmt = "";
int restart_file_minidx = -1;
int restart_file_maxidx = -1;
std::string basis_file = "";
ROMStage rom_stage = ROMStage::UNSUPPORTED;
};

Expand Down
49 changes: 46 additions & 3 deletions src/rom_workflows.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,63 @@
// Please also read this link https://github.com/llnl/mgmol/LICENSE

#include "rom_workflows.h"
#include <memory>
#include <string>
#include <stdexcept>

template<typename ... Args>
std::string string_format( const std::string& format, Args ... args )
{
int size_s = std::snprintf( nullptr, 0, format.c_str(), args ... ) + 1; // Extra space for '\0'
if( size_s <= 0 ){ throw std::runtime_error( "Error during formatting." ); }
auto size = static_cast<size_t>( size_s );
std::unique_ptr<char[]> buf( new char[ size ] );
std::snprintf( buf.get(), size, format.c_str(), args ... );
return std::string( buf.get(), buf.get() + size - 1 ); // We don't want the '\0' inside
}

template <class OrbitalsType>
void readRestartFiles(MGmolInterface *mgmol_)
{
Control& ct = *(Control::instance());
ROMPrivateOptions rom_options = ct.getROMOptions();
assert(rom_options.restart_file_minidx >= 0);
assert(rom_options.restart_file_maxidx >= 0);
const int minidx = rom_options.restart_file_minidx;
const int maxidx = rom_options.restart_file_maxidx;
const int num_restart = maxidx - minidx + 1;

MGmol<OrbitalsType> *mgmol = static_cast<MGmol<OrbitalsType> *>(mgmol_);
OrbitalsType *orbitals = nullptr;
std::string filename;

OrbitalsType *orbitals = mgmol->loadOrbitalFromRestartFile(rom_options.restart_filename);
/* Read the first snapshot to determin dimension and number of snapshots */
filename = string_format(rom_options.restart_file_fmt, minidx);
orbitals = mgmol->loadOrbitalFromRestartFile(filename);
const int dim = orbitals->getLocNumpt();
const int chrom_num = orbitals->chromatic_number();
const int totalSamples = orbitals->chromatic_number() * num_restart;
delete orbitals;

mgmol->save_orbital_snapshot("test", *orbitals);
/* Initialize libROM classes */
CAROM::Options svd_options(dim, totalSamples, 1);
CAROM::BasisGenerator basis_generator(svd_options, false, rom_options.basis_file);

/* Collect the restart files */
for (int k = minidx; k <= maxidx; k++)
{
filename = string_format(rom_options.restart_file_fmt, k);
orbitals = mgmol->loadOrbitalFromRestartFile(filename);
assert(dim == orbitals->getLocNumpt());
assert(chrom_num == orbitals->chromatic_number());

delete orbitals;
for (int i = 0; i < chrom_num; ++i)
basis_generator.takeSample(orbitals->getPsi(i));

delete orbitals;
}
basis_generator.writeSnapshot();
basis_generator.endSamples();
}

template void readRestartFiles<LocGridOrbitals>(MGmolInterface *mgmol_);
Expand Down
10 changes: 8 additions & 2 deletions src/tools/OptionDescription.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,12 @@ void setupHiddenOption(po::options_description &hidden)
void setupROMConfigOption(po::options_description &rom_cfg)
{
rom_cfg.add_options()
("ROM.offline.restartFilename", po::value<string>()->required(),
"File name to read for snapshots.");
("ROM.offline.restart_filefmt", po::value<string>()->required(),
"File name format to read for snapshots.")
("ROM.offline.restart_min_idx", po::value<int>()->required(),
"Minimum index for snapshot file format.")
("ROM.offline.restart_max_idx", po::value<int>()->required(),
"Maximum index for snapshot file format.")
("ROM.offline.basis_file", po::value<string>()->required(),
"File name for libROM snapshot/POD matrices.");
}

0 comments on commit cf08928

Please sign in to comment.