Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,24 @@ version = "2.0.0"
[deps]
BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e"
ExponentialFamily = "62312e5e-252a-4322-ace9-a5f4bf9b357b"
FastCholesky = "2d5283b6-8564-42b6-bb00-83ed8e915756"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
ManifoldsBase = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"

[weakdeps]
ManifoldDiff = "af67fdf4-a580-4b9f-bbec-742ef357defd"

[compat]
ADTypes = "1.14.0"
BayesBase = "1.3"
ExponentialFamily = "2.0.0"
FastCholesky = "1.3.1"
LinearAlgebra = "1.10"
ManifoldDiff = "0.4.2"
Manifolds = "0.10"
ManifoldsBase = "1"
Random = "1.10"
Expand All @@ -25,6 +32,7 @@ Static = "0.8, 1"
julia = "1.10"

[extras]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -36,4 +44,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Distributions", "JET", "Test", "ReTestItems", "StableRNGs", "StaticArrays", "Manopt", "ForwardDiff"]
test = ["ADTypes", "Aqua", "Distributions", "JET", "Test", "ReTestItems", "StableRNGs", "StaticArrays", "Manopt", "ForwardDiff"]
188 changes: 180 additions & 8 deletions src/natural_manifolds.jl
Original file line number Diff line number Diff line change
@@ -1,27 +1,119 @@
using ManifoldsBase, Manifolds, Static, RecursiveArrayTools, Random, ExponentialFamily
using FastCholesky

import ExponentialFamily: exponential_family_typetag


struct ChartNOrderRetraction{Order,E} <: AbstractRetractionMethod

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extra::E
end

function ChartNOrderRetraction{O}() where {O}
return ChartNOrderRetraction{O,Nothing}(nothing)
end

const FirstOrderRetraction = ChartNOrderRetraction{1}
const SecondOrderRetraction = ChartNOrderRetraction{2}

"""
SecondOrderRetraction(; backend=nothing)

Create a second-order retraction method that uses Christoffel symbols to compute
a more accurate retraction. If a backend is provided, it will be used for any
automatic differentiation needed to compute the Christoffel symbols.

# Arguments
- `backend`: Optional backend for automatic differentiation (e.g., `ADTypes.AutoForwardDiff()`)
"""
function SecondOrderRetraction(; backend=nothing)
return ChartNOrderRetraction{2,typeof(backend)}(backend)
end

"""
FisherInformationMetric <: RiemannianMetric

Specifier that we need to use the Fisher information metric.
"""
struct FisherInformationMetric{R} <: RiemannianMetric
default_retraction::R
end

function FisherInformationMetric()
retraction = FirstOrderRetraction()
return FisherInformationMetric{typeof(retraction)}(retraction)
end

"""
BaseMetric <: RiemannianMetric

Specifier that we need to use the metric from the base manifold.
"""
struct BaseMetric <: RiemannianMetric end

"""
getdefaultmetric(::Type{T}) where {T}

Returns the default metric for the distribution of type `T`.
"""
function getdefaultmetric(::Type{T}) where {T}
return FisherInformationMetric()
end

"""
NaturalParametersManifold(::Type{T}, dims, base, conditioner)

The manifold for the natural parameters of the distribution of type `T` with dimensions `dims`.
An internal structure, use `get_natural_manifold` to create an instance of a manifold for the natural parameters of distribution of type `T`.
"""
struct NaturalParametersManifold{𝔽,T,D,M,C} <: AbstractDecoratorManifold{𝔽}
struct NaturalParametersManifold{𝔽,T,D,M,C,MT} <: AbstractDecoratorManifold{𝔽}
dims::D
base::M
conditioner::C
metric::MT
end

getdims(M::NaturalParametersManifold) = M.dims
getbase(M::NaturalParametersManifold) = M.base
getconditioner(M::NaturalParametersManifold) = M.conditioner
getmetric(M::NaturalParametersManifold) = M.metric

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally would prefer get_X methods, since Julia is often snake_case (ok also with counter examples like isapprox)

Copy link
Member Author

@Nimrais Nimrais Mar 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


# The `NaturalParametersManifold` simply adds extra properties to the `base` and
# acts as a "decorator"
@inline ManifoldsBase.active_traits(f::F, ::NaturalParametersManifold, ::Any...) where {F} =
ManifoldsBase.IsExplicitDecorator()
function select_skip_methods(
::F, ::NaturalParametersManifold{𝔽,T,D,MB,C,BaseMetric}
) where {F,𝔽,T,D,MB,C}
return ManifoldsBase.IsExplicitDecorator()
end

function select_skip_methods(
f::F, ::NaturalParametersManifold{𝔽,T,D,MB,C,<:FisherInformationMetric}
) where {F,𝔽,T,D,MB,C}
if f in (
ManifoldsBase.retract,
ManifoldsBase.retract!,
ManifoldsBase.retract_fused,
ManifoldsBase.retract_fused!,
Manifolds.local_metric,
Manifolds.local_metric_jacobian,
Manifolds.inverse_local_metric,
Manifolds.default_retraction_method,
Manifolds.get_basis_default,
Manifolds.christoffel_symbols_second,
Manifolds.christoffel_symbols_first,
Manifolds.representation_size
)
return ManifoldsBase.EmptyTrait()
else
return ManifoldsBase.IsExplicitDecorator()
end
end

@inline function ManifoldsBase.active_traits(
f::F, M::NaturalParametersManifold, args...
) where {F}
return select_skip_methods(f, M)
end

@inline ManifoldsBase.decorated_manifold(M::NaturalParametersManifold) = M.base

function ExponentialFamily.exponential_family_typetag(
Expand All @@ -31,9 +123,15 @@ function ExponentialFamily.exponential_family_typetag(
end

function NaturalParametersManifold(
::Type{T}, dims::D, base::M, conditioner::C=nothing
) where {T,𝔽,D,M<:AbstractManifold{𝔽},C}
return NaturalParametersManifold{𝔽,T,D,M,C}(dims, base, conditioner)
::Type{T},
dims::D,
base::M,
conditioner::C=nothing,
metric::MT=getdefaultmetric(T),
) where {T,𝔽,D,M<:AbstractManifold{𝔽},C,MT}
return NaturalParametersManifold{𝔽,T,D,M,C,MT}(
dims, base, conditioner, metric
)
end

"""
Expand All @@ -52,9 +150,11 @@ julia> ExponentialFamilyManifolds.get_natural_manifold(MvNormalMeanCovariance, (
true
```
"""
function get_natural_manifold(::Type{T}, dims, conditioner=nothing) where {T}
function get_natural_manifold(
::Type{T}, dims, conditioner=nothing, metric=getdefaultmetric(T)
) where {T}
return NaturalParametersManifold(
T, dims, get_natural_manifold_base(T, dims, conditioner), conditioner
T, dims, get_natural_manifold_base(T, dims, conditioner), conditioner, metric
)
end

Expand Down Expand Up @@ -88,3 +188,75 @@ function Base.convert(
exponential_family_typetag(M), p, getconditioner(M), nothing
)
end


function ManifoldsBase.default_retraction_method(
M::NaturalParametersManifold{𝔽,TD,D,BM,C,<:FisherInformationMetric}, ::Type{T}
) where {𝔽,T,TD,D,BM,C}
return getmetric(M).default_retraction
end

function ManifoldsBase.retract_fused!(
::NaturalParametersManifold, q, p, X, t::Number, method::FirstOrderRetraction
)
q .= p .+ t .* X
return q
end
Comment on lines +199 to +204

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to indicate your retraction type atop should indeed be an AbstractRetraction subtype – and in the long run this could be documented a bit more. For me natural coordinates seemed a bit magic in the beginning ;)


function ManifoldsBase.retract!(
M::NaturalParametersManifold, q, p, X, method::FirstOrderRetraction
)
return ManifoldsBase.retract_fused!(M, q, p, X, one(eltype(X)), method)
end

function ManifoldsBase.retract_fused!(
M::NaturalParametersManifold{𝔽,T,D,BM,C,<:FisherInformationMetric},
q,
p,
X,
t::Number,
method::SecondOrderRetraction,
) where {𝔽,T,D,BM,C}
basis = ManifoldsBase.get_basis_default(M, p)
Γ = Manifolds.christoffel_symbols_second(M, p, basis; backend=method.extra)

Δ = similar(p)
Manifolds.@einsum Δ[k] = -0.5 * Γ[k, i, j] * (t * X[i]) * (t * X[j])
q .= p .+ t .* X .+ Δ
return q
end

function ManifoldsBase.retract!(
M::NaturalParametersManifold, q, p, X, method::SecondOrderRetraction
)
return ManifoldsBase.retract_fused!(M, q, p, X, one(eltype(X)), method)
end

struct NaturalBasis{𝔽,VST<:VectorSpaceType} <: AbstractBasis{𝔽,VST}
vector_space::VST
end

NaturalBasis(𝔽=ℝ, vs::VectorSpaceType=TangentSpaceType()) = NaturalBasis{𝔽,typeof(vs)}(vs)
function NaturalBasis{𝔽}(vs::VectorSpaceType=TangentSpaceType()) where {𝔽}
return NaturalBasis{𝔽,typeof(vs)}(vs)
end

function ManifoldsBase.get_basis_default(
::NaturalParametersManifold{𝔽,T,D,MB,C,<:FisherInformationMetric}, p
) where {𝔽,T,D,MB,C}
return NaturalBasis{𝔽}()
end

function Manifolds.local_metric(
M::NaturalParametersManifold{𝔽,T,D,MB,C,<:FisherInformationMetric}, p, ::NaturalBasis
) where {𝔽,T,D,MB,C}
ef = convert(ExponentialFamilyDistribution, M, p)
return ExponentialFamily.fisherinformation(ef)
end

function Manifolds.inverse_local_metric(
M::NaturalParametersManifold{𝔽,T,D,MB,C,<:FisherInformationMetric}, p, ::NaturalBasis
) where {𝔽,T,D,MB,C}
ef = convert(ExponentialFamilyDistribution, M, p)
return cholinv(ExponentialFamily.fisherinformation(ef))
end
2 changes: 2 additions & 0 deletions src/natural_manifolds/beta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ Converts the `point` to a compatible representation for the natural manifold of
function partition_point(::Type{Beta}, ::Tuple{}, p, conditioner=nothing)
return ArrayPartition(view(p, 1:1), view(p, 2:2))
end

Manifolds.representation_size(::NaturalParametersManifold{𝔽, Beta}) where {𝔽} = (2,)
4 changes: 4 additions & 0 deletions src/natural_manifolds/categorical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ Converts the `point` to a compatible representation for the natural manifold of
function partition_point(::Type{Categorical}, ::Tuple{}, p, conditioner=nothing)
return ArrayPartition(view(p, 1:(conditioner - 1)), view(p, conditioner:conditioner))
end

function Manifolds.representation_size(M::NaturalParametersManifold{𝔽, Categorical}) where {𝔽}
return (getconditioner(M),)
end
5 changes: 5 additions & 0 deletions src/natural_manifolds/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,8 @@ function partition_point(::Type{Dirichlet}, dims::Tuple{Int}, p, conditioner=not
# See comment in `get_natural_manifold_base` for `Dirichlet`
return ArrayPartition(p')
end

function Manifolds.representation_size(M::NaturalParametersManifold{𝔽, Dirichlet}) where {𝔽}
dims = getdims(M)
return (first(dims),)
end
2 changes: 2 additions & 0 deletions src/natural_manifolds/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ Converts the `point` to a compatible representation for the natural manifold of
function partition_point(::Type{Gamma}, ::Tuple{}, p, conditioner=nothing)
return ArrayPartition(view(p, 1:1), view(p, 2:2))
end

Manifolds.representation_size(::NaturalParametersManifold{𝔽, Gamma}) where {𝔽} = (2,)
2 changes: 2 additions & 0 deletions src/natural_manifolds/inverse_gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,5 @@ function partition_point(
)
return ArrayPartition(view(p, 1:1), view(p, 2:2))
end

Manifolds.representation_size(::NaturalParametersManifold{𝔽, ExponentialFamily.GammaInverse}) where {𝔽} = (2,)
2 changes: 2 additions & 0 deletions src/natural_manifolds/lognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@ Converts the `point` to a compatible representation for the natural manifold of
function partition_point(::Type{LogNormal}, ::Tuple{}, p, conditioner=nothing)
return ArrayPartition(view(p, 1:1), view(p, 2:2))
end

Manifolds.representation_size(::NaturalParametersManifold{𝔽, LogNormal}) where {𝔽} = (2,)
10 changes: 10 additions & 0 deletions src/natural_manifolds/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,13 @@ function partition_point(
k = first(dims)
return ArrayPartition(view(p, 1:k), view(p, (k + 1):(k + 1)))
end

function getdefaultmetric(::Type{MvNormalMeanCovariance})
return BaseMetric()
end

Manifolds.representation_size(::NaturalParametersManifold{𝔽, NormalMeanVariance}) where {𝔽} = (2,)

function Manifolds.representation_size(M::NaturalParametersManifold{𝔽, MvNormalMeanScalePrecision}) where {𝔽}
return (M.dims[1] +1, )
end
4 changes: 4 additions & 0 deletions src/natural_manifolds/wishart.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ function partition_point(
k = first(dims)
return ArrayPartition(view(p, 1:1), reshape(view(p, 2:(1 + k^2)), (k, k)))
end

function getdefaultmetric(::Type{ExponentialFamily.WishartFast})
return BaseMetric()
end
22 changes: 22 additions & 0 deletions test/natural_manifolds/bernoulli_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,25 @@
return Bernoulli(rand(rng))
end
end

@testitem "Check SecondOrderRetraction" begin
include("natural_manifolds_setuptests.jl")

using ADTypes: AutoForwardDiff

rng = StableRNG(42)
M = ExponentialFamilyManifolds.get_natural_manifold(Bernoulli, ())
p = rand(rng, M)
X = rand(rng, M)
basis = ExponentialFamilyManifolds.NaturalBasis()

@show Manifolds.local_metric(M, p, basis)
@show Manifolds.local_metric_jacobian(M, p, basis, backend=AutoForwardDiff())
q = retract(
M,
p,
X,
ExponentialFamilyManifolds.SecondOrderRetraction(; backend=AutoForwardDiff()),
)
@show q
end
1 change: 1 addition & 0 deletions test/natural_manifolds/beta_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
test_natural_manifold() do rng
return Beta(10rand(rng), 10rand(rng))
end

end
2 changes: 1 addition & 1 deletion test/natural_manifolds/categorical_tests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
@testitem "Check `Categorical` natural manifold" begin
include("natural_manifolds_setuptests.jl")

test_natural_manifold() do rng
test_natural_manifold(test_injectivity_radius=false) do rng
p = rand(rng, 10)
normalize!(p, 1)
return Categorical(p)
Expand Down
Loading
Loading