Commit a9ab220

Add WoodburyEstimator for high dimensionality (#94)
* Add `WoodburyEstimator` for high dimensionality

  When the covariance matrix is too large to handle in full, this provides the option to model it as `Σ = σ²I + U * Λ * U'` for some low-rank `U` and diagonal `Λ`.

* Address review comments

* Fix errors

  I swapped two of the norms (whoops) and made one typo.

* Add complete docs on Woodbury

* Fix c formula
1 parent 3c916a6 commit a9ab220

12 files changed: +411 −4 lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
@@ -1,17 +1,21 @@
 name = "CovarianceEstimation"
 uuid = "587fd27a-f159-11e8-2dae-1979310e6154"
 authors = ["Mateusz Baran <[email protected]>", "Thibaut Lienart"]
-version = "0.2.11"
+version = "0.2.12"

 [deps]
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+TSVD = "9449cd9e-2762-5aa3-a617-5413e99d722e"
+WoodburyMatrices = "efce3f68-66dc-5838-9240-27a6d6f5f9b6"

 [compat]
 LinearAlgebra = "1"
 Statistics = "1"
 StatsBase = "0.33, 0.34"
+WoodburyMatrices = "1"
+TSVD = "0.4"
 julia = "1.6"

 [extras]
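Since two dependencies are new here, a brief orientation sketch may help (not part of the diff; the data and values are illustrative): `TSVD.tsvd` computes truncated SVDs, and `WoodburyMatrices.SymWoodbury` stores `σ²I + B*D*B'` in factored form, which is exactly the representation the new estimator returns.

```julia
using TSVD, WoodburyMatrices, LinearAlgebra

X = randn(100, 500)        # 100 observations of a 500-dimensional variable
U, s, V = tsvd(X, 5)       # rank-5 truncated SVD: X ≈ U * Diagonal(s) * V'

# SymWoodbury represents σ²I + B*D*B' without forming the dense 500×500 matrix
W = SymWoodbury(0.1 * I(500), V, Diagonal(s.^2 ./ 100))
v = randn(500)
W * v                      # matrix-vector products stay cheap
W \ v                      # solves use the Woodbury identity
```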

docs/src/assets/donoho_fig3.png

202 KB

docs/src/index.md

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ A package for robustly estimating covariance matrices of real-valued data.
 ## Package Features

 - Standard corrected and uncorrected covariance estimators,
-- Linear and Nonlinear shrinkage estimators
+- Linear and Nonlinear shrinkage estimators, including estimators for covariance matrices too large to store in dense form
 - Focus on speed and lightweight dependencies

 ## Manual outline

docs/src/lib/public.md

Lines changed: 3 additions & 0 deletions
@@ -29,4 +29,7 @@ PerfectPositiveCorrelation
 ConstantCorrelation
 AnalyticalNonlinearShrinkage
 BiweightMidcovariance
+NormLossCov
+StatLossCov
+WoodburyEstimator
 ```

docs/src/man/nlshrink.md

Lines changed: 23 additions & 1 deletion
@@ -8,4 +8,26 @@ F = eigen(X)
 F.U*(d̃ .* F.U') # d̃ is a vector of transformed eigenvalues
 ```

-Currently, only the analytical nonlinear shrinkage ([`AnalyticalNonlinearShrinkage`](@ref)) method is implemented.
+Currently, there are two flavors of analytical nonlinear shrinkage:
+- [`AnalyticalNonlinearShrinkage`](@ref) is recommended in cases where the covariance matrix can be stored as a dense matrix
+- for cases where the covariance matrix is too large to handle in dense form, [`WoodburyEstimator`](@ref) models the covariance matrix as
+
+      Σ = σ²I + U * Λ * U'
+
+  where `σ` is a scalar, `I` is the identity matrix, `U` is a low-rank semi-orthogonal matrix, and `Λ` is diagonal.
+One can readily compute with this representation via the [Woodbury matrix identity](https://en.wikipedia.org/wiki/Woodbury_matrix_identity) and the [WoodburyMatrices package](https://github.com/JuliaLinearAlgebra/WoodburyMatrices.jl).
+This formulation approximates the covariance matrix as if all but a few of the largest eigenvalues were equal to `σ²`.
+A [truncated singular value decomposition](https://github.com/JuliaLinearAlgebra/TSVD.jl) of the data matrix is
+performed and the corresponding eigenvalues are shrunk by methods that are optimal for a wide variety of loss functions:
+
+- [`NormLossCov`](@ref) lets you minimize a norm-based loss against the "true" covariance matrix
+- [`StatLossCov`](@ref) lets you optimize for specific statistical outcomes, e.g., the accuracy of Mahalanobis distances.
+
+The eigenvalue shrinkage function is plotted for all choices below:
+
+![Donoho et al Fig 3](../assets/donoho_fig3.png)
+
+For complete details, see:
+
+Donoho, D.L., Gavish, M. and Johnstone, I.M., 2018.
+Optimal shrinkage of eigenvalues in the spiked covariance model. Annals of Statistics, 46(4), p. 1742.
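To make the workflow concrete, here is a minimal usage sketch (not part of the commit; the data and parameter choices are illustrative, while `WoodburyEstimator`, `StatLossCov`, and `cov` are the names this commit adds):

```julia
using CovarianceEstimation, LinearAlgebra

# 2000-dimensional data, 100 observations, with one strong low-rank "spike"
# planted on top of isotropic noise:
n, p = 100, 2000
v = normalize(randn(p))
X = randn(n, p) .+ 10 .* randn(n) .* v'

estimator = WoodburyEstimator(StatLossCov(:ent), 10)   # keep at most 10 spikes
Σ = cov(estimator, X)                                  # a SymWoodbury matrix, never dense

# The σ²I + UΛU' structure keeps linear algebra cheap:
w = randn(p)
Σ * w      # O(p · rank) multiply
Σ \ w      # solve via the Woodbury identity
```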

src/CovarianceEstimation.jl

Lines changed: 9 additions & 1 deletion
@@ -4,6 +4,8 @@ using Statistics
 using StatsBase
 using LinearAlgebra
 import StatsBase: cov
+using WoodburyMatrices
+using TSVD

 export cov
 export CovarianceEstimator, SimpleCovariance,
@@ -14,12 +16,18 @@ export CovarianceEstimator, SimpleCovariance,
     # Eigendecomposition-based methods
     AnalyticalNonlinearShrinkage,
     # Biweight midcovariance
-    BiweightMidcovariance
+    BiweightMidcovariance,
+    # Woodbury-based methods
+    WoodburyEstimator,
+    # Loss functions
+    NormLossCov, StatLossCov


 include("utils.jl")
+include("loss.jl")
 include("biweight.jl")
 include("linearshrinkage.jl")
 include("nonlinearshrinkage.jl")
+include("woodbury.jl")

 end # module

src/loss.jl

Lines changed: 120 additions & 0 deletions
# Nonlinear shrinkage estimation can be described in terms of a loss function measuring the error between
# the target and the sample eigenvalues. In:
#     Donoho, D.L., Gavish, M. and Johnstone, I.M., 2018.
#     Optimal shrinkage of eigenvalues in the spiked covariance model. Annals of Statistics, 46(4), p. 1742.
# there is a systematic analysis of different loss functions; their classification scheme is encoded here.

# Implementation note:
# While we could parametrize all these different loss functions in the type system and use dispatch,
# that would induce needless specialization: every function that took a `LossFunction` would have to be
# specialized, even if the argument encoding the loss function is merely "passed through."
# So instead, we hide details from the type system, dividing the Donoho classification scheme into just two types,
# `NormLossCov` and `StatLossCov`, and use a fast runtime check to determine which shrinkage function to use.

abstract type LossFunction end

"""
    NormLossCov(norm::Symbol, pivotidx::Int)

Specify a loss function for which the estimated covariance will be optimal. `norm` is one of
`:L1`, `:L2`, or `:Linf`, and `pivotidx` is an integer from 1 to 7, as specified in Table 1 (p. 1755)
of Donoho et al. (2018). In the table below, `A` and `B` are the target and sample covariances,
respectively, and the loss function is the specified norm of the quantity in the `pivot` column:

| `pivotidx` | `pivot` | Notes |
|------------|---------|-------|
| 1 | `A - B` | |
| 2 | `A⁻¹ - B⁻¹` | |
| 3 | `A⁻¹ B - I` | Not available for `:Linf` |
| 4 | `B⁻¹ A - I` | Not available for `:Linf` |
| 5 | `A⁻¹ B + B⁻¹ A - 2I` | Not supported |
| 6 | `sqrt(A) \\ B / sqrt(A) - I` | |
| 7 | `log(sqrt(A) \\ B / sqrt(A))` | Not supported |

See also [`StatLossCov`](@ref).

Reference:
Donoho, D.L., Gavish, M. and Johnstone, I.M., 2018.
Optimal shrinkage of eigenvalues in the spiked covariance model. Annals of Statistics, 46(4), p. 1742.
"""
struct NormLossCov <: LossFunction
    # Lᴺᴷ where N is the norm and K is an integer (1 through 7) representing the pivot function
    norm::Symbol
    pivotidx::Int

    function NormLossCov(norm::Symbol, pivotidx::Int)
        norm ∈ (:L1, :L2, :Linf) || throw(ArgumentError("norm must be :L1, :L2, or :Linf"))
        1 <= pivotidx <= 7 || throw(ArgumentError("pivotidx must be from 1 to 7 (see Table 1 in Donoho et al. (2018))"))
        return new(norm, pivotidx)
    end
end

"""
    StatLossCov(mode::Symbol)

Specify a loss function for which the estimated covariance will be optimal. `mode` is one of
`:st`, `:ent`, `:div`, `:aff`, or `:fre`, as specified in Table 2 (p. 1757) of Donoho et al. (2018).
In the table below, `A` and `B` are the target and sample covariances, respectively:

| `mode` | loss | Interpretation |
|--------|------|----------------|
| `:st`  | `st(A, B) = tr(A⁻¹ B - I) - log(det(B)/det(A))` | Minimize `2 Dₖₗ(N(0, B)||N(0, A))`, where `N` is the normal distribution |
| `:ent` | `st(B, A)` | Minimize errors in Mahalanobis distances |
| `:div` | `st(A, B) + st(B, A)` | |
| `:aff` | `0.5 * log(det(A + B) / (2 * sqrt(det(A*B))))` | Minimize the Hellinger distance between `N(0, A)` and `N(0, B)` |
| `:fre` | `tr(A + B - 2sqrt(A*B))` | |
"""
struct StatLossCov <: LossFunction
    mode::Symbol

    function StatLossCov(mode::Symbol)
        statlosses = (:st, :ent, :div, :aff, :fre)
        mode ∈ statlosses || throw(ArgumentError("mode must be among $(statlosses)"))
        return new(mode)
    end
end


# Implement Table 2, Donoho et al. (2018), p. 1757

function shrinker(loss::NormLossCov, ℓ::Real, c::Real, s::Real)
    # See top of file for why these are branches rather than dispatch
    norm, pivotidx = loss.norm, loss.pivotidx
    pivotidx ∈ (5, 7) && throw(ArgumentError("Pivot index $(pivotidx) is not supported, see Table 2 in Donoho et al. 2018"))
    if norm == :L2   # Frobenius
        return pivotidx == 1 ? ℓ * c^2 + s^2 :
               pivotidx == 2 ? ℓ / (c^2 + ℓ * s^2) :
               pivotidx == 3 ? (ℓ * c^2 + ℓ^2 * s^2) / (c^2 + ℓ^2 * s^2) :
               pivotidx == 4 ? (ℓ^2 * c^2 + s^2) / (ℓ * c^2 + s^2) :
               #= pivotidx == 6 =# 1 + (ℓ - 1) * c^2 / (c^2 + ℓ * s^2)^2
    elseif norm == :Linf   # Operator
        pivotidx ∈ (3, 4) && throw(ArgumentError("Pivot index $(pivotidx) is not supported for Linf norm, see Table 2 in Donoho et al. 2018"))
        return pivotidx ∈ (1, 2) ? ℓ :
               #= pivotidx == 6 =# 1 + (ℓ - 1) / (c^2 + ℓ * s^2)
    elseif norm == :L1   # Nuclear
        val = pivotidx == 1 ? 1 + (ℓ - 1) * (1 - 2s^2) :
              pivotidx == 2 ? ℓ / (c^2 + (2ℓ - 1) * s^2) :
              pivotidx == 3 ? ℓ / (c^2 + ℓ^2 * s^2) :
              pivotidx == 4 ? (ℓ^2 * c^2 + s^2) / ℓ :
              #= pivotidx == 6 =# (ℓ - (ℓ - 1)^2 * c^2 * s^2) / (c^2 + ℓ * s^2)^2
        return max(val, 1)
    end
    throw(ArgumentError("Norm $(norm) is not supported"))
end

function shrinker(loss::StatLossCov, ℓ::Real, c::Real, s::Real)
    mode = loss.mode
    if mode == :st
        return ℓ / (c^2 + ℓ * s^2)
    elseif mode == :ent
        return ℓ * c^2 + s^2
    elseif mode == :div
        return sqrt((ℓ^2 * c^2 + ℓ * s^2) / (c^2 + ℓ * s^2))
    elseif mode == :fre
        return (sqrt(ℓ) * c^2 + s^2)^2
    elseif mode == :aff
        return ((1 + c^2) * ℓ + s^2) / (1 + c^2 + ℓ * s^2)
    end
    throw(ArgumentError("Mode $(mode) is not supported"))
end
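For orientation, a small sketch of how these types behave (the numbers are illustrative; `shrinker` is internal, so it is reached through the module):

```julia
using CovarianceEstimation
const CE = CovarianceEstimation

CE.NormLossCov(:L2, 1)    # Frobenius norm on A - B
CE.StatLossCov(:ent)      # optimize Mahalanobis-distance accuracy
# CE.NormLossCov(:L3, 1) would throw an ArgumentError ("norm must be :L1, :L2, or :Linf")

# For a de-biased eigenvalue ℓ = 4 with cos/sin of the eigenvector angle
# c = 0.8, s = 0.6, the Frobenius/pivot-1 shrinker returns ℓ*c^2 + s^2:
CE.shrinker(CE.NormLossCov(:L2, 1), 4.0, 0.8, 0.6)   # == 2.92
```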

src/utils.jl

Lines changed: 6 additions & 0 deletions
@@ -17,3 +17,9 @@ totalweight(_, weights) = sum(weights)
 # Dividing by zero produces zero
 guardeddiv(num, denom) = iszero(denom) ? zero(num)/oneunit(denom) : num/denom
 diaginv(guard::Bool, num, v) = guard ? map(z -> guardeddiv(num, z), v) : num ./ v
+
+function weightedX(X::AbstractMatrix, weights::FrequencyWeights; dims=1)
+    rootweights = sqrt.(weights)
+    return dims == 1 ? X .* rootweights : rootweights' .* X
+end
+weightedX(X::AbstractMatrix; dims=1) = X
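A quick illustration of the new helper (hypothetical values; `weightedX` is internal to the package, so it is reached through the module):

```julia
using CovarianceEstimation, StatsBase
const CE = CovarianceEstimation

X = [1.0 2.0;
     3.0 4.0]
w = FrequencyWeights([1, 4])

# With dims=1 (observations in rows), row i is scaled by sqrt(w[i]), so that
# weightedX(X, w)' * weightedX(X, w) equals the weighted scatter X' * Diagonal(w) * X:
CE.weightedX(X, w)    # == [1.0 2.0; 6.0 8.0]
CE.weightedX(X)       # no weights: X is returned unchanged
```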

src/woodbury.jl

Lines changed: 109 additions & 0 deletions
# Covariance estimation for high-dimensional data
# Uses a covariance model of the form `Σ = σ²I + U * Λ * U'`, where `U` is a low-rank matrix of eigenvectors and
# `Λ` is a diagonal matrix capturing the excess width along the dimensions in `U` compared to isotropic.

# If you're curious about dispatch and inferrability, see the "Implementation note" at the top of src/loss.jl.
# We employ the same de-specialization trick even for `σ²` in `WoodburyEstimator`, as we'll adopt the eltype of
# `Λ` anyway.

"""
    WoodburyEstimator(loss::LossFunction, rank::Integer;
                      σ²::Union{Real,Nothing}=nothing, corrected::Bool=false)

Specify that covariance matrices should be estimated using a "spiked" covariance model

    Σ = σ²I + U * Λ * U'

`loss` is either a [`NormLossCov`](@ref) or [`StatLossCov`](@ref) object, which specifies the
loss function for which the estimated covariance will be optimal. `rank` is the maximum
number of eigenvalues `Λ` to retain in the model. Optionally, one may specify `σ²` directly,
or it can be estimated from the data matrix (`σ²=nothing`). Set `corrected=true` to use
the unbiased estimator of the variance.
"""
struct WoodburyEstimator{L<:LossFunction} <: CovarianceEstimator
    loss::L
    rank::Int
    σ²::Union{Real,Nothing}   # common diagonal variance, `nothing` indicates unknown
    corrected::Bool
end
WoodburyEstimator(loss::LossFunction, rank::Integer; σ²::Union{Real,Nothing}=nothing, corrected::Bool=false) =
    WoodburyEstimator(loss, rank, σ², corrected)

"""
    cov(estimator::WoodburyEstimator, X::AbstractMatrix, weights::FrequencyWeights...; dims::Int=1, mean=nothing, UsV=nothing)

Estimate the covariance matrix from the data matrix `X` using a "spiked" covariance model

    Σ = σ²I + U * Λ * U',

where `U` is a low-rank matrix of eigenvectors and `Λ` is a diagonal matrix.

Reference:
Donoho, D.L., Gavish, M. and Johnstone, I.M., 2018.
Optimal shrinkage of eigenvalues in the spiked covariance model. Annals of Statistics, 46(4), p. 1742.

When `σ²` is not supplied in `estimator`, it is calculated from the residuals `X - X̂`, where `X̂` is the
low-rank approximation of `X` used to generate `U` and `Λ`.

If `X` is too large to manipulate in memory, you can pass `UsV = (U, s, V)` (a truncated SVD of `X - mean(X; dims)`),
and then `X` will only be used to compute the dimensionality and number of observations. This requires that you
specify `estimator.σ²`.
"""
function cov(estimator::WoodburyEstimator, X::AbstractMatrix{<:Real}, weights::FrequencyWeights...;
             dims::Int=1, mean=nothing, UsV=nothing)
    # Argument validation
    dims ∈ (1, 2) || throw(ArgumentError("Argument dims can only be 1 or 2 (given: $dims)"))
    p = size(X, 3 - dims)
    p >= estimator.rank || throw(ArgumentError("Argument rank (got $(estimator.rank)) must be less than the number of observations (size(X, dims)=$(size(X, dims)))"))
    wn = totalweight(size(X, dims), weights...)

    local ΔX
    U, s, V = if UsV === nothing
        if mean === nothing
            mean = Statistics.mean(X, weights...; dims=dims)
        end
        # Compute the low-rank approximation of the centered data matrix
        ΔX = weightedX(X .- mean, weights...; dims=dims)
        tsvd(ΔX, estimator.rank)
    else
        UsV
    end

    T = eltype(s)
    σ² = estimator.σ²
    σ² = if σ² === nothing
        ΔΔX = ΔX - U * Diagonal(s) * V'
        # The number of degrees of freedom is (number of observations minus the rank) * dimensionality
        nσ = (totalweight(size(X, dims), weights...) - estimator.rank) * size(X, 3 - dims)
        sum(abs2, ΔΔX) / (nσ - estimator.corrected)
    else
        T(σ²)
    end::T   # fix inferrability (see note at top of file)

    # Ratio of dimensionality to number of observations (the principal parameter in Random Matrix Theory)
    γ = p / wn

    # Implement the optimal shrinkage algorithm
    λ_shrunk = shrink.(Ref(estimator.loss), s.^2 ./ wn, σ², γ)
    keep = (!iszero).(λ_shrunk)

    # Return the shrunk covariance matrix as a WoodburyMatrix
    return SymWoodbury(σ² * I(p), dims == 1 ? V[:, keep] : U[:, keep], Diagonal(λ_shrunk[keep]))
end

function shrink(loss::LossFunction, λ::Real, σ²::Real, γ::Real)
    # Implement the procedure in Donoho et al. (2018), p. 1758
    # We return the difference from σ², since that's already contained in the diagonal term
    λu = λ / σ²
    λ₊ = (1 + sqrt(γ))^2
    λu < λ₊ && return zero(σ²)
    # Calculate the "de-biased" eigenvalue ℓ (Eq. 1.10)
    λ′ = λu + 1 - γ
    ℓ = (λ′ + sqrt(λ′^2 - 4λu)) / 2
    # Calculate the cosine (Eq. 1.6)
    c = sqrt((1 - γ / (ℓ - 1)^2) / (1 + γ / (ℓ - 1)))
    # Calculate the sine
    s = sqrt(1 - c^2)
    # Apply the shrinker
    return σ² * (shrinker(loss, ℓ, c, s) - 1)
end
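Note that `shrink` zeroes out every sample eigenvalue below the bulk-edge threshold `σ²(1 + √γ)²` (the `λ₊` check above), so directions indistinguishable from noise drop out of the model entirely. A hedged sketch of the `UsV` path for data too large to center twice in memory (the data and names like `μ` are illustrative):

```julia
using CovarianceEstimation, TSVD, Statistics, LinearAlgebra

# Illustrative data: isotropic unit-variance noise plus one planted spike.
n, p = 200, 5000
v = normalize(randn(p))
X = randn(n, p) .+ 12 .* randn(n) .* v'

μ = mean(X; dims=1)
U, s, V = tsvd(X .- μ, 20)          # precompute the truncated SVD once

# When passing UsV, σ² must be supplied in the estimator:
est = WoodburyEstimator(NormLossCov(:L2, 1), 20; σ²=1.0)
Σ = cov(est, X; mean=μ, UsV=(U, s, V))

# With γ = p/n = 25, the detection threshold is σ²(1 + √γ)² = 36: the planted
# spike (sample eigenvalue ≈ 144) survives, while the other candidate
# directions fall below the threshold and are shrunk to zero.
```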
