Skip to content

Commit a4a9808

Browse files
fix up tests
1 parent 84dc2cf commit a4a9808

File tree

2 files changed

+20
-14
lines changed

2 files changed

+20
-14
lines changed

docs/src/tutorials/ensemble_modeling.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ data_train = [
142142
R => (t_train,fullR[1:15]),
143143
]
144144
t_ensem = 0:21
145-
data_train = [
145+
data_ensem = [
146146
S => (t_ensem,fullS[1:22]),
147147
I => (t_ensem,fullI[1:22]),
148148
R => (t_ensem,fullR[1:22]),

test/ensemble.jl

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,33 +56,39 @@ sol = solve(enprob; saveat = 1);
5656

5757
weights = [0.2, 0.5, 0.3]
5858

59+
fullS = vec(sum(stack(weights .* sol[:,S]),dims=2))
60+
fullI = vec(sum(stack(weights .* sol[:,I]),dims=2))
61+
fullR = vec(sum(stack(weights .* sol[:,R]),dims=2))
62+
5963
t_train = 0:14
6064
data_train = [
61-
S => (t_train,vec(sum(stack([weights[i] * sol[i][S][1:15] for i in 1:3]), dims = 2))),
62-
I => (t_train,vec(sum(stack([weights[i] * sol[i][I][1:15] for i in 1:3]), dims = 2))),
63-
R => (t_train,vec(sum(stack([weights[i] * sol[i][R][1:15] for i in 1:3]), dims = 2))),
65+
S => (t_train,fullS[1:15]),
66+
I => (t_train,fullI[1:15]),
67+
R => (t_train,fullR[1:15]),
6468
]
6569
t_ensem = 0:21
6670
data_ensem = [
67-
S => (t_ensem,vec(sum(stack([weights[i] * sol[i][S][1:22] for i in 1:3]), dims = 2))),
68-
I => (t_ensem,vec(sum(stack([weights[i] * sol[i][I][1:22] for i in 1:3]), dims = 2))),
69-
R => (t_ensem,vec(sum(stack([weights[i] * sol[i][R][1:22] for i in 1:3]), dims = 2))),
71+
S => (t_ensem,fullS[1:22]),
72+
I => (t_ensem,fullI[1:22]),
73+
R => (t_ensem,fullR[1:22]),
7074
]
7175
t_forecast = 0:30
7276
data_forecast = [
73-
S => (t_forecast,vec(sum(stack([weights[i] * sol[i][S][1:end] for i in 1:3]), dims = 2))),
74-
I => (t_forecast,vec(sum(stack([weights[i] * sol[i][I][1:end] for i in 1:3]), dims = 2))),
75-
R => (t_forecast,vec(sum(stack([weights[i] * sol[i][R][1:end] for i in 1:3]), dims = 2))),
77+
S => (t_forecast,fullS),
78+
I => (t_forecast,fullI),
79+
R => (t_forecast,fullR),
7680
]
7781

7882
sol = solve(enprob; saveat = t_ensem);
7983

8084
@test ensemble_weights(sol, data_ensem) [0.2, 0.5, 0.3]
8185

82-
probs = [prob, prob2, prob3]
83-
ps = [[β => Uniform(0.01, 10.0), γ => Uniform(0.01, 10.0)] for i in 1:3]
84-
datas = [data_train,data_train,data_train]
86+
probs = (prob, prob2, prob3)
87+
ps = Tuple([β => Uniform(0.01, 10.0), γ => Uniform(0.01, 10.0)] for i in 1:3)
88+
datas = (data_train,data_train,data_train)
8589
enprobs = bayesian_ensemble(probs, ps, datas)
8690

8791
sol = solve(enprobs; saveat = t_ensem);
88-
ensemble_weights(sol, data_ensem)
92+
ensemble_weights(sol, data_ensem)
93+
94+
bayesian_datafit(probs, ps, datas)

0 commit comments

Comments
 (0)