Skip to content

Commit

Permalink
Merge pull request #4819 from camelto2/change_VariableSet_to_RealType
Browse files Browse the repository at this point in the history
Convert VariableSet to RealType only
  • Loading branch information
ye-luo authored Nov 8, 2023
2 parents d27cab8 + cf0c220 commit 864c22d
Show file tree
Hide file tree
Showing 10 changed files with 43 additions and 67 deletions.
2 changes: 1 addition & 1 deletion src/Optimize/OptimizeBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class CostFunctionBase

virtual int getNumParams() const = 0;

virtual Return_t& Params(int i) = 0;
virtual Return_rt& Params(int i) = 0;

virtual Return_t Params(int i) const = 0;

Expand Down
4 changes: 2 additions & 2 deletions src/QMCDrivers/WFOpt/QMCCostFunctionBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ void QMCCostFunctionBase::reportParametersH5()
if (!myComm->rank())
{
int ci_size = 0;
std::vector<opt_variables_type::value_type> CIcoeff;
std::vector<opt_variables_type::real_type> CIcoeff;
for (int i = 0; i < OptVariables.size(); i++)
{
std::array<char, 128> Coeff;
Expand All @@ -271,7 +271,7 @@ void QMCCostFunctionBase::reportParametersH5()
if (ci_size > 0)
{
CI_Opt = true;
newh5 = RootName + ".opt.h5";
newh5 = RootName + ".opt.h5";
*msg_stream << " <Ci Coeffs saved in opt_coeffs=\"" << newh5 << "\">" << std::endl;
hdf_archive hout;
hout.create(newh5, H5F_ACC_TRUNC);
Expand Down
4 changes: 2 additions & 2 deletions src/QMCDrivers/WFOpt/QMCCostFunctionBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class QMCCostFunctionBase : public CostFunctionBase<QMCTraits::RealType>, public
SUM_INDEX_SIZE
};

using EffectiveWeight = QMCTraits::QTFull::RealType;
using EffectiveWeight = QMCTraits::QTFull::RealType;
using FullPrecRealType = QMCTraits::FullPrecRealType;
///Constructor.
QMCCostFunctionBase(ParticleSet& w, TrialWaveFunction& psi, QMCHamiltonian& h, Communicate* comm);
Expand All @@ -85,7 +85,7 @@ class QMCCostFunctionBase : public CostFunctionBase<QMCTraits::RealType>, public
///Path and name of the HDF5 prefix where CI coeffs are saved
std::string newh5;
///assign optimization parameter i
Return_t& Params(int i) override { return OptVariables[i]; }
Return_rt& Params(int i) override { return OptVariables[i]; }
///return optimization parameter i
Return_t Params(int i) const override { return OptVariables[i]; }
int getType(int i) const { return OptVariables.getType(i); }
Expand Down
14 changes: 7 additions & 7 deletions src/QMCDrivers/WFOpt/QMCFixedSampleLinearOptimize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1054,7 +1054,7 @@ bool QMCFixedSampleLinearOptimize::adaptive_three_shift_run()
<< std::endl;

// for each set of shifts, solve the linear method equations for the parameter update direction
std::vector<std::vector<ValueType>> parameterDirections;
std::vector<std::vector<RealType>> parameterDirections;
#ifdef HAVE_LMY_ENGINE
// call the engine to perform update
EngineObj->wfn_update_compute();
Expand All @@ -1070,7 +1070,7 @@ bool QMCFixedSampleLinearOptimize::adaptive_three_shift_run()
if (true)
{
for (int j = 0; j < N; j++)
parameterDirections.at(i).at(j) = EngineObj->wfn_update().at(i * N + j);
parameterDirections.at(i).at(j) = std::real(EngineObj->wfn_update().at(i * N + j));
}
else
parameterDirections.at(i).at(0) = 1.0;
Expand All @@ -1080,7 +1080,7 @@ bool QMCFixedSampleLinearOptimize::adaptive_three_shift_run()
optTarget->setneedGrads(false);

// prepare vectors to hold the initial and current parameters
std::vector<ValueType> currParams(numParams, 0.0);
std::vector<RealType> currParams(numParams, 0.0);

// initialize the initial and current parameter vectors
for (int i = 0; i < numParams; i++)
Expand Down Expand Up @@ -1168,8 +1168,8 @@ bool QMCFixedSampleLinearOptimize::adaptive_three_shift_run()
}

// find the best shift and the corresponding update direction
const std::vector<ValueType>* bestDirection = 0;
int best_shift = -1;
const std::vector<RealType>* bestDirection = 0;
int best_shift = -1;
for (int k = 0; k < costValues.size() && std::abs((initCost - initCost) / initCost) < max_relative_cost_change; k++)
if (is_best_cost(k, costValues, shifts_i, initCost) && good_update.at(k))
{
Expand Down Expand Up @@ -1440,7 +1440,7 @@ bool QMCFixedSampleLinearOptimize::descent_run()

for (int i = 0; i < results.size(); i++)
{
optTarget->Params(i) = results[i];
optTarget->Params(i) = std::real(results[i]);
}

//If descent is being run as part of a hybrid optimization, need to check if a vector of
Expand Down Expand Up @@ -1488,7 +1488,7 @@ bool QMCFixedSampleLinearOptimize::hybrid_run()
app_log() << "Update descent engine parameter values after Blocked LM step" << std::endl;
for (int i = 0; i < optTarget->getNumParams(); i++)
{
ValueType val = optTarget->Params(i);
RealType val = optTarget->Params(i);
descentEngineObj->setParamVal(i, val);
}
}
Expand Down
14 changes: 7 additions & 7 deletions src/QMCDrivers/WFOpt/QMCFixedSampleLinearOptimizeBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1344,7 +1344,7 @@ bool QMCFixedSampleLinearOptimizeBatched::adaptive_three_shift_run()
<< std::endl;

// for each set of shifts, solve the linear method equations for the parameter update direction
std::vector<std::vector<ValueType>> parameterDirections;
std::vector<std::vector<RealType>> parameterDirections;
#ifdef HAVE_LMY_ENGINE
// call the engine to perform update
EngineObj->wfn_update_compute();
Expand All @@ -1360,7 +1360,7 @@ bool QMCFixedSampleLinearOptimizeBatched::adaptive_three_shift_run()
if (true)
{
for (int j = 0; j < N; j++)
parameterDirections.at(i).at(j) = EngineObj->wfn_update().at(i * N + j);
parameterDirections.at(i).at(j) = std::real(EngineObj->wfn_update().at(i * N + j));
}
else
parameterDirections.at(i).at(0) = 1.0;
Expand All @@ -1370,7 +1370,7 @@ bool QMCFixedSampleLinearOptimizeBatched::adaptive_three_shift_run()
//There will be updates of 0 for parameters that were filtered out before derivative ratios were used by the engine.
if (options_LMY_.filter_param)
{
std::vector<std::vector<ValueType>> tmpParameterDirections;
std::vector<std::vector<RealType>> tmpParameterDirections;
tmpParameterDirections.resize(shifts_i.size());

for (int i = 0; i < shifts_i.size(); i++)
Expand Down Expand Up @@ -1400,7 +1400,7 @@ bool QMCFixedSampleLinearOptimizeBatched::adaptive_three_shift_run()
optTarget->setneedGrads(false);

// prepare vectors to hold the initial and current parameters
std::vector<ValueType> currParams(numParams, 0.0);
std::vector<RealType> currParams(numParams, 0.0);

// initialize the initial and current parameter vectors
for (int i = 0; i < numParams; i++)
Expand Down Expand Up @@ -1493,8 +1493,8 @@ bool QMCFixedSampleLinearOptimizeBatched::adaptive_three_shift_run()
}

// find the best shift and the corresponding update direction
const std::vector<ValueType>* bestDirection = 0;
int best_shift = -1;
const std::vector<RealType>* bestDirection = 0;
int best_shift = -1;
for (int k = 0;
k < costValues.size() && std::abs((initCost - initCost) / initCost) < options_LMY_.max_relative_cost_change; k++)
if (is_best_cost(k, costValues, shifts_i, initCost) && good_update.at(k))
Expand Down Expand Up @@ -1778,7 +1778,7 @@ bool QMCFixedSampleLinearOptimizeBatched::descent_run()

for (int i = 0; i < results.size(); i++)
{
optTarget->Params(i) = results[i];
optTarget->Params(i) = std::real(results[i]);
}

//If descent is being run as part of a hybrid optimization, need to check if a vector of
Expand Down
4 changes: 2 additions & 2 deletions src/QMCDrivers/tests/test_DescentEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ TEST_CASE("DescentEngine RMSprop update", "[drivers][descent]")
optimize::VariableSet myVars;

//Two fake parameters are specified
optimize::VariableSet::value_type first_param(1.0);
optimize::VariableSet::value_type second_param(-2.0);
optimize::VariableSet::real_type first_param(1.0);
optimize::VariableSet::real_type second_param(-2.0);

myVars.insert("first", first_param);
myVars.insert("second", second_param);
Expand Down
4 changes: 2 additions & 2 deletions src/QMCWaveFunctions/Fermion/SlaterDetBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -561,10 +561,10 @@ std::unique_ptr<MultiSlaterDetTableMethod> SlaterDetBuilder::createMSDFast(
Optimizable = CI_Optimizable = true;
if (csf_data_ptr)
for (int i = 1; i < csf_data_ptr->coeffs.size(); i++)
myVars.insert(CItags[i], csf_data_ptr->coeffs[i], true, optimize::LINEAR_P);
myVars.insert(CItags[i], std::real(csf_data_ptr->coeffs[i]), true, optimize::LINEAR_P);
else
for (int i = 1; i < C.size(); i++)
myVars.insert(CItags[i], C[i], true, optimize::LINEAR_P);
myVars.insert(CItags[i], std::real(C[i]), true, optimize::LINEAR_P);
}
else
{
Expand Down
8 changes: 4 additions & 4 deletions src/QMCWaveFunctions/VariableSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void VariableSet::insertFrom(const VariableSet& input)

void VariableSet::insertFromSum(const VariableSet& input_1, const VariableSet& input_2)
{
value_type sum_val;
real_type sum_val;
std::string vname;

// Check that objects to be summed together have the same number of active
Expand Down Expand Up @@ -94,7 +94,7 @@ void VariableSet::insertFromSum(const VariableSet& input_1, const VariableSet& i

void VariableSet::insertFromDiff(const VariableSet& input_1, const VariableSet& input_2)
{
value_type diff_val;
real_type diff_val;
std::string vname;

// Check that objects to be subtracted have the same number of active
Expand Down Expand Up @@ -259,7 +259,7 @@ void VariableSet::writeToHDF(const std::string& filename, qmcplusplus::hdf_archi

hout.push("name_value_lists");

std::vector<qmcplusplus::QMCTraits::ValueType> param_values;
std::vector<qmcplusplus::QMCTraits::RealType> param_values;
std::vector<std::string> param_names;
for (auto& pair_it : NameAndValue)
{
Expand Down Expand Up @@ -292,7 +292,7 @@ void VariableSet::readFromHDF(const std::string& filename, qmcplusplus::hdf_arch
throw std::runtime_error(err_msg.str());
}

std::vector<qmcplusplus::QMCTraits::ValueType> param_values;
std::vector<qmcplusplus::QMCTraits::RealType> param_values;
hin.read(param_values, "parameter_values");

std::vector<std::string> param_names;
Expand Down
13 changes: 6 additions & 7 deletions src/QMCWaveFunctions/VariableSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,9 @@ enum
*/
struct VariableSet
{
using value_type = qmcplusplus::QMCTraits::ValueType;
using real_type = qmcplusplus::QMCTraits::RealType;
using real_type = qmcplusplus::QMCTraits::RealType;

using pair_type = std::pair<std::string, value_type>;
using pair_type = std::pair<std::string, real_type>;
using index_pair_type = std::pair<std::string, int>;
using iterator = std::vector<pair_type>::iterator;
using const_iterator = std::vector<pair_type>::const_iterator;
Expand Down Expand Up @@ -131,7 +130,7 @@ struct VariableSet
return -1;
}

inline void insert(const std::string& vname, value_type v, bool enable = true, int type = OTHER_P)
inline void insert(const std::string& vname, real_type v, bool enable = true, int type = OTHER_P)
{
iterator loc = find(vname);
int ind_loc = loc - NameAndValue.begin();
Expand Down Expand Up @@ -169,7 +168,7 @@ struct VariableSet

/** equivalent to std::map<std::string,T>[string] operator
*/
inline value_type& operator[](const std::string& vname)
inline real_type& operator[](const std::string& vname)
{
iterator loc = find(vname);
if (loc == NameAndValue.end())
Expand All @@ -192,12 +191,12 @@ struct VariableSet
/** return the i-th value
* @param i index
*/
inline value_type operator[](int i) const { return NameAndValue[i].second; }
inline real_type operator[](int i) const { return NameAndValue[i].second; }

/** assign the i-th value
* @param i index
*/
inline value_type& operator[](int i) { return NameAndValue[i].second; }
inline real_type& operator[](int i) { return NameAndValue[i].second; }

/** get the i-th parameter's type
* @param i index
Expand Down
43 changes: 10 additions & 33 deletions src/QMCWaveFunctions/tests/test_variable_set.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@


#include "catch.hpp"
#include "complex_approx.hpp"

#include "VariableSet.h"
#include "io/hdf/hdf_archive.h"
Expand All @@ -20,7 +19,6 @@
#include <string>

using std::string;
using qmcplusplus::ValueApprox;

namespace optimize
{
Expand All @@ -37,7 +35,7 @@ TEST_CASE("VariableSet empty", "[optimize]")
TEST_CASE("VariableSet one", "[optimize]")
{
VariableSet vs;
VariableSet::value_type first_val(1.123456789);
VariableSet::real_type first_val(1.123456789);
vs.insert("first", first_val);
std::vector<std::string> names{"first"};
vs.activate(names.begin(), names.end(), true);
Expand All @@ -47,44 +45,31 @@ TEST_CASE("VariableSet one", "[optimize]")
REQUIRE(vs.getIndex("first") == 0);
REQUIRE(vs.name(0) == "first");
double first_val_real = 1.123456789;
CHECK(std::real(vs[0] ) == Approx(first_val_real));
CHECK(vs[0] == Approx(first_val_real));

std::ostringstream o;
vs.print(o, 0, false);
//std::cout << o.str() << std::endl;
#ifdef QMC_COMPLEX
REQUIRE(o.str() == "first (1.123457e+00,0.000000e+00) 0 1 ON 0\n");
#else
REQUIRE(o.str() == "first 1.123457e+00 0 1 ON 0\n");
#endif

std::ostringstream o2;
vs.print(o2, 1, true);
//std::cout << o2.str() << std::endl;

#ifdef QMC_COMPLEX
char formatted_output[] = " Name Value Type Recompute Use Index\n"
" ----- ---------------------------- ---- --------- --- -----\n"
" first (1.123457e+00,0.000000e+00) 0 1 ON 0\n";


REQUIRE(o2.str() == formatted_output);
#else
char formatted_output[] = " Name Value Type Recompute Use Index\n"
" ----- ---------------------------- ---- --------- --- -----\n"
" first 1.123457e+00 0 1 ON 0\n";


REQUIRE(o2.str() == formatted_output);
#endif
}

TEST_CASE("VariableSet output", "[optimize]")
{
VariableSet vs;
VariableSet::value_type first_val(11234.56789);
VariableSet::value_type second_val(0.000256789);
VariableSet::value_type third_val(-1.2);
VariableSet::real_type first_val(11234.56789);
VariableSet::real_type second_val(0.000256789);
VariableSet::real_type third_val(-1.2);
vs.insert("s", first_val);
vs.insert("second", second_val);
vs.insert("really_long_name", third_val);
Expand All @@ -95,29 +80,21 @@ TEST_CASE("VariableSet output", "[optimize]")
vs.print(o, 0, true);
//std::cout << o.str() << std::endl;

#ifdef QMC_COMPLEX
char formatted_output[] = " Name Value Type Recompute Use Index\n"
"---------------- ---------------------------- ---- --------- --- -----\n"
" s (1.123457e+04,0.000000e+00) 0 1 ON 0\n"
" second (2.567890e-04,0.000000e+00) 0 1 ON 1\n"
"really_long_name (-1.200000e+00,0.000000e+00) 0 1 ON 2\n";
#else
char formatted_output[] = " Name Value Type Recompute Use Index\n"
"---------------- ---------------------------- ---- --------- --- -----\n"
" s 1.123457e+04 0 1 ON 0\n"
" second 2.567890e-04 0 1 ON 1\n"
"really_long_name -1.200000e+00 0 1 ON 2\n";
#endif

REQUIRE(o.str() == formatted_output);
}

TEST_CASE("VariableSet HDF output and input", "[optimize]")
{
VariableSet vs;
VariableSet::value_type first_val(11234.56789);
VariableSet::value_type second_val(0.000256789);
VariableSet::value_type third_val(-1.2);
VariableSet::real_type first_val(11234.56789);
VariableSet::real_type second_val(0.000256789);
VariableSet::real_type third_val(-1.2);
vs.insert("s", first_val);
vs.insert("second", second_val);
vs.insert("really_really_really_long_name", third_val);
Expand All @@ -129,8 +106,8 @@ TEST_CASE("VariableSet HDF output and input", "[optimize]")
vs2.insert("second", 0.0);
qmcplusplus::hdf_archive hin;
vs2.readFromHDF("vp.h5", hin);
CHECK(vs2.find("s")->second == ValueApprox(first_val));
CHECK(vs2.find("second")->second == ValueApprox(second_val));
CHECK(vs2.find("s")->second == Approx(first_val));
CHECK(vs2.find("second")->second == Approx(second_val));
// This value as in the file, but not in the VariableSet that loaded the file,
// so the value does not get added.
CHECK(vs2.find("really_really_really_long_name") == vs2.end());
Expand Down

0 comments on commit 864c22d

Please sign in to comment.