-
Notifications
You must be signed in to change notification settings - Fork 1
Fisher Information metric #33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 8 commits
e1daf80
a239159
582e0bd
a1d7584
47c1997
4d6fad4
9981d6e
621880e
cb1b998
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,27 +1,119 @@ | ||
| using ManifoldsBase, Manifolds, Static, RecursiveArrayTools, Random, ExponentialFamily | ||
| using FastCholesky | ||
|
|
||
| import ExponentialFamily: exponential_family_typetag | ||
|
|
||
|
|
||
| struct ChartNOrderRetraction{Order,E} <: AbstractRetractionMethod | ||
| extra::E | ||
| end | ||
|
|
||
| function ChartNOrderRetraction{O}() where {O} | ||
| return ChartNOrderRetraction{O,Nothing}(nothing) | ||
| end | ||
|
|
||
| const FirstOrderRetraction = ChartNOrderRetraction{1} | ||
| const SecondOrderRetraction = ChartNOrderRetraction{2} | ||
|
|
||
| """ | ||
| SecondOrderRetraction(; backend=nothing) | ||
|
|
||
| Create a second-order retraction method that uses Christoffel symbols to compute | ||
| a more accurate retraction. If a backend is provided, it will be used for any | ||
| automatic differentiation needed to compute the Christoffel symbols. | ||
|
|
||
| # Arguments | ||
| - `backend`: Optional backend for automatic differentiation (e.g., `ADTypes.AutoForwardDiff()`) | ||
| """ | ||
| function SecondOrderRetraction(; backend=nothing) | ||
| return ChartNOrderRetraction{2,typeof(backend)}(backend) | ||
| end | ||
|
|
||
| """ | ||
| FisherInformationMetric <: RiemannianMetric | ||
|
|
||
| Specifier that we need to use the Fisher information metric. | ||
| """ | ||
| struct FisherInformationMetric{R} <: RiemannianMetric | ||
| default_retraction::R | ||
| end | ||
|
|
||
| function FisherInformationMetric() | ||
| retraction = FirstOrderRetraction() | ||
| return FisherInformationMetric{typeof(retraction)}(retraction) | ||
| end | ||
|
|
||
| """ | ||
| BaseMetric <: RiemannianMetric | ||
|
|
||
| Specifier that we need to use the metric from the base manifold. | ||
| """ | ||
| struct BaseMetric <: RiemannianMetric end | ||
|
|
||
| """ | ||
| getdefaultmetric(::Type{T}) where {T} | ||
|
|
||
| Returns the default metric for the distribution of type `T`. | ||
| """ | ||
| function getdefaultmetric(::Type{T}) where {T} | ||
| return FisherInformationMetric() | ||
| end | ||
|
|
||
| """ | ||
| NaturalParametersManifold(::Type{T}, dims, base, conditioner) | ||
|
|
||
| The manifold for the natural parameters of the distribution of type `T` with dimensions `dims`. | ||
| An internal structure, use `get_natural_manifold` to create an instance of a manifold for the natural parameters of distribution of type `T`. | ||
| """ | ||
| struct NaturalParametersManifold{𝔽,T,D,M,C} <: AbstractDecoratorManifold{𝔽} | ||
| struct NaturalParametersManifold{𝔽,T,D,M,C,MT} <: AbstractDecoratorManifold{𝔽} | ||
| dims::D | ||
| base::M | ||
| conditioner::C | ||
| metric::MT | ||
| end | ||
|
|
||
| getdims(M::NaturalParametersManifold) = M.dims | ||
| getbase(M::NaturalParametersManifold) = M.base | ||
| getconditioner(M::NaturalParametersManifold) = M.conditioner | ||
| getmetric(M::NaturalParametersManifold) = M.metric | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I personally would prefer
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think it’s exactly the same, thanks |
||
|
|
||
| # The `NaturalParametersManifold` simply adds extra properties to the `base` and | ||
| # acts as a "decorator" | ||
| @inline ManifoldsBase.active_traits(f::F, ::NaturalParametersManifold, ::Any...) where {F} = | ||
| ManifoldsBase.IsExplicitDecorator() | ||
| function select_skip_methods( | ||
| ::F, ::NaturalParametersManifold{𝔽,T,D,MB,C,BaseMetric} | ||
| ) where {F,𝔽,T,D,MB,C} | ||
| return ManifoldsBase.IsExplicitDecorator() | ||
| end | ||
|
|
||
| function select_skip_methods( | ||
| f::F, ::NaturalParametersManifold{𝔽,T,D,MB,C,<:FisherInformationMetric} | ||
| ) where {F,𝔽,T,D,MB,C} | ||
| if f in ( | ||
| ManifoldsBase.retract, | ||
| ManifoldsBase.retract!, | ||
| ManifoldsBase.retract_fused, | ||
| ManifoldsBase.retract_fused!, | ||
| Manifolds.local_metric, | ||
| Manifolds.local_metric_jacobian, | ||
| Manifolds.inverse_local_metric, | ||
| Manifolds.default_retraction_method, | ||
| Manifolds.get_basis_default, | ||
| Manifolds.christoffel_symbols_second, | ||
| Manifolds.christoffel_symbols_first, | ||
| Manifolds.representation_size | ||
| ) | ||
| return ManifoldsBase.EmptyTrait() | ||
| else | ||
| return ManifoldsBase.IsExplicitDecorator() | ||
| end | ||
| end | ||
|
|
||
| @inline function ManifoldsBase.active_traits( | ||
| f::F, M::NaturalParametersManifold, args... | ||
| ) where {F} | ||
| return select_skip_methods(f, M) | ||
| end | ||
|
|
||
| @inline ManifoldsBase.decorated_manifold(M::NaturalParametersManifold) = M.base | ||
|
|
||
| function ExponentialFamily.exponential_family_typetag( | ||
|
|
@@ -31,9 +123,15 @@ function ExponentialFamily.exponential_family_typetag( | |
| end | ||
|
|
||
| function NaturalParametersManifold( | ||
| ::Type{T}, dims::D, base::M, conditioner::C=nothing | ||
| ) where {T,𝔽,D,M<:AbstractManifold{𝔽},C} | ||
| return NaturalParametersManifold{𝔽,T,D,M,C}(dims, base, conditioner) | ||
| ::Type{T}, | ||
| dims::D, | ||
| base::M, | ||
| conditioner::C=nothing, | ||
| metric::MT=getdefaultmetric(T), | ||
| ) where {T,𝔽,D,M<:AbstractManifold{𝔽},C,MT} | ||
| return NaturalParametersManifold{𝔽,T,D,M,C,MT}( | ||
| dims, base, conditioner, metric | ||
| ) | ||
| end | ||
|
|
||
| """ | ||
|
|
@@ -52,9 +150,11 @@ julia> ExponentialFamilyManifolds.get_natural_manifold(MvNormalMeanCovariance, ( | |
| true | ||
| ``` | ||
| """ | ||
| function get_natural_manifold(::Type{T}, dims, conditioner=nothing) where {T} | ||
| function get_natural_manifold( | ||
| ::Type{T}, dims, conditioner=nothing, metric=getdefaultmetric(T) | ||
| ) where {T} | ||
| return NaturalParametersManifold( | ||
| T, dims, get_natural_manifold_base(T, dims, conditioner), conditioner | ||
| T, dims, get_natural_manifold_base(T, dims, conditioner), conditioner, metric | ||
| ) | ||
| end | ||
|
|
||
|
|
@@ -88,3 +188,75 @@ function Base.convert( | |
| exponential_family_typetag(M), p, getconditioner(M), nothing | ||
| ) | ||
| end | ||
|
|
||
|
|
||
| function ManifoldsBase.default_retraction_method( | ||
| M::NaturalParametersManifold{𝔽,TD,D,BM,C,<:FisherInformationMetric}, ::Type{T} | ||
| ) where {𝔽,T,TD,D,BM,C} | ||
| return getmetric(M).default_retraction | ||
| end | ||
|
|
||
| function ManifoldsBase.retract_fused!( | ||
| ::NaturalParametersManifold, q, p, X, t::Number, method::FirstOrderRetraction | ||
| ) | ||
| q .= p .+ t .* X | ||
| return q | ||
| end | ||
|
Comment on lines
+199
to
+204
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems to indicate your retraction type atop should indeed be an |
||
|
|
||
| function ManifoldsBase.retract!( | ||
| M::NaturalParametersManifold, q, p, X, method::FirstOrderRetraction | ||
| ) | ||
| return ManifoldsBase.retract_fused!(M, q, p, X, one(eltype(X)), method) | ||
| end | ||
|
|
||
| function ManifoldsBase.retract_fused!( | ||
| M::NaturalParametersManifold{𝔽,T,D,BM,C,<:FisherInformationMetric}, | ||
| q, | ||
| p, | ||
| X, | ||
| t::Number, | ||
| method::SecondOrderRetraction, | ||
| ) where {𝔽,T,D,BM,C} | ||
| basis = ManifoldsBase.get_basis_default(M, p) | ||
| Γ = Manifolds.christoffel_symbols_second(M, p, basis; backend=method.extra) | ||
|
|
||
| Δ = similar(p) | ||
| Manifolds.@einsum Δ[k] = -0.5 * Γ[k, i, j] * (t * X[i]) * (t * X[j]) | ||
| q .= p .+ t .* X .+ Δ | ||
| return q | ||
| end | ||
|
|
||
| function ManifoldsBase.retract!( | ||
| M::NaturalParametersManifold, q, p, X, method::SecondOrderRetraction | ||
| ) | ||
| return ManifoldsBase.retract_fused!(M, q, p, X, one(eltype(X)), method) | ||
| end | ||
|
|
||
| struct NaturalBasis{𝔽,VST<:VectorSpaceType} <: AbstractBasis{𝔽,VST} | ||
| vector_space::VST | ||
| end | ||
|
|
||
| NaturalBasis(𝔽=ℝ, vs::VectorSpaceType=TangentSpaceType()) = NaturalBasis{𝔽,typeof(vs)}(vs) | ||
| function NaturalBasis{𝔽}(vs::VectorSpaceType=TangentSpaceType()) where {𝔽} | ||
| return NaturalBasis{𝔽,typeof(vs)}(vs) | ||
| end | ||
|
|
||
| function ManifoldsBase.get_basis_default( | ||
| ::NaturalParametersManifold{𝔽,T,D,MB,C,<:FisherInformationMetric}, p | ||
| ) where {𝔽,T,D,MB,C} | ||
| return NaturalBasis{𝔽}() | ||
| end | ||
|
|
||
| function Manifolds.local_metric( | ||
| M::NaturalParametersManifold{𝔽,T,D,MB,C,<:FisherInformationMetric}, p, ::NaturalBasis | ||
| ) where {𝔽,T,D,MB,C} | ||
| ef = convert(ExponentialFamilyDistribution, M, p) | ||
| return ExponentialFamily.fisherinformation(ef) | ||
| end | ||
|
|
||
| function Manifolds.inverse_local_metric( | ||
| M::NaturalParametersManifold{𝔽,T,D,MB,C,<:FisherInformationMetric}, p, ::NaturalBasis | ||
| ) where {𝔽,T,D,MB,C} | ||
| ef = convert(ExponentialFamilyDistribution, M, p) | ||
| return cholinv(ExponentialFamily.fisherinformation(ef)) | ||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,4 +4,5 @@ | |
| test_natural_manifold() do rng | ||
| return Beta(10rand(rng), 10rand(rng)) | ||
| end | ||
|
|
||
| end | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this a retraction type? Then it could subtype https://github.com/JuliaManifolds/ManifoldsBase.jl/blob/85b42907c26df0463f3fc91ba7dafd3fa534f800/src/retractions.jl#L8-L13 ?