Skip to content

Commit 24d6582

Browse files
Merge pull request #965 from SciML/ChrisRackauckas-patch-1
Fix tests for optimization bump
2 parents 0c26379 + ff512bd commit 24d6582

File tree

3 files changed

+18
-17
lines changed

3 files changed

+18
-17
lines changed

docs/src/examples/tensor_layer.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ ẍ = - kx - αx³ - βẋ -γẋ³.
1010
```
1111

1212
We first transform this second order differential equation into a system of first order
13-
differential equations for use in `DiffEqFlux`: We let `ẋ = v` then
13+
differential equations for use in `DiffEqFlux`: We let `ẋ = v` then
14+
1415
```math
1516
ẋ = v \\
1617
v̇ = - kx - αx³ - βv̇ -γv̇³.

test/multiple_shoot_tests.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,15 @@
4343
function loss_single_shooting(p)
4444
pred = predict_single_shooting(p)
4545
l = loss_function(ode_data, pred)
46-
return l, pred
46+
return l
4747
end
4848

4949
adtype = Optimization.AutoZygote()
5050
optf = Optimization.OptimizationFunction((p, _) -> loss_single_shooting(p), adtype)
5151
optprob = Optimization.OptimizationProblem(optf, p_init)
5252
res_single_shooting = Optimization.solve(optprob, Adam(0.05); maxiters = 300)
5353

54-
loss_ss, _ = loss_single_shooting(res_single_shooting.minimizer)
54+
loss_ss = loss_single_shooting(res_single_shooting.minimizer)
5555
@info "Single shooting loss: $(loss_ss)"
5656

5757
## Test Multiple Shooting
@@ -60,7 +60,7 @@
6060

6161
function loss_multiple_shooting(p)
6262
return multiple_shoot(p, ode_data, tsteps, prob_node, loss_function, Tsit5(),
63-
group_size; continuity_term, abstol = 1e-8, reltol = 1e-6) # test solver kwargs
63+
group_size; continuity_term, abstol = 1e-8, reltol = 1e-6)[1] # test solver kwargs
6464
end
6565

6666
adtype = Optimization.AutoZygote()
@@ -69,7 +69,7 @@
6969
res_ms = Optimization.solve(optprob, Adam(0.05); maxiters = 300)
7070

7171
# Calculate single shooting loss with parameter from multiple_shoot training
72-
loss_ms, _ = loss_single_shooting(res_ms.minimizer)
72+
loss_ms = loss_single_shooting(res_ms.minimizer)
7373
println("Multiple shooting loss: $(loss_ms)")
7474
@test loss_ms < 10loss_ss
7575

@@ -83,7 +83,7 @@
8383

8484
function loss_multiple_shooting_abs2(p)
8585
return multiple_shoot(p, ode_data, tsteps, prob_node, loss_function,
86-
continuity_loss_abs2, Tsit5(), group_size; continuity_term)
86+
continuity_loss_abs2, Tsit5(), group_size; continuity_term)[1]
8787
end
8888

8989
adtype = Optimization.AutoZygote()
@@ -92,15 +92,15 @@
9292
optprob = Optimization.OptimizationProblem(optf, p_init)
9393
res_ms_abs2 = Optimization.solve(optprob, Adam(0.05); maxiters = 300)
9494

95-
loss_ms_abs2, _ = loss_single_shooting(res_ms_abs2.minimizer)
95+
loss_ms_abs2 = loss_single_shooting(res_ms_abs2.minimizer)
9696
println("Multiple shooting loss with abs2: $(loss_ms_abs2)")
9797
@test loss_ms_abs2 < loss_ss
9898

9999
## Test different SensitivityAlgorithm (default is InterpolatingAdjoint)
100100
function loss_multiple_shooting_fd(p)
101101
return multiple_shoot(
102102
p, ode_data, tsteps, prob_node, loss_function, continuity_loss_abs2,
103-
Tsit5(), group_size; continuity_term, sensealg = ForwardDiffSensitivity())
103+
Tsit5(), group_size; continuity_term, sensealg = ForwardDiffSensitivity())[1]
104104
end
105105

106106
adtype = Optimization.AutoZygote()
@@ -109,14 +109,14 @@
109109
res_ms_fd = Optimization.solve(optprob, Adam(0.05); maxiters = 300)
110110

111111
# Calculate single shooting loss with parameter from multiple_shoot training
112-
loss_ms_fd, _ = loss_single_shooting(res_ms_fd.minimizer)
112+
loss_ms_fd = loss_single_shooting(res_ms_fd.minimizer)
113113
println("Multiple shooting loss with ForwardDiffSensitivity: $(loss_ms_fd)")
114114
@test loss_ms_fd < 10loss_ss
115115

116116
# Integration return codes `!= :Success` should return infinite loss.
117117
# In this case, we trigger `retcode = :MaxIters` by setting the solver option `maxiters=1`.
118-
loss_fail, _ = multiple_shoot(p_init, ode_data, tsteps, prob_node, loss_function,
119-
Tsit5(), datasize; maxiters = 1, verbose = false)
118+
loss_fail = multiple_shoot(p_init, ode_data, tsteps, prob_node, loss_function,
119+
Tsit5(), datasize; maxiters = 1, verbose = false)[1]
120120
@test loss_fail == Inf
121121

122122
## Test for DomainErrors
@@ -142,7 +142,7 @@
142142
function loss_multiple_shooting_ens(p)
143143
return multiple_shoot(p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg,
144144
loss_function, Tsit5(), group_size; continuity_term,
145-
trajectories, abstol = 1e-8, reltol = 1e-6) # test solver kwargs
145+
trajectories, abstol = 1e-8, reltol = 1e-6)[1] # test solver kwargs
146146
end
147147

148148
adtype = Optimization.AutoZygote()
@@ -151,7 +151,7 @@
151151
optprob = Optimization.OptimizationProblem(optf, p_init)
152152
res_ms_ensembles = Optimization.solve(optprob, Adam(0.05); maxiters = 300)
153153

154-
loss_ms_ensembles, _ = loss_single_shooting(res_ms_ensembles.minimizer)
154+
loss_ms_ensembles = loss_single_shooting(res_ms_ensembles.minimizer)
155155

156156
println("Multiple shooting loss with EnsembleProblem: $(loss_ms_ensembles)")
157157

test/second_order_ode_tests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@
2525

2626
function loss_n_ode(p)
2727
pred = predict(p)
28-
return sum(abs2, correct_pos .- pred[1:2, :]), pred
28+
return sum(abs2, correct_pos .- pred[1:2, :])
2929
end
3030

3131
l1 = loss_n_ode(p)
3232

33-
function callback(p, l, pred)
33+
function callback(p, l)
3434
@info "[SecondOrderODE] Loss: $l"
3535
return l < 0.01
3636
end
@@ -52,7 +52,7 @@
5252

5353
function loss_n_ode(p)
5454
pred = predict(p)
55-
return sum(abs2, correct_pos .- pred[1:2, :]), pred
55+
return sum(abs2, correct_pos .- pred[1:2, :])
5656
end
5757

5858
optfunc = Optimization.OptimizationFunction(
@@ -72,7 +72,7 @@
7272

7373
function loss_n_ode(p)
7474
pred = predict(p)
75-
return sum(abs2, correct_pos .- pred[1:2, :]), pred
75+
return sum(abs2, correct_pos .- pred[1:2, :])
7676
end
7777

7878
optfunc = Optimization.OptimizationFunction(

0 commit comments

Comments
 (0)