|
43 | 43 | function loss_single_shooting(p) |
44 | 44 | pred = predict_single_shooting(p) |
45 | 45 | l = loss_function(ode_data, pred) |
46 | | - return l, pred |
| 46 | + return l |
47 | 47 | end |
48 | 48 |
|
49 | 49 | adtype = Optimization.AutoZygote() |
50 | 50 | optf = Optimization.OptimizationFunction((p, _) -> loss_single_shooting(p), adtype) |
51 | 51 | optprob = Optimization.OptimizationProblem(optf, p_init) |
52 | 52 | res_single_shooting = Optimization.solve(optprob, Adam(0.05); maxiters = 300) |
53 | 53 |
|
54 | | - loss_ss, _ = loss_single_shooting(res_single_shooting.minimizer) |
| 54 | + loss_ss = loss_single_shooting(res_single_shooting.minimizer) |
55 | 55 | @info "Single shooting loss: $(loss_ss)" |
56 | 56 |
|
57 | 57 | ## Test Multiple Shooting |
|
60 | 60 |
|
61 | 61 | function loss_multiple_shooting(p) |
62 | 62 | return multiple_shoot(p, ode_data, tsteps, prob_node, loss_function, Tsit5(), |
63 | | - group_size; continuity_term, abstol = 1e-8, reltol = 1e-6) # test solver kwargs |
| 63 | + group_size; continuity_term, abstol = 1e-8, reltol = 1e-6)[1] # test solver kwargs |
64 | 64 | end |
65 | 65 |
|
66 | 66 | adtype = Optimization.AutoZygote() |
|
69 | 69 | res_ms = Optimization.solve(optprob, Adam(0.05); maxiters = 300) |
70 | 70 |
|
71 | 71 | # Calculate single shooting loss with parameter from multiple_shoot training |
72 | | - loss_ms, _ = loss_single_shooting(res_ms.minimizer) |
| 72 | + loss_ms = loss_single_shooting(res_ms.minimizer) |
73 | 73 | println("Multiple shooting loss: $(loss_ms)") |
74 | 74 | @test loss_ms < 10loss_ss |
75 | 75 |
|
|
83 | 83 |
|
84 | 84 | function loss_multiple_shooting_abs2(p) |
85 | 85 | return multiple_shoot(p, ode_data, tsteps, prob_node, loss_function, |
86 | | - continuity_loss_abs2, Tsit5(), group_size; continuity_term) |
| 86 | + continuity_loss_abs2, Tsit5(), group_size; continuity_term)[1] |
87 | 87 | end |
88 | 88 |
|
89 | 89 | adtype = Optimization.AutoZygote() |
|
92 | 92 | optprob = Optimization.OptimizationProblem(optf, p_init) |
93 | 93 | res_ms_abs2 = Optimization.solve(optprob, Adam(0.05); maxiters = 300) |
94 | 94 |
|
95 | | - loss_ms_abs2, _ = loss_single_shooting(res_ms_abs2.minimizer) |
| 95 | + loss_ms_abs2 = loss_single_shooting(res_ms_abs2.minimizer) |
96 | 96 | println("Multiple shooting loss with abs2: $(loss_ms_abs2)") |
97 | 97 | @test loss_ms_abs2 < loss_ss |
98 | 98 |
|
99 | 99 | ## Test different SensitivityAlgorithm (default is InterpolatingAdjoint) |
100 | 100 | function loss_multiple_shooting_fd(p) |
101 | 101 | return multiple_shoot( |
102 | 102 | p, ode_data, tsteps, prob_node, loss_function, continuity_loss_abs2, |
103 | | - Tsit5(), group_size; continuity_term, sensealg = ForwardDiffSensitivity()) |
| 103 | + Tsit5(), group_size; continuity_term, sensealg = ForwardDiffSensitivity())[1] |
104 | 104 | end |
105 | 105 |
|
106 | 106 | adtype = Optimization.AutoZygote() |
|
109 | 109 | res_ms_fd = Optimization.solve(optprob, Adam(0.05); maxiters = 300) |
110 | 110 |
|
111 | 111 | # Calculate single shooting loss with parameter from multiple_shoot training |
112 | | - loss_ms_fd, _ = loss_single_shooting(res_ms_fd.minimizer) |
| 112 | + loss_ms_fd = loss_single_shooting(res_ms_fd.minimizer) |
113 | 113 | println("Multiple shooting loss with ForwardDiffSensitivity: $(loss_ms_fd)") |
114 | 114 | @test loss_ms_fd < 10loss_ss |
115 | 115 |
|
116 | 116 | # Integration return codes `!= :Success` should return infinite loss. |
117 | 117 | # In this case, we trigger `retcode = :MaxIters` by setting the solver option `maxiters=1`. |
118 | | - loss_fail, _ = multiple_shoot(p_init, ode_data, tsteps, prob_node, loss_function, |
119 | | - Tsit5(), datasize; maxiters = 1, verbose = false) |
| 118 | + loss_fail = multiple_shoot(p_init, ode_data, tsteps, prob_node, loss_function, |
| 119 | + Tsit5(), datasize; maxiters = 1, verbose = false)[1] |
120 | 120 | @test loss_fail == Inf |
121 | 121 |
|
122 | 122 | ## Test for DomainErrors |
|
142 | 142 | function loss_multiple_shooting_ens(p) |
143 | 143 | return multiple_shoot(p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg, |
144 | 144 | loss_function, Tsit5(), group_size; continuity_term, |
145 | | - trajectories, abstol = 1e-8, reltol = 1e-6) # test solver kwargs |
| 145 | + trajectories, abstol = 1e-8, reltol = 1e-6)[1] # test solver kwargs |
146 | 146 | end |
147 | 147 |
|
148 | 148 | adtype = Optimization.AutoZygote() |
|
151 | 151 | optprob = Optimization.OptimizationProblem(optf, p_init) |
152 | 152 | res_ms_ensembles = Optimization.solve(optprob, Adam(0.05); maxiters = 300) |
153 | 153 |
|
154 | | - loss_ms_ensembles, _ = loss_single_shooting(res_ms_ensembles.minimizer) |
| 154 | + loss_ms_ensembles = loss_single_shooting(res_ms_ensembles.minimizer) |
155 | 155 |
|
156 | 156 | println("Multiple shooting loss with EnsembleProblem: $(loss_ms_ensembles)") |
157 | 157 |
|
|
0 commit comments