Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MSSL compilation fixes #428

Closed
wants to merge 56 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
eb770a4
Add grad function member variable
Nov 27, 2023
732fe65
add onnx inpainting test
Nov 27, 2023
f22ead7
Merge branch 'cg_ort_interface' into mm_CRR_in_FB
Nov 27, 2023
3a69bc4
Merge branch 'cg_ort_interface' into mm_CRR_in_FB
Dec 13, 2023
8b0247c
Merge branch 'cg_ort_interface' into mm_CRR_in_FB
Jan 8, 2024
955ddba
gradient as function of image and residual. Rename variables and fix …
Jan 8, 2024
78ee5df
Add image dependence to f_gradient and clarify notation
Jan 8, 2024
4d338be
correct type for l2 norm which doesn't take x
Jan 9, 2024
5ca3ad5
Make f_gradient settable
Jan 9, 2024
45cd200
create combined f_gradient in inpainting test
Jan 9, 2024
3575c7d
Create real indicator function and rename base class
Jan 9, 2024
aff60be
Complex and real implementation; syntax fixes
Jan 9, 2024
7dcf03f
Don't return a temp reference!!
Jan 9, 2024
1f17ead
Comment indicator
Jan 9, 2024
76b8cf5
Merge branch 'development' into mm_CRR_in_FB
Jan 15, 2024
3bea9e2
Move linear operator inside gradient
Jan 17, 2024
8b32eaf
Fix types in l2FB test / class
Jan 17, 2024
7328edb
Merge branch 'development' into mm_CRR_in_FB
May 28, 2024
1d63a65
Cleaning up notation of FB algorithm
May 29, 2024
d2392b1
Update L2 FB to be compatible with f(image, residual)
May 29, 2024
d2f6bed
Clarify inputs and outputs for ProximalFunction type
May 29, 2024
adee12a
spelling and type clarification
May 29, 2024
3f40f3c
Fix type typo
May 30, 2024
ef73570
Experimenting with placement of defn of f_gradient
May 30, 2024
69ec273
Add Phi to gradient explicitly
Jun 19, 2024
79773ac
Call a constructor for the sake of being vaguely intelligible
Jun 19, 2024
f3fad1e
add example ONNX file for cost function
20DM Jun 20, 2024
00ae2f1
harmonise file names
20DM Jun 20, 2024
e224cab
Set properties for convergence
Jun 21, 2024
af9c7d2
Correct log for real indicator
Jun 21, 2024
45e50a9
Non const to allow missing gradient to be set
Jun 21, 2024
59d217c
Remove redundant l2 gradient member
Jun 21, 2024
796f4c3
Fix eigen image comparison
Jun 21, 2024
8c454f2
Compactify the mse calculation
Jun 21, 2024
81f0b74
Put convergence checks back in onnx
Jun 21, 2024
67ab84a
Update comments on non differentiable functions
Jun 21, 2024
88dfc09
Renaming of g_proximal.h to make sense
Jun 21, 2024
5f1a931
Rename other g_proximal headers
Jun 21, 2024
3e91373
Fix include guards
Jun 21, 2024
b066a56
Correcting class comments
Jun 21, 2024
80103e8
Merge remote-tracking branch 'origin/cg_onnxrt_cost' into mm_CRR_in_FB
Jun 24, 2024
ceb12de
More renaming
Jun 25, 2024
6712e33
Adding a differentiable function class
Jun 25, 2024
133d2ab
Updating inpainting tests
Jun 25, 2024
ee91016
Adding headers properly!
Jun 25, 2024
06da525
Improved comments
Jun 25, 2024
03a337d
Add an f(x) = L2 norm class
Jun 25, 2024
0daeae8
Add wavelet transform access
Jun 27, 2024
b9b0e15
Fix gradient operator syntax
Jun 27, 2024
1ce9666
Iteration bug fix
Jul 12, 2024
360ed21
Merge branch 'development' into mm_CRR_in_FB
Jul 12, 2024
d934777
Compatibility. ONNX inpainting works but does not converge
Jul 12, 2024
606c7e5
Convert vector if necessary
Jul 14, 2024
295c610
Add real indicator header
Jul 14, 2024
a2ec0b5
Only compile onnx inpainting if onnxrt on
Jul 14, 2024
d78edb4
MSSL compilation fixes
Jul 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/examples/forward_backward/inpainting.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <ctime>

#include "sopt/imaging_forward_backward.h"
#include "sopt/l1_g_proximal.h"
#include "sopt/l1_non_diff_function.h"
#include "sopt/logging.h"
#include "sopt/maths.h"
#include "sopt/relative_variation.h"
Expand Down Expand Up @@ -111,7 +111,7 @@ int main(int argc, char const **argv) {
.Psi(psi);

// Once the properties are set, inject it into the ImagingForwardBackward object
fb.g_proximal(gp);
fb.g_function(gp);

SOPT_HIGH_LOG("Starting Forward Backward");
// Alternatively, forward-backward can be called with a tuple (x, residual) as argument
Expand Down
4 changes: 2 additions & 2 deletions cpp/examples/forward_backward/inpainting_credible_interval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

#include "sopt/credible_region.h"
#include "sopt/imaging_forward_backward.h"
#include "sopt/l1_g_proximal.h"
#include "sopt/l1_non_diff_function.h"
#include "sopt/logging.h"
#include "sopt/maths.h"
#include "sopt/relative_variation.h"
Expand Down Expand Up @@ -112,7 +112,7 @@ int main(int argc, char const **argv) {
.Psi(psi);

// Once the properties are set, inject it into the ImagingForwardBackward object
fb.g_proximal(gp);
fb.g_function(gp);

SOPT_HIGH_LOG("Starting Forward Backward");
// Alternatively, forward-backward can be called with a tuple (x, residual) as argument
Expand Down
4 changes: 2 additions & 2 deletions cpp/examples/forward_backward/inpainting_joint_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <ctime>

#include "sopt/imaging_forward_backward.h"
#include "sopt/l1_g_proximal.h"
#include "sopt/l1_non_diff_function.h"
#include "sopt/joint_map.h"
#include "sopt/logging.h"
#include "sopt/maths.h"
Expand Down Expand Up @@ -111,7 +111,7 @@ int main(int argc, char const **argv) {
.Psi(psi);

// Once the properties are set, inject it into the ImagingForwardBackward object
fb->g_proximal(gp);
fb->g_function(gp);

SOPT_HIGH_LOG("Starting Forward Backward");
// Alternatively, forward-backward can be called with a tuple (x, residual) as argument
Expand Down
17 changes: 14 additions & 3 deletions cpp/sopt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@ set(headers
bisection_method.h chained_operators.h credible_region.h
imaging_padmm.h logging.h
forward_backward.h imaging_forward_backward.h
g_proximal.h l1_g_proximal.h joint_map.h
non_differentiable_func.h l1_non_diff_function.h real_indicator.h joint_map.h
differentiable_func.h
imaging_primal_dual.h primal_dual.h
maths.h proximal.h relative_variation.h sdmm.h
wavelets.h conjugate_gradient.h l1_proximal.h padmm.h proximal_expression.h
reweighted.h types.h wrapper.h exception.h linear_transform.h positive_quadrant.h
tv_primal_dual.h gradient_operator.h
l2_forward_backward.h l2_primal_dual.h
real_type.h sampling.h power_method.h objective_functions.h ${PROJECT_BINARY_DIR}/include/sopt/config.h
)
)

set(wavelet_headers
wavelets/direct.h wavelets/indirect.h wavelets/innards.impl.h wavelets/sara.h
Expand All @@ -31,7 +32,9 @@ if(SOPT_MPI)
endif()

if(onnxrt)
list(APPEND headers ort_session.h tf_g_proximal.h)

list(APPEND headers ort_session.h onnx_differentiable_func.h tf_non_diff_function.h)

endif()

add_library(sopt SHARED ${sources})
Expand All @@ -52,6 +55,7 @@ endif()
if(EIGEN3_INCLUDE_DIR)
target_include_directories(sopt SYSTEM PUBLIC ${EIGEN3_INCLUDE_DIR})
endif()
message(STATUS "Eigen3 include dir ${EIGEN3_INCLUDE_DIR}")

if(SOPT_OPENMP)
target_link_libraries(sopt OpenMP::OpenMP_CXX)
Expand All @@ -62,9 +66,16 @@ if(SOPT_MPI)
target_include_directories(sopt SYSTEM PUBLIC ${MPI_CXX_INCLUDE_PATH})
endif()

message(STATUS "onnxrt enabled? ${onnxrt}")
if(onnxrt)
if(${onnxruntime_FOUND})
message(STATUS "ONNX Source dir ${onnxruntime_SOURCE_DIR}")
message(STATUS "ONNX Include dir ${onnxruntime_INCLUDE_DIR}")
endif()
target_link_libraries(sopt ${onnxruntime_LIBRARIES})
target_include_directories(sopt SYSTEM PUBLIC ${onnxruntime_INCLUDE_DIR})
else()
message(STATUS "ONNXrt not found. ONNXrt support disabled.")
endif()

target_link_libraries(sopt ${CONAN_LIBS})
Expand Down
1 change: 1 addition & 0 deletions cpp/sopt/config.in.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include <string>
#include <tuple>
#include <cstdint>

namespace sopt {

Expand Down
43 changes: 43 additions & 0 deletions cpp/sopt/differentiable_func.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#ifndef DIFFERENTIABLE_FUNC_H
#define DIFFERENTIABLE_FUNC_H

#include "sopt/forward_backward.h"

// Abstract base class providing the interface for differentiable functions f(x)
// with a defined gradient, for use in Forward-Backward style algorithms.
template <typename SCALAR> class DifferentiableFunc
{

 public:
  using FB = sopt::algorithm::ForwardBackward<SCALAR>;
  using Real = typename FB::Real;
  using t_Vector = typename FB::t_Vector;
  using t_Gradient = typename FB::t_Gradient;
  using t_LinearTransform = typename FB::t_LinearTransform;

  // Polymorphic base class: a virtual destructor ensures derived-class
  // destructors run when deleting through a base pointer.
  virtual ~DifferentiableFunc() = default;

  // A function that prints a log message
  virtual void log_message() const = 0;

  // Return a callable wrapping the gradient of this function.
  // The callable must be of type t_Gradient, that is
  // void gradient(t_Vector &output, const t_Vector &image,
  //               const t_Vector &residual, const t_LinearTransform &Phi)
  virtual t_Gradient gradient() const
  {
    // Capture `this` so the returned closure dispatches to the (virtual)
    // three-argument gradient overload below.
    return [this](t_Vector &output, const t_Vector &image, const t_Vector &residual,
                  const t_LinearTransform &Phi) -> void { this->gradient(output, image, residual, Phi); };
  }

  // Calculate the gradient directly, writing the result into `output`.
  virtual void gradient(t_Vector &output, const t_Vector &image, const t_Vector &residual,
                        const t_LinearTransform &Phi) const = 0;

  // Evaluate the function value f(image) directly, given the measurement
  // vector y and the measurement operator Phi.
  virtual Real function(t_Vector const &image, t_Vector const &y, t_LinearTransform const &Phi) const = 0;

  // Transforms input image to a different basis.
  // Return linear_transform_identity() if transform not necessary.
  //virtual const t_LinearTransform &Phi() const = 0;

};

#endif
66 changes: 41 additions & 25 deletions cpp/sopt/forward_backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,14 @@

namespace sopt::algorithm {

//! \brief Forward Backward Splitting
//! \details \f$\min_{x} f(\Phi x - y) + g(z)\f$. \f$y\f$ is a target vector.
/*! \brief Forward Backward Splitting
An optimisation method to solve the problem \f$y = \Phi(x) + N(\sigma)\f$:
\f$\min_{x} f(x, y, \Phi) + g(x)\f$. \f$y\f$ is a target vector, while \f$x\f$ is the current solution.
\f$f\f$ is a differentiable function. It is necessary to supply the gradient.
\f$f\f$ is represented using a DifferentiableFunc object, which supplies the function and its gradient.
\f$g\f$ is a non-differentiable function. It is necessary to supply a proximal operator (or similar stepping function, e.g. a TensorFlow denoiser).
\f$g\f$ is represented using a NonDifferentiableFunc object, which supplies the function and its proximal operator.
*/
template <typename SCALAR>
class ForwardBackward {
public:
Expand All @@ -33,7 +39,8 @@ class ForwardBackward {
//! Type of the proximal operator
using t_Proximal = ProximalFunction<Scalar>;
//! Type of the gradient
using t_Gradient = typename std::function<void(t_Vector &, const t_Vector &)>;
// The first argument is the output vector, the others are inputs
using t_Gradient = std::function<void(t_Vector &gradient, const t_Vector &image, const t_Vector &residual, const t_LinearTransform& Phi)>;

//! Values indicating how the algorithm ran
struct Diagnostic {
Expand All @@ -56,8 +63,8 @@ class ForwardBackward {
};

//! Setups ForwardBackward
//! \param[in] f_gradient: gradient of the \f$f\f$ function.
//! \param[in] g_proximal: proximal operator of the \f$g\f$ function
//! \param[in] f_function: the differentiable function \f$f\f$ with a gradient
//! \param[in] g_function: the non-differentiable function \f$g\f$ with a proximal operator
template <typename DERIVED>
ForwardBackward(t_Gradient const &f_gradient, t_Proximal const &g_proximal,
Eigen::MatrixBase<DERIVED> const &target)
Expand All @@ -66,7 +73,7 @@ class ForwardBackward {
beta_(1),
nu_(1),
is_converged_(),
fista_(true),
fista_(true),
Phi_(linear_transform_identity<Scalar>()),
f_gradient_(f_gradient),
g_proximal_(g_proximal),
Expand Down Expand Up @@ -107,9 +114,9 @@ class ForwardBackward {
//! Second proximal
SOPT_MACRO(g_proximal, t_Proximal);
#undef SOPT_MACRO
//! \brief Simplifies calling the proximal of f.
void f_gradient(t_Vector &out, t_Vector const &x) const { f_gradient()(out, x); }
//! \brief Simplifies calling the proximal of f.
//! \brief Simplifies calling the gradient function
void f_gradient(t_Vector &out, t_Vector const &x, t_Vector const &res, t_LinearTransform const &Phi) const { f_gradient()(out, x, res, Phi); }
//! \brief Simplifies calling the proximal function
void g_proximal(t_Vector &out, Real gamma, t_Vector const &x) const {
g_proximal()(out, gamma, x);
}
Expand Down Expand Up @@ -231,15 +238,21 @@ class ForwardBackward {
};

template <typename SCALAR>
void ForwardBackward<SCALAR>::iteration_step(t_Vector &out, t_Vector &residual, t_Vector &p,
t_Vector &z, const t_real lambda) const {
p = out;
f_gradient(z, residual);
const t_Vector input = out - beta() / nu() * (Phi().adjoint() * z);
void ForwardBackward<SCALAR>::iteration_step(t_Vector &image_new, t_Vector &residual, t_Vector &image_current,
t_Vector &gradient_current, const t_real FISTA_ratio) const {
image_current = image_new;
SOPT_LOW_LOG("Calculate gradient");
f_gradient(gradient_current, image_current, residual, Phi()); // takes residual and calculates the grad = 1/sig^2 residual
SOPT_LOW_LOG("Take a gradient step");
image_new = image_current - beta() / nu() * gradient_current; // step to new image using gradient
SOPT_LOW_LOG("Calculate the weight");
const Real weight = gamma() * beta();
g_proximal(out, weight, input);
p = out + lambda * (out - p);
residual = (Phi() * p) - target();
SOPT_LOW_LOG("Apply proximal operator");
g_proximal(image_new, weight, image_new); // apply proximal operator to new image
SOPT_LOW_LOG("FISTA acceleration step");
image_current = image_new + FISTA_ratio * (image_new - image_current); // FISTA acceleration step
SOPT_LOW_LOG("Calculate the residual");
residual = (Phi() * image_current) - target(); // calculates the residual for the NEXT iteration.
}

template <typename SCALAR>
Expand All @@ -253,30 +266,33 @@ typename ForwardBackward<SCALAR>::Diagnostic ForwardBackward<SCALAR>::operator()
}
sanity_check(x_guess, res_guess);

t_Vector p = t_Vector::Zero(x_guess.size());
t_Vector z = t_Vector::Zero(target().size());
const uint image_size = x_guess.size();

t_Vector image_current = x_guess;
t_Vector residual = res_guess;
t_Vector gradient_current = t_Vector::Zero(image_size);
out = x_guess;

t_uint niters(0);
bool converged = false;
Real theta = 1.0;
Real theta_new = 1.0;
Real lambda = 0.0;
Real FISTA_ratio = 0.0;
for (; (not converged) && (niters < itermax()); ++niters) {
SOPT_LOW_LOG(" - [FB] Iteration {}/{}", niters, itermax());
SOPT_HIGH_LOG(" - [FB] Iteration {}/{}", niters, itermax());
if (fista()) {
theta_new = (1 + std::sqrt(1 + 4 * theta * theta)) / 2.;
lambda = (theta - 1) / (theta_new);
FISTA_ratio = (theta - 1) / (theta_new);
theta = theta_new;
}
iteration_step(out, residual, p, z, lambda);
SOPT_LOW_LOG(" - [FB] Sum of residuals: {}", residual.array().abs().sum());
SOPT_HIGH_LOG(" - Call iteration step");
iteration_step(out, residual, image_current, gradient_current, FISTA_ratio);
SOPT_HIGH_LOG(" - [FB] Sum of residuals: {}", residual.array().abs().sum());
converged = is_converged(out, residual);
}

if (converged) {
SOPT_MEDIUM_LOG(" - [FB] converged in {} of {} iterations", niters, itermax());
SOPT_HIGH_LOG(" - [FB] converged in {} of {} iterations", niters, itermax());
} else if (static_cast<bool>(is_converged())) {
// not meaningful if not convergence function
SOPT_ERROR(" - [FB] did not converge within {} iterations", itermax());
Expand Down
Loading
Loading