Skip to content

Commit

Permalink
Merge pull request #389 from seoklab/crdgen-perf
Browse files Browse the repository at this point in the history
feat(algo/optim/bfgs): implement BFGS algorithm
  • Loading branch information
jnooree authored Oct 25, 2024
2 parents 914b15f + f040a14 commit 8e7245b
Show file tree
Hide file tree
Showing 13 changed files with 865 additions and 361 deletions.
2 changes: 1 addition & 1 deletion .clang-tidy
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ CheckOptions:
- key: performance-unnecessary-value-param.IncludeStyle
value: google
- key: performance-unnecessary-value-param.AllowedTypes
value: '^(Eigen::Ref|nuri::(Mut|Const)Ref)'
value: "^(Eigen::Ref|nuri::(Mut|Const)Ref)"
- key: google-readability-braces-around-statements.ShortStatementLines
value: 3
- key: readability-braces-around-statements.ShortStatementLines
Expand Down
38 changes: 38 additions & 0 deletions NOTICE.md
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,44 @@ Spectra is under the MPLv2 license.
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

### BFGS

- Project URL: <http://github.com/scipy/scipy>
- Full license text:

```txt
Copyright (c) 2001-2002 Enthought, Inc. 2003-2024, SciPy Developers.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following
disclaimer in the documentation and/or other materials provided
with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

### Kabsch

Our implementation of the Kabsch algorithm is based on the TM-align software.
Expand Down
298 changes: 270 additions & 28 deletions include/nuri/algo/optim.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,51 @@ class LBfgsB;

namespace internal {
constexpr double kEpsMach = 2.220446049250313e-16;
constexpr double kSqrtEpsMach =
1.4901161193847655978721599999999997530663236602513281125266e-8;

enum class DcsrchStatus : std::uint8_t {
kFound,
kConverged,
kContinue,
};

class Dcsrch {
public:
Dcsrch(double f0, double g0, double step0, double stepmin = 0,
double stepmax = 1e+10, double ftol = 1e-3, double gtol = 0.9,
double xtol = 0.1) noexcept;

DcsrchStatus operator()(double f, double g);

double step() const { return step_; }

double finit() const { return finit_; }

double ginit() const { return ginit_; }

private:
constexpr static double kXTrapL = 1.1, kXTrapU = 4;

double finit_, ginit_;
double stepmin_, stepmax_;
double gtest_, gtol_, xtol_;

double step_;

// The variables stx, fx, gx contain the values of the step, function,
// and derivative at the best step.
// The variables sty, fy, gy contain the value of the step, function,
// and derivative at sty.
// The variables stp, f, g contain the values of the step, function,
// and derivative at stp.
double stx_ = 0, fx_, gx_;
double sty_ = 0, fy_, gy_;
double stmin_ = 0, stmax_;
double width_, width1_;

bool brackt_ = false, bisect_ = false;
};

/**
* nbd == 0x1 if has lower bound,
Expand Down Expand Up @@ -103,13 +148,15 @@ namespace internal {

double dtd() const { return dtd_; }

double step() const { return step_; }
double step() const { return dcsrch_.step(); }

double xstep() const { return step_ * dnorm_; }
double xstep() const { return step() * dnorm_; }

double finit() const { return finit_; }
double finit() const { return dcsrch_.finit(); }

double ginit() const { return ginit_; }
double ginit() const { return dcsrch_.ginit(); }

constexpr static double kStepMin = 0, kStepMax = 1e+10;

private:
MutRef<ArrayXd> &x() { return *x_; }
Expand All @@ -122,37 +169,14 @@ namespace internal {

const LbfgsbBounds &bounds() const { return *bounds_; }

enum class DcsrchStatus : std::uint8_t;
DcsrchStatus dcsrch(double f, double g);

void step_x();

// NOLINTNEXTLINE(readability-identifier-naming)
constexpr static double xtrapl_ = 1.1, xtrapu_ = 4, stepmin_ = 0;

MutRef<ArrayXd> *x_;
const ArrayXd *t_, *z_, *d_;
const LbfgsbBounds *bounds_;
double dtd_, dnorm_;

double stepmax_ = 1e+10;
double finit_, ginit_, gtest_;
double gtol_, xtol_;

double step_;

// The variables stx, fx, gx contain the values of the step, function,
// and derivative at the best step.
// The variables sty, fy, gy contain the value of the step, function,
// and derivative at sty.
// The variables stp, f, g contain the values of the step, function,
// and derivative at stp.
double stx_ = 0, fx_, gx_;
double sty_ = 0, fy_, gy_;
double stmin_ = 0, stmax_;
double width_, width1_;

bool brackt_ = false, stage1_ = true;
Dcsrch dcsrch_;
};

struct CauchyBrkpt {
Expand Down Expand Up @@ -548,6 +572,224 @@ LbfgsbResult l_bfgs_b(FuncGrad &&fg, MutRef<ArrayXd> x, const ArrayXi &nbd,
return lbfgsb.minimize(std::forward<FuncGrad>(fg), factr, pgtol, maxiter,
maxls);
}

enum class BfgsResultCode {
kSuccess,
kMaxIterReached,
kInvalidInput,
kAbnormalTerm,
};

struct BfgsResult {
BfgsResultCode code;
int niter;
double fx;
ArrayXd gx;
};

/**
* @brief BFGS minimizer
* @sa bfgs
*
* References:
* - "Broyden-Fletcher-Goldfarb-Shanno algorithm",
* [Wikipedia](https://en.wikipedia.org/wiki/Broyden%E2%80%93Fletcher%E2%80%93Goldfarb%E2%80%93Shanno_algorithm)
* (Accessed 2024-10-25).
*
* This implementation is based on the Python implementation of BFGS in the
* SciPy library, with optimized Hessian update step suggested by the linked
* Wikipedia page. The original implementation is released under the BSD
* 3-Clause License (included below).
*
* \code{.unparsed}
* Copyright (c) 2001-2002 Enthought, Inc. 2003-2024, SciPy Developers.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
*
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following
* disclaimer in the documentation and/or other materials provided
* with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived
* from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
* \endcode
*/
class Bfgs {
public:
Bfgs(MutRef<ArrayXd> x);

template <class FuncGrad>
BfgsResult minimize(FuncGrad fg, const double pgtol = 1e-5,
const double xrtol = 0, int maxiter = -1,
const int maxls = 100, const double ftol = 1e-4,
const double gtol = 0.9, const double xtol = 1e-14) {
using internal::Dcsrch;
using internal::DcsrchStatus;

if (maxiter < 0)
maxiter = 200 * static_cast<int>(x().size());

Hk_.setIdentity();

ArrayXd gfk(x().size());

const double f0 = fg(gfk, x());
double gnorm = gfk.abs().maxCoeff();
if (gnorm <= pgtol)
return { BfgsResultCode::kSuccess, 0, f0, std::move(gfk) };

int k = 0;
double fk = f0, fkm1 = fk + gfk.matrix().norm() * 0.5;
for (; k < maxiter; ++k) {
Dcsrch dcsrch = prepare_lnsrch(gfk, fk, fkm1, ftol, gtol, xtol);
fkm1 = fk;

bool success = false;
for (int iter = 0; iter < maxls; ++iter) {
xk() = x() + dcsrch.step() * pk().array();
fk = fg(gfkp1(), xk());

auto status = dcsrch(fk, gfkp1().matrix().dot(pk()));
if (status == DcsrchStatus::kContinue)
continue;

success = true;
break;
}
if (!success)
return { BfgsResultCode::kAbnormalTerm, k + 1, fk, std::move(gfk) };

bool converged = prepare_next_iter(gfk, dcsrch.step(), pgtol, xrtol);
if (converged)
return { BfgsResultCode::kSuccess, k + 1, fk, std::move(gfk) };
}

return { BfgsResultCode::kMaxIterReached, maxiter, fk, std::move(gfk) };
}

private:
internal::Dcsrch prepare_lnsrch(const ArrayXd &gfk, double fk, double fkm1,
double ftol, double gtol, double xtol);

bool prepare_next_iter(ArrayXd &gfk, double step, double pgtol, double xrtol);

MutRef<ArrayXd> &x() { return x_; }

// NOLINTNEXTLINE(readability-identifier-naming)
Eigen::SelfAdjointView<MatrixXd, Eigen::Upper> Hk() {
return Hk_.selfadjointView<Eigen::Upper>();
}

VectorXd &pk() { return pk_; }

ArrayXd &xk() { return xk_; }
Eigen::MatrixWrapper<ArrayXd> sk() { return xk_.matrix(); }

VectorXd &yk() { return yk_; }

ArrayXd &gfkp1() { return gfkp1_; }
// NOLINTNEXTLINE(readability-identifier-naming)
Eigen::MatrixWrapper<ArrayXd> Hk_yk() { return gfkp1_.matrix(); }

MutRef<ArrayXd> x_;
ArrayXd xk_, gfkp1_;

MatrixXd Hk_;
VectorXd pk_, yk_;
};

/**
* @brief Minimize a function using BFGS algorithm.
*
* @tparam FuncGrad Function object that computes the function value and
* gradient. Function value should be returned and gradient should be
* updated in the input gradient vector.
* @param fg Function object.
* @param x Initial guess. Will be modified in-place.
* @param pgtol Stop when the projected gradient is less than this value.
* @param xrtol Stop when the relative change in x is less than this value.
* @param maxiter Maximum number of iterations. If negative, it will be set to
* 200 times the number of variables.
* @param maxls Maximum number of line search steps.
* @return A struct with the result code, number of iterations, final function
* value, and final gradient.
*
* @note The input `x` will be modified in-place.
* @sa Bfgs
*
* References:
* - "Broyden-Fletcher-Goldfarb-Shanno algorithm",
* [Wikipedia](https://en.wikipedia.org/wiki/Broyden%E2%80%93Fletcher%E2%80%93Goldfarb%E2%80%93Shanno_algorithm)
* (Accessed 2024-10-25).
*
* This implementation is based on the Python implementation of BFGS in the
* SciPy library, with optimized Hessian update step suggested by the linked
* Wikipedia page. The original implementation is released under the BSD
* 3-Clause License (included below).
*
* \code{.unparsed}
* Copyright (c) 2001-2002 Enthought, Inc. 2003-2024, SciPy Developers.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
*
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following
* disclaimer in the documentation and/or other materials provided
* with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived
* from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
* \endcode
*/
template <class FuncGrad>
inline BfgsResult bfgs(FuncGrad &&fg, MutRef<ArrayXd> x,
const double pgtol = 1e-5, const double xrtol = 0,
int maxiter = -1, const int maxls = 100,
const double ftol = 1e-4, const double gtol = 0.9,
const double xtol = 1e-14) {
Bfgs bfgs(x);
return bfgs.minimize(std::forward<FuncGrad>(fg), pgtol, xrtol, maxiter, maxls,
ftol, gtol, xtol);
}
} // namespace nuri

#endif /* NURI_ALGO_OPTIM_H_ */
Loading

0 comments on commit 8e7245b

Please sign in to comment.