# Nonlinear shrinkage estimation can be described in terms of a loss function measuring the error between
# the target and the sample eigenvalues. In
#     Donoho, D.L., Gavish, M. and Johnstone, I.M., 2018.
#     Optimal shrinkage of eigenvalues in the spiked covariance model. Annals of Statistics, 46(4), p. 1742.
# there is a systematic analysis of different loss functions; their classification scheme is encoded here.

# Implementation note:
# While we could parametrize all these different loss functions in the type system and use dispatch,
# that would induce needless specialization: every function that took a `LossFunction` would have to be
# specialized, even if the argument encoding the loss function is merely "passed through."
# So instead, we hide details from the type system, dividing the Donoho classification scheme into just two
# types, `NormLossCov` and `StatLossCov`, and use a fast runtime check to select the shrinkage formula.

abstract type LossFunction end
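
# For example (a hypothetical caller, not part of this file), a function that merely forwards the loss
# compiles once per concrete subtype rather than once per norm/pivot combination, because the choice of
# loss lives in runtime fields instead of type parameters:
#
#     shrink_all(loss::LossFunction, ℓs, c, s) = [shrinker(loss, ℓ, c, s) for ℓ in ℓs]
#
# (`shrinker` is defined at the end of this file.)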

"""
    NormLossCov(norm::Symbol, pivotidx::Int)

Specify a loss function for which the estimated covariance will be optimal. `norm` is one of
`:L1`, `:L2`, or `:Linf`, and `pivotidx` is an integer from 1 to 7, as specified in Table 1 (p. 1755)
of Donoho et al. (2018). In the table below, `A` and `B` are the target and sample covariances,
respectively, and the loss is the specified norm applied to the quantity in the `pivot` column:

| `pivotidx` | `pivot` | Notes |
|------------|---------|-------|
| 1 | `A - B` | |
| 2 | `A⁻¹ - B⁻¹` | |
| 3 | `A⁻¹ B - I` | Not available for `:Linf` |
| 4 | `B⁻¹ A - I` | Not available for `:Linf` |
| 5 | `A⁻¹ B + B⁻¹ A - 2I` | Not supported |
| 6 | `sqrt(A) \\ B / sqrt(A) - I` | |
| 7 | `log(sqrt(A) \\ B / sqrt(A))` | Not supported |

See also [`StatLossCov`](@ref).

Reference:
    Donoho, D.L., Gavish, M. and Johnstone, I.M., 2018.
    Optimal shrinkage of eigenvalues in the spiked covariance model. Annals of Statistics, 46(4), p. 1742.
"""
struct NormLossCov <: LossFunction
    # Lᴺᴷ where N is the norm and K is an integer (1 through 7) representing the pivot function
    norm::Symbol
    pivotidx::Int

    function NormLossCov(norm::Symbol, pivotidx::Int)
        norm ∈ (:L1, :L2, :Linf) || throw(ArgumentError("norm must be :L1, :L2, or :Linf"))
        1 <= pivotidx <= 7 || throw(ArgumentError("pivotidx must be from 1 to 7 (see Table 1 in Donoho et al. (2018))"))
        return new(norm, pivotidx)
    end
end
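
# Usage sketch (hypothetical values): the Frobenius-norm (`:L2`) loss on the pivot A⁻¹ B - I,
# passed to `shrinker` (defined at the end of this file) for a single spiked eigenvalue:
#
#     loss = NormLossCov(:L2, 3)
#     shrinker(loss, 2.5, 0.8, 0.6)   # ℓ = 2.5, c = 0.8, s = 0.6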

"""
    StatLossCov(mode::Symbol)

Specify a loss function for which the estimated covariance will be optimal. `mode` is one of
`:st`, `:ent`, `:div`, `:aff`, or `:fre`, as specified in Table 2 (p. 1757) of Donoho et al. (2018).
In the table below, `A` and `B` are the target and sample covariances, respectively:

| `mode` | Loss | Interpretation |
|--------|------|----------------|
| `:st` | `st(A, B) = tr(A⁻¹ B - I) - log(det(B)/det(A))` | Minimize `2 Dₖₗ(N(0, B)||N(0, A))`, where `N(0, C)` is the normal distribution with covariance `C` |
| `:ent` | `st(B, A)` | Minimize errors in Mahalanobis distances |
| `:div` | `st(A, B) + st(B, A)` | |
| `:aff` | `0.5 * log(det(A + B) / (2 * sqrt(det(A*B))))` | Minimize Hellinger distance between `N(0, A)` and `N(0, B)` |
| `:fre` | `tr(A + B - 2sqrt(A*B))` | |

See also [`NormLossCov`](@ref).
"""
struct StatLossCov <: LossFunction
    mode::Symbol

    function StatLossCov(mode::Symbol)
        statlosses = (:st, :ent, :div, :aff, :fre)
        mode ∈ statlosses || throw(ArgumentError("mode must be among $(statlosses)"))
        return new(mode)
    end
end
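
# Usage sketch (hypothetical values): Stein loss, i.e. minimizing 2 Dₖₗ(N(0, B) || N(0, A)):
#
#     loss = StatLossCov(:st)
#     shrinker(loss, 2.5, 0.8, 0.6)   # same arguments as in the NormLossCov sketch above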

# Implement Table 2, Donoho et al. (2018), p. 1757

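# Notation (following Donoho et al. 2018): in the methods below, ℓ is a spiked eigenvalue and c, s are the
# cosine and sine of the angle between the corresponding sample and population eigenvectors (c^2 + s^2 == 1);
# each method returns the shrunken eigenvalue that is optimal for the given loss.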
function shrinker(loss::NormLossCov, ℓ::Real, c::Real, s::Real)
    # See top of file for why these are branches rather than dispatch
    norm, pivotidx = loss.norm, loss.pivotidx
    pivotidx ∈ (5, 7) && throw(ArgumentError("Pivot index $(pivotidx) is not supported, see Table 2 in Donoho et al. 2018"))
    if norm == :L2 # Frobenius
        return pivotidx == 1 ? ℓ * c^2 + s^2 :
               pivotidx == 2 ? ℓ / (c^2 + ℓ * s^2) :
               pivotidx == 3 ? (ℓ * c^2 + ℓ^2 * s^2) / (c^2 + ℓ^2 * s^2) :
               pivotidx == 4 ? (ℓ^2 * c^2 + s^2) / (ℓ * c^2 + s^2) :
               #= pivotidx == 6 =# 1 + (ℓ - 1) * c^2 / (c^2 + ℓ * s^2)^2
    elseif norm == :Linf # Operator
        pivotidx ∈ (3, 4) && throw(ArgumentError("Pivot index $(pivotidx) is not supported for Linf norm, see Table 2 in Donoho et al. 2018"))
        return pivotidx ∈ (1, 2) ? ℓ :
               #= pivotidx == 6 =# 1 + (ℓ - 1) / (c^2 + ℓ * s^2)
    elseif norm == :L1 # Nuclear
        val = pivotidx == 1 ? 1 + (ℓ - 1) * (1 - 2s^2) :
              pivotidx == 2 ? ℓ / (c^2 + (2ℓ-1)*s^2) :
              pivotidx == 3 ? ℓ / (c^2 + ℓ^2*s^2) :
              pivotidx == 4 ? (ℓ^2*c^2 + s^2) / ℓ :
              #= pivotidx == 6 =# (ℓ - (ℓ - 1)^2*c^2*s^2) / (c^2 + ℓ*s^2)^2
        return max(val, 1)
    end
    throw(ArgumentError("Norm $(norm) is not supported"))
end
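
# Sanity check (hypothetical numbers): for the operator norm (`:Linf`) with pivot 1 or 2, the method above
# returns the spiked eigenvalue unchanged:
#
#     shrinker(NormLossCov(:Linf, 1), 3.0, 0.8, 0.6) == 3.0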

function shrinker(loss::StatLossCov, ℓ::Real, c::Real, s::Real)
    mode = loss.mode
    if mode == :st
        return ℓ / (c^2 + ℓ * s^2)
    elseif mode == :ent
        return ℓ * c^2 + s^2
    elseif mode == :div
        return sqrt((ℓ^2 * c^2 + ℓ * s^2) / (c^2 + ℓ * s^2))
    elseif mode == :fre
        return (sqrt(ℓ) * c^2 + s^2)^2
    elseif mode == :aff
        return ((1 + c^2)*ℓ + s^2) / (1 + c^2 + ℓ * s^2)
    end
    throw(ArgumentError("Mode $(mode) is not supported"))
end
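
# Sanity check (hypothetical numbers): when c = 1 and s = 0, every statistical loss leaves the
# eigenvalue unchanged, e.g.
#
#     shrinker(StatLossCov(:div), 2.0, 1.0, 0.0) == 2.0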