Skip to content

Commit f123f5f

Browse files
authored
AbstractFloat support (#84)
* Float32 support for linear shrinkages * Float32 support for nonlinear shrinkage * fix unit tests * cast matrix types to float as suggested by @mateuszbaran * bump package version
1 parent 4760963 commit f123f5f

File tree

6 files changed

+72
-38
lines changed

6 files changed

+72
-38
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "CovarianceEstimation"
22
uuid = "587fd27a-f159-11e8-2dae-1979310e6154"
33
authors = ["Mateusz Baran <[email protected]>", "Thibaut Lienart"]
4-
version = "0.2.7"
4+
version = "0.2.8"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/linearshrinkage.jl

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,9 @@ function linear_shrinkage(::DiagonalUnitVariance, Xc::AbstractMatrix,
255255
corrected::Bool)
256256

257257
F = I
258+
T = float(eltype(S))
258259
κ = n - Int(corrected)
259-
γ = κ/n
260+
γ = T(κ/n)
260261
Xc² = Xc.^2
261262
# computing the shrinkage
262263
if λ [:auto, :lw]
@@ -265,15 +266,15 @@ function linear_shrinkage(::DiagonalUnitVariance, Xc::AbstractMatrix,
265266
λ /= κ * (ΣS² - 2tr(S) + p)
266267
elseif λ == :ss
267268
# use the standardised data matrix
268-
d = 1.0 ./ vec(sum(Xc², dims=1))
269+
d = one(T) ./ vec(sum(Xc², dims=1))
269270
= rescale(S, sqrt.(d)) # this has diagonal 1/κ
270271
ΣS̄² = sumij2(S̄, with_diag=true)
271272
λ = sumij(rescale!(uccov(Xc²), d), with_diag=true) / γ^2 - ΣS̄²
272-
λ /= κ * ΣS̄² - p / κ
273+
λ /= T(κ * ΣS̄² - p / κ)
273274
else
274275
throw(ArgumentError("Unsupported shrinkage method for target DiagonalUnitVariance: ."))
275276
end
276-
λ = clamp(λ, 0.0, 1.0)
277+
λ = clamp(λ, zero(T), one(T))
277278
return linshrink(F, S, λ)
278279
end
279280

@@ -298,8 +299,9 @@ function linear_shrinkage(::DiagonalCommonVariance, Xc::AbstractMatrix,
298299
corrected::Bool)
299300

300301
F = target_B(S, p)
302+
T = float(eltype(F))
301303
κ = n - Int(corrected)
302-
γ = κ/n
304+
γ = T(κ/n)
303305
Xc² = Xc.^2
304306
# computing the shrinkage
305307
if λ [:auto, :lw]
@@ -309,28 +311,28 @@ function linear_shrinkage(::DiagonalCommonVariance, Xc::AbstractMatrix,
309311
λ /= κ * (ΣS² - p*v^2)
310312
elseif λ == :ss
311313
# use the standardised data matrix
312-
d = 1.0 ./ vec(sum(Xc², dims=1))
314+
d = one(T) ./ vec(sum(Xc², dims=1))
313315
= rescale(S, sqrt.(d)) # this has diagonal 1/κ
314316
= κ # tr(S̄)/p
315317
ΣS̄² = sumij2(S̄, with_diag=true)
316318
λ = sumij(rescale!(uccov(Xc²), d), with_diag=true) / γ^2 - ΣS̄²
317-
λ /= κ * ΣS̄² - p/κ
319+
λ /= T(κ * ΣS̄² - p/κ)
318320
elseif λ == :rblw
319321
# https://arxiv.org/pdf/0907.4698.pdf equations 17, 19
320322
trS² = sum(abs2, S)
321323
tr²S = tr(S)^2
322324
# note: using corrected or uncorrected S does not change λ
323-
λ = ((n-2)/n * trS² + tr²S) / ((n+2) * (trS² - tr²S/p))
325+
λ = T(((n-2)/n * trS² + tr²S) / ((n+2) * (trS² - tr²S/p)))
324326
elseif λ == :oas
325327
# https://arxiv.org/pdf/0907.4698.pdf equation 23
326328
trS² = sum(abs2, S)
327329
tr²S = tr(S)^2
328330
# note: using corrected or uncorrected S does not change λ
329-
λ = ((1.0-2.0/p) * trS² + tr²S) / ((n+1.0-2.0/p) * (trS² - tr²S/p))
331+
λ = ((one(T)-T(2.0)/p) * trS² + tr²S) / ((n+one(T)-T(2.0)/p) * (trS² - tr²S/p))
330332
else
331333
throw(ArgumentError("Unsupported shrinkage method for target DiagonalCommonVariance: ."))
332334
end
333-
λ = clamp(λ, 0.0, 1.0)
335+
λ = clamp(λ, zero(T), one(T))
334336
return linshrink(F, S, λ)
335337
end
336338

@@ -355,8 +357,9 @@ function linear_shrinkage(::DiagonalUnequalVariance, Xc::AbstractMatrix,
355357
corrected::Bool)
356358

357359
F = target_D(S)
360+
T = float(eltype(F))
358361
κ = n - Int(corrected)
359-
γ = κ/n
362+
γ = T/ n)
360363
Xc² = Xc.^2
361364
# computing the shrinkage
362365
if λ [:auto, :lw]
@@ -365,14 +368,14 @@ function linear_shrinkage(::DiagonalUnequalVariance, Xc::AbstractMatrix,
365368
λ /= κ * ΣS²
366369
elseif λ == :ss
367370
# use the standardised data matrix
368-
d = 1.0 ./ vec(sum(Xc², dims=1))
371+
d = one(T) ./ vec(sum(Xc², dims=1))
369372
ΣS̄² = sumij2(rescale(S, sqrt.(d)))
370373
λ = sumij(rescale!(uccov(Xc²), d)) / γ^2 - ΣS̄²
371374
λ /= κ * ΣS̄²
372375
else
373376
throw(ArgumentError("Unsupported shrinkage method for target DiagonalUnequalVariance: ."))
374377
end
375-
λ = clamp(λ, 0.0, 1.0)
378+
λ = clamp(λ, zero(T), one(T))
376379
return linshrink(F, S, λ)
377380
end
378381

@@ -405,24 +408,25 @@ function linear_shrinkage(::CommonCovariance, Xc::AbstractMatrix,
405408
corrected::Bool)
406409

407410
F, v, c = target_C(S, p)
411+
T = float(eltype(F))
408412
κ = n - Int(corrected)
409-
γ = κ/n
413+
γ = T(κ/n)
410414
Xc² = Xc.^2
411415
# computing the shrinkage
412416
if λ [:auto, :lw]
413417
ΣS² = sumij2(S, with_diag=true)
414418
λ = sumij(uccov(Xc²), with_diag=true) / γ^2 - ΣS²
415419
λ /= κ * (ΣS² - p*(p-1)*c^2 - p*v^2)
416420
elseif λ == :ss
417-
d = 1.0 ./ vec(sum(Xc², dims=1))
421+
d = one(T) ./ vec(sum(Xc², dims=1))
418422
= rescale(S, sqrt.(d))
419423
ΣS̄² = sumij2(S̄, with_diag=true)
420424
λ = sumij(rescale!(uccov(Xc²), d), with_diag=true) / γ^2 - ΣS̄²
421425
λ /= κ * ΣS̄² - p/κ - κ * sumij(S̄; with_diag=false)^2 / (p * (p - 1))
422426
else
423427
throw(ArgumentError("Unsupported shrinkage method for target CommonCovariance: ."))
424428
end
425-
λ = clamp(λ, 0.0, 1.0)
429+
λ = clamp(λ, zero(T), one(T))
426430
return linshrink!(F, S, λ)
427431
end
428432

@@ -451,8 +455,9 @@ function linear_shrinkage(::PerfectPositiveCorrelation, Xc::AbstractMatrix,
451455
corrected::Bool)
452456

453457
F = target_E(S)
458+
T = float(eltype(F))
454459
κ = n - Int(corrected)
455-
γ = κ/n
460+
γ = T(κ/n)
456461
Xc² = Xc.^2
457462
# computing the shrinkage
458463
if λ [:auto, :lw]
@@ -461,7 +466,7 @@ function linear_shrinkage(::PerfectPositiveCorrelation, Xc::AbstractMatrix,
461466
λ -= sum_fij(Xc, S, n, κ)
462467
λ /= sumij2(S - F)
463468
elseif λ == :ss
464-
d = 1.0 ./ vec(sum(Xc², dims=1))
469+
d = one(T) ./ vec(sum(Xc², dims=1))
465470
s = sqrt.(d)
466471
= rescale(S, s)
467472
ΣS̄² = sumij2(S̄)
@@ -472,7 +477,7 @@ function linear_shrinkage(::PerfectPositiveCorrelation, Xc::AbstractMatrix,
472477
else
473478
throw(ArgumentError("Unsupported shrinkage method for target PerfectPositiveCorrelation: ."))
474479
end
475-
λ = clamp(λ, 0.0, 1.0)
480+
λ = clamp(λ, zero(T), one(T))
476481
return linshrink!(F, S, λ)
477482
end
478483

@@ -505,8 +510,9 @@ function linear_shrinkage(::ConstantCorrelation, Xc::AbstractMatrix,
505510
corrected::Bool)
506511

507512
F, r̄ = target_F(S, p)
513+
T = float(eltype(F))
508514
κ = n - Int(corrected)
509-
γ = κ/n
515+
γ = T(κ/n)
510516
Xc² = Xc.^2
511517
# computing the shrinkage
512518
if λ [:auto, :lw]
@@ -515,7 +521,7 @@ function linear_shrinkage(::ConstantCorrelation, Xc::AbstractMatrix,
515521
λ -=* sum_fij(Xc, S, n, κ)
516522
λ /= sumij2(S - F)
517523
elseif λ == :ss
518-
d = 1.0 ./ vec(sum(Xc², dims=1))
524+
d = one(T) ./ vec(sum(Xc², dims=1))
519525
s = sqrt.(d)
520526
= rescale(S, s)
521527
F̄, r̄ = target_F(S̄, p)
@@ -526,6 +532,6 @@ function linear_shrinkage(::ConstantCorrelation, Xc::AbstractMatrix,
526532
else
527533
throw(ArgumentError("Unsupported shrinkage method for target ConstantCorrelation: ."))
528534
end
529-
λ = clamp(λ, 0.0, 1.0)
535+
λ = clamp(λ, zero(T), one(T))
530536
return linshrink!(F, S, λ)
531537
end

src/nonlinearshrinkage.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,24 +27,24 @@ const EPAN_3 = 0.3 * INVPI
2727
2828
Return the Epanechnikov kernel evaluated at `x`.
2929
"""
30-
epanechnikov(x::Real) = EPAN_1 * max(0.0, 1.0 - x^2/5.0)
30+
epanechnikov(x::T) where T<:Real = float(T)(EPAN_1 * max(0.0, 1.0 - x^2/5.0))
3131

3232
"""
3333
epnanechnikov_HT(x)
3434
3535
Return the Hilbert Transform of the Epanechnikov kernel evaluated at `x`
3636
if `|x|≂̸√5`.
3737
"""
38-
function epanechnikov_HT1(x::Real)
39-
-EPAN_3 * x + EPAN_2 * (1.0 - x^2/5.0) * log(abs((SQRT5 - x)/(SQRT5 + x)))
38+
function epanechnikov_HT1(x::T) where T <: Real
39+
float(T)(-EPAN_3 * x + EPAN_2 * (1.0 - x^2/5.0) * log(abs((SQRT5 - x)/(SQRT5 + x))))
4040
end
4141

4242
"""
4343
epnanechnikov_HT2(x)
4444
Return the Hilbert Transform of the Epanechnikov kernel evaluated at `x`
4545
if `|x|=√5`.
4646
"""
47-
epanechnikov_HT2(x::Real) = -EPAN_3*x
47+
epanechnikov_HT2(x::T) where T <: Real = float(T)(-EPAN_3*x)
4848

4949
"""
5050
analytical_nonlinear_shrinkage(S, n, p; decomp)
@@ -66,19 +66,20 @@ function analytical_nonlinear_shrinkage(S::AbstractMatrix{<:Real},
6666
sample_perm = @view perm[max(1, (p - η) + 1):p]
6767
λ = @view F.values[sample_perm]
6868
U = F.vectors[:, perm]
69+
T = float(eltype(F))
6970

7071
# dominant cost forming of S or eigen(S) --> O(max{np^2, p^3})
7172

7273
# compute analytical nonlinear shrinkage kernel formula
7374
L = repeat(λ, outer=(1, min(p, η)))
7475

7576
# Equation (4.9)
76-
h = η^(-1//3)
77+
h = T(η^(-1//3))
7778
H = h * L'
7879
x = (L .- L') ./ H
7980

8081
# additional useful definitions
81-
γ = p/η
82+
γ = T(p/η)
8283
πλ = π * λ
8384

8485
# Equation (4.7) in http://www.econ.uzh.ch/static/wp/econwp264.pdf
@@ -96,17 +97,17 @@ function analytical_nonlinear_shrinkage(S::AbstractMatrix{<:Real},
9697
if p <= η
9798
# Equation (4.3)
9899
πγλ = γ * πλ
99-
denom = @. (πγλ * f̃)^2 + (1.0 - γ - πγλ * Hf̃)^2
100+
denom = @. (πγλ * f̃)^2 + (one(T) - γ - πγλ * Hf̃)^2
100101
= λ ./ denom
101102
else
102103
# Equation (C.8)
103-
hs5 = h * SQRT5
104-
Hf̃0 = (0.3/h^2 + 0.75/hs5 * (1.0 - 0.2/h^2) * log((1+hs5)/(1-hs5)))
105-
Hf̃0 *= mean(1.0 ./ πλ)
104+
hs5 = T(h * SQRT5)
105+
Hf̃0 = T((0.3/h^2 + 0.75/hs5 * (1.0 - 0.2/h^2) * log((1+hs5)/(1-hs5))))
106+
Hf̃0 *= mean(one(T) ./ πλ)
106107
# Equation (C.5)
107-
d̃0 = INVPI / ((γ - 1.0) * Hf̃0)
108+
d̃0 = T(INVPI / ((γ - one(T)) * Hf̃0))
108109
# Eq. (C.4)
109-
d̃1 = @. 1.0 /* πλ * (f̃^2 + Hf̃^2))
110+
d̃1 = @. one(T) /* πλ * (f̃^2 + Hf̃^2))
110111
= [fill(d̃0, (p - η, 1)); d̃1]
111112
end
112113

src/utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
function linshrink(F::AbstractMatrix, S::AbstractMatrix, λ::Real)
2-
return Symmetric((1.0 .- λ).*S .+ λ.*F)
2+
return Symmetric((one(λ) .- λ).*S .+ λ.*F)
33
end
44

55
function linshrink(F::UniformScaling, S::AbstractMatrix, λ::Real)
6-
return Symmetric((1.0 .- λ).*S + λ.*F)
6+
return Symmetric((one(λ) .- λ).*S + λ.*F)
77
end
88

99
function linshrink!(F::AbstractMatrix, S::AbstractMatrix, λ::Real)
10-
F .= (1.0 .- λ).*S .+ λ.*F
10+
F .= (one(λ) .- λ).*S .+ λ.*F
1111
return Symmetric(F)
1212
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,5 @@ end
7575
include("test_biweight.jl")
7676
include("test_linearshrinkage.jl")
7777
include("test_nonlinearshrinkage.jl")
78+
include("test_float32.jl")
7879
end

test/test_float32.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
@testset "Float32 matrices" begin
2+
# linear shrinkages
3+
for X in test_matrices
4+
x = convert(Matrix{Float32}, X)
5+
for target in [
6+
DiagonalUnitVariance(),
7+
DiagonalCommonVariance(),
8+
DiagonalUnequalVariance(),
9+
CommonCovariance(),
10+
PerfectPositiveCorrelation(),
11+
ConstantCorrelation()
12+
]
13+
for shrinkage in [:lw, :ss]
14+
@test eltype(cov(LinearShrinkage(target, shrinkage), x)) == Float32
15+
end
16+
end
17+
end
18+
19+
# nonlinear shrinkages
20+
ANS = AnalyticalNonlinearShrinkage()
21+
for X in test_matrices
22+
size(X, 1) < 12 && continue
23+
x = convert(Matrix{Float32}, X)
24+
@test eltype(cov(ANS, x)) == Float32
25+
end
26+
end

0 commit comments

Comments
 (0)