Skip to content

Commit

Permalink
* added AddCostFunction overloads for 5-8 parameter blocks
Browse files Browse the repository at this point in the history
* added `TrySolve` returning whether or not the solution is usable and converged
  • Loading branch information
krauthaufen committed May 2, 2024
1 parent b7b2e50 commit 3e207c4
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 8 deletions.
4 changes: 4 additions & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
### 0.9.37
* added `AddCostFunction` overloads for 5-8 parameter blocks
* added `TrySolve` returning whether or not the solution is usable and converged

### 0.9.36
* added MaxIterations/Tolerances to CeresBundleIteration
* updated Ceres packages and added glfags
Expand Down
7 changes: 7 additions & 0 deletions global.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"sdk": {
"version": "6.0.204",
"rollForward": "latestMinor",
"allowPrerelease": false
}
}
21 changes: 20 additions & 1 deletion src/Ceres/CeresRaw.fs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ type CeresBundleIteration private(projConstant : int, distConstant : int, camCon
functionTolerance, parameterTolerance, gradientTolerance, maxIterations
)

type CeresTerminationType =
| Convergence = 0
| NoConvergence = 1
| Failure = 2
| UserSuccess = 3
| UserFailure = 4

module CeresRaw =

[<Literal>]
Expand Down Expand Up @@ -189,7 +196,19 @@ module CeresRaw =
extern void cAddResidualFunction4(CeresProblem problem, CeresLossFunctionHandle loss, CeresCostFunction func, double* p0, double* p1, double* p2, double* p3)

[<DllImport(lib); SuppressUnmanagedCodeSecurity>]
extern float cSolve(CeresProblem problem, CeresOptions* options)
extern void cAddResidualFunction5(CeresProblem problem, CeresLossFunctionHandle loss, CeresCostFunction func, double* p0, double* p1, double* p2, double* p3, double* p4)

[<DllImport(lib); SuppressUnmanagedCodeSecurity>]
extern void cAddResidualFunction6(CeresProblem problem, CeresLossFunctionHandle loss, CeresCostFunction func, double* p0, double* p1, double* p2, double* p3, double* p4, double* p5)

[<DllImport(lib); SuppressUnmanagedCodeSecurity>]
extern void cAddResidualFunction7(CeresProblem problem, CeresLossFunctionHandle loss, CeresCostFunction func, double* p0, double* p1, double* p2, double* p3, double* p4, double* p5, double* p6)

[<DllImport(lib); SuppressUnmanagedCodeSecurity>]
extern void cAddResidualFunction8(CeresProblem problem, CeresLossFunctionHandle loss, CeresCostFunction func, double* p0, double* p1, double* p2, double* p3, double* p4, double* p5, double* p6, double* p7)

[<DllImport(lib); SuppressUnmanagedCodeSecurity>]
extern float cSolve(CeresProblem problem, CeresOptions* options, CeresTerminationType* termination, int* usable)

[<DllImport(lib); SuppressUnmanagedCodeSecurity>]
extern float cOptimizePhotonetwork (
Expand Down
37 changes: 35 additions & 2 deletions src/Ceres/Problem.fs
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,46 @@ type Problem() =
| [|p0; p1|] -> CeresRaw.cAddResidualFunction2(handle, loss, fhandle, p0, p1)
| [|p0; p1; p2|] -> CeresRaw.cAddResidualFunction3(handle, loss, fhandle, p0, p1, p2)
| [|p0; p1; p2; p3|] -> CeresRaw.cAddResidualFunction4(handle, loss, fhandle, p0, p1, p2, p3)
| [|p0; p1; p2; p3; p4|] -> CeresRaw.cAddResidualFunction5(handle, loss, fhandle, p0, p1, p2, p3, p4)
| [|p0; p1; p2; p3; p4; p5|] -> CeresRaw.cAddResidualFunction6(handle, loss, fhandle, p0, p1, p2, p3, p4, p5)
| [|p0; p1; p2; p3; p4; p5; p6|] -> CeresRaw.cAddResidualFunction7(handle, loss, fhandle, p0, p1, p2, p3, p4, p5, p6)
| [|p0; p1; p2; p3; p4; p5; p6; p7|] -> CeresRaw.cAddResidualFunction8(handle, loss, fhandle, p0, p1, p2, p3, p4, p5, p6, p7)
| _ -> failwithf "too many parameter-blocks for cost function: %A" parameters.Length

member x.Solve(options : Config) =
use termination = fixed [| CeresTerminationType.Convergence |]
use usable = fixed [| 0 |]
use pOptions = fixed [| Config.toCeresOptions options |]
let res = CeresRaw.cSolve(handle, pOptions)
let res = CeresRaw.cSolve(handle, pOptions, termination, usable)
for b in blocks do b.Mark()
res

let termination = NativePtr.read termination
match termination with
| CeresTerminationType.Convergence
| CeresTerminationType.UserSuccess ->
res
| _ ->
System.Double.PositiveInfinity

member x.TrySolve(options : Config) =
use termination = fixed [| CeresTerminationType.Convergence |]
use usable = fixed [| 0 |]
use pOptions = fixed [| Config.toCeresOptions options |]
let res = CeresRaw.cSolve(handle, pOptions, termination, usable)
for b in blocks do b.Mark()

let termination = NativePtr.read termination
let usable = NativePtr.read usable

if usable <> 0 then
match termination with
| CeresTerminationType.Convergence
| CeresTerminationType.UserSuccess ->
Some (true, res)
| _ ->
Some (false, res)
else
None

member private x.Dispose(disposing : bool) =
if disposing then GC.SuppressFinalize(x)
Expand Down
37 changes: 33 additions & 4 deletions src/CeresNative/CeresNative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ DllExport(ceres::LossFunction*) cCreateLossFunction(CeresLossFunction f)
default: return new ceres::TrivialLoss();
}

}DllExport(void) cReleaseLossFunction(ceres::LossFunction* f)
}

DllExport(void) cReleaseLossFunction(ceres::LossFunction* f)
{
if(f) delete f;
}
Expand Down Expand Up @@ -87,7 +89,31 @@ DllExport(void) cAddResidualFunction4(Problem* problem, ceres::LossFunction* los
problem->AddResidualBlock(cost, loss, p0, p1, p2, p3);
}

DllExport(double) cSolve(Problem* problem, CeresOptions* options)
DllExport(void) cAddResidualFunction5(Problem* problem, ceres::LossFunction* loss, CustomCostFunction* cost, double* p0, double* p1, double* p2, double* p3, double* p4)
{
disableGoogleLogging();
problem->AddResidualBlock(cost, loss, p0, p1, p2, p3, p4);
}

DllExport(void) cAddResidualFunction6(Problem* problem, ceres::LossFunction* loss, CustomCostFunction* cost, double* p0, double* p1, double* p2, double* p3, double* p4, double* p5)
{
disableGoogleLogging();
problem->AddResidualBlock(cost, loss, p0, p1, p2, p3, p4, p5);
}

DllExport(void) cAddResidualFunction7(Problem* problem, ceres::LossFunction* loss, CustomCostFunction* cost, double* p0, double* p1, double* p2, double* p3, double* p4, double* p5, double* p6)
{
disableGoogleLogging();
problem->AddResidualBlock(cost, loss, p0, p1, p2, p3, p4, p5, p6);
}

DllExport(void) cAddResidualFunction8(Problem* problem, ceres::LossFunction* loss, CustomCostFunction* cost, double* p0, double* p1, double* p2, double* p3, double* p4, double* p5, double* p6, double* p7)
{
disableGoogleLogging();
problem->AddResidualBlock(cost, loss, p0, p1, p2, p3, p4, p5, p6, p7);
}

DllExport(double) cSolve(Problem* problem, CeresOptions* options, ceres::TerminationType* status, int* usable)
{
disableGoogleLogging();
ceres::Solver::Options opt;
Expand All @@ -106,13 +132,16 @@ DllExport(double) cSolve(Problem* problem, CeresOptions* options)

if(options->PrintProgress != 0) printf("%s\n", summary.FullReport().c_str());

if (summary.termination_type == ceres::TerminationType::CONVERGENCE)
*status = summary.termination_type;
if (summary.IsSolutionUsable())
{
*usable = 1;
return summary.final_cost;
}
else
{
return INFINITY;
*usable = 0;
return summary.final_cost;
}

}
Expand Down
6 changes: 5 additions & 1 deletion src/CeresNative/CeresNative.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,12 @@ DllExport(void) cAddResidualFunction1(Problem* problem, ceres::LossFunction* los
DllExport(void) cAddResidualFunction2(Problem* problem, ceres::LossFunction* loss, CustomCostFunction* cost, double* p0, double* p1);
DllExport(void) cAddResidualFunction3(Problem* problem, ceres::LossFunction* loss, CustomCostFunction* cost, double* p0, double* p1, double* p2);
DllExport(void) cAddResidualFunction4(Problem* problem, ceres::LossFunction* loss, CustomCostFunction* cost, double* p0, double* p1, double* p2, double* p3);
DllExport(void) cAddResidualFunction5(Problem* problem, ceres::LossFunction* loss, CustomCostFunction* cost, double* p0, double* p1, double* p2, double* p3, double* p4);
DllExport(void) cAddResidualFunction6(Problem* problem, ceres::LossFunction* loss, CustomCostFunction* cost, double* p0, double* p1, double* p2, double* p3, double* p4, double* p5);
DllExport(void) cAddResidualFunction7(Problem* problem, ceres::LossFunction* loss, CustomCostFunction* cost, double* p0, double* p1, double* p2, double* p3, double* p4, double* p5, double* p6);
DllExport(void) cAddResidualFunction8(Problem* problem, ceres::LossFunction* loss, CustomCostFunction* cost, double* p0, double* p1, double* p2, double* p3, double* p4, double* p5, double* p6, double* p7);

DllExport(double) cSolve(Problem* problem, CeresOptions* options);
DllExport(double) cSolve(Problem* problem, CeresOptions* options, ceres::TerminationType* status, int* usable);


typedef struct {
Expand Down

0 comments on commit 3e207c4

Please sign in to comment.