Skip to content

Commit f23eb7a

Browse files
committed
fix move benchmark model to main file
1 parent 2a1755a commit f23eb7a

File tree

2 files changed

+28
-34
lines changed

2 files changed

+28
-34
lines changed

bench/benchmarks.jl

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,34 @@ BLAS.set_num_threads(min(4, Threads.nthreads()))
1616
@info sprint(versioninfo)
1717
@info "BLAS threads: $(BLAS.get_num_threads())"
1818

19-
include("unconstrdist.jl")
19+
struct Dist{D<:ContinuousMultivariateDistribution}
20+
dist::D
21+
end
22+
23+
function LogDensityProblems.logdensity(model::Dist, x)
24+
return logpdf(model.dist, x)
25+
end
26+
27+
function LogDensityProblems.logdensity_and_gradient(model::Dist, θ)
28+
return (
29+
LogDensityProblems.logdensity(model, θ),
30+
ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
31+
)
32+
end
33+
34+
function LogDensityProblems.dimension(model::Dist)
35+
return length(model.dist)
36+
end
37+
38+
function LogDensityProblems.capabilities(::Type{<:Dist})
39+
return LogDensityProblems.LogDensityOrder{0}()
40+
end
41+
42+
function normal(; n_dims=10, realtype=Float64)
43+
μ = fill(realtype(5), n_dims)
44+
Σ = Diagonal(ones(realtype, n_dims))
45+
return Dist(MvNormal(μ, Σ))
46+
end
2047

2148
const SUITES = BenchmarkGroup()
2249

bench/unconstrdist.jl

Lines changed: 0 additions & 33 deletions
This file was deleted.

0 commit comments

Comments
 (0)