Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Maybe don't split distributions extension #757

Merged
merged 5 commits into from
Oct 19, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
- name: Install JuliaFormatter and format
run: |
using Pkg
Pkg.add(PackageSpec(name="JuliaFormatter"))
Pkg.add(PackageSpec(name="JuliaFormatter", version="1"))
using JuliaFormatter
format("."; verbose=true)
shell: julia --color=yes {0}
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.10.4] - 2024-10-18

### Changed

* The `Distributions.jl` extension has been split into an extension that additionally requires `RecursiveArrayTools.jl` and one that does not.

## [0.10.3] - 2024-10-04

### Changed
Expand Down
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Manifolds"
uuid = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
authors = ["Seth Axen <[email protected]>", "Mateusz Baran <[email protected]>", "Ronny Bergmann <[email protected]>", "Antoine Levitt <[email protected]>"]
version = "0.10.3"
version = "0.10.4"

[deps]
Einsum = "b7d42ee7-0b51-5a75-98ca-779d3107e4c0"
Expand Down Expand Up @@ -35,7 +35,8 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[extensions]
ManifoldsBoundaryValueDiffEqExt = "BoundaryValueDiffEq"
ManifoldsDistributionsExt = ["Distributions", "RecursiveArrayTools"]
ManifoldsDistributionsExt = "Distributions"
ManifoldsDistributionsRecursiveArrayToolsExt = ["Distributions", "RecursiveArrayTools"]
ManifoldsHybridArraysExt = "HybridArrays"
ManifoldsNLsolveExt = "NLsolve"
ManifoldsOrdinaryDiffEqDiffEqCallbacksExt = ["DiffEqCallbacks", "OrdinaryDiffEq", "RecursiveArrayTools"]
Expand Down
4 changes: 0 additions & 4 deletions ext/ManifoldsDistributionsExt/ManifoldsDistributionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ if isdefined(Base, :get_extension)
uniform_distribution

using Manifolds: get_iterator, get_parameter, _read, _write

using RecursiveArrayTools: ArrayPartition
else
# imports need to be relative for Requires.jl-based workflows:
# https://github.com/JuliaArrays/ArrayInterface.jl/pull/387
Expand All @@ -30,8 +28,6 @@ else
uniform_distribution

using ..Manifolds: get_iterator, get_parameter, _read, _write

using ..RecursiveArrayTools: ArrayPartition
end

include("distributions.jl")
Expand Down
99 changes: 0 additions & 99 deletions ext/ManifoldsDistributionsExt/distributions_for_manifolds.jl
Original file line number Diff line number Diff line change
@@ -1,103 +1,4 @@

## product manifold

"""
ProductPointDistribution(M::ProductManifold, distributions)

Product distribution on manifold `M`, combined from `distributions`.
"""
struct ProductPointDistribution{
TM<:ProductManifold,
TD<:(NTuple{N,Distribution} where {N}),
} <: MPointDistribution{TM}
manifold::TM
distributions::TD
end

function ProductPointDistribution(M::ProductManifold, distributions::MPointDistribution...)
return ProductPointDistribution{typeof(M),typeof(distributions)}(M, distributions)
end
function ProductPointDistribution(distributions::MPointDistribution...)
M = ProductManifold(map(d -> support(d).manifold, distributions)...)
return ProductPointDistribution(M, distributions...)
end

"""
ProductFVectorDistribution([type::VectorSpaceFiber], [x], distrs...)

Generates a random vector at point `x` from vector space (a fiber of a tangent
bundle) of type `type` using the product distribution of given distributions.

Vector space type and `x` can be automatically inferred from distributions `distrs`.
"""
struct ProductFVectorDistribution{
TSpace<:VectorSpaceFiber{<:Any,<:ProductManifold},
TD<:(NTuple{N,Distribution} where {N}),
} <: FVectorDistribution{TSpace}
type::TSpace
distributions::TD
end

function ProductFVectorDistribution(distributions::FVectorDistribution...)
M = ProductManifold(map(d -> support(d).space.manifold, distributions)...)
fiber_type = support(distributions[1]).space.fiber_type
if !all(d -> support(d).space.fiber_type == fiber_type, distributions)
error(
"Not all distributions have support in vector spaces of the same type, which is currently not supported",
)
end
# Probably worth considering sum spaces in the future?
p = ArrayPartition(map(d -> support(d).space.point, distributions)...)
return ProductFVectorDistribution(Fiber(M, p, fiber_type), distributions)
end

function Random.rand(rng::AbstractRNG, d::ProductFVectorDistribution)
return ArrayPartition(map(d -> rand(rng, d), d.distributions)...)
end
function Random.rand(rng::AbstractRNG, d::ProductPointDistribution)
return ArrayPartition(map(d -> rand(rng, d), d.distributions)...)
end

function Distributions._rand!(
rng::AbstractRNG,
d::ProductFVectorDistribution,
X::ArrayPartition,
)
map(
(t1, t2) -> Distributions._rand!(rng, t1, t2),
d.distributions,
submanifold_components(d.type.manifold, X),
)
return X
end
function Distributions._rand!(
rng::AbstractRNG,
d::ProductPointDistribution,
p::ArrayPartition,
)
map(
(t1, t2) -> Distributions._rand!(rng, t1, t2),
d.distributions,
submanifold_components(d.manifold, p),
)
return p
end

Distributions.support(d::ProductPointDistribution) = MPointSupport(d.manifold)
function Distributions.support(tvd::ProductFVectorDistribution)
return FVectorSupport(tvd.type)
end

function uniform_distribution(M::ProductManifold)
return ProductPointDistribution(M, map(uniform_distribution, M.manifolds))
end
function uniform_distribution(M::ProductManifold, p)
return ProductPointDistribution(
M,
map(uniform_distribution, M.manifolds, submanifold_components(M, p)),
)
end

## power manifold

"""
Expand Down
123 changes: 123 additions & 0 deletions ext/ManifoldsDistributionsRecursiveArrayToolsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
module ManifoldsDistributionsRecursiveArrayToolsExt

if isdefined(Base, :get_extension)
using Manifolds
using Distributions
using Random
using LinearAlgebra

import Manifolds: uniform_distribution

using RecursiveArrayTools: ArrayPartition
else
# imports need to be relative for Requires.jl-based workflows:
# https://github.com/JuliaArrays/ArrayInterface.jl/pull/387
using ..Manifolds
using ..Distributions
using ..Random

import ..Manifolds: uniform_distribution

using ..RecursiveArrayTools: ArrayPartition
end

## product manifold

"""
ProductPointDistribution(M::ProductManifold, distributions)

Product distribution on manifold `M`, combined from `distributions`.
"""
struct ProductPointDistribution{
TM<:ProductManifold,
TD<:(NTuple{N,Distribution} where {N}),
} <: MPointDistribution{TM}
manifold::TM
distributions::TD
end

function ProductPointDistribution(M::ProductManifold, distributions::MPointDistribution...)
return ProductPointDistribution{typeof(M),typeof(distributions)}(M, distributions)
end
function ProductPointDistribution(distributions::MPointDistribution...)
M = ProductManifold(map(d -> support(d).manifold, distributions)...)
return ProductPointDistribution(M, distributions...)
end

"""
ProductFVectorDistribution([type::VectorSpaceFiber], [x], distrs...)

Generates a random vector at point `x` from vector space (a fiber of a tangent
bundle) of type `type` using the product distribution of given distributions.

Vector space type and `x` can be automatically inferred from distributions `distrs`.
"""
struct ProductFVectorDistribution{
TSpace<:VectorSpaceFiber{<:Any,<:ProductManifold},
TD<:(NTuple{N,Distribution} where {N}),
} <: FVectorDistribution{TSpace}
type::TSpace
distributions::TD
end

function ProductFVectorDistribution(distributions::FVectorDistribution...)
M = ProductManifold(map(d -> support(d).space.manifold, distributions)...)
fiber_type = support(distributions[1]).space.fiber_type
if !all(d -> support(d).space.fiber_type == fiber_type, distributions)
error(
"Not all distributions have support in vector spaces of the same type, which is currently not supported",
)
end
# Probably worth considering sum spaces in the future?
p = ArrayPartition(map(d -> support(d).space.point, distributions)...)
return ProductFVectorDistribution(Fiber(M, p, fiber_type), distributions)
end

function Random.rand(rng::AbstractRNG, d::ProductFVectorDistribution)
return ArrayPartition(map(d -> rand(rng, d), d.distributions)...)
end
function Random.rand(rng::AbstractRNG, d::ProductPointDistribution)
return ArrayPartition(map(d -> rand(rng, d), d.distributions)...)
end

function Distributions._rand!(
rng::AbstractRNG,
d::ProductFVectorDistribution,
X::ArrayPartition,
)
map(
(t1, t2) -> Distributions._rand!(rng, t1, t2),
d.distributions,
submanifold_components(d.type.manifold, X),
)
return X
end
function Distributions._rand!(
rng::AbstractRNG,
d::ProductPointDistribution,
p::ArrayPartition,
)
map(
(t1, t2) -> Distributions._rand!(rng, t1, t2),
d.distributions,
submanifold_components(d.manifold, p),
)
return p
end

Distributions.support(d::ProductPointDistribution) = MPointSupport(d.manifold)
function Distributions.support(tvd::ProductFVectorDistribution)
return FVectorSupport(tvd.type)
end

function uniform_distribution(M::ProductManifold)
return ProductPointDistribution(M, map(uniform_distribution, M.manifolds))
end
function uniform_distribution(M::ProductManifold, p)
return ProductPointDistribution(
M,
map(uniform_distribution, M.manifolds, submanifold_components(M, p)),
)
end

end
12 changes: 11 additions & 1 deletion src/Manifolds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,12 @@ function __init__()
if exc.f === solve_exp_ode
print(io, "\nDid you forget to load OrdinaryDiffEq? For example: ")
printstyled(io, "`using OrdinaryDiffEq`", color=:cyan)
elseif exc.f === uniform_distribution
print(
io,
"\nDid you forget to load Distributions or RecurisveArrayTools? For example: ",
)
printstyled(io, "`using Distributions`", color=:cyan)
end
end
end
Expand All @@ -610,6 +616,10 @@ function __init__()
include("../ext/ManifoldsBoundaryValueDiffEqExt.jl")
end

@require Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" begin
include("../ext/ManifoldsDistributionsExt/ManifoldsDistributionsExt.jl")
end

@require NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" begin
include("../ext/ManifoldsNLsolveExt.jl")
end
Expand All @@ -630,7 +640,7 @@ function __init__()
)

@require Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" begin
include("../ext/ManifoldsDistributionsExt/ManifoldsDistributionsExt.jl")
include("../ext/ManifoldsDistributionsRecursiveArrayToolsExt.jl")
end

@require OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" begin
Expand Down
Loading