Skip to content

Commit

Permalink
Unify exp and retract, log and inverse_retract (#167)
Browse files Browse the repository at this point in the history
* Initial rework to unify exp/retract and log/inverse_retract to “allocate first”.
* bump version, since this is breaking.
* remove further parts of the allocating dispatch tree.
* Update Project.toml
* bump doc dependencies, remove unneeded code.
* Improve documentation.
* drop support for Julia 1.0
* Update news.md
* Introduce back half-a-layer-2 for the allocatiing functions.

---------

Co-authored-by: Mateusz Baran <[email protected]>
  • Loading branch information
kellertuer and mateuszbaran authored Oct 16, 2023
1 parent 5e7f6ae commit 2d1ac2b
Show file tree
Hide file tree
Showing 13 changed files with 223 additions and 658 deletions.
5 changes: 4 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.15.0] xx/xx/2023
## [0.15.0] dd/mm/2023

### Added

Expand All @@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- `retract` now behaves like `exp` in the sense that it allocates early,
which reduces the amount of code to dispatch through levels 1-3 twice
- `inverse_retract` now behaves like `log` in the sense that it allocates early
- `Requires.jl` is added as a dependency to facilitate loading some methods related to `ProductManifolds` on Julia 1.6 to 1.8. Later versions rely on package extensions.
- `Documenter.jl` was updated to 1.0.
- `PowerManifold` can now store its size either in a field or in a type, similarly to `DefaultManifold`. By default the size is stored in a field.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ ManifoldsBaseRecursiveArrayToolsExt = "RecursiveArrayTools"

[compat]
DoubleFloats = ">= 0.9.2"
julia = "1.6"
RecursiveArrayTools = "2"
Requires = "1"
julia = "1.0"

[extras]
DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78"
Expand Down
24 changes: 14 additions & 10 deletions docs/src/design.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ also avoiding ambiguities in multiple dispatch using the [dispatch on one argume

Since the central element for functions on a manifold is the manifold itself, it should always be the first parameter, even for in-place functions. Then the classical parameters of a function (for example a point and a tangent vector for the retraction) follow and the final part are parameters to further dispatch on, which usually have their defaults.

Besides this order the functions follow the scheme “allocate early”, i.e. to switch to the
mutating variant when reasonable, cf. [Mutating and allocating functions](@ref inplace-and-noninplace).

## A 3-Layer architecture for dispatch

The general architecture consists of three layers
Expand All @@ -46,7 +49,6 @@ Note that all other parameters of a function should be as least typed as possibl
With respect to the [dispatch on one argument at a time](https://docs.julialang.org/en/v1/manual/methods/#Dispatch-on-one-argument-at-a-time) paradigm, this layer dispatches the _manifold first_.
We also stay as abstract as possible, for example on the [`AbstractManifold`](@ref) level if possible.


If a function has optional positional arguments, (like [`retract`](@ref)) their default values might be filled/provided on this layer.
This layer ends usually in calling the same functions like [`retract`](@ref) but prefixed with a `_` to enter [Layer II](@ref design-layer2).

Expand Down Expand Up @@ -83,24 +85,26 @@ To close this section, let‘s look at an example.
The high level (or [Layer I](@ref design-layer1)) definition of the retraction is given by

```julia
retract(M::AbstractManifold, p, X, m::AbstractRetractionMethod=default_retraction_method(M, typeof(p))) = _retract(M, p, X, m)
retract!(M::AbstractManifold, q, p, X, m::AbstractRetractionMethod=default_retraction_method(M, typeof(p))) = _retract!(M, q, p, X, m)
```

Note that the convenience function `retract(M, q, p, X, m)` first allocates a `q` before calling this function as well.

This level now dispatches on different retraction types `m`.
It usually passes to specific functions implemented in [Layer III](@ref design-layer3), here for example

```julia
_retract(M::AbstractManifold, p, X, m::Exponentialretraction) = exp(M, p, X)
_retract(M::AbstractManifold, p, X, m::PolarRetraction) = retract_polar(M, p, X)
_retract!(M::AbstractManifold, q, p, X, m::Exponentialretraction) = exp(M, q, p, X)
_retract!(M::AbstractManifold, q, p, X, m::PolarRetraction) = retract_polar(M, q, p, X)
```

where the [`ExponentialRetraction`](@ref) is resolved by again calling a function on [Layer I](@ref design-layer1) (to fill futher default values if these exist). The [`PolarRetraction`](@ref) is dispatched to [`retract_polar`](@ref ManifoldsBase.retract_polar), a function on [Layer III](@ref design-layer3).
where the [`ExponentialRetraction`](@ref) is resolved by again calling a function on [Layer I](@ref design-layer1) (to fill futher default values if these exist). The [`PolarRetraction`](@ref) is dispatched to [`retract_polar!`](@ref ManifoldsBase.retract_polar!), a function on [Layer III](@ref design-layer3).

For further details and dispatches, see [retractions and inverse retractions](@ref sec-retractions) for an overview.

!!! note
The documentation should be attached to the high level functions, since this again fosters ease of use.
If you implement a polar retraction, you should write a method of function [`retract_polar`](@ref ManifoldsBase.retract_polar) but the doc string should be attached to `retract(::M, ::P, ::V, ::PolarRetraction)` for your types `::M, ::P, ::V` of the manifold, points and vectors, respectively.
If you implement a polar retraction, you should write a method of function [`retract_polar!`](@ref ManifoldsBase.retract_polar!) but the doc string should be attached to `retract(::M, ::P, ::V, ::PolarRetraction)` for your types `::M, ::P, ::V` of the manifold, points and vectors, respectively.

To summarize, with respect to the [dispatch on one argument at a time](https://docs.julialang.org/en/v1/manual/methods/#Dispatch-on-one-argument-at-a-time) paradigm, this layer dispatches the (optional) _parameters second_.

Expand All @@ -111,13 +115,13 @@ It should have as few as possible optional parameters and as concrete as possibl

This means

* the function name should be similar to its high level parent (for example [`retract`](@ref) and [`retract_polar`](@ref ManifoldsBase.retract_polar) above)
* the function name should be similar to its high level parent (for example [`retract!`](@ref) and [`retract_polar!`](@ref ManifoldsBase.retract_polar!) above)
* The manifold type in method signature should always be as narrow as possible.
* The points/vectors should either be untyped (for the default representation or if there is only one implementation) or provide all type bounds (for second representations or when using [`AbstractManifoldPoint`](@ref) and [`TVector`](@ref TVector), respectively).

The first step that often happens on this level is memory allocation and calling the in-place function. If faster, it might also implement the function at hand itself.

Usually functions from this layer are not exported, when they have an analogue on the first layer. For example the function [`retract_polar`](@ref ManifoldsBase.retract_polar)`(M, p, X)` is not exported, since when using the interface one would use the [`PolarRetraction`](@ref) or to be precise call [`retract`](@ref)`(M, p, X, PolarRetraction())`.
Usually functions from this layer are not exported, when they have an analogue on the first layer. For example the function [`retract_polar!`](@ref ManifoldsBase.retract_polar!)`(M, q, p, X)` is not exported, since when using the interface one would use the [`PolarRetraction`](@ref) or to be precise call [`retract!`](@ref)`(M, q, p, X, PolarRetraction())`.
When implementing your own manifold, you have to import functions like these anyway.

To summarize, with respect to the [dispatch on one argument at a time](https://docs.julialang.org/en/v1/manual/methods/#Dispatch-on-one-argument-at-a-time) paradigm, this layer dispatches the _concrete manifold and point/vector types last_.
Expand Down Expand Up @@ -182,5 +186,5 @@ log(::M, ::P, ::P)
but the return type would be ``V``, whose internal sizes (fields/arrays) will depend on the concrete type of one of the points. This is accomplished by implementing a method `allocate_result(::M, ::typeof(log), ::P, ::P)` that returns the concrete variable for the result. This way, even with specific types, one just has to implement `log!` and the one line for the allocation.

!!! note
This dispatch from the allocating to the in-place variant happens in Layer III, that is, functions like `exp` or [`retract_polar`](@ref ManifoldsBase.retract_polar) (but not [`retract`](@ref) itself) allocate their result (using `::typeof(retract)` for the second function)
and call the in-place variant `exp!` and [`retract_polar!`](@ref ManifoldsBase.retract_polar!) afterwards.
This dispatch from the allocating to the in-place variant happens in Layer I (which changed in ManifoldsBase.jl 0.15), that is, functions like `exp` or [`retract`](@ref) allocate their result
and call the in-place variant [`exp!`](@ref) and [`retract!`](@ref ManifoldsBase.retract!) afterwards, where the ladder passes down to layer III to reach [`retract_polar!`](@ref ManifoldsBase.retract_polar!).
17 changes: 11 additions & 6 deletions docs/src/types.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,7 @@ Throughout the documentation of `ManifoldsBase.jl` we might use the [Euclidean S
AbstractManifold
```

which should store information about the manifold, for example parameters inherent to the manifold. The parameters are stored in two possible ways, as a type parameter to dispatch on or as a field. For these the following internal functions exist

```@docs
ManifoldsBase.wrap_type_parameter
ManifoldsBase.TypeParameter
```
which should store information about the manifold, for example parameters inherent to the manifold.

## Points on a manifold

Expand Down Expand Up @@ -63,3 +58,13 @@ Modules = [ManifoldsBase]
Pages = ["numbers.jl"]
Order = [:type, :function]
```

## [Type Parameter](@id type-parameter)

Concrete [`AbstractManifold`](@ref)s usually correspond to families of manifolds that are parameterized by some numbers, for example determining their [`manifold_dimension`](@ref). Those numbers can either be stored in a field or as a type parameter of the structure. The [`TypeParameter`](@ref ManifoldsBase.TypeParameter) offers the flexibility
to have this parameter either as type parameter or a field.

```@docs
ManifoldsBase.TypeParameter
ManifoldsBase.wrap_type_parameter
```
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ if isdefined(Base, :get_extension)
parallel_transport_direction,
parallel_transport_to,
project,
_retract,
riemann_tensor,
submanifold_component,
submanifold_components,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,24 +263,6 @@ function Random.rand(
end
end

function _retract(
M::ProductManifold,
p::ArrayPartition,
X::ArrayPartition,
t::Number,
method::ProductRetraction,
)
return ArrayPartition(
map(
(N, pc, Xc, rm) -> retract(N, pc, Xc, t, rm),
M.manifolds,
submanifold_components(M, p),
submanifold_components(M, X),
method.retractions,
),
)
end

function riemann_tensor(
M::ProductManifold,
p::ArrayPartition,
Expand Down
69 changes: 2 additions & 67 deletions src/point_vector_fallbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ macro default_manifold_fallbacks(TM, TP, TV, pfield::Symbol, vfield::Symbol)
m::ExponentialRetraction,
)
retract!(M, q.$pfield, p.$pfield, X.$vfield, m)
return X
return q
end

function ManifoldsBase.vector_transport_along!(
Expand Down Expand Up @@ -249,16 +249,11 @@ macro default_manifold_fallbacks(TM, TP, TV, pfield::Symbol, vfield::Symbol)
)
end
end
# TODO forward retraction / inverse_retraction
for f_postfix in [:polar, :project, :qr, :softmax]
ra = Symbol("retract_$(f_postfix)")
rm = Symbol("retract_$(f_postfix)!")
push!(
block.args,
quote
function ManifoldsBase.$ra(M::$TM, p::$TP, X::$TV, t::Number)
return $TP(ManifoldsBase.$ra(M, p.$pfield, X.$vfield, t))
end
function ManifoldsBase.$rm(M::$TM, q, p::$TP, X::$TV, t::Number)
ManifoldsBase.$rm(M, q.$pfield, p.$pfield, X.$vfield, t)
return q
Expand All @@ -269,16 +264,6 @@ macro default_manifold_fallbacks(TM, TP, TV, pfield::Symbol, vfield::Symbol)
push!(
block.args,
quote
function ManifoldsBase.retract_exp_ode(
M::$TM,
p::$TP,
X::$TV,
t::Number,
m::AbstractRetractionMethod,
B::ManifoldsBase.AbstractBasis,
)
return $TP(ManifoldsBase.retract_exp_ode(M, p.$pfield, X.$vfield, t, m, B))
end
function ManifoldsBase.retract_exp_ode!(
M::$TM,
q::$TP,
Expand All @@ -291,15 +276,6 @@ macro default_manifold_fallbacks(TM, TP, TV, pfield::Symbol, vfield::Symbol)
ManifoldsBase.retract_exp_ode!(M, q.$pfield, p.$pfield, X.$vfield, t, m, B)
return q
end
function ManifoldsBase.retract_pade(
M::$TM,
p::$TP,
X::$TV,
t::Number,
m::PadeRetraction,
)
return $TP(ManifoldsBase.retract_pade(M, p.$pfield, X.$vfield, t, m))
end
function ManifoldsBase.retract_pade!(
M::$TM,
q::$TP,
Expand All @@ -311,15 +287,6 @@ macro default_manifold_fallbacks(TM, TP, TV, pfield::Symbol, vfield::Symbol)
ManifoldsBase.retract_pade!(M, q.$pfield, p.$pfield, X.$vfield, t, m)
return q
end
function ManifoldsBase.retract_embedded(
M::$TM,
p::$TP,
X::$TV,
t::Number,
m::AbstractRetractionMethod,
)
return $TP(ManifoldsBase.retract_embedded(M, p.$pfield, X.$vfield, t, m))
end
function ManifoldsBase.retract_embedded!(
M::$TM,
q::$TP,
Expand All @@ -331,15 +298,6 @@ macro default_manifold_fallbacks(TM, TP, TV, pfield::Symbol, vfield::Symbol)
ManifoldsBase.retract_embedded!(M, q.$pfield, p.$pfield, X.$vfield, t, m)
return q
end
function ManifoldsBase.retract_sasaki(
M::$TM,
p::$TP,
X::$TV,
t::Number,
m::SasakiRetraction,
)
return $TP(ManifoldsBase.retract_sasaki(M, p.$pfield, X.$vfield, t, m))
end
function ManifoldsBase.retract_sasaki!(
M::$TM,
q::$TP,
Expand All @@ -354,12 +312,8 @@ macro default_manifold_fallbacks(TM, TP, TV, pfield::Symbol, vfield::Symbol)
end,
)
for f_postfix in [:polar, :project, :qr, :softmax]
ra = Symbol("inverse_retract_$(f_postfix)")
rm = Symbol("inverse_retract_$(f_postfix)!")
push!(block.args, quote
function ManifoldsBase.$ra(M::$TM, p::$TP, q::$TP)
return $TV((ManifoldsBase.$ra)(M, p.$pfield, q.$pfield))
end
function ManifoldsBase.$rm(M::$TM, Y::$TV, p::$TP, q::$TP)
ManifoldsBase.$rm(M, Y.$vfield, p.$pfield, q.$pfield)
return Y
Expand All @@ -369,16 +323,6 @@ macro default_manifold_fallbacks(TM, TP, TV, pfield::Symbol, vfield::Symbol)
push!(
block.args,
quote
function ManifoldsBase.inverse_retract_embedded(
M::$TM,
p::$TP,
q::$TP,
m::AbstractInverseRetractionMethod,
)
return $TV(
ManifoldsBase.inverse_retract_embedded(M, p.$pfield, q.$pfield, m),
)
end
function ManifoldsBase.inverse_retract_embedded!(
M::$TM,
X::$TV,
Expand All @@ -395,16 +339,7 @@ macro default_manifold_fallbacks(TM, TP, TV, pfield::Symbol, vfield::Symbol)
)
return X
end
function ManifoldsBase.inverse_retract_nlsolve(
M::$TM,
p::$TP,
q::$TP,
m::NLSolveInverseRetraction,
)
return $TV(
ManifoldsBase.inverse_retract_nlsolve(M, p.$pfield, q.$pfield, m),
)
end

function ManifoldsBase.inverse_retract_nlsolve!(
M::$TM,
X::$TV,
Expand Down
Loading

0 comments on commit 2d1ac2b

Please sign in to comment.