From 3e207c426c2ed3dfee4fe29a5776a88ca6a70b7e Mon Sep 17 00:00:00 2001 From: krauthaufen Date: Thu, 2 May 2024 09:24:49 +0200 Subject: [PATCH] * added `AddCostFunction` overloads for 5-8 parameter blocks * added `TrySolve` returning whether or not the solution is usable and converged --- RELEASE_NOTES.md | 4 ++++ global.json | 7 +++++++ src/Ceres/CeresRaw.fs | 21 ++++++++++++++++++- src/Ceres/Problem.fs | 37 +++++++++++++++++++++++++++++++-- src/CeresNative/CeresNative.cpp | 37 +++++++++++++++++++++++++++++---- src/CeresNative/CeresNative.h | 6 +++++- 6 files changed, 104 insertions(+), 8 deletions(-) create mode 100644 global.json diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index a16e400..ff38b34 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -1,3 +1,7 @@ +### 0.9.37 +* added `AddCostFunction` overloads for 5-8 parameter blocks +* added `TrySolve` returning whether or not the solution is usable and converged + ### 0.9.36 * added MaxIterations/Tolerances to CeresBundleIteration * updated Ceres packages and added glfags diff --git a/global.json b/global.json new file mode 100644 index 0000000..d06e2a7 --- /dev/null +++ b/global.json @@ -0,0 +1,7 @@ +{ + "sdk": { + "version": "6.0.204", + "rollForward": "latestMinor", + "allowPrerelease": false + } +} diff --git a/src/Ceres/CeresRaw.fs b/src/Ceres/CeresRaw.fs index 214d52b..2ec0b76 100644 --- a/src/Ceres/CeresRaw.fs +++ b/src/Ceres/CeresRaw.fs @@ -153,6 +153,13 @@ type CeresBundleIteration private(projConstant : int, distConstant : int, camCon functionTolerance, parameterTolerance, gradientTolerance, maxIterations ) +type CeresTerminationType = + | Convergence = 0 + | NoConvergence = 1 + | Failure = 2 + | UserSuccess = 3 + | UserFailure = 4 + module CeresRaw = [] @@ -189,7 +196,19 @@ module CeresRaw = extern void cAddResidualFunction4(CeresProblem problem, CeresLossFunctionHandle loss, CeresCostFunction func, double* p0, double* p1, double* p2, double* p3) [] - extern float cSolve(CeresProblem problem, CeresOptions* options) + extern void cAddResidualFunction5(CeresProblem problem, CeresLossFunctionHandle loss, CeresCostFunction func, double* p0, double* p1, double* p2, double* p3, double* p4) + + [] + extern void cAddResidualFunction6(CeresProblem problem, CeresLossFunctionHandle loss, CeresCostFunction func, double* p0, double* p1, double* p2, double* p3, double* p4, double* p5) + + [] + extern void cAddResidualFunction7(CeresProblem problem, CeresLossFunctionHandle loss, CeresCostFunction func, double* p0, double* p1, double* p2, double* p3, double* p4, double* p5, double* p6) + + [] + extern void cAddResidualFunction8(CeresProblem problem, CeresLossFunctionHandle loss, CeresCostFunction func, double* p0, double* p1, double* p2, double* p3, double* p4, double* p5, double* p6, double* p7) + + [] + extern float cSolve(CeresProblem problem, CeresOptions* options, CeresTerminationType* termination, int* usable) [] extern float cOptimizePhotonetwork ( diff --git a/src/Ceres/Problem.fs b/src/Ceres/Problem.fs index 926a440..138b259 100644 --- a/src/Ceres/Problem.fs +++ b/src/Ceres/Problem.fs @@ -145,13 +145,46 @@ type Problem() = | [|p0; p1|] -> CeresRaw.cAddResidualFunction2(handle, loss, fhandle, p0, p1) | [|p0; p1; p2|] -> CeresRaw.cAddResidualFunction3(handle, loss, fhandle, p0, p1, p2) | [|p0; p1; p2; p3|] -> CeresRaw.cAddResidualFunction4(handle, loss, fhandle, p0, p1, p2, p3) + | [|p0; p1; p2; p3; p4|] -> CeresRaw.cAddResidualFunction5(handle, loss, fhandle, p0, p1, p2, p3, p4) + | [|p0; p1; p2; p3; p4; p5|] -> CeresRaw.cAddResidualFunction6(handle, loss, fhandle, p0, p1, p2, p3, p4, p5) + | [|p0; p1; p2; p3; p4; p5; p6|] -> CeresRaw.cAddResidualFunction7(handle, loss, fhandle, p0, p1, p2, p3, p4, p5, p6) + | [|p0; p1; p2; p3; p4; p5; p6; p7|] -> CeresRaw.cAddResidualFunction8(handle, loss, fhandle, p0, p1, p2, p3, p4, p5, p6, p7) | _ -> failwithf "too many parameter-blocks for cost function: %A" parameters.Length member x.Solve(options : Config) = + use termination = fixed [| CeresTerminationType.Convergence |] + use usable = fixed [| 0 |] use pOptions = fixed [| Config.toCeresOptions options |] - let res = CeresRaw.cSolve(handle, pOptions) + let res = CeresRaw.cSolve(handle, pOptions, termination, usable) for b in blocks do b.Mark() - res + + let termination = NativePtr.read termination + match termination with + | CeresTerminationType.Convergence + | CeresTerminationType.UserSuccess -> + res + | _ -> + System.Double.PositiveInfinity + + member x.TrySolve(options : Config) = + use termination = fixed [| CeresTerminationType.Convergence |] + use usable = fixed [| 0 |] + use pOptions = fixed [| Config.toCeresOptions options |] + let res = CeresRaw.cSolve(handle, pOptions, termination, usable) + for b in blocks do b.Mark() + + let termination = NativePtr.read termination + let usable = NativePtr.read usable + + if usable <> 0 then + match termination with + | CeresTerminationType.Convergence + | CeresTerminationType.UserSuccess -> + Some (true, res) + | _ -> + Some (false, res) + else + None member private x.Dispose(disposing : bool) = if disposing then GC.SuppressFinalize(x) diff --git a/src/CeresNative/CeresNative.cpp b/src/CeresNative/CeresNative.cpp index 96d53e2..0d3a1aa 100644 --- a/src/CeresNative/CeresNative.cpp +++ b/src/CeresNative/CeresNative.cpp @@ -34,7 +34,9 @@ DllExport(ceres::LossFunction*) cCreateLossFunction(CeresLossFunction f) default: return new ceres::TrivialLoss(); } -}DllExport(void) cReleaseLossFunction(ceres::LossFunction* f) +} + +DllExport(void) cReleaseLossFunction(ceres::LossFunction* f) { if(f) delete f; } @@ -87,7 +89,31 @@ DllExport(void) cAddResidualFunction4(Problem* problem, ceres::LossFunction* los problem->AddResidualBlock(cost, loss, p0, p1, p2, p3); } -DllExport(double) cSolve(Problem* problem, CeresOptions* options) +DllExport(void) cAddResidualFunction5(Problem* problem, ceres::LossFunction* loss, CustomCostFunction* cost, double* p0, double* p1, double* p2, double* p3, double* p4) +{ + disableGoogleLogging(); + problem->AddResidualBlock(cost, loss, p0, p1, p2, p3, p4); +} + +DllExport(void) cAddResidualFunction6(Problem* problem, ceres::LossFunction* loss, CustomCostFunction* cost, double* p0, double* p1, double* p2, double* p3, double* p4, double* p5) +{ + disableGoogleLogging(); + problem->AddResidualBlock(cost, loss, p0, p1, p2, p3, p4, p5); +} + +DllExport(void) cAddResidualFunction7(Problem* problem, ceres::LossFunction* loss, CustomCostFunction* cost, double* p0, double* p1, double* p2, double* p3, double* p4, double* p5, double* p6) +{ + disableGoogleLogging(); + problem->AddResidualBlock(cost, loss, p0, p1, p2, p3, p4, p5, p6); +} + +DllExport(void) cAddResidualFunction8(Problem* problem, ceres::LossFunction* loss, CustomCostFunction* cost, double* p0, double* p1, double* p2, double* p3, double* p4, double* p5, double* p6, double* p7) +{ + disableGoogleLogging(); + problem->AddResidualBlock(cost, loss, p0, p1, p2, p3, p4, p5, p6, p7); +} + +DllExport(double) cSolve(Problem* problem, CeresOptions* options, ceres::TerminationType* status, int* usable) { disableGoogleLogging(); ceres::Solver::Options opt; @@ -106,13 +132,16 @@ DllExport(double) cSolve(Problem* problem, CeresOptions* options) if(options->PrintProgress != 0) printf("%s\n", summary.FullReport().c_str()); - if (summary.termination_type == ceres::TerminationType::CONVERGENCE) + *status = summary.termination_type; + if (summary.IsSolutionUsable()) { + *usable = 1; return summary.final_cost; } else { - return INFINITY; + *usable = 0; + return summary.final_cost; } } diff --git a/src/CeresNative/CeresNative.h b/src/CeresNative/CeresNative.h index 074304a..b66533a 100644 --- a/src/CeresNative/CeresNative.h +++ b/src/CeresNative/CeresNative.h @@ -102,8 +102,12 @@ DllExport(void) cAddResidualFunction1(Problem* problem, ceres::LossFunction* los DllExport(void) cAddResidualFunction2(Problem* problem, ceres::LossFunction* loss, CustomCostFunction* cost, double* p0, double* p1); DllExport(void) cAddResidualFunction3(Problem* problem, ceres::LossFunction* loss, CustomCostFunction* cost, double* p0, double* p1, double* p2); DllExport(void) cAddResidualFunction4(Problem* problem, ceres::LossFunction* loss, CustomCostFunction* cost, double* p0, double* p1, double* p2, double* p3); +DllExport(void) cAddResidualFunction5(Problem* problem, ceres::LossFunction* loss, CustomCostFunction* cost, double* p0, double* p1, double* p2, double* p3, double* p4); +DllExport(void) cAddResidualFunction6(Problem* problem, ceres::LossFunction* loss, CustomCostFunction* cost, double* p0, double* p1, double* p2, double* p3, double* p4, double* p5); +DllExport(void) cAddResidualFunction7(Problem* problem, ceres::LossFunction* loss, CustomCostFunction* cost, double* p0, double* p1, double* p2, double* p3, double* p4, double* p5, double* p6); +DllExport(void) cAddResidualFunction8(Problem* problem, ceres::LossFunction* loss, CustomCostFunction* cost, double* p0, double* p1, double* p2, double* p3, double* p4, double* p5, double* p6, double* p7); -DllExport(double) cSolve(Problem* problem, CeresOptions* options); +DllExport(double) cSolve(Problem* problem, CeresOptions* options, ceres::TerminationType* status, int* usable); typedef struct {